In [1]:
from typing import Optional, Literal, Iterator
from repsheet_backend.common import MEMBER_VOTES_TABLE, BILLS_TABLE, VOTES_HELD_TABLE, MEMBERS_TABLE, db_connect, JUSTIN, PIERRE
from repsheet_backend.bills import BillId, BillIssues, BillSummary
from repsheet_backend.genai import generate_text, GEMINI_PRO_2_5, GEMINI_FLASH_2, CLAUDE_SONNET
from pydantic import BaseModel
import json
import asyncio
import re
from math import floor, ceil
from random import Random

In [2]:

# Load the two prompt templates once, at cell execution time. Both contain a
# {{RAW_INPUT_DATA}} placeholder that the prompt builders below substitute.
with open("prompts/summarize-member/001.txt", "r") as summarise_prompt_file:
    summarise_member_prompt_template = summarise_prompt_file.read()

with open("prompts/merge-summaries/001.txt", "r") as merge_prompt_file:
    merge_summaries_prompt_template = merge_prompt_file.read()

class BillVotingRecord(BaseModel):
    """One member's vote on one bill, together with that bill's summary and issues.

    Built by get_member_voting_record and dumped to JSON to fill the member
    summarisation prompt.
    """
    # Summary text taken from the bill's stored BillSummary.
    summary: str
    # Value of the bill's [Bill ID] column.
    billID: str
    # Value of the bill's [Bill Number] column.
    billNumber: str
    # Lower-cased vote; a missing/empty stored value is recorded as "abstain".
    voted: Literal["yea", "nay", "abstain"]
    # Issues taken from the bill's stored BillSummary.
    issues: BillIssues


class MemberSummary(BaseModel):
    """Summary of a member's voting record, as parsed from a model response.

    Used both for the per-batch summaries and for the final merged summary.
    """
    # Free-text summary.
    summary: str
    # Issue breakdown, using the same BillIssues type as the bill summaries.
    issues: BillIssues


# Query returning one row per bill for a single member (bound as :member_id).
# The CTE picks, for each bill, the highest [Vote ID] among the votes the member
# has a row for; the outer query then joins back to that member's row for the
# chosen vote and to the bill, keeping only bills whose [Summary] is not NULL.
# Columns returned: bill_id, bill_number, full_summary, voted.
# NOTE(review): this treats the highest [Vote ID] as the most recent vote on a
# bill — confirm that Vote IDs are assigned in chronological order.
MEMBER_BILL_VOTING_QUERY = f"""
WITH most_recent_vote AS (
SELECT              
    b.[Bill ID] AS bill_id,
    MAX(v.[Vote ID]) AS vote_id
FROM {MEMBER_VOTES_TABLE} AS mv
JOIN {VOTES_HELD_TABLE} v
    ON mv.[Vote ID] = v.[Vote ID]
JOIN {BILLS_TABLE} AS b
    ON v.[Bill ID] = b.[Bill ID]
WHERE
    mv.[Member ID] = :member_id
GROUP BY 
    b.[Bill ID]
)

SELECT 
    b.[Bill ID] AS bill_id,
    b.[Bill Number] AS bill_number,
    b.[Summary] AS full_summary,
    mv.[Member Voted] AS voted
FROM most_recent_vote
JOIN {MEMBER_VOTES_TABLE} AS mv
    ON most_recent_vote.vote_id = mv.[Vote ID]
JOIN {BILLS_TABLE} AS b
    ON most_recent_vote.bill_id = b.[Bill ID]
WHERE
    mv.[Member ID] = :member_id
AND
    b.[Summary] IS NOT NULL
"""

# Number of batches a member's voting record is split into (evaluates to 23).
# NOTE(review): 200000 and 8192 look like a context-window size and a
# per-response token limit — confirm and consider naming them.
BATCH_COUNT = 200000 // 8192 - 1

# Seed for shuffling the voting record. It is fixed so that the batches are
# deterministic, which allows AI responses to be cached between runs.
RANDOM_SEED = 338

def get_member_voting_record(member_id: str) -> list[BillVotingRecord]:
    """Load one BillVotingRecord per bill returned by MEMBER_BILL_VOTING_QUERY.

    Args:
        member_id: Bound to the query's :member_id parameter.

    Returns:
        The member's voting records, in the order the query returned them.

    Raises:
        AssertionError: If the same bill ID appears more than once.
    """
    with db_connect() as db:
        cursor = db.execute(MEMBER_BILL_VOTING_QUERY, {"member_id": member_id})
        rows = cursor.fetchall()

    def to_record(row) -> BillVotingRecord:
        # The stored summary column is a JSON-serialised BillSummary.
        bill_summary = BillSummary.model_validate_json(row["full_summary"])
        raw_vote = row["voted"]
        return BillVotingRecord(
            summary=bill_summary.summary,
            billID=row["bill_id"],
            billNumber=row["bill_number"],
            # A missing or empty vote value is treated as an abstention.
            voted=raw_vote.lower() if raw_vote else "abstain",
            issues=bill_summary.issues,
        )

    voting_record = [to_record(row) for row in rows]

    # Each bill must appear at most once in the result.
    unique_bill_ids = {record.billID for record in voting_record}
    assert len(unique_bill_ids) == len(voting_record), "Duplicate bill IDs found in voting record"
    return voting_record

def batched(iterable: list, batches: int) -> Iterator[list]:
    """Split a list into exactly `batches` consecutive chunks.

    The first `batches - 1` chunks each hold `len(iterable) // batches` items
    (possibly zero); the final chunk holds everything that is left over.

    Args:
        iterable: The list to split.
        batches: How many chunks to yield. Values below 1 yield nothing, except
            0, which raises ZeroDivisionError on first iteration.

    Yields:
        Each chunk in order.
    """
    chunk_len = len(iterable) // batches
    remaining = iterable
    for _ in range(batches - 1):
        yield remaining[:chunk_len]
        remaining = remaining[chunk_len:]
    if batches >= 1:
        # The last chunk absorbs the remainder of the integer division.
        yield remaining
            
def get_member_summarisation_prompts(member_id: str) -> list[str]:
    """Build one summarisation prompt per batch of the member's voting record.

    The voting record is shuffled with a fixed seed (so the batches, and hence
    the prompts, are identical on every run), split into BATCH_COUNT batches,
    and each non-empty batch is serialised to JSON and substituted for the
    {{RAW_INPUT_DATA}} placeholder in the summarisation prompt template.

    Args:
        member_id: The member whose voting record is summarised.

    Returns:
        One prompt per non-empty batch; an empty list if the member has no
        voting record.
    """
    voting_record = get_member_voting_record(member_id)
    # Seeded shuffle: deterministic batches allow AI responses to be cached.
    Random(RANDOM_SEED).shuffle(voting_record)
    votes_json = [vote.model_dump(mode="json", exclude_none=True) for vote in voting_record]

    prompts: list[str] = []
    for batch in batched(votes_json, BATCH_COUNT):
        if not batch:
            # With fewer votes than BATCH_COUNT the leading batches are empty;
            # there is nothing to summarise in them.
            continue
        # Serialise only this batch, not the whole voting record.
        batch_json = json.dumps(batch, indent=2)
        prompts.append(summarise_member_prompt_template.replace("{{RAW_INPUT_DATA}}", batch_json))
    return prompts

def get_summary_merge_prompt(summaries: list[MemberSummary]) -> str:
    """Build the prompt that asks a model to merge several batch summaries.

    Args:
        summaries: The per-batch summaries to merge.

    Returns:
        The merge prompt template with its {{RAW_INPUT_DATA}} placeholder
        replaced by the summaries serialised as an indented JSON array.
    """
    payload = json.dumps(
        [summary.model_dump(mode="json") for summary in summaries],
        indent=2,
    )
    return merge_summaries_prompt_template.replace("{{RAW_INPUT_DATA}}", payload)


# Matches a stray quoted string that follows a comma and directly precedes a
# closing brace, i.e. `"value", "rest of value"}` where the second string was
# meant to be part of the first.
value_meant_to_be_part_of_previous_key_regex = re.compile(r'",\s*"([^"]+)"\s*}')


def fix_crappy_json(json_str: str) -> str:
    """Repair two kinds of malformed JSON before parsing.

    1. A stray string before a closing brace is folded into the preceding
       string value, separated by a space.
    2. A backslash followed by a literal newline is turned into the two
       characters backslash and "n".

    Args:
        json_str: The raw JSON text to repair.

    Returns:
        The repaired text; input without these defects is returned unchanged.
    """
    rejoined = value_meant_to_be_part_of_previous_key_regex.sub(r' \1"}', json_str)
    return rejoined.replace("\\\n", "\\n")

def parse_member_summary_response(response: str | None) -> MemberSummary:
    """Parse a model response into a MemberSummary.

    A leading "```json" fence line and a trailing "```" fence line are removed
    when present, the remaining text is passed through fix_crappy_json, and the
    result is validated as MemberSummary JSON.

    Args:
        response: The raw response text; must not be None.

    Returns:
        The validated MemberSummary.

    Raises:
        AssertionError: If response is None.
    """
    assert response is not None
    unfenced = response.removeprefix("```json\n").removesuffix("\n```")
    return MemberSummary.model_validate_json(fix_crappy_json(unfenced))


async def generate_member_summary(
    member_id: str,
    *,
    model=GEMINI_FLASH_2,
    merge_model=CLAUDE_SONNET,
) -> MemberSummary:
    """Summarise a member's voting record in two stages.

    First, every batch prompt for the member is sent concurrently to `model`
    and each response is parsed into a MemberSummary. Then the batch summaries
    are combined into a single merge prompt, which is sent to `merge_model`;
    the parsed result of that call is returned.

    Args:
        member_id: The member to summarise.
        model: Model used for the per-batch summaries. Defaults to the cheaper
            model, as the batch prompts have a high token count.
        merge_model: Model used for the final merge. Defaults to the more
            expensive model, as the merge prompt is small and its output is
            the final, polished result.

    Returns:
        The merged MemberSummary.

    Raises:
        AssertionError: If any model call returns None.
    """
    prompts = get_member_summarisation_prompts(member_id)
    batch_responses = await asyncio.gather(*[
        generate_text(prompt, model=model)
        for prompt in prompts
    ])
    batch_summaries = [
        parse_member_summary_response(batch_response)
        for batch_response in batch_responses
    ]
    merge_prompt = get_summary_merge_prompt(batch_summaries)
    merged_response = await generate_text(merge_prompt, model=merge_model)
    assert merged_response is not None
    return parse_member_summary_response(merged_response)

In [3]:
x = await generate_member_summary(PIERRE)
with open('merged.txt', 'w') as f:
    f.write(x.model_dump_json(indent=2))

Generating text with claude-3-7-sonnet-latest (124970 chars)
Received response from claude-3-7-sonnet-latest (4811 chars)


In [None]:
# Ad-hoc probe: ask for a summary of a single bill. No model is specified, so
# whichever default generate_text applies is used.
await generate_text("Please summarise bill C-79 of the 1st session of the 44th parliament of Canada")

In [None]:
# Summarise the same member twice to compare models: once with the default
# model and once with GEMINI_PRO_2_5.
# NOTE(review): the second call requires generate_member_summary to accept a
# `model` keyword argument.
pierre_2 = await generate_member_summary(PIERRE)
pierre_2_5 = await generate_member_summary(PIERRE, model=GEMINI_PRO_2_5)
# pierre_sonnet = await generate_member_summary(PIERRE, model=SONNET)

In [None]:
# Save both comparison summaries as indented JSON, one file per model.
with open('pierre-pro.json', 'w') as pro_file:
    pro_file.write(pierre_2_5.model_dump_json(indent=2))

with open('pierre-flash.json', 'w') as flash_file:
    flash_file.write(pierre_2.model_dump_json(indent=2))

In [None]:
with db_connect() as db:
    all_member_ids = [row[0] for row in db.execute(f"SELECT DISTINCT [Member ID] FROM {MEMBERS_TABLE}").fetchall()]

summaries = await asyncio.gather(*[
    generate_member_summary(member_id)
    for member_id in all_member_ids
])

member_summaries = [
    {"member_id": member_id, "summary": summary.model_dump_json()}
    for member_id, summary in zip(all_member_ids, summaries)
    if summary is not None
]

In [None]:


# Store each generated summary on its member's row; one UPDATE is executed
# per entry in member_summaries.
update_summary_sql = f"UPDATE {MEMBERS_TABLE} SET Summary = :summary WHERE [Member ID] = :member_id"
with db_connect() as db:
    db.executemany(update_summary_sql, member_summaries)
