In [13]:
import redis
# Synchronous client for a local Redis server. decode_responses=True makes
# replies come back as str rather than bytes.
r = redis.Redis(
    host='localhost',
    port=6379,
    decode_responses=True,
)
# Round-trip a PING so a connection problem surfaces in this cell.
print(r.ping())

True


In [2]:
# Write a key and read it back to confirm the connection works both ways.
r.set('foo', 'bar')
# -> True
r.get('foo')
# -> 'bar' (a str, because the client was created with decode_responses=True)

'bar'

In [3]:
from dataclasses import dataclass, field
from typing import Any, Optional
from datetime import datetime
import json
from enum import Enum

In [4]:
class WorkingMemorySlot(Enum):
    """
    Cognitive slots in working memory.
    
    Each slot represents a distinct type of information the agent
    might hold in active cognition. The separation allows independent
    access and update—changing the current goal doesn't require
    rewriting the entire working memory.

    Each member's value is the string stored for that slot (see
    WorkingMemoryEntry.to_redis_hash, which writes ``slot.value``).

    NOTE(review): a later cell defines WorkingMemorySlot again with a
    different member set; once that cell has run, this definition is
    shadowed. Confirm which set of slots is intended.
    """
    CURRENT_GOAL = "current_goal"
    ACTIVE_CONTEXT = "active_context"
    SCRATCHPAD = "scratchpad"
    RECENT_OBSERVATIONS = "recent_observations"
    PENDING_ACTIONS = "pending_actions"
    USER_INTENT = "user_intent"
    REASONING_TRACE = "reasoning_trace"

In [21]:
@dataclass
class WorkingMemoryEntry:
    """
    A single entry in working memory.

    Each entry has content (what we're remembering), metadata (when, how important),
    and lifecycle information (how long until it expires).

    Attributes:
        slot: Which working-memory slot this entry belongs to.
        content: The remembered value; must be JSON-serializable.
        timestamp: Creation time. Defaults to a naive UTC datetime.
        ttl_seconds: Intended lifetime in seconds. Not written by
            ``to_redis_hash``; the caller applies it as the key's expiry.
        importance: Weight on a 0-1 scale; higher = retain longer under pressure.
    """
    slot: WorkingMemorySlot
    content: Any  # JSON-serializable content
    # NOTE: datetime.utcnow() yields a naive datetime and is deprecated since
    # Python 3.12; kept so existing timestamps keep their format.
    timestamp: datetime = field(default_factory=datetime.utcnow)
    ttl_seconds: int = 300  # 5 minute default—tune based on your application
    importance: float = 0.5  # 0-1 scale; higher = retain longer under pressure

    def to_redis_hash(self) -> dict:
        """
        Convert to Redis hash format.

        Redis hashes store field-value pairs, so every value is rendered as a
        string: the content is JSON-encoded, the timestamp is ISO-8601, and
        the importance is its ``str()`` form. ``ttl_seconds`` is deliberately
        left out.

        Returns:
            A dict with the keys "slot", "content", "timestamp" and "importance".
        """
        return {
            "slot": self.slot.value,
            "content": json.dumps(self.content),
            "timestamp": self.timestamp.isoformat(),
            "importance": str(self.importance)
        }

    @classmethod
    def from_redis_hash(cls, data: dict) -> "WorkingMemoryEntry":
        """
        Reconstruct an entry from Redis hash data.

        Accepts both reply shapes a Redis client can produce: bytes keys and
        values (the default) or str keys and values (``decode_responses=True``).

        Args:
            data: Mapping as produced by ``to_redis_hash`` and read back from Redis.

        Returns:
            The rebuilt entry. ``ttl_seconds`` is 0 because the expiry is
            already managed by Redis for a stored entry.

        Raises:
            KeyError: If "slot", "content", "timestamp" or "importance" is missing.
            ValueError: If a field cannot be parsed (unknown slot, invalid JSON,
                malformed timestamp or importance).
        """
        def _text(raw):
            # Decode only what is actually bytes so str replies pass through.
            return raw.decode() if isinstance(raw, (bytes, bytearray)) else raw

        fields = {_text(name): _text(value) for name, value in data.items()}
        return cls(
            slot=WorkingMemorySlot(fields["slot"]),
            content=json.loads(fields["content"]),
            timestamp=datetime.fromisoformat(fields["timestamp"]),
            importance=float(fields["importance"]),
            ttl_seconds=0  # Already in Redis; TTL managed there
        )

In [22]:
pip install redis[hiredis]


Note: you may need to restart the kernel to use updated packages.


pip install hiredis

In [6]:
import json
import redis.asyncio as redis
from datetime import datetime, timezone # ✅ Updated import
from typing import Any, Optional, List, Dict
from enum import Enum

class WorkingMemorySlot(Enum):
    """
    Slots addressed by WorkingMemoryManager.

    Each member's value becomes the last segment of that slot's Redis key
    (``wm:{session_id}:{value}``), and iteration order is the order in which
    ``get_full_context`` reads the slots.

    NOTE(review): this rebinds the WorkingMemorySlot name defined in an
    earlier cell. It adds TASK_STATE and omits ACTIVE_CONTEXT,
    RECENT_OBSERVATIONS, PENDING_ACTIONS and REASONING_TRACE — confirm the
    reduced set is intentional.
    """
    USER_INTENT = "user_intent"
    CURRENT_GOAL = "current_goal"
    TASK_STATE = "task_state"
    SCRATCHPAD = "scratchpad"

class WorkingMemoryManager:
    """
    Session-scoped working memory stored in Redis.

    Each slot lives in its own hash at ``wm:{session_id}:{slot.value}``;
    observations live in a sorted set at ``wm:{session_id}:observations``,
    scored by insertion time. Every write refreshes the key's TTL.

    The client may be a ``redis.asyncio`` client or a synchronous one, created
    with or without ``decode_responses``: replies are awaited only when they
    are awaitable and decoded only when they are bytes. A synchronous client
    blocks the event loop for the duration of each call.
    """

    def __init__(self, redis_client: redis.Redis, session_id: str, max_observations: int = 10, default_ttl: int = 300):
        """
        Args:
            redis_client: Redis client used for all reads and writes.
            session_id: Identifier used to namespace this session's keys.
            max_observations: Maximum number of observations kept in the sorted set.
            default_ttl: Expiry in seconds for slot keys; observations use twice this.
        """
        self.redis = redis_client
        self.session_id = session_id
        self.max_observations = max_observations
        self.default_ttl = default_ttl
        self._key_prefix = f"wm:{session_id}"

    @staticmethod
    async def _resolve(result: Any) -> Any:
        """Await ``result`` if the client returned an awaitable, else return it unchanged."""
        import inspect  # local import: keeps this class self-contained

        if inspect.isawaitable(result):
            return await result
        return result

    @staticmethod
    def _as_text(value: Any) -> Any:
        """Decode bytes replies to str; pass already-decoded values through."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode()
        return value

    async def _slot_key(self, slot: WorkingMemorySlot) -> str:
        """Return the Redis key for ``slot``. Kept as a coroutine for existing callers."""
        return f"{self._key_prefix}:{slot.value}"

    async def set_slot(self, slot: WorkingMemorySlot, content: Any, ttl_seconds: Optional[int] = None, importance: float = 0.5) -> None:
        """
        Overwrite a slot and reset its expiry.

        Args:
            slot: Slot to write.
            content: JSON-serializable value to store.
            ttl_seconds: Expiry in seconds; ``None`` or ``0`` falls back to ``default_ttl``.
            importance: Weight stored alongside the content (as a string).
        """
        key = await self._slot_key(slot)

        entry_data = {
            "content": json.dumps(content),
            "importance": str(importance),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

        # Queue the write and the expiry in one transaction so the key is never
        # left without a TTL.
        pipe = self.redis.pipeline(transaction=True)
        pipe.hset(key, mapping=entry_data)
        pipe.expire(key, ttl_seconds or self.default_ttl)
        # An asyncio pipeline's execute() is a coroutine; without awaiting it
        # the queued commands would never be sent.
        await self._resolve(pipe.execute())

    async def get_slot(self, slot: WorkingMemorySlot) -> Optional[Dict]:
        """
        Read a slot.

        Returns:
            ``None`` if the key is missing or expired; otherwise a dict of the
            stored fields, with "content" JSON-decoded and the remaining
            fields ("importance", "timestamp") left as strings.
        """
        key = await self._slot_key(slot)
        data = await self._resolve(self.redis.hgetall(key))
        if not data:
            return None

        decoded = {self._as_text(k): self._as_text(v) for k, v in data.items()}
        if "content" in decoded:
            decoded["content"] = json.loads(decoded["content"])
        return decoded

    async def append_observation(self, observation: dict) -> None:
        """
        Record an observation, keeping only the newest ``max_observations``.

        The observation is stored as its JSON text, scored by the current UTC
        timestamp. Sorted-set members are unique, so re-adding an observation
        that serializes to the same text only updates its score.
        """
        key = f"{self._key_prefix}:observations"
        timestamp = datetime.now(timezone.utc).timestamp()

        pipe = self.redis.pipeline(transaction=True)
        pipe.zadd(key, {json.dumps(observation): timestamp})
        # Ranks are ascending by score, so this trims the oldest entries and
        # leaves the last ``max_observations``.
        pipe.zremrangebyrank(key, 0, -(self.max_observations + 1))
        pipe.expire(key, self.default_ttl * 2)
        await self._resolve(pipe.execute())

    async def get_recent_observations(self, limit: int = 5) -> List[dict]:
        """
        Return up to ``limit`` most recent observations, oldest first.

        A non-positive ``limit`` returns an empty list.
        """
        if limit <= 0:
            # Guard: -0 == 0 would turn the range below into 0..-1, i.e. everything.
            return []
        key = f"{self._key_prefix}:observations"
        observations = await self._resolve(self.redis.zrange(key, -limit, -1))
        return [json.loads(self._as_text(obs)) for obs in observations]

    async def get_full_context(self) -> dict:
        """
        Collect the content of every populated slot plus recent observations.

        Returns:
            A dict keyed by slot value for each slot that currently has data,
            plus "recent_observations" holding ``get_recent_observations()``.
        """
        context = {}
        for slot in WorkingMemorySlot:
            entry = await self.get_slot(slot)
            if entry:
                context[slot.value] = entry["content"]
        context["recent_observations"] = await self.get_recent_observations()
        return context

    async def update_scratchpad(self, field: str, value: Any) -> None:
        """
        Set one field of the scratchpad slot, creating the scratchpad if absent.

        This is a read-modify-write and is not atomic: concurrent updates to
        the same session can overwrite each other. The rewrite uses the
        default TTL and importance.
        """
        current = await self.get_slot(WorkingMemorySlot.SCRATCHPAD)
        content = current["content"] if current else {}
        content[field] = value
        await self.set_slot(WorkingMemorySlot.SCRATCHPAD, content)

    async def get_scratchpad(self) -> Dict:
        """Return the scratchpad content, or an empty dict when none is stored."""
        data = await self.get_slot(WorkingMemorySlot.SCRATCHPAD)
        return data["content"] if data else {}

# --- External Reasoning Function ---

async def agent_reasoning_step(wm: WorkingMemoryManager, user_input: str, llm: Any) -> str:
    """
    Run one reasoning turn against working memory.

    Snapshots the current context, records the user input as both the
    USER_INTENT slot and an observation, produces a response, and appends the
    exchange to the scratchpad's "reasoning_steps" list.

    Args:
        wm: Working memory for the current session.
        user_input: Raw text from the user.
        llm: Language-model handle. Currently unused: the response is simulated.

    Returns:
        The response text.
    """
    # 1. Load current context (a snapshot taken before this turn's writes)
    context = await wm.get_full_context()

    # 2. Update with new input
    # Timestamps are timezone-aware UTC, matching what WorkingMemoryManager
    # stores; datetime.utcnow() is naive and deprecated since Python 3.12.
    await wm.set_slot(
        WorkingMemorySlot.USER_INTENT,
        {
            "raw_input": user_input,
            "parsed_at": datetime.now(timezone.utc).isoformat()
        },
        importance=0.9
    )

    # 3. Record as observation
    await wm.append_observation({
        "type": "user_input",
        "content": user_input,
        "timestamp": datetime.now(timezone.utc).isoformat()
    })

    # 4. Generate response (Assuming build_prompt and llm.generate exist)
    # response = await llm.generate(prompt=build_prompt(user_input, context))
    response = f"Simulated response to: {user_input}"

    # 5. Update reasoning trace
    await wm.update_scratchpad("last_response", response)

    # `or` fallbacks cover a scratchpad or step list that was stored as null.
    scratchpad = context.get("scratchpad") or {}
    reasoning_steps = list(scratchpad.get("reasoning_steps") or [])
    reasoning_steps.append({
        "input": user_input,
        "output": response,
        "timestamp": datetime.now(timezone.utc).isoformat()
    })
    await wm.update_scratchpad("reasoning_steps", reasoning_steps)

    return response

In [7]:
# NOTE(review): `r` is the synchronous client created in the first cell (with
# decode_responses=True), whereas WorkingMemoryManager's methods are coroutines
# written against redis.asyncio — confirm which client is intended here.
wmm = WorkingMemoryManager(redis_client=r, session_id="session123")

In [8]:
# Store a plain string in the CURRENT_GOAL slot; no ttl_seconds or importance
# is passed, so the manager's defaults apply.
await wmm.set_slot(WorkingMemorySlot.CURRENT_GOAL, "Learn about Redis")