In [11]:
from utils import *

In [12]:
import getpass
import os

if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

In [13]:
desired_sections = ["section_1", "section_12", "section_13", "section_14", "section_15"]

In [14]:
_dataset1 = load_dataset(dataset_name, f"year_{2018}", split="train")
_dataset2 = load_dataset(dataset_name, f"year_{2019}", split="train")
_dataset3 = load_dataset(dataset_name, f"year_{2020}", split="train")

In [15]:
next(iter((set(_dataset1["cik"]).intersection(set(_dataset2["cik"]))).intersection(set(_dataset3["cik"]))))

'37634'

In [16]:
from collections import defaultdict

cik = '788784'
dataset_dict = defaultdict(list)
def add_year_row(_dataset, year):
   d = _dataset.filter(lambda example: example["cik"] == cik)\
               .select_columns(desired_sections)
   d_dict = d[0]
   d_dict.update({"year": year})
   for k, v in d_dict.items():
      dataset_dict[k].append(v)

add_year_row(_dataset1, "2018")
add_year_row(_dataset2, "2019")
add_year_row(_dataset3, "2020")

dataset = Dataset.from_dict(dataset_dict)


##### Chunk

In [17]:
for section in desired_sections:
    dataset = dataset.map(lambda example: {f"{section}_chunked": chunk_text(example[section], chunk_size=50)})

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

In [18]:
dataset

Dataset({
    features: ['section_1', 'section_12', 'section_13', 'section_14', 'section_15', 'year', 'section_1_chunked', 'section_12_chunked', 'section_13_chunked', 'section_14_chunked', 'section_15_chunked'],
    num_rows: 3
})

In [19]:
data_dict_exploded = defaultdict(list)
for item in dataset:
    for section in desired_sections:
        chunks = item[f"{section}_chunked"]
        n_chunks = len(chunks)
        data_dict_exploded["year"] += [item["year"]] * n_chunks
        data_dict_exploded["section"] += [section] * n_chunks
        data_dict_exploded["doc_chunk"] += chunks

dataset_exploded = Dataset.from_dict(data_dict_exploded)

##### Build index

In [20]:
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from uuid import uuid4


documents = [
    Document(
        id=str(i),
        page_content= item["doc_chunk"],
        metadata= {
            "year": item["year"],
            "section": item["section"],
        }
    )
    for i, item in enumerate(dataset_exploded)
]

# uuids = [str(uuid4()) for _ in range(len(documents))]
ids = [str(i) for i in range(len(documents))]

# embeddings = HuggingFaceEmbeddings(model_name=f"sentence-transformers/{SENTENCE_EMBED_MODEL}")
embeddings =  OpenAIEmbeddings(model="text-embedding-3-small")

vector_store = Chroma(
    collection_name="fin_report_store",
    embedding_function=embeddings,
    # persist_directory=v_db_path,  
)
vector_store.add_documents(documents=documents, ids=ids)



['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '30',
 '31',
 '32',
 '33',
 '34',
 '35',
 '36',
 '37',
 '38',
 '39',
 '40',
 '41',
 '42',
 '43',
 '44',
 '45',
 '46',
 '47',
 '48',
 '49',
 '50',
 '51',
 '52',
 '53',
 '54',
 '55',
 '56',
 '57',
 '58',
 '59',
 '60',
 '61',
 '62',
 '63',
 '64',
 '65',
 '66',
 '67',
 '68',
 '69',
 '70',
 '71',
 '72',
 '73',
 '74',
 '75',
 '76',
 '77',
 '78',
 '79',
 '80',
 '81',
 '82',
 '83',
 '84',
 '85',
 '86',
 '87',
 '88',
 '89',
 '90',
 '91',
 '92',
 '93',
 '94',
 '95',
 '96',
 '97',
 '98',
 '99',
 '100',
 '101',
 '102',
 '103',
 '104',
 '105',
 '106',
 '107',
 '108',
 '109',
 '110',
 '111',
 '112',
 '113',
 '114',
 '115',
 '116',
 '117',
 '118',
 '119',
 '120',
 '121',
 '122',
 '123',
 '124',
 '125',
 '126',
 '127',
 '128',
 '129',
 '130',
 '131',
 '132',
 '133',
 '134',
 '135',
 '136',
 '137',
 '138'

##### Reranker

In [29]:
def retrieve(query: str, top_k: int=5):
    retriever = vector_store.as_retriever(search_kwargs={"k": 10*top_k})
    docs = retriever.invoke(query)
    doc_contents = [f"The following was documented in year {doc.metadata['year']}: {doc.page_content}" for doc in docs]
    pairs = [[query, doc] for doc in doc_contents]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        scores = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
    ranked_pairs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
    ranked_docs = [r_doc[0] for r_doc in ranked_pairs]
    return ranked_docs[:top_k] if top_k < len(ranked_docs) else ranked_docs

##### Generate Validation Set

In [30]:

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Create two valid questions solely based on the context below:\n\n{context}. Return as a list of questions"),
    ]
)
chain = create_stuff_documents_chain(llm, prompt)

def get_questions(doc_idx:tuple):
    context = [documents[i] for i in range(doc_idx[0], doc_idx[1])]
    ans = chain.invoke({"context": context}).split("\n")
    return [a.split(".")[1].strip() for a in ans]


qa_dict = []
for range_ in [(50,55), (600,605), (900, 905)]:
    qa_dict.append({
        "questions": get_questions(range_),
        "contexts": documents[range_[0]:range_[1]]
    })
# qa_dict

##### Validate

In [57]:
from collections import Counter

def get_recall_by_ids(pred, true):
    return len(set(pred).intersection(set(true))) / len(true)

def get_recall_by_years(pred, true):
    pred_ctr = Counter(pred)
    true_ctr = Counter(true)
    tp = 0
    for k in pred_ctr:
        if k in true_ctr:
            tp += min(pred_ctr[k], true_ctr[k])
    return tp / len(true)

def get_report(query, true_chunks, retrieved):
    true_chunks_ids = [doc.id for doc in true_chunks]
    true_chunks_years = [doc.metadata["year"] for doc in true_chunks]
    retireved_ids = [doc.id for doc in retrieved]
    retireved_years = [doc.metadata["year"] for doc in retrieved]
    report = f"""
                Question: {query}

                Retrieved document ids: {retireved_ids}
                Retrieved document year: {retireved_years}

                True document ids: {true_chunks_ids}
                True document year: {true_chunks_years}

                Recall by doc_ids: {get_recall_by_ids(retireved_ids, true_chunks_ids)}
                Recall by doc_years: {get_recall_by_years(retireved_years, true_chunks_years)}

                """
    return report

In [58]:
# First questions set - year 2018
queries = qa_dict[0]["questions"]
true_chunks = qa_dict[0]["contexts"]

#question 1:
query = f"In the year 2018: {queries[0]}"
retrieved = retrieve(query)
report = get_report(query, true_chunks, retrieved)

print(report)



                Question: In the year 2018: How much has PSE&G invested in replacing gas mains and service lines with stronger, more durable plastic piping to reduce the potential for leaks and release of methane gas?

                Retrieved document ids: ['53', '50', '51', '45', '5']
                Retrieved document year: ['2018', '2018', '2018', '2018', '2018']

                True document ids: ['50', '51', '52', '53', '54']
                True document year: ['2018', '2018', '2018', '2018', '2018']

                Recall by doc_ids: 0.6
                Recall by doc_years: 1.0

                


In [59]:
#question 2:
query = f"In the year 2018: {queries[1]}"
retrieved = retrieve(query)
report = get_report(query, true_chunks, retrieved)

print(report)



                Question: In the year 2018: What additional improvements will be made to the gas system under PSE&G's Gas System Modernization Program II, which is set to invest $1

                Retrieved document ids: ['53', '47', '5', '54', '45']
                Retrieved document year: ['2018', '2018', '2018', '2018', '2018']

                True document ids: ['50', '51', '52', '53', '54']
                True document year: ['2018', '2018', '2018', '2018', '2018']

                Recall by doc_ids: 0.4
                Recall by doc_years: 1.0

                


In [60]:
# Second questions set - year 2019
queries = qa_dict[1]["questions"]
true_chunks = qa_dict[1]["contexts"]

#question 1:
query = f"In the year 2019: {queries[0]}"
retrieved = retrieve(query)
report = get_report(query, true_chunks, retrieved)

print(report)


                Question: In the year 2019: What is the significance of filing Exhibit 4 with the Quarterly Report on Form 10-Q for the quarter ended June 30, 2013, and how does it impact the information provided in the report?

                Retrieved document ids: ['601', '273', '872', '607', '558']
                Retrieved document year: ['2019', '2018', '2020', '2019', '2019']

                True document ids: ['600', '601', '602', '603', '604']
                True document year: ['2019', '2019', '2019', '2019', '2019']

                Recall by doc_ids: 0.2
                Recall by doc_years: 0.6

                


In [62]:
#question 2:
query = f"In the year 2019: {queries[1]}"
retrieved = retrieve(query)
report = get_report(query, true_chunks, retrieved)
print(report)


                Question: In the year 2019: How does the incorporation of Exhibit 4a(22) in the Quarterly Report on Form 10-Q for the quarter ended September 30, 2014, influence the overall understanding of the financial status of the company?

                Retrieved document ids: ['604', '276', '874', '566', '523']
                Retrieved document year: ['2019', '2018', '2020', '2019', '2019']

                True document ids: ['600', '601', '602', '603', '604']
                True document year: ['2019', '2019', '2019', '2019', '2019']

                Recall by doc_ids: 0.2
                Recall by doc_years: 0.6

                


In [63]:
# Second questions set - year 2020
queries = qa_dict[2]["questions"]
true_chunks = qa_dict[2]["contexts"]

#question 1:
query = f"In the year 2020: {queries[0]}"
retrieved = retrieve(query)
report = get_report(query, true_chunks, retrieved)

print(report)


                Question: In the year 2020: What is the significance of the $8 million included in the financial report due to the adoption of ASU 2016-13 by PSEG POWER LLC?

                Retrieved document ids: ['900', '898', '713', '915', '718']
                Retrieved document year: ['2020', '2020', '2020', '2020', '2020']

                True document ids: ['900', '901', '902', '903', '904']
                True document year: ['2020', '2020', '2020', '2020', '2020']

                Recall by doc_ids: 0.2
                Recall by doc_years: 1.0

                


In [64]:
#question 2:
query = f"In the year 2020: {queries[1]}"
retrieved = retrieve(query)
report = get_report(query, true_chunks, retrieved)
print(report)


                Question: In the year 2020: How does the reduction of reserves to appropriate levels and removal of obsolete inventory impact the financial statements of Public Service Enterprise Group Incorporated?

                Retrieved document ids: ['898', '900', '297', '296', '624']
                Retrieved document year: ['2020', '2020', '2018', '2018', '2019']

                True document ids: ['900', '901', '902', '903', '904']
                True document year: ['2020', '2020', '2020', '2020', '2020']

                Recall by doc_ids: 0.2
                Recall by doc_years: 0.4

                
