Skip to content

feat: Add title attribute to Session model #1372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Session(BaseModel):
"""The name of the app."""
user_id: str
"""The id of the user."""
title: str | None = None
"""An optional title for the session."""
state: dict[str, Any] = Field(default_factory=dict)
"""The state of the session."""
events: list[Event] = Field(default_factory=list)
Expand Down
116 changes: 107 additions & 9 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,13 @@ async def test_create_get_session(service_type):
state = {'key': 'value'}

session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state
app_name=app_name, user_id=user_id, state=state, title="test_title"
)
assert session.app_name == app_name
assert session.user_id == user_id
assert session.id
assert session.state == state
assert session.title == "test_title"
assert (
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
Expand Down Expand Up @@ -95,9 +96,12 @@ async def test_create_and_list_sessions(service_type):
user_id = 'test_user'

session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids:
for i, session_id in enumerate(session_ids):
await session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
app_name=app_name,
user_id=user_id,
session_id=session_id,
title=f"test_title_{i}",
)

list_sessions_response = await session_service.list_sessions(
Expand All @@ -106,6 +110,7 @@ async def test_create_and_list_sessions(service_type):
sessions = list_sessions_response.sessions
for i in range(len(sessions)):
assert sessions[i].id == session_ids[i]
assert sessions[i].title == f"test_title_{i}"


@pytest.mark.asyncio
Expand All @@ -128,18 +133,24 @@ async def test_session_state(service_type):
user_id=user_id_1,
state=state_11,
session_id=session_id_11,
title="session_11_title",
)
await session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_12,
session_id=session_id_12,
title="session_12_title",
)
await session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
app_name=app_name,
user_id=user_id_2,
session_id=session_id_2,
title="session_2_title",
)

assert session_11.state.get('key11') == 'value11'
assert session_11.title == "session_11_title"

event = Event(
invocation_id='invocation',
Expand Down Expand Up @@ -201,7 +212,11 @@ async def test_create_new_session_will_merge_states(service_type):
state_1 = {'key1': 'value1'}

session_1 = await session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
app_name=app_name,
user_id=user_id,
state=state_1,
session_id=session_id_1,
title="session_1_title",
)

event = Event(
Expand All @@ -225,9 +240,14 @@ async def test_create_new_session_will_merge_states(service_type):
assert not session_1.state.get('temp:key')

session_2 = await session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
app_name=app_name,
user_id=user_id,
state={},
session_id=session_id_2,
title="session_2_title",
)
# Session 2 has the persisted states
assert session_2.title == "session_2_title"
assert session_2.state.get('app:key') == 'value'
assert session_2.state.get('user:key1') == 'value1'
assert not session_2.state.get('key1')
Expand All @@ -244,7 +264,7 @@ async def test_append_event_bytes(service_type):
user_id = 'user'

session = await session_service.create_session(
app_name=app_name, user_id=user_id
app_name=app_name, user_id=user_id, title="append_event_bytes_title"
)

test_content = types.Content(
Expand Down Expand Up @@ -285,7 +305,7 @@ async def test_append_event_complete(service_type):
user_id = 'user'

session = await session_service.create_session(
app_name=app_name, user_id=user_id
app_name=app_name, user_id=user_id, title="append_event_complete_title"
)
event = Event(
invocation_id='invocation',
Expand Down Expand Up @@ -326,7 +346,7 @@ async def test_get_session_with_config(service_type):

num_test_events = 5
session = await session_service.create_session(
app_name=app_name, user_id=user_id
app_name=app_name, user_id=user_id, title="get_session_with_config_title"
)
for i in range(1, num_test_events + 1):
event = Event(author='user', timestamp=i)
Expand Down Expand Up @@ -357,6 +377,84 @@ async def test_get_session_with_config(service_type):
)
events = session.events
assert len(events) == num_test_events - after_timestamp + 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_session_with_title(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user_title'
session_title = "My Test Session Title"

session = await session_service.create_session(
app_name=app_name, user_id=user_id, title=session_title
)
assert session.title == session_title

retrieved_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert retrieved_session is not None
assert retrieved_session.title == session_title


@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_session_without_title(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user_no_title'

session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
assert session.title is None

retrieved_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert retrieved_session is not None
assert retrieved_session.title is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_list_sessions_includes_titles(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user_list_titles'

session_info = [
{"id": "s1", "title": "Session One"},
{"id": "s2", "title": None},
{"id": "s3", "title": "Session Three"},
]

for info in session_info:
await session_service.create_session(
app_name=app_name,
user_id=user_id,
session_id=info["id"],
title=info["title"],
)

list_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
assert list_response is not None
listed_sessions = sorted(list_response.sessions, key=lambda s: s.id)

assert len(listed_sessions) == len(session_info)
for i, listed_session in enumerate(listed_sessions):
assert listed_session.id == session_info[i]["id"]
assert listed_session.title == session_info[i]["title"]
assert events[0].timestamp == after_timestamp

# Expect no events if none are > after_timestamp.
Expand Down
16 changes: 14 additions & 2 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
'key': {'value': 'test_value'},
},
'userId': 'user',
'title': 'Test Session 1',
}
MOCK_SESSION_JSON_2 = {
'name': (
Expand All @@ -47,6 +48,7 @@
),
'updateTime': '2024-12-13T12:12:12.123456Z',
'userId': 'user',
'title': 'Test Session 2',
}
MOCK_SESSION_JSON_3 = {
'name': (
Expand All @@ -55,6 +57,7 @@
),
'updateTime': '2024-12-14T12:12:12.123456Z',
'userId': 'user2',
'title': 'Test Session 3',
}
MOCK_EVENT_JSON = [
{
Expand Down Expand Up @@ -112,6 +115,7 @@
app_name='123',
user_id='user',
id='1',
title='Test Session 1',
state=MOCK_SESSION_JSON_1['sessionState'],
last_update_time=isoparse(MOCK_SESSION_JSON_1['updateTime']).timestamp(),
events=[
Expand All @@ -138,6 +142,7 @@
app_name='123',
user_id='user',
id='2',
title='Test Session 2',
last_update_time=isoparse(MOCK_SESSION_JSON_2['updateTime']).timestamp(),
events=[
Event(
Expand Down Expand Up @@ -225,9 +230,11 @@ async def async_request(
+ new_session_id
),
'userId': request_dict['user_id'],
'title': request_dict.get('title'),
'sessionState': request_dict.get('session_state', {}),
'updateTime': '2024-12-12T12:12:12.123456Z',
}
# Return LRO for session creation
return {
'name': (
'projects/test_project/locations/test_location/'
Expand Down Expand Up @@ -343,18 +350,23 @@ async def test_create_session():
session_service = mock_vertex_ai_session_service()

state = {'key': 'value'}
title = "test_create_title"
session = await session_service.create_session(
app_name='123', user_id='user', state=state
app_name='123', user_id='user', state=state, title=title
)
assert session.state == state
assert session.app_name == '123'
assert session.user_id == 'user'
assert session.title == title
assert session.last_update_time is not None

session_id = session.id
assert session == await session_service.get_session(
# Retrieve the session again to ensure persistence and correctness
retrieved_session = await session_service.get_session(
app_name='123', user_id='user', session_id=session_id
)
assert session == retrieved_session
assert retrieved_session.title == title


@pytest.mark.asyncio
Expand Down