# Introduction
**Goal:** Democratize the understanding of legislation that flows through Congress. <br />

**Motivation**: <br />
Thousands of bills are introduced in the United States House of Representatives or Senate every year. From the beginning of the **2023 congressional session** to now, there have been **over 10,000 pieces** of legislation from the House side alone. These documents, when passed into law, have significant impacts on organizations and the people living within the United States. It is therefore in the best interest of the affected groups to be able to access and understand what these effects will be, so that they can make informed decisions. However, legislation can be wordy and is not readable to everyone. Between 2021 and 2024, there were 2 laws passed with roughly **700,000 words** each. Given the sheer volume of bills, their length, and their semantics, it can be cumbersome or nigh impossible to keep track of all that is happening in Congress, even more so in the age of misinformation. <br />

**Idea:** <br/>
The adoption and application of large language models (LLMs) has changed the world in several ways and improved efficiency in many contexts where language is concerned. As legislation consists of written language, LLMs are well suited to the task of helping people understand the impacts of laws. In this notebook, we combine retrieval augmented generation (RAG) with a question and answer framework. The technologies used will be **Google's Gemini LLM**, which has proven to be one of the top performing models for a broad range of natural language tasks, and LlamaIndex paired with Chroma. Both LlamaIndex and Chroma are open source tools that assist in the storage and retrieval of text. You can read more about [Gemini here](https://ai.google.dev/), [LlamaIndex here](https://www.llamaindex.ai/), and [Chromadb here](https://www.trychroma.com/). One key advantage of using **Gemini is its long context window of 2 million tokens** for the pro version and 1 million for the flash version, so that it is possible to fit even legislation with hundreds of thousands of words into one request. A second advantage is **context caching**, where users can pre-process any data that is at least 32,768 tokens long when several queries need to be made against that data. This saves computational cost and time, which in turn alleviates the burden on the environment. <br />

**Method:** <br />
First, we will pull bills from the congress.gov API. Then, an indexed vector database of the bills will be created for efficient retrieval of relevant documents. Lastly, users will be able to ask questions about any area of law they are invested in. For example, suppose I am an individual concerned about health care laws. I can submit a query to pull relevant laws from the database, cache those laws, start a chat with Gemini, and ask questions that come to mind as I review the responses. As another example, suppose I am a property developer and I want to know how legal conditions have changed or may change. Again, I would be able to type an initial query to retrieve laws related to construction and development, cache those laws, and chat about them.

# Setup
To begin, sign up for Google AI Studio and congress.gov to obtain API keys for Gemini and for the congress.gov endpoint used to retrieve legislation. I have already saved all of the laws passed from 2021 to 2024 as JSON and as LlamaIndex indices. Note that the congress.gov API is free, but it places limits on requests. **You do not have to interact with the congress.gov API, as I have already loaded the passed laws for 2021-2024 as a public Kaggle dataset.** If you want proposed legislation, legislation from other years, or bills that were not signed into law, you can get them using the methods defined in the Retrieval of Bills section.

### Installing Gemini, LlamaIndex, and Chroma API

In [None]:
%pip install llama-index-vector-stores-chroma
%pip install llama-index-llms-gemini llama-index
%pip install llama-index-embeddings-huggingface
%pip install -U -q "google-generativeai>=0.8.3" chromadb

### API Keys
- To generate a congress.gov API key go [here to congress.gov](https://www.congress.gov/help/using-data-offsite) and sign up
- To generate Google AI Studio API key for Gemini go [here to aistudio.google.com](https://aistudio.google.com/app/apikey) <br />
Afterwards you can add API keys to the kaggle notebook by using **"Add-ons"** dropdown menu and clicking **"Secrets"**.
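
Once the keys are stored as secrets, they can be read in code rather than pasted into the notebook. The following is a minimal sketch, assuming the secrets were saved under the labels `CONGRESS_API_KEY` and `GOOGLE_API_KEY` (these labels and the `load_api_key` helper are illustrative choices, not something the notebook defines). It falls back to environment variables when run outside Kaggle.

```python
import os


def load_api_key(name: str) -> str:
    """Return a key from Kaggle Secrets if available, else from the environment."""
    try:
        # kaggle_secrets is only importable inside a Kaggle notebook session.
        from kaggle_secrets import UserSecretsClient

        return UserSecretsClient().get_secret(name)
    except Exception:
        # Outside Kaggle, or the secret is not attached to this notebook:
        # fall back to an environment variable, or an empty string if unset.
        return os.environ.get(name, "")


# Hypothetical secret labels; use whatever labels you chose in the Secrets menu.
CONGRESS_API_KEY = load_api_key("CONGRESS_API_KEY")
GOOGLE_API_KEY = load_api_key("GOOGLE_API_KEY")
```

Reading keys this way keeps them out of the saved notebook, which matters if the notebook is shared publicly.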

In [None]:
# SET YOUR API KEYS IN THE QUOTATION MARKS

CONGRESS_API_KEY = ""
GOOGLE_API_KEY = ""

### Code to Initialize congress.gov API Client

In [None]:
import os
import time
from typing import List, Union, Optional
import requests
import re
import json
from urllib.parse import urljoin
import xml.etree.ElementTree as ET

API_VERSION = "v3"
ROOT_URL = "https://api.congress.gov/"
RESPONSE_FORMAT = "json"

In [None]:
# Directly copied from congress.gov github repo: https://github.com/LibraryOfCongress/api.congress.gov
class _MethodWrapper:
    """ Wrap request method to facilitate queries.  Supports requests signature. """

    def __init__(self, parent, http_method):
        self._parent = parent
        self._method = getattr(parent._session, http_method)

    def __call__(self, endpoint, *args, **kwargs):  # full signature passed here
        response = self._method(
            urljoin(self._parent.base_url, endpoint), *args, **kwargs
        )
        # unpack
        if response.headers.get("content-type", "").startswith("application/json"):
            return response.json(), response.status_code
        else:
            return response.content, response.status_code


class CDGClient:
    """ A sample client to interface with Congress.gov. """

    def __init__(
        self,
        api_key,
        api_version=API_VERSION,
        response_format=RESPONSE_FORMAT,
        raise_on_error=True,
    ):
        self.base_url = urljoin(ROOT_URL, api_version) + "/"
        self._session = requests.Session()

        # do not use url parameters, even if offered, use headers
        self._session.params = {"format": response_format}
        self._session.headers.update({"x-api-key": api_key})

        if raise_on_error:
            self._session.hooks = {
                "response": lambda r, *args, **kwargs: r.raise_for_status()
            }

    def __getattr__(self, method_name):
        """Find the session method dynamically and cache for later."""
        method = _MethodWrapper(self, method_name)
        self.__dict__[method_name] = method
        return method


In [None]:
# This is the client to interact with to get the texts and legislation metadata.
client = CDGClient(CONGRESS_API_KEY)  # pass the key, response_format="xml" if needed

### Congressional Session to Year
The next cell creates a mapping from congressional sessions to years. congress.gov identifies each congressional session by number. For example, 2023-2024 is session 118, and the API only returns the session number.
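
As a quick sanity check on that numbering, here is a small sketch of the arithmetic, assuming each session covers two calendar years and that session 1 began in 1789. The `congress_year_range` helper is a hypothetical name used only for illustration.

```python
def congress_year_range(congress: int) -> str:
    """Return the 'start-end' calendar years covered by a numbered session."""
    # Session 1 began in 1789 and every session covers two calendar years.
    start = 1787 + 2 * congress
    return f"{start}-{start + 1}"


for number in (117, 118):
    print(number, congress_year_range(number))
```

Sessions 117 and 118 are the two the rest of the notebook works with, so they make convenient spot checks.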

In [None]:
congress_to_years = {}
for i in reversed(range(1, 119)):
    # each congress spans two calendar years; the 1st congress began in 1789
    start = 1787 + 2 * i
    end = start + 1
    congress_to_years[i] = str(start) + "-" + str(end)

# Retrieval of Bills
Run get_laws() and pass in a list of integers specifying which congressional sessions to pull legislation from, along with the chambers of interest. The only chambers are the House of Representatives ('hr') and the Senate ('s'). Currently the method only supports choosing whether or not to filter out unpassed bills, but it can easily be modified to include only bills that passed one chamber or both chambers of Congress. Keep in mind that the congress.gov API **only allows 5,000 requests per hour** and each request can return a maximum of 250 bills.
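
To get a feel for how far those limits go, the sketch below works out how many paged requests a given number of bills requires. It only uses the two figures stated above; the helper names are hypothetical.

```python
import math

PAGE_LIMIT = 250            # maximum bills returned per request
HOURLY_REQUEST_CAP = 5_000  # requests allowed per hour


def requests_needed(total_bills: int, page_limit: int = PAGE_LIMIT) -> int:
    """Number of paged requests required to list `total_bills` bills."""
    return math.ceil(total_bills / page_limit)


def page_offsets(total_bills: int, page_limit: int = PAGE_LIMIT) -> list:
    """Offsets to request, one per page, starting at 0."""
    return list(range(0, total_bills, page_limit))


needed = requests_needed(10_000)
print(f"{needed} requests, within hourly cap: {needed <= HOURLY_REQUEST_CAP}")
```

Listing bill metadata is therefore cheap. Downloading the text of each law costs at least one extra request per law, so that step is where the hourly cap is more likely to matter.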

In [None]:
from typing import Tuple

def create_bill_dict(bill: dict) -> Tuple[dict, str]:
    "Create a dictionary to hold the legislation metadata."

    tmp = {}
    tmp['latest_action'] = bill['latestAction']['text']
    tmp['num'] = bill['number']
    tmp['congress'] = bill['congress']
    tmp['latest_action_date'] = bill['latestAction']['actionDate']
    tmp['chamber'] = bill['type'].lower()
    title = bill['title']

    return tmp, title

# only updates if the bill has become a law already
def update_laws(
    bills: dict,
    laws: dict,
    passed: bool = True,
) -> dict:
    """
    Update the dict of laws with newly retrieved laws where the keys are titles and values
    are law metadata.
    """
    for b in bills:
        if passed:
            if 'Became Public Law' in b['latestAction']['text']:
                tmp, title = create_bill_dict(b)
                laws[title] = tmp
        else:
            tmp, title = create_bill_dict(b)
            laws[title] = tmp
    return laws

# options for chamber are "hr" for house and "s" for senate
# 118 is the congress 2023-2024
def get_laws(
    congress: List[int] = [118, 117],
    chamber: List[str] = ['hr', 's'],
    passed: bool = True,
) -> dict:
    """
    Loop through specified chamber of congress and congressional sessions to pull the associated
    bills presented by those chambers and in those years.
    """
    laws = {}
    for ch in chamber:
        for c in congress:
            num_bills_retrieved = 0
            endpoint = f"https://api.congress.gov/v3/bill/{c}/{ch}?offset=0&limit=250&format=json"
            data, _ = client.get(endpoint)
            total_num_bills = data['pagination']['count']
            # the 'next' key is absent when there is only one page
            next_page = data['pagination'].get('next')
            bills = data['bills']
            num_bills_retrieved += len(bills)

            # pass the filter flag through so the first page is treated like the rest
            laws = update_laws(bills=bills, laws=laws, passed=passed)

            while next_page and num_bills_retrieved < total_num_bills:

                data, _ = client.get(next_page)
                bills = data['bills']
                num_bills_retrieved += len(bills)
                # store this page before looking for another one so the last page is kept
                laws = update_laws(bills=bills, laws=laws, passed=passed)
                next_page = data['pagination'].get('next')
                if next_page is None:
                    print("No next page")
                    print(f"Num bills retrieved: {num_bills_retrieved} out of {data['pagination']['count']}")
    return laws

In [None]:
# Retrieving law dict w/o text
passed_laws = get_laws(congress=[118, 117], chamber=['hr', 's'])

In [None]:
titles = list(passed_laws.keys())
num_laws_pulled = len(titles)
print(f"Num laws passed in last two years from both chambers of congress: {num_laws_pulled}")
print(f"Name of first retrieved law: {titles[0]}")
print(f"Example of first law from dict: {passed_laws[titles[0]]}")

In [None]:
# save file so as not to need to get from API again
with open("passed_laws_117_118_both_chambers.json", "w") as fout:
    json.dump(passed_laws, fout)

# RAG with LlamaIndex and Chroma
The items returned above from the congress.gov API contain only the legislation metadata, which we can use to query the API again for links to the HTML, XML, and PDF versions of the text. Getting the actual texts requires using requests to download the HTML, which is the friendliest format to feed to an LLM without much pre-processing. Additionally, to store the legislation, we need to turn each law into a Document class object used by LlamaIndex, which will later be transformed into nodes, embedded, and stored in a Chroma database instance.

In [None]:
from IPython.display import Markdown, display

from llama_index.core import Document, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.schema import MetadataMode, TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore

from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings

import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings

from google.api_core import retry

## Retrieving The Text Associated with Each Law
The functions in the next cell request the URL for each legislative text and turn each legislation dictionary into a Document for further processing by LlamaIndex and Chroma.
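
To see what the text cleaning does in isolation, here is a minimal sketch that applies the same four steps (strip tags, strip underscores, collapse spaces, replace newlines) to a made-up HTML snippet. The `clean_html` name and the sample string are illustrative only; no request is made.

```python
import re

# Non-greedy pattern that matches a single html tag.
TAG_PATTERN = re.compile(r"<.*?>")


def clean_html(raw: str) -> str:
    """Apply the cleaning steps in the same order as the notebook."""
    text = TAG_PATTERN.sub("", raw)   # remove html tags
    text = text.replace("_", "")      # remove underscore separators
    text = re.sub(r" +", " ", text)   # collapse runs of spaces
    text = text.replace("\n", " ")    # replace newlines with spaces
    return text


sample = "<html><body><pre>SECTION 1.\n____\nSHORT   TITLE.</pre></body></html>"
cleaned = clean_html(sample)
print(repr(cleaned))

# Newlines are replaced after spaces are collapsed, so consecutive newlines
# can leave double spaces behind. Splitting and re-joining removes them.
normalized = " ".join(cleaned.split())
print(repr(normalized))
```

The leftover double spaces are harmless for retrieval but do cost tokens, which is one reason to normalize whitespace again before sending text to the LLM.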

In [None]:
def get_law_doc(
    name: str,
    law: dict,
    pattern: re.Pattern
) -> Optional[str]:
    """
    Download the html text of the specified legislation and clean it by removing html tags,
    underscores that act as separators, additional white space, and new lines. Returns None
    when the first listed format is not formatted text.
    """
    # remove the underscores and newlines and spaces to save on token usage
    def req_and_clean_text(url: str) -> str:

        r = requests.get(url)
        text = r.text
        # remove html tags
        cleantext = re.sub(pattern, '', text)
        # remove underscore separators
        cleantext = re.sub("_", "", cleantext)
        # replace extra white space with single space
        cleantext = re.sub(' +', ' ', cleantext)
        # replace newline with space
        cleantext = re.sub('\n', ' ', cleantext)
        return cleantext

    BILL_PATH = "bill"
    CONGRESS = law['congress']
    CH = law['chamber']
    BILL_NUM = law['num']
    endpoint = f"{BILL_PATH}/{CONGRESS}/{CH}/{BILL_NUM}/text"
    data, _ = client.get(endpoint)
    text = data['textVersions'][0]['formats'][0] # assuming formatted txt always first
    if text['type'] == "Formatted Text": # this is formatted html
        url = text['url']
        return req_and_clean_text(url=url)
    return None

# issue is that we might have too much txt to store on ram so we want to chunk into 5000
# not an issue for passed laws in 117 & 118 congress so ignore for now
# create and return documents here
def create_documents(
    laws: dict,
    save_laws: bool,
    out_dir: str,
    law_file_name: str
) -> List[Document]:
    """
    Request and turn legislature dicts into documents. Option to save the dict of legislature
    with text to a json file.
    """
    # as per recommendation from @freylis, compile once only
    CLEANR = re.compile('<.*?>') # remove html tags
    docs = []
    # iterate over a copy because entries without text are deleted inside the loop
    for n, d in list(laws.items()):
        txt = get_law_doc(name=n, law=d, pattern=CLEANR)
        if txt is None:
            del laws[n]
        else:
            d['doc'] = txt
            # llamaindex document creation to append to all docs for vector db
            document = Document(
                text=txt,
                metadata={
                    "index_id": str(d["congress"]) + "-" + d["chamber"] + "-" + str(d["num"]),
                    "title": n,
                    "congress": d["congress"],
                    "chamber": d["chamber"],
                    "bill_num": d["num"],
                    "latest_action_date": d["latest_action_date"],
                    "lates_action": d['latest_action']
                },
                excluded_llm_metadata_keys=["index_id", "congress", "chamber", "bill_num", "latest_action_date"],
                metadata_seperator="::",
                metadata_template="{key}=>{value}",
                text_template="Metadata: {metadata_str}\n-----\nContent: {content}",
            )
            docs.append(document)

    # saving laws as json to directory
    if save_laws:
        with open(os.path.join(out_dir, f"{law_file_name}.json"), 'w') as fout:
            json.dump(laws, fout)

    return docs

def dict_to_docs(laws: dict) -> List[Document]:
    """
    Input a dict of laws with titles as keys and values as dicts of metadata and text to get
    a list of LlamaIndex Document class items.
    """
    laws_docs = []
    for n, d in laws.items():
        document = Document(
                    text=d['doc'],
                    metadata={
                        "index_id": str(d["congress"]) + "-" + d["chamber"] + "-" + str(d["num"]),
                        "title": n,
                        "congress": d["congress"],
                        "chamber": d["chamber"],
                        "bill_num": d["num"],
                        "latest_action_date": d["latest_action_date"],
                        #"lates_action": d['latest_action']
                    },
                    excluded_llm_metadata_keys=["index_id", "congress", "chamber", "bill_num", "latest_action_date"],
                    metadata_seperator="::",
                    metadata_template="{key}=>{value}",
                    text_template="Metadata: {metadata_str}\n-----\nContent: {content}",
                )
        laws_docs.append(document)

    return laws_docs

In [None]:
# Creating documents w/ text of laws retrieved from laws dict

# passed_laws_docs = create_documents(
#     laws=passed_laws,
#     save_laws=True,
#     out_dir="/kaggle/working",
#     law_file_name="passed_laws_117_118_both_w_txt"
# )

In [None]:
# loading bills saved to disk as a dict
LAW_TEXT_PATH = ""

with open(LAW_TEXT_PATH, "r") as fop:
    passed_laws = json.load(fop)

In [None]:
passed_laws_docs = dict_to_docs(laws=passed_laws)

In [None]:
law_lengths = []
for n, d in passed_laws.items():
    law_lengths.append(len(d['doc'].split(" ")))

In [None]:
avg_law_length = sum(law_lengths) / len(law_lengths)

print(f"Average law length by number of words: {avg_law_length}")

## Creating Chroma Vector DB and LlamaIndex Index
The first step in creating a LlamaIndex index for RAG is to initialize a storage instance using Chroma. Once the vector store is specified, use the VectorStoreIndex class from LlamaIndex and pass in a list of Documents. There is also the requirement to select a model to embed the text. In this case we will be using BAAI's BGE 1.5 large, hosted on Hugging Face, as it is an open source model that performs well on Hugging Face's MTEB leaderboard, which ranks embedding models. Unfortunately, LlamaIndex isn't completely integrated with Google's AI Studio: an embedding model must be given to create a vector index store, whereas Google's AI Studio only has a method to return embeddings. Find out more about BAAI BGE [here at huggingface](https://huggingface.co/BAAI/bge-large-en-v1.5).
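
Conceptually, the vector index stores one embedding per chunk of text and answers a query by returning the stored vectors closest to the query vector. The sketch below illustrates that idea with cosine similarity over hand-written three-dimensional vectors. The vectors and document ids are toy values, not BGE embeddings, and the sketch makes no claim about how Chroma implements its search internally.

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(
    query_vec: List[float], store: Dict[str, List[float]], k: int = 2
) -> List[Tuple[str, float]]:
    """Return the k stored ids most similar to the query, best first."""
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in store.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Toy "embeddings": real ones have hundreds of dimensions.
toy_store = {
    "health-bill": [0.9, 0.1, 0.0],
    "housing-bill": [0.1, 0.9, 0.1],
    "defense-bill": [0.0, 0.2, 0.9],
}

results = top_k([1.0, 0.2, 0.0], toy_store, k=2)
for doc_id, score in results:
    print(f"{doc_id}: {score:.3f}")
```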

In [None]:
# create vector db and index from docs
def create_and_save_index(
    docs: List[Union[Document, TextNode]],
    huggingface_model: str,
    db_name: str = "passed_laws",
    db_path: str = "./passed_laws_db",
    persist_idx_dir: str = "./passed_laws_index",
    save_idx: bool = True
) -> VectorStoreIndex:

    embed_model = HuggingFaceEmbedding(model_name=huggingface_model)
    chroma_client = chromadb.PersistentClient(path=db_path)
    if isinstance(docs[0], Document):
        #embed_fn = GeminiEmbeddingFunction()
        #embed_fn.document_mode = True
        chroma_collection = chroma_client.get_or_create_collection(
            name=db_name,
            #embedding_function=embed_fn
        )
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            docs, storage_context=storage_context, embed_model=embed_model
        )
    else:
        # reuse the client created above; get_or_create avoids an error when the collection already exists
        chroma_collection = chroma_client.get_or_create_collection(name=db_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex(
            objects=docs, storage_context=storage_context, embed_model=embed_model
        )

    # persist index to disk
    if save_idx:
        index.storage_context.persist(persist_dir=persist_idx_dir)

    return index

### Document Index

Run the next cell, preferably with a GPU, to embed the documents and create the document index.

In [None]:
# Creating and saving source documents into vector db

passed_laws_index = create_and_save_index(
    docs=passed_laws_docs,
    huggingface_model="BAAI/bge-large-en-v1.5",
    db_name="passed_laws_117_118_both",
    db_path="./passed_laws_117_118_both_db",
    persist_idx_dir="./passed_laws_117_118_both_index",
    save_idx=True
)

In [None]:
# load document index from disk
# SAVED_LAW_DB = ""

# db2 = chromadb.PersistentClient(path=SAVED_LAW_DB)
# chroma_collection = db2.get_or_create_collection("passed_laws_117_118_both")
# vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
# embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
# passed_laws_index = index = VectorStoreIndex.from_vector_store(
#     vector_store,
#     llm='local',
#     embed_model=embed_model,
# )

## Summary Index for Structured Hierarchical Retrieval
To enhance retrieval of relevant documents, we employ what LlamaIndex refers to as [structured hierarchical retrieval](https://docs.llamaindex.ai/en/stable/examples/query_engine/multi_doc_auto_retrieval/multi_doc_auto_retrieval/) (the link goes to the LlamaIndex tutorial). The core idea is to use an LLM to create a summary of each document, embed the summaries, and form a summary index. A retriever is then created from the summary index along with the original document index, and query strings are matched against the summaries instead of the entire documents. The documents themselves are still pulled as nodes for answering. This matters because documents can be immensely long, longer than even Gemini's context window allows, and that length affects the quality of the matches. When documents are too long they usually have to be chunked, which leads to the issue of having to find the similarity of the query with each chunk and then finding and pulling the original parent document.
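
The two-level lookup can be sketched without any libraries: score the query against short summaries, then use the matching id to fetch the full text. In the sketch below, word overlap stands in for embedding similarity, and the ids, summaries, and texts are made-up placeholders that follow the notebook's `congress-chamber-number` id format.

```python
from typing import Dict, List, Tuple

# Toy summary index: one short summary per law, keyed by index_id.
summaries: Dict[str, str] = {
    "118-hr-1": "Funds rural hospitals and expands health care coverage.",
    "118-s-2": "Sets zoning and construction standards for federal housing.",
}

# Toy document store: the (much longer) full text, keyed by the same ids.
full_texts: Dict[str, str] = {
    "118-hr-1": "Placeholder for the full text of the health care law.",
    "118-s-2": "Placeholder for the full text of the housing law.",
}


def overlap(query: str, text: str) -> int:
    """Count shared words; a crude stand-in for embedding similarity."""
    query_words = set(query.lower().split())
    text_words = set(text.lower().replace(".", "").split())
    return len(query_words & text_words)


def retrieve(query: str, k: int = 1) -> List[Tuple[str, str]]:
    """Match the query against summaries, then return the parent documents."""
    ranked = sorted(summaries, key=lambda i: overlap(query, summaries[i]), reverse=True)
    hits = [i for i in ranked[:k] if overlap(query, summaries[i]) > 0]
    return [(i, full_texts[i]) for i in hits]


matches = retrieve("health care coverage")
print(matches)
```

The design point is that matching happens on text of roughly uniform, short length, while the answer is still grounded in the complete law.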

In [None]:
import google.generativeai as genai

from llama_index.core import SummaryIndex
from llama_index.llms.gemini import Gemini
from llama_index.core.schema import IndexNode
from llama_index.core.vector_stores import (
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

In [None]:
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
genai.configure(api_key=GOOGLE_API_KEY)

### Summary Index
Here we are leveraging Gemini's long context window to create summaries. The majority of the legislative texts can be summarized without chunking. It is important to note that there are also limits imposed by Google AI Studio's API. Safeguards have been built into the methods in the next cell to account for these limits. As of now, any document that fails to summarize, for example because it is too long for Gemini's context window, is skipped. Regardless, here are the limits of the **free tiers** for future use:
- Flash:
    - 1500 requests/day
    - 15 requests/minute
    - 1,000,000 tokens/minute
- Pro:
    - 50 requests/day
    - 2 requests/minute
    - 32,000 tokens/minute
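
One way to respect per-minute limits like these is to keep a sliding 60-second window of past requests and wait whenever the next request would not fit. The sketch below is one possible approach under that assumption, not the safeguard the notebook itself uses. The `MinuteBudget` class and `throttled` helper are hypothetical names, and the clock is injectable so the logic can be checked without actually sleeping.

```python
import time
from collections import deque


class MinuteBudget:
    """Track requests and tokens over a sliding 60-second window."""

    def __init__(self, max_requests: int, max_tokens: int, clock=time.monotonic):
        self.max_requests = max_requests
        self.max_tokens = max_tokens
        self.clock = clock
        self.events = deque()  # (timestamp, tokens) for each recorded request

    def wait_time(self, tokens: int) -> float:
        """Seconds to wait before re-checking whether the request fits."""
        now = self.clock()
        # Drop events that have aged out of the window.
        while self.events and now - self.events[0][0] >= 60:
            self.events.popleft()
        used = sum(t for _, t in self.events)
        fits = len(self.events) < self.max_requests and used + tokens <= self.max_tokens
        if fits or not self.events:
            # An empty window cannot be improved by waiting, so let the call through.
            return 0.0
        # Wait until the oldest event leaves the window, then check again.
        return 60 - (now - self.events[0][0])

    def record(self, tokens: int) -> None:
        """Register a request that was just sent."""
        self.events.append((self.clock(), tokens))


def throttled(budget: MinuteBudget, tokens: int) -> None:
    """Block until the request fits the budget, then record it."""
    while (delay := budget.wait_time(tokens)) > 0:
        time.sleep(delay)
    budget.record(tokens)


# Free-tier Flash limits from the list above.
flash_budget = MinuteBudget(max_requests=15, max_tokens=1_000_000)
```

Calling `throttled(flash_budget, estimated_tokens)` before each summary request would pause only when needed, rather than on a fixed schedule.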

In [None]:
def process_doc(doc, doc_index, include_summary: bool = True):
    """Process doc."""
    new_metadata = doc.metadata
    doc.text = " ".join([t for t in doc.text.split(" ") if t != ''])
    print(f"Length of text after join: {len(doc.text.split(' '))}")
    # now extract out summary
    summary_index = SummaryIndex.from_documents([doc])
    query_str = "Give a concise summary of this text in five sentences."
    if len(doc.text.split(" ")) > 100_000:
        query_engine = summary_index.as_query_engine(
            llm=Gemini(model="models/gemini-1.5-pro")
        )
    else:
        query_engine = summary_index.as_query_engine(
            llm=Gemini(model="models/gemini-1.5-flash")
        )
    try:
        summary_txt = query_engine.query(query_str)
    except Exception as e:
        print(e)
        return None
    summary_txt = str(summary_txt)

    index_id = doc.metadata["index_id"]
    # filter for the specific doc id
    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="index_id", operator=FilterOperator.EQ, value=index_id
            ),
        ]
    )
    # might get an error here due to doc.id_
    # create an index node using the summary text
    try:
        index_node = IndexNode(
            text=summary_txt,
            metadata=new_metadata,
            obj=doc_index.as_retriever(filters=filters),
            index_id=doc.id_,
        )
    except Exception as e:
        print(f"Tried creating index_node w/ exception: {e}")
        # index_node was never assigned, so report failure instead of crashing on return
        return None

    return index_node


def process_docs(docs, doc_index):
    """Process metadata on docs."""

    index_nodes = []
    tokens_processed = 0
    #model = genai.GenerativeModel(model_name='gemini-1.5-flash-002')
    for i, doc in enumerate(docs):
        print(f"document {i} in passed list")
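        # limit is 15 requests per minute for gemini flash and 1 mil tokens per minute
        # every 4 characters is a token according to google gemini, so set a low num tokens
        tokens_processed += len(doc.text.split(" "))
        print(f"Number of words: {len(doc.text.split(' '))}")
        # pause once the running word count crosses the budget, then reset the counter
        if tokens_processed >= 200_000:
            time.sleep(70)
            tokens_processed = 0
        # also pause after every 10 documents to stay under the requests-per-minute limit
        elif i > 0 and i % 10 == 0:
            time.sleep(70)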
        node = process_doc(doc, doc_index=doc_index)
        if node:
            index_nodes.append(node)

    return index_nodes

### Create Summary Index Nodes
Run the next two cells.

In [None]:
summary_index_nodes = process_docs(passed_laws_docs, passed_laws_index)

In [None]:
# create summary index from nodes

summarized_law_index = create_and_save_index(
    docs=summary_index_nodes,
    huggingface_model="BAAI/bge-large-en-v1.5",
    db_name="summary_passed_laws",
    db_path="./summary_passed_laws_db",
    persist_idx_dir="./summary_passed_laws_index",
    save_idx=True
)

In [None]:
# load summary index from disk

# SUMMARY_DB_PATH = ""

# db3 = chromadb.PersistentClient(path=SUMMARY_DB_PATH)
# chroma_collection = db3.get_or_create_collection("summary_passed_laws")
# vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
# embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
# summarized_law_index = index = VectorStoreIndex.from_vector_store(
#     vector_store,
#     llm='local',
#     embed_model=embed_model,
# )

### Vector Store Info
This class can be passed to a retriever created from an index so that LLMs can use document metadata to retrieve more accurately.

In [None]:
# schema setup
from llama_index.core.vector_stores import MetadataInfo, VectorStoreInfo

vector_store_info = VectorStoreInfo(
    content_info="Passed Laws",
    metadata_info=[
        MetadataInfo(
            name="title",
            description="Name of the bill.",
            type="string",
        ),
        MetadataInfo(
            name="congress",
            description="Number of the congressional session that relates to year.",
            type="integer",
        ),
        MetadataInfo(
            name="chamber",
            description="Congressional body that introduced the bill. Either House or Senate.",
            type="string",
        ),
        MetadataInfo(
            name="bill_num",
            description="The number assigned to the bill.",
            type="integer",
        ),
        MetadataInfo(
            name="latest_action_date",
            description="Date of most recent action on the bill.",
            type="string",
        ),
    ],
)

# Gemini Legislature Agent
Finally, we can assemble all of the pieces into a Gemini LLM agent that can answer questions about laws. When we do retrieval, it is necessary to set the top_k parameter that controls the number of documents returned. LlamaIndex also implements post-processors that can further filter the retrieved documents. For our use case we will set top_k to be large and post-process documents by making sure they have a similarity score of more than 0.2.
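
The combination of a generous top_k and a similarity cutoff amounts to "take many candidates, then keep only those that are close enough". The following plain-Python sketch shows that filtering on made-up scores. It is a stand-in for illustration and does not call the LlamaIndex post-processor; `apply_cutoff` is a hypothetical helper.

```python
from typing import List, Tuple


def apply_cutoff(
    scored: List[Tuple[str, float]], top_k: int = 1000, cutoff: float = 0.2
) -> List[Tuple[str, float]]:
    """Keep the top_k best-scoring items, then drop any at or below the cutoff."""
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]
    return [(doc_id, score) for doc_id, score in ranked if score > cutoff]


# Made-up similarity scores for three retrieved summaries.
candidates = [("law-a", 0.9), ("law-b", 0.15), ("law-c", 0.35)]
kept = apply_cutoff(candidates)
print(kept)
```

With a large top_k the cutoff does most of the work, so the number of laws passed to Gemini depends on how many are actually relevant to the query rather than on a fixed count.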

In [None]:
from llama_index.core.retrievers import VectorIndexAutoRetriever
from llama_index.core import QueryBundle, Settings
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.response.notebook_utils import display_source_node

In [None]:
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
Settings.llm = None
Settings.embed_model = embed_model

## Agent Workflow
- Init the agent and pass either "flash" or "pro" as the model string. Also pass in the document index, summary index, vector_store_info, and the path to the dict of laws with text.
- You can ask for a summary by calling the summary() method. If the returned texts are long enough, it will automatically use context caching; otherwise it is a one-off query.
- If you want a longer session, use start_chat() with a query to match laws against, and then repeatedly prompt the agent using the ask() method. If a chat was not started, ask() will be a one-off query.
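
The "long enough" decision can be sketched as a token estimate compared against a minimum cacheable size. The sketch below assumes roughly four characters per token and a minimum of 32,768 tokens. Both numbers are assumptions to check against the current Gemini documentation, and the function names are hypothetical rather than part of the agent class.

```python
from typing import List, Tuple


def estimate_tokens(text: str, chars_per_token: int = 4) -> int:
    """Rough token estimate, assuming about four characters per token."""
    return len(text) // chars_per_token


def should_cache(texts: List[str], min_cache_tokens: int = 32_768) -> Tuple[bool, int]:
    """Return (use_cache, estimated_total_tokens) for a list of law texts."""
    total = sum(estimate_tokens(t) for t in texts)
    return total >= min_cache_tokens, total


use_cache, total = should_cache(["short bill text"])
print(f"cache: {use_cache}, estimated tokens: {total}")
```

An estimate like this is only a pre-check. The API's own token count is the authoritative figure when a request is close to the threshold.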

In [None]:
class LegislatureAgent(object):

    def __init__(self,
                 model: str,
                 legislature_index: VectorStoreIndex,
                 summary_index: VectorStoreIndex,
                 vector_store_info: VectorStoreInfo,
                 law_dict_path: str = "/kaggle/input/passed-legislation-117-118/passed_laws_117_118_both_w_txt.json",
                ):

        # for context caching only able to use stable versions, hence need to specify 00x.
        if model == "flash":
            self.model = "gemini-1.5-flash-002"
        elif model == "pro":
            self.model = "gemini-1.5-pro-002"
        else:
            print("Model not supported. Defaulting to free tier, flash.")
            self.model = "gemini-1.5-flash-002"

        self.document_index = legislature_index
        self.summary_index = summary_index
        self.retriever = VectorIndexAutoRetriever(
            summary_index,
            vector_store_info=vector_store_info,
            similarity_top_k=1000, # set this high to include as many as possible in case of get all
            empty_query_top_k=1000,  # if only metadata filters are specified, this is the limit
            verbose=True,
        )
        with open(law_dict_path, "r") as fop:
            laws = json.load(fop)
        self.law_docs = dict_to_docs(laws=laws)
        self.chat = None
        self.chat_len = None
        self.chat_start_time = None
        self.agent = None
        self.cache = None

    def start_chat(
        self,
        about: str,
        chat_len: int, #minutes
        temp: float = 0.1,
        sim_cutoff: Optional[float] = None
    ):
        "Initiate chat session."
        self.chat_len = chat_len
        summary_nodes = self._retrieve_summaries(query=about, sim_cutoff=sim_cutoff)
        related_docs, total_len, too_long = self._retrieve_original_doc(nodes=summary_nodes)
        # need to turn law nodes into text docs
        #law_texts, too_long, total_len = self._node_to_text(nodes=law_nodes, return_too_long=True)
        if too_long:
            print(f"Returned texts exceed token limit for {self.model} w/ {total_len} tokens.")
            return None
        self.chat_start_time = time.time()
        # only able to cache if num tokens greater than or equal to 32,768
        self._init_agent(total_len, about, temp, related_docs, chat_len)
        self.chat = self.agent.start_chat(history=[])

    def ask(
        self,
        about: str,
        # temp: float = 0.2,
        # sim_cutoff: Optional[float] = None
    ):
        """
        If there was a chat started, then use chat to continue, else init new model and ask a
        standalone question.
        """

        if self.chat:
            time_since_chat_start = (time.time() - self.chat_start_time) / 60.0
            if time_since_chat_start >= self.chat_len:
                print("Chat time expired based on context cache. Please start a new chat session.")
                self._reset_chat()
                return None
            print(f"Time spent since chat start: {time_since_chat_start:.2f} minutes")
            response = self.chat.send_message(about)
            print(response.text)
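        else:
            # no chat or summary has initialized the agent yet, so create a plain model for a one-off query
            if self.agent is None:
                self.agent = genai.GenerativeModel(model_name=self.model)
            response = self.agent.generate_content(about)
            print(response.text)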

    def summary(
        self,
        about: str,
        chat_len: int,
        temp: float = 0.1,
        num_sentences: int = 5,
        sim_cutoff: Optional[float] = None
    ):
        """
        Putting all the laws into Gemini at once to create a single summary.
        """
        summary_nodes = self._retrieve_summaries(query=about, sim_cutoff=sim_cutoff)
        # need to turn law nodes into text docs
        law_texts, too_long, total_len = self._node_to_text(
            nodes=summary_nodes,
            return_too_long=True
        )
        if too_long:
            print(f"Returned texts exceed token limit for {self.model} w/ {total_len} tokens.")
            return None

        self._init_agent(
            total_len=total_len,
            about=about,
            temp=temp,
            text_content=law_texts,
            chat_len=chat_len)

        if self.cache:
            response = self.agent.generate_content(
                f"Give a {num_sentences}-sentence summary of the text in cached content."
            )
        else:
            law_texts.insert(0, f"Give a {num_sentences}-sentence summary of the given text.")
            response = self.agent.generate_content(
                law_texts
            )
        print(response.text)


    def _init_agent(
        self,
        total_len: int,
        about: str,
        temp: float,
        text_content: List[str],
        chat_len: int
    ):
        "Inits agent based on context caching or not."
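        # Only able to cache if the number of tokens is at least 32,769. Note that total_len is
        # a whitespace word count rather than a token count, so this check is approximate.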
        if total_len >= 20_000:
            self.cache = genai.caching.CachedContent.create(
                model=self.model,
                display_name=f'{about}', # used to identify the cache
                system_instruction=(
                    'You are an expert in US legislature, and your job is to answer '
                    'the user\'s query based on the text file you have access to.'
                ),
                # generation_config=genai.GenerationConfig(
                #     temperature=temp,
                # ),
                contents=text_content,
                ttl=datetime.timedelta(minutes=chat_len),
            )
            model = genai.GenerativeModel.from_cached_content(cached_content=self.cache)
        else:
            model = genai.GenerativeModel(
                self.model,
                system_instruction=(
                    'You are an expert in US legislature, and your job is to answer '
                    'the user\'s query based on the text file you have access to.'
                ),
                generation_config=genai.GenerationConfig(
                    temperature=temp,
                )
            )
        self.agent = model

    def _retrieve_summaries(self, query: str, sim_cutoff: Optional[float] = None):
        """
        Returns nodes of documents using cosine similarity based on summary index summaries compared
        to the query.
        """

        nodes = self.retriever.retrieve(QueryBundle(query))
        if sim_cutoff:
            postprocessor = SimilarityPostprocessor(similarity_cutoff=sim_cutoff)
            nodes = postprocessor.postprocess_nodes(nodes)
        return nodes

    def _node_to_text(self, nodes: List[TextNode], return_too_long: bool = False):
        """
        Intakes llamaindex nodes and returns the associated text. If return_too_long is True, also
        returns whether the number of words may exceed the context window of Gemini, along with
        the number of words.
        """
        total_len = 0
        texts = []
        for n in nodes:
            txt = n.text
            total_len += len(txt.split(" "))
            texts.append(txt)
        too_long = (
            ("flash" in self.model and total_len > 300_000)
            or ("pro" in self.model and total_len > 1_000_000)
        )
        # Return a consistent 3-tuple when requested, even if the texts are too long,
        # so callers can always unpack (texts, too_long, total_len).
        if return_too_long:
            return texts, too_long, total_len
        return texts

    def _retrieve_original_doc(self, nodes: List[TextNode]):
        "Get the original law text, not the summaries."
        total_len = 0
        docs = []
        # A set gives constant-time membership checks while scanning all law documents.
        indices = {n.metadata['index_id'] for n in nodes}
        for d in self.law_docs:
            if d.metadata['index_id'] in indices:
                txt = d.text
                total_len += len(txt.split(" "))
                docs.append(txt)
        # Word-count limits per model family; these are word counts, not token counts.
        too_long = (
            ("flash" in self.model and total_len > 300_000)
            or ("pro" in self.model and total_len > 1_000_000)
        )

        return docs, total_len, too_long

    def _reset_chat(self):

        self.chat = None
        self.chat_len = None
        self.chat_start_time = None
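
The length check used by `_node_to_text` and `_retrieve_original_doc` can be tried in isolation. The following is a minimal sketch, not the class's actual code path: `check_word_budget` and `WORD_LIMITS` are hypothetical names, the model string is only an example, and the sketch splits on any whitespace rather than on single spaces. The limits are the word counts used in the class above, not token counts.

```python
from typing import List, Tuple

# Word budgets taken from the class above; they count words, not tokens.
WORD_LIMITS = {"flash": 300_000, "pro": 1_000_000}


def check_word_budget(texts: List[str], model: str) -> Tuple[int, bool]:
    """Return (total_words, too_long) for the given texts and model name."""
    # Assumption: split on any whitespace; the class splits on single spaces.
    total_words = sum(len(t.split()) for t in texts)
    too_long = any(
        family in model and total_words > limit
        for family, limit in WORD_LIMITS.items()
    )
    return total_words, too_long


# Example model name, for illustration only.
total, too_long = check_word_budget(["a b c", "d e"], "gemini-1.5-flash")
print(total, too_long)
```

Because a word count only approximates a token count, a text that passes this check could still exceed the model's real token limit.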


# Example: Individual Concerned About Medicare

In [None]:
law_bot = LegislatureAgent(
    model="flash",
    legislature_index=passed_laws_index,
    summary_index=summarized_law_index,
    vector_store_info=vector_store_info,
    )

### Let's start by asking for a summary of Medicare laws.

In [None]:
law_bot.summary(
    about="healthcare impacting medicare recipients.",
    chat_len=15,
    temp=0.1,
    num_sentences=5,
    sim_cutoff=0.4
    )
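
The `sim_cutoff=0.4` argument drops retrieved summaries whose similarity score falls below the threshold before any text is sent to Gemini. As a rough illustration of that filtering step, the sketch below uses plain `(text, score)` pairs in place of LlamaIndex nodes. The function name `filter_by_similarity` and the example scores are hypothetical and are not results from this notebook.

```python
from typing import List, Optional, Tuple


def filter_by_similarity(
    scored: List[Tuple[str, float]],
    cutoff: Optional[float] = None,
) -> List[Tuple[str, float]]:
    """Keep only (text, score) pairs whose score is at least `cutoff`."""
    if cutoff is None:
        # No cutoff given: return everything that was retrieved.
        return list(scored)
    return [(text, score) for text, score in scored if score >= cutoff]


# Hypothetical summaries and scores, for illustration only.
hits = [
    ("Medicare drug pricing", 0.62),
    ("Highway funding", 0.18),
    ("Hospital payments", 0.41),
]
kept = filter_by_similarity(hits, cutoff=0.4)
print([text for text, _ in kept])
```

A higher cutoff returns fewer, more closely related laws, which also keeps the combined text further from the context limit.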

### Now let's try starting a chat with context caching.

In [None]:
law_bot.start_chat(
    about="healthcare impacting medicare recipients.",
    chat_len=30, #minutes
    temp=0.1,
    sim_cutoff=0.4
)

In [None]:
law_bot.ask("Will changes to medicare law impact recipients negatively in the next year?")
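
Each call to `ask` compares the minutes elapsed since `start_chat` against `chat_len`, because the cached context is created with a matching time-to-live. The sketch below isolates that timing check. `ChatWindow` is a hypothetical helper, not part of `LegislatureAgent`, and it takes explicit timestamps so the behavior can be checked without waiting.

```python
import time
from typing import Optional


class ChatWindow:
    """Hypothetical helper that tracks whether a timed chat session is still open."""

    def __init__(self, length_minutes: float, start: Optional[float] = None):
        self.length_minutes = length_minutes
        # Default to the current time, as start_chat does with time.time().
        self.start = time.time() if start is None else start

    def minutes_elapsed(self, now: Optional[float] = None) -> float:
        now = time.time() if now is None else now
        return (now - self.start) / 60.0

    def expired(self, now: Optional[float] = None) -> bool:
        # Same comparison as in ask(): expired once elapsed minutes reach the limit.
        return self.minutes_elapsed(now) >= self.length_minutes


window = ChatWindow(length_minutes=30, start=0.0)
print(window.expired(now=10 * 60))  # 10 minutes in
print(window.expired(now=30 * 60))  # 30 minutes in
```

Once the window has expired, a new `start_chat` call is needed, since the cached content may no longer exist on the server side.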