In [1]:
import os
from getpass import getpass

import chromadb
import logfire
import polars as pl
from chromadb.utils import embedding_functions
from logfire.experimental.query_client import AsyncLogfireQueryClient
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext

from knd.ai import system_message, user_message
from knd.memory import AgentMemories

%load_ext autoreload
%autoreload 2

In [None]:
# Local persistent vector store plus the embedding model used for the "exp" collection.
chroma_client = chromadb.PersistentClient(path="chroma_db")
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="Alibaba-NLP/gte-modernbert-base",
)
collection_name = "exp"
# Start every run from an empty collection. `delete_collection` raises when the
# collection does not exist (e.g. on a fresh `chroma_db`), so only delete it if present.
# `list_collections()` yields Collection objects in some chromadb releases and plain
# names in others; `getattr(..., "name", c)` normalises both to names.
existing_collection_names = {getattr(c, "name", c) for c in chroma_client.list_collections()}
if collection_name in existing_collection_names:
    chroma_client.delete_collection(name=collection_name)
collection = chroma_client.get_or_create_collection(
    name=collection_name, embedding_function=sentence_transformer_ef
)


In [4]:
# Seed history that interleaves user prompts with system prompts arriving mid-conversation
# (presumably to probe how the model treats late system instructions).
messages = [
    build(text)
    for build, text in (
        (user_message, "Tell me a joke about the justice league"),
        (system_message, "You are a joke teller. talk like tony stark"),
        (user_message, "make the joke about how the avengers are better"),
        (system_message, "talk with emojis"),
    )
]
# Bare agent with no system prompt of its own; all instructions come from `messages`.
agent = Agent(model="google-gla:gemini-1.5-flash")


In [5]:
res = await agent.run(user_prompt="go on", message_history=messages)

  res = await agent.run(user_prompt="go on", message_history=messages)


In [6]:
# Seeded history, followed by the new "go on" request and the model's reply.
res.all_messages()

[ModelRequest(parts=[UserPromptPart(content='Tell me a joke about the justice league', timestamp=datetime.datetime(2025, 1, 26, 15, 53, 27, 988673, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'),
 ModelRequest(parts=[SystemPromptPart(content='You are a joke teller. talk like tony stark', dynamic_ref=None, part_kind='system-prompt')], kind='request'),
 ModelRequest(parts=[UserPromptPart(content='make the joke about how the avengers are better', timestamp=datetime.datetime(2025, 1, 26, 15, 53, 27, 988683, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'),
 ModelRequest(parts=[SystemPromptPart(content='talk with emojis', dynamic_ref=None, part_kind='system-prompt')], kind='request'),
 ModelRequest(parts=[UserPromptPart(content='go on', timestamp=datetime.datetime(2025, 1, 26, 15, 53, 46, 233948, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'),
 ModelResponse(parts=[TextPart(content='Alright, listen up, you micros

In [2]:
# Structured verdict produced by the `joke_critic` agent (used as its `result_type`).
# Documented with comments, not a docstring: pydantic copies a class docstring into the
# JSON schema `description`, which would change the schema the model is shown.
class Critique(BaseModel):
    funny: bool  # the critic's overall verdict
    reason: str = ""  # why the joke is not funny; defaults to empty
    pointers: list[str] = Field(default_factory=list)  # suggestions for improving the joke


# Orchestrator: answers the user by delegating to the `joke_teller` tool defined below.
agent = Agent(
    model="google-gla:gemini-1.5-flash",
    name="joker_agent",
    system_prompt="Use the tool to tell jokes",
)

# Worker agent that `joke_teller` runs to actually produce the joke.
joker = Agent(
    model="google-gla:gemini-1.5-flash",
    name="joker_tool",
    system_prompt="Tell knock knock jokes",
)

# Grader used by the result validator; returns a structured `Critique`.
# NOTE(review): pydantic_ai appears to skip an agent's system prompt when `run` is given a
# non-empty `message_history`, which is likely why `validate_joke` repeats it — confirm.
critic = Agent(
    model="google-gla:gemini-1.5-flash",
    name="joke_critic",
    result_type=Critique,
    system_prompt=(
        "Critique the joke as funny or not funny. "
        "If not funny, give a reason for your opinion and pointers for improvement"
    ),
)


@agent.tool_plain
async def joke_teller(premise: str) -> str:
    "Tool to tell jokes about anything"
    # The docstring above is the tool description pydantic_ai shows the model, so it is
    # kept verbatim; explanatory notes go in comments instead.
    # Delegate to the `joker` agent and return its text answer as the tool result.
    joker_result = await joker.run(premise)
    return joker_result.data


@agent.result_validator
async def validate_joke(ctx: RunContext, joke: str) -> str:
    """Grade the agent's final joke with `critic` and log the verdict to logfire.

    The joke is returned unchanged in both cases, so this validator never triggers a
    retry; it only records whether the joke landed. Unfunny jokes are logged at error
    level with the ``unfunny_joke`` tag, together with the critic's ``reason`` and
    ``pointers`` as log attributes so later trace queries can see why.

    Args:
        ctx: Run context of the enclosing agent run; its messages are passed to the
            critic as history so it sees the request and the joke.
        joke: The final text result produced by the agent.

    Returns:
        ``joke``, unchanged.
    """
    critique = (
        await critic.run(
            # The critic's instructions are repeated here. NOTE(review): pydantic_ai appears
            # to skip an agent's system prompt when `message_history` is non-empty — confirm.
            user_prompt="Critique the joke as funny or not funny. If not funny, give a reason for your opinion and pointers for improvement. It will always be a knock knock joke so don't mention that",
            message_history=ctx.messages,
        )
    ).data
    if critique.funny:
        logfire.info("hilarious")
    else:
        # Record the critique on the log event rather than appending it to `ctx.messages`.
        # The joke is returned either way, so an appended message would never reach the
        # model. It could only leave a dangling user request at the end of the run history.
        logfire.error(
            "not funny",
            _tags=["unfunny_joke"],
            reason=critique.reason,
            pointers=critique.pointers,
        )
    return joke


In [3]:
# `joke` holds the whole run result, not just the text: the message history is
# `joke.all_messages()` and the joke itself is `joke.data`.
joke = await agent.run("Tell me a joke about the justice league")

14:27:11.694 joker_agent run prompt=Tell me a joke about the justice league
14:27:11.694   preparing model and tools run_step=1
14:27:11.695   model request


14:27:14.634   handle model response
14:27:14.635     running tools=['joke_teller']
14:27:14.635     joker_tool run prompt=Justice League
14:27:14.636       preparing model and tools run_step=1
14:27:14.636       model request
14:27:15.392       handle model response
14:27:15.399   preparing model and tools run_step=2
14:27:15.400   model request
14:27:16.174   handle model response
14:27:16.176     joke_critic run prompt=Critique the joke as funny or not funny. If not funny, give a ...nt. It will always be a knock knock joke so don't mention that
14:27:16.178       preparing model and tools run_step=1
14:27:16.180       model request
14:27:17.858       handle model response
14:27:17.873     not funny [unfunny_joke]


In [4]:
# Full run history, including the `joke_teller` tool call and its return value.
joke.all_messages()

[ModelRequest(parts=[SystemPromptPart(content='Use the tool to tell jokes', dynamic_ref=None, part_kind='system-prompt'), UserPromptPart(content='Tell me a joke about the justice league', timestamp=datetime.datetime(2025, 1, 25, 14, 27, 11, 694489, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'),
 ModelResponse(parts=[ToolCallPart(tool_name='joke_teller', args=ArgsDict(args_dict={'premise': 'Justice League'}), tool_call_id=None, part_kind='tool-call')], model_name='gemini-1.5-flash', timestamp=datetime.datetime(2025, 1, 25, 14, 27, 14, 633532, tzinfo=datetime.timezone.utc), kind='response'),
 ModelRequest(parts=[ToolReturnPart(tool_name='joke_teller', content="Knock knock.\n\nWho's there?\n\nJustice.\n\nJustice who?\n\nJustice League of extraordinary jokes!  (or... Justice League of heroes!)\n", tool_call_id=None, timestamp=datetime.datetime(2025, 1, 25, 14, 27, 15, 397216, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request'),
 ModelRespon

In [5]:
# print() rather than a bare expression so the joke's newlines render instead of "\n".
print(joke.data)

Knock knock.

Who's there?

Justice.

Justice who?

Justice League of extraordinary jokes!  (or... Justice League of heroes!)



In [6]:
query = """
WITH agent_traces AS (
  SELECT DISTINCT trace_id 
  FROM records 
  WHERE attributes->>'agent_name' = 'joker_agent'
)
SELECT 
  r.trace_id,
  r.span_id,
  r.span_name,
  r.start_timestamp,
  r.end_timestamp,
  r.duration,
  r.level,
  r.message,
  r.tags,
  r.attributes->>'agent_name' as agent_name
FROM records r
JOIN agent_traces at ON r.trace_id = at.trace_id
ORDER BY r.trace_id, r.start_timestamp;
"""

async with AsyncLogfireQueryClient(read_token="H0CTvcy0WCrl6xjxm8r8ZjWxP3LPSq5Mzdv81GvXXRPz") as client:
    df_from_arrow = pl.DataFrame(pl.from_arrow(await client.query_arrow(sql=query)))
    print(df_from_arrow)

shape: (51, 10)
┌────────────┬────────────┬────────────┬───────────┬───┬───────┬───────────┬───────────┬───────────┐
│ trace_id   ┆ span_id    ┆ span_name  ┆ start_tim ┆ … ┆ level ┆ message   ┆ tags      ┆ agent_nam │
│ ---        ┆ ---        ┆ ---        ┆ estamp    ┆   ┆ ---   ┆ ---       ┆ ---       ┆ e         │
│ str        ┆ str        ┆ str        ┆ ---       ┆   ┆ u16   ┆ str       ┆ list[str] ┆ ---       │
│            ┆            ┆            ┆ datetime[ ┆   ┆       ┆           ┆           ┆ str       │
│            ┆            ┆            ┆ μs, UTC]  ┆   ┆       ┆           ┆           ┆           │
╞════════════╪════════════╪════════════╪═══════════╪═══╪═══════╪═══════════╪═══════════╪═══════════╡
│ 01949dca3c ┆ 9d2f035a8c ┆ {agent_nam ┆ 2025-01-2 ┆ … ┆ 9     ┆ joker_age ┆ []        ┆ joker_age │
│ 466bcad3a3 ┆ 29ec50     ┆ e} run     ┆ 5 14:07:4 ┆   ┆       ┆ nt run    ┆           ┆ nt        │
│ fffe5d736e ┆            ┆ {prompt=}  ┆ 4.198295  ┆   ┆       ┆ prompt=Te 

In [8]:
# Keep only the "not funny" log events emitted by `validate_joke` (tagged `unfunny_joke`).
df_from_arrow.filter(pl.col("tags").list.contains("unfunny_joke"))

trace_id,span_id,span_name,start_timestamp,end_timestamp,duration,level,message,tags,agent_name
str,str,str,"datetime[μs, UTC]","datetime[μs, UTC]",f64,u16,str,list[str],str
"""01949ddc0ccef3d1ad341dff704d1f…","""c083c4bbdf53c116""","""not funny""",2025-01-25 14:27:17.873139 UTC,2025-01-25 14:27:17.873139 UTC,,17,"""not funny""","[""unfunny_joke""]",


In [9]:
# Column names returned by the Logfire SQL query.
df_from_arrow.columns

['trace_id',
 'span_id',
 'span_name',
 'start_timestamp',
 'end_timestamp',
 'duration',
 'level',
 'message',
 'tags',
 'agent_name']

In [1]:
from uuid import uuid4

from knd.ai import user_message
from knd.memory import UserSpecificExperience


In [2]:
# Memory container for a throwaway test agent: a fresh random user id and no
# agent-level experience.
memories = AgentMemories(
    agent_name="test_agent",
    agent_experience=None,
    user_specific_experience=UserSpecificExperience(user_id=uuid4()),
)

In [3]:
# Record a user turn; the next cell shows it in `user_specific_experience.message_history`.
memories.add_message(user_message("hello"))

In [6]:
# Inspect the per-user message history populated by `add_message`.
memories.user_specific_experience.message_history

[ModelRequest(parts=[UserPromptPart(content='hello', timestamp=datetime.datetime(2025, 1, 27, 9, 15, 27, 760464, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request')]