# [SOLUTION] Exercise - Building an Agent with Short-Term Memory

In this exercise, you’ll extend your agent with short-term memory that persists across runs and is organized by session. While state is used to manage the agent’s progress within a single run, memory allows your agent to remember what happened in previous runs, enabling context continuity across multiple user interactions.

You’ll learn how to use a memory object to store and retrieve conversation history, tool usage, and other relevant information, grouped by session. This is a key step toward building agents that can hold a conversation or remember facts within a session.
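A toy illustration of the state/memory distinction in plain Python, not the lab's classes: the list built inside `run_once` plays the role of state (created per run, discarded afterwards), while the module-level `memory` dict plays the role of memory (outlives each run, keyed by session). All names here are illustrative only.

```python
from typing import Dict, List

# Memory: lives outside any single run and is keyed by session ID
memory: Dict[str, List[str]] = {}

def run_once(query: str, session_id: str) -> List[str]:
    # State: built fresh for this run, seeded from a copy of the session's memory
    state_messages = list(memory.get(session_id, []))
    state_messages.append(f"user: {query}")
    state_messages.append(f"assistant: (answer to {query!r})")
    # Persist the final state so the next run in this session can build on it
    memory[session_id] = state_messages
    return state_messages

run_once("What's the best game?", "games")
history = run_once("And what was its score?", "games")
print(len(history))                  # 4
print(len(memory.get("other", [])))  # 0: a new session starts empty
```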

## Challenge

- Understand the difference between state and memory:
    - State is local to a single run and is lost when the run ends.
    - Memory persists across runs and is grouped by session, allowing the agent to remember what happened earlier in the same session.
- Use the provided ShortTermMemory class to manage session memory.
- Implement an Agent class that:
    - Accepts a session_id for each interaction.
    - Stores each state in memory under the correct session.
    - Retrieves and uses session history to provide context for new queries.
- Demonstrate how the agent can continue a conversation across multiple invocations.
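The notebook relies on `ShortTermMemory` from `lib.memory`, whose source is not shown here. As a rough mental model only, here is a hypothetical session-keyed store that mirrors the methods the agent calls on it (`create_session`, `add`, `get_last_object`, `get_all_objects`, `reset`). The class name `SessionMemory` and its internals are assumptions, not the library's implementation.

```python
from typing import Any, Dict, List, Optional

class SessionMemory:
    """Hypothetical stand-in for a short-term memory store keyed by session ID."""

    def __init__(self) -> None:
        # Each session ID maps to an ordered list of stored objects (e.g. runs)
        self._sessions: Dict[str, List[Any]] = {}

    def create_session(self, session_id: str = "default") -> None:
        # Idempotent: an existing session's history is left untouched
        self._sessions.setdefault(session_id, [])

    def add(self, obj: Any, session_id: str = "default") -> None:
        self.create_session(session_id)
        self._sessions[session_id].append(obj)

    def get_last_object(self, session_id: str = "default") -> Optional[Any]:
        objects = self._sessions.get(session_id, [])
        return objects[-1] if objects else None

    def get_all_objects(self, session_id: str = "default") -> List[Any]:
        # Return a copy so callers cannot reorder the stored history
        return list(self._sessions.get(session_id, []))

    def reset(self, session_id: str = "default") -> None:
        self._sessions[session_id] = []


store = SessionMemory()
store.add("run-1", "games")
store.add("run-2", "games")
store.add("run-A", "other_session")
print(store.get_last_object("games"))          # run-2
print(store.get_all_objects("other_session"))  # ['run-A']
```

Keeping each session's objects in its own list is what keeps sessions such as `games` and `other_session` isolated from each other.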


## Setup
First, let's import the necessary libraries:

In [1]:
from typing import TypedDict, List, Optional, Union
import json
from dotenv import load_dotenv

from lib.state_machine import StateMachine, Step, EntryPoint, Termination, Run
from lib.llm import LLM
from lib.messages import AIMessage, UserMessage, SystemMessage, ToolMessage, BaseMessage
from lib.tooling import Tool, ToolCall, tool
from lib.memory import ShortTermMemory

In [2]:
load_dotenv()

True

## Define a State Schema

Create a TypedDict to represent the agent’s state, including fields for the user query, instructions, message history, any pending tool calls and the session_id.

In [3]:
class AgentState(TypedDict):
    user_query: str  # The current user query being processed
    instructions: str  # System instructions for the agent
    messages: List[BaseMessage]  # List of conversation messages (System/User/AI/Tool)
    current_tool_calls: Optional[List[ToolCall]]  # Current pending tool calls
    session_id: str  # Session identifier for memory management
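The step functions below return only some of the `AgentState` keys (for example, `_prepare_messages_step` returns just `messages` and `session_id`). A plausible reading, assumed here rather than confirmed from `lib.state_machine`, is that the state machine merges each partial update into the current state. A minimal sketch of such a merge; note that `TypedDict` annotations are not enforced at runtime, so the schema documents intent rather than validating values.

```python
from typing import Any, Dict, List, Optional, TypedDict, cast

class ToyState(TypedDict):
    # Same keys as AgentState; renamed so running this cell does not shadow it
    user_query: str
    instructions: str
    messages: List[Any]
    current_tool_calls: Optional[List[Any]]
    session_id: str

def merge_update(state: ToyState, update: Dict[str, Any]) -> ToyState:
    # Keys present in the update overwrite current values; all others carry over
    merged: Dict[str, Any] = {**state, **update}
    return cast(ToyState, merged)

toy_state: ToyState = {
    "user_query": "What's the best game?",
    "instructions": "You can bring insights about a game dataset",
    "messages": [],
    "current_tool_calls": None,
    "session_id": "games",
}
toy_state = merge_update(toy_state, {"messages": ["system", "user"], "session_id": "games"})
print(toy_state["user_query"], len(toy_state["messages"]))  # What's the best game? 2
```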


## Create your Agent with Memory

In [4]:
class MemoryAgent:
    def __init__(self, 
                 model_name: str,
                 instructions: str, 
                 tools: Optional[List[Tool]] = None,
                 temperature: float = 0.7):
        """
        Initialize a MemoryAgent instance
        
        Args:
            model_name: Name/identifier of the LLM model to use
            instructions: System instructions for the agent
            tools: Optional list of tools available to the agent
            temperature: Temperature parameter for LLM (default: 0.7)
        """
        self.instructions = instructions
        self.tools = tools if tools else []
        self.model_name = model_name
        self.temperature = temperature
        
        # Initialize memory and state machine
        self.memory = ShortTermMemory()
        self.workflow = self._create_state_machine()

    def _prepare_messages_step(self, state: AgentState) -> AgentState:
        """Step logic: Prepare messages for LLM consumption"""
        # Copy the list so appending below does not mutate the previous run's stored history
        messages = list(state.get("messages", []))
        
        # If no messages exist, start with system message
        if not messages:
            messages = [SystemMessage(content=state["instructions"])]
            
        # Add the new user message
        messages.append(UserMessage(content=state["user_query"]))
        
        return {
            "messages": messages,
            "session_id": state["session_id"]
        }

    def _llm_step(self, state: AgentState) -> AgentState:
        """Step logic: Process the current state through the LLM"""
        # Initialize LLM
        llm = LLM(
            model=self.model_name,
            temperature=self.temperature,
            tools=self.tools
        )

        response = llm.invoke(state["messages"])
        tool_calls = response.tool_calls if response.tool_calls else None

        # Create AI message with content and tool calls
        ai_message = AIMessage(content=response.content, tool_calls=tool_calls)
        
        return {
            "messages": state["messages"] + [ai_message],
            "current_tool_calls": tool_calls,
            "session_id": state["session_id"]
        }

    def _tool_step(self, state: AgentState) -> AgentState:
        """Step logic: Execute any pending tool calls"""
        tool_calls = state["current_tool_calls"] or []
        tool_messages = []
        
        for call in tool_calls:
            # Unpack the tool name, its JSON-encoded arguments, and the call ID
            function_name = call.function.name
            function_args = json.loads(call.function.arguments)
            tool_call_id = call.id
            # Find the matching tool by name
            matching_tool = next((t for t in self.tools if t.name == function_name), None)
            if matching_tool:
                content = json.dumps(matching_tool(**function_args))
            else:
                # Answer every tool call, even unknown ones, so the next LLM request stays valid
                content = json.dumps({"error": f"Unknown tool: {function_name}"})
            tool_messages.append(ToolMessage(
                content=content,
                tool_call_id=tool_call_id,
                name=function_name,
            ))
        
        # Clear tool calls and add results to messages
        return {
            "messages": state["messages"] + tool_messages,
            "current_tool_calls": None,
            "session_id": state["session_id"]
        }

    def _create_state_machine(self) -> StateMachine[AgentState]:
        """Create the internal state machine for the agent"""
        machine = StateMachine[AgentState](AgentState)
        
        # Create steps
        entry = EntryPoint[AgentState]()
        message_prep = Step[AgentState]("message_prep", self._prepare_messages_step)
        llm_processor = Step[AgentState]("llm_processor", self._llm_step)
        tool_executor = Step[AgentState]("tool_executor", self._tool_step)
        termination = Termination[AgentState]()
        
        machine.add_steps([entry, message_prep, llm_processor, tool_executor, termination])
        
        # Add transitions
        machine.connect(entry, message_prep)
        machine.connect(message_prep, llm_processor)
        
        # Transition based on whether there are tool calls
        def check_tool_calls(state: AgentState) -> Union[Step[AgentState], str]:
            """Transition logic: Check if there are tool calls"""
            if state.get("current_tool_calls"):
                return tool_executor
            return termination
        
        machine.connect(llm_processor, [tool_executor, termination], check_tool_calls)
        machine.connect(tool_executor, llm_processor)  # Go back to llm after tool execution
        
        return machine

    def invoke(self, query: str, session_id: Optional[str] = None) -> Run:
        """
        Run the agent on a query
        
        Args:
            query: The user's query to process
            session_id: Optional session identifier (uses "default" if None)
            
        Returns:
            The final run object after processing
        """
        session_id = session_id or "default"

        # Create session if it doesn't exist
        self.memory.create_session(session_id)

        # Get previous messages from last run if available
        previous_messages = []
        last_run: Run = self.memory.get_last_object(session_id)
        if last_run:
            last_state = last_run.get_final_state()
            if last_state:
                previous_messages = last_state["messages"]

        initial_state: AgentState = {
            "user_query": query,
            "instructions": self.instructions,
            "messages": previous_messages,
            "current_tool_calls": None,
            "session_id": session_id,
        }

        run_object = self.workflow.run(initial_state)
        
        # Store the complete run object in memory
        self.memory.add(run_object, session_id)
        
        return run_object

    def get_session_runs(self, session_id: Optional[str] = None) -> List[Run]:
        """Get all Run objects for a session
        
        Args:
            session_id: Optional session ID (uses "default" if None)
            
        Returns:
            List of Run objects in the session
        """
        session_id = session_id or "default"
        return self.memory.get_all_objects(session_id)

    def reset_session(self, session_id: Optional[str] = None):
        """Reset memory for a specific session
        
        Args:
            session_id: Optional session to reset (uses "default" if None)
        """
        session_id = session_id or "default"
        self.memory.reset(session_id)
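The core of `_prepare_messages_step` is small: add the system message only when the history is empty, then append the new user message. The sketch below reproduces that rule with plain dicts instead of the `lib.messages` classes (an assumption for self-containment), and copies the incoming list so a stored history is never modified in place.

```python
from typing import Dict, List

def prepare_messages(previous: List[Dict[str, str]], instructions: str,
                     user_query: str) -> List[Dict[str, str]]:
    # Copy so the previous run's stored history is never mutated in place
    messages = list(previous)
    if not messages:
        # Only the first turn of a session gets the system message
        messages.append({"role": "system", "content": instructions})
    messages.append({"role": "user", "content": user_query})
    return messages

turn1 = prepare_messages([], "Answer questions about games", "What's the best game?")
turn1.append({"role": "assistant", "content": "(answer)"})
turn2 = prepare_messages(turn1, "Answer questions about games", "And its score?")

print([m["role"] for m in turn2])
# ['system', 'user', 'assistant', 'user']
print(len(turn1))  # 3: turn1 was not modified by the second call
```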

## Define your tools and instantiate your Agent

In [5]:
@tool
def get_games(num_games: int = 1, top: bool = True) -> List[dict]:
    """
    Returns the top or bottom N games with highest or lowest scores.    
    args:
        num_games (int): Number of games to return (default is 1)
        top (bool): If True, return top games, otherwise return bottom (default is True)
    """
    data = [
        {"Game": "The Legend of Zelda: Breath of the Wild", "Platform": "Switch", "Score": 98},
        {"Game": "Super Mario Odyssey", "Platform": "Switch", "Score": 97},
        {"Game": "Metroid Prime", "Platform": "GameCube", "Score": 97},
        {"Game": "Super Smash Bros. Brawl", "Platform": "Wii", "Score": 93},
        {"Game": "Mario Kart 8 Deluxe", "Platform": "Switch", "Score": 92},
        {"Game": "Fire Emblem: Awakening", "Platform": "3DS", "Score": 92},
        {"Game": "Donkey Kong Country Returns", "Platform": "Wii", "Score": 87},
        {"Game": "Luigi's Mansion 3", "Platform": "Switch", "Score": 86},
        {"Game": "Pikmin 3", "Platform": "Wii U", "Score": 85},
        {"Game": "Animal Crossing: New Leaf", "Platform": "3DS", "Score": 88}
    ]
    # Sort the games list by Score
    # If top is True, descending order
    sorted_games = sorted(data, key=lambda x: x['Score'], reverse=top)
    
    # Return the N games
    return sorted_games[:num_games]

In [6]:
tools = [get_games]
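When the model requests a tool, `_tool_step` looks the tool up by name and decodes its JSON-encoded arguments, such as `'{"num_games":1,"top":false}'` in run 3 below. A minimal sketch of that dispatch, using a plain dict registry in place of the `Tool` objects and a trimmed, hypothetical `toy_get_games` so the notebook's `get_games` is not shadowed:

```python
import json
from typing import Any, Callable, Dict, List

def toy_get_games(num_games: int = 1, top: bool = True) -> List[dict]:
    # Trimmed copy of the notebook's dataset, for illustration only
    data = [
        {"Game": "The Legend of Zelda: Breath of the Wild", "Score": 98},
        {"Game": "Pikmin 3", "Score": 85},
        {"Game": "Mario Kart 8 Deluxe", "Score": 92},
    ]
    return sorted(data, key=lambda g: g["Score"], reverse=top)[:num_games]

# Registry keyed by the name the model sees in its tool schema
registry: Dict[str, Callable[..., Any]] = {"get_games": toy_get_games}

def dispatch(name: str, arguments: str) -> str:
    fn = registry.get(name)
    if fn is None:
        return json.dumps({"error": f"Unknown tool: {name}"})
    # Arguments arrive as a JSON string and are expanded into keyword arguments
    return json.dumps(fn(**json.loads(arguments)))

print(dispatch("get_games", '{"num_games":1,"top":false}'))
# [{"Game": "Pikmin 3", "Score": 85}]
```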

In [7]:
agent = MemoryAgent(
    model_name="gpt-4o-mini",
    instructions="You can bring insights about a game dataset based on users questions",
    tools=tools
)

In [8]:
def print_messages(messages: List[BaseMessage]):
    for m in messages:
        print(f" -> (role = {m.role}, content = {m.content}, tool_calls = {getattr(m, 'tool_calls', None)})")

## Run your Agent

In [9]:
# First interaction in session "games"
print("First interaction:")
run1 = agent.invoke("What's the best game in the dataset?", "games")

print("\nMessages from run 1:")
messages = run1.get_final_state()["messages"]
print_messages(messages)

First interaction:
[StateMachine] Starting: __entry__
[StateMachine] Executing step: message_prep
[StateMachine] Executing step: llm_processor
[StateMachine] Executing step: tool_executor
[StateMachine] Executing step: llm_processor
[StateMachine] Terminating: __termination__

Messages from run 1:
 -> (role = system, content = You can bring insights about a game dataset based on users questions, tool_calls = None)
 -> (role = user, content = What's the best game in the dataset?, tool_calls = None)
 -> (role = assistant, content = None, tool_calls = [ChatCompletionMessageToolCall(id='call_SLomFYOi1ZA5YFwoO0zPBCoQ', function=Function(arguments='{"num_games":1,"top":true}', name='get_games'), type='function')])
 -> (role = tool, content = [{"Game": "The Legend of Zelda: Breath of the Wild", "Platform": "Switch", "Score": 98}], tool_calls = None)
 -> (role = assistant, content = The best game in the dataset is **The Legend of Zelda: Breath of the Wild**, available on the Switch, with a sco
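The step trace for run 1 (`llm_processor`, `tool_executor`, `llm_processor`) follows from the conditional transition in `_create_state_machine`: go to the tool step while the latest LLM response carries tool calls, otherwise terminate. A stripped-down loop with a scripted, hypothetical `fake_llm` reproduces the same trace. The `max_turns` cap is an addition for safety and is not part of the notebook's state machine.

```python
from typing import List, Optional

def fake_llm(turn: int) -> Optional[List[str]]:
    # Scripted stand-in: request one tool call on the first turn, then answer
    return ["get_games"] if turn == 0 else None

def run_agent_loop(max_turns: int = 5) -> List[str]:
    trace = ["message_prep"]
    for turn in range(max_turns):
        trace.append("llm_processor")
        tool_calls = fake_llm(turn)
        if not tool_calls:  # same condition as check_tool_calls
            trace.append("__termination__")
            return trace
        trace.append("tool_executor")
    # Stop even if the model keeps requesting tools
    trace.append("__termination__")
    return trace

print(run_agent_loop())
# ['message_prep', 'llm_processor', 'tool_executor', 'llm_processor', '__termination__']
```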

In [10]:
# Second interaction in same session
print("\nSecond interaction (same session):")
run2 = agent.invoke("And what was its score?", "games")

print("\nMessages from run 2:")
messages = run2.get_final_state()["messages"]
print_messages(messages)


Second interaction (same session):
[StateMachine] Starting: __entry__
[StateMachine] Executing step: message_prep
[StateMachine] Executing step: llm_processor
[StateMachine] Terminating: __termination__

Messages from run 2:
 -> (role = system, content = You can bring insights about a game dataset based on users questions, tool_calls = None)
 -> (role = user, content = What's the best game in the dataset?, tool_calls = None)
 -> (role = assistant, content = None, tool_calls = [ChatCompletionMessageToolCall(id='call_SLomFYOi1ZA5YFwoO0zPBCoQ', function=Function(arguments='{"num_games":1,"top":true}', name='get_games'), type='function')])
 -> (role = tool, content = [{"Game": "The Legend of Zelda: Breath of the Wild", "Platform": "Switch", "Score": 98}], tool_calls = None)
 -> (role = assistant, content = The best game in the dataset is **The Legend of Zelda: Breath of the Wild**, available on the Switch, with a score of **98**., tool_calls = None)
 -> (role = user, content = And what wa
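Run 2 could answer without a tool call because its message list begins with run 1's full history. One way to check that continuity, sketched here over `(role, content)` pairs rather than the lab's message objects, is to test that the earlier history is a strict prefix of the later one. The helper name `continues` and the sample histories are illustrative.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, content)

def continues(previous: List[Message], current: List[Message]) -> bool:
    # True when current starts with every message of previous, in order, and adds more
    return len(current) > len(previous) and current[:len(previous)] == previous

hist1 = [("system", "instructions"), ("user", "What's the best game?"), ("assistant", "(answer)")]
hist2 = hist1 + [("user", "And what was its score?"), ("assistant", "(answer)")]
hist_other = [("system", "instructions"), ("user", "What's the worst game?")]

print(continues(hist1, hist2))       # True
print(continues(hist1, hist_other))  # False: a different session starts fresh
```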

In [11]:
# New session
print("\nNew session interaction:")
run3 = agent.invoke("What's the worst game?", "other_session")

print("\nMessages from run 3:")
messages = run3.get_final_state()["messages"]
print_messages(messages)



New session interaction:
[StateMachine] Starting: __entry__
[StateMachine] Executing step: message_prep
[StateMachine] Executing step: llm_processor
[StateMachine] Executing step: tool_executor
[StateMachine] Executing step: llm_processor
[StateMachine] Terminating: __termination__

Messages from run 3:
 -> (role = system, content = You can bring insights about a game dataset based on users questions, tool_calls = None)
 -> (role = user, content = What's the worst game?, tool_calls = None)
 -> (role = assistant, content = None, tool_calls = [ChatCompletionMessageToolCall(id='call_k28lcnTDh43aJfp2fqf2ud8e', function=Function(arguments='{"num_games":1,"top":false}', name='get_games'), type='function')])
 -> (role = tool, content = [{"Game": "Pikmin 3", "Platform": "Wii U", "Score": 85}], tool_calls = None)
 -> (role = assistant, content = The worst game, based on the lowest score, is "Pikmin 3" for the Wii U, with a score of 85., tool_calls = None)


## Check session histories

In [12]:
print("Games session runs:")
runs = agent.get_session_runs("games")
for i, run_object in enumerate(runs, 1):
    print(f"\n# Run {i}", run_object.metadata)
    print("Messages:")
    print_messages(run_object.get_final_state()["messages"])

Games session runs:

# Run 1 {'run_id': '08203247-b0ba-4dd5-b2fc-280d653d7ab3', 'start_timestamp': '2025-05-12 03:06:05.323496', 'end_timestamp': '2025-05-12 03:06:08.289115', 'snapshot_counts': 5}
Messages:
 -> (role = system, content = You can bring insights about a game dataset based on users questions, tool_calls = None)
 -> (role = user, content = What's the best game in the dataset?, tool_calls = None)
 -> (role = assistant, content = None, tool_calls = [ChatCompletionMessageToolCall(id='call_SLomFYOi1ZA5YFwoO0zPBCoQ', function=Function(arguments='{"num_games":1,"top":true}', name='get_games'), type='function')])
 -> (role = tool, content = [{"Game": "The Legend of Zelda: Breath of the Wild", "Platform": "Switch", "Score": 98}], tool_calls = None)
 -> (role = assistant, content = The best game in the dataset is **The Legend of Zelda: Breath of the Wild**, available on the Switch, with a score of **98**., tool_calls = None)

# Run 2 {'run_id': '09f29ee6-f883-4df1-a228-409a73305ed4
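Because each run's final state appears to carry the whole session history (run 2 above repeats run 1's messages), the messages a given run added are the suffix beyond the previous run's history length. A small hypothetical helper, shown over role lists instead of message objects:

```python
from typing import List, Sequence

def new_messages_per_run(final_histories: Sequence[List[str]]) -> List[List[str]]:
    # Each history includes all earlier turns, so a run's own additions
    # are everything past the previous run's history length
    deltas: List[List[str]] = []
    previous_len = 0
    for history in final_histories:
        deltas.append(history[previous_len:])
        previous_len = len(history)
    return deltas

histories = [
    ["system", "user", "assistant(tool_calls)", "tool", "assistant"],
    ["system", "user", "assistant(tool_calls)", "tool", "assistant", "user", "assistant"],
]
print(new_messages_per_run(histories))
# [['system', 'user', 'assistant(tool_calls)', 'tool', 'assistant'], ['user', 'assistant']]
```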

In [13]:
print("Games session snapshots:\n")

runs = agent.get_session_runs("games")
for run_object in runs:
    print(run_object)
    for snp in run_object.snapshots:
        print(f"-> {snp}")
    print("\n")

Games session snapshots:

Run('08203247-b0ba-4dd5-b2fc-280d653d7ab3')
-> Snapshot(e3a47265-bde3-4dd4-9583-8fe1487bc82c) @ [2025-05-12 03:06:05.323543]: __entry__.State({'user_query': "What's the best game in the dataset?", 'instructions': 'You can bring insights about a game dataset based on users questions', 'messages': [], 'current_tool_calls': None, 'session_id': 'games'})
-> Snapshot(a3769f17-639f-44db-8aa6-5395bb57d6df) @ [2025-05-12 03:06:05.323624]: message_prep.State({'user_query': "What's the best game in the dataset?", 'instructions': 'You can bring insights about a game dataset based on users questions', 'messages': [SystemMessage(role='system', content='You can bring insights about a game dataset based on users questions'), UserMessage(role='user', content="What's the best game in the dataset?")], 'current_tool_calls': None, 'session_id': 'games'})
-> Snapshot(97769972-e1b9-4335-b88b-85a39c699c18) @ [2025-05-12 03:06:06.579710]: llm_processor.State({'user_query': "What's th
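The listing suggests each run records the state after every step, tagged with an ID and a timestamp. The `Snapshot` dataclass and `run_steps` function below are hypothetical and not the `lib.state_machine` implementation; they only illustrate one way to produce such a record. Deep-copying the state is what keeps an earlier snapshot from changing when a later step mutates a shared list.

```python
import copy
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Dict, List, Tuple

@dataclass
class Snapshot:
    step_name: str
    state: Dict[str, Any]
    snapshot_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: datetime = field(default_factory=datetime.now)

StepFn = Callable[[Dict[str, Any]], Dict[str, Any]]

def run_steps(state: Dict[str, Any], steps: List[Tuple[str, StepFn]]) -> List[Snapshot]:
    # Record the entry state, then one snapshot after each step's update is merged
    snapshots = [Snapshot("__entry__", copy.deepcopy(state))]
    for name, step in steps:
        state = {**state, **step(state)}
        # Deep copy so later steps cannot alter an earlier snapshot
        snapshots.append(Snapshot(name, copy.deepcopy(state)))
    return snapshots

snaps = run_steps(
    {"user_query": "What's the best game?", "messages": []},
    [("message_prep", lambda s: {"messages": s["messages"] + ["system", "user"]})],
)
for s in snaps:
    print(s.step_name, s.state["messages"])
# __entry__ []
# message_prep ['system', 'user']
```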