In [1]:
import os
import getpass
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.prompts import ChatPromptTemplate
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from typing import Annotated, List, Tuple, Union, TypedDict, Optional
from langchain_core.tools import tool
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from qdrant_client import QdrantClient
from langchain_community.tools.tavily_search import TavilySearchResults

In [2]:
import nest_asyncio

nest_asyncio.apply()

In [14]:
data_folder = 'data'
pdf_files = [os.path.join(data_folder, f) for f in os.listdir(data_folder) if f.endswith('.pdf')]

# Load all PDFs
all_docs = []
for pdf_file in pdf_files:
    loader = PyMuPDFLoader(pdf_file)
    docs = loader.load()
    all_docs.extend(docs)

# Split documents
splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
split_docs = splitter.split_documents(all_docs)
print(f'Total split documents: {len(split_docs)}')

Total split documents: 875


In [17]:
training_documents = split_docs

In [18]:
import uuid

id_set = set()

for document in training_documents:
  id = str(uuid.uuid4())
  while id in id_set:
    id = uuid.uuid4()
  id_set.add(id)
  document.metadata["id"] = id

In [19]:
training_split_documents = training_documents[:700]
val_split_documents = training_documents[700:788]
test_split_documents = training_documents[788:]

In [20]:
from langchain_openai import ChatOpenAI

qa_chat_model = ChatOpenAI(
    model="gpt-4.1-mini",
    temperature=0
)

In [22]:
from langchain_core.prompts import ChatPromptTemplate

qa_prompt = """\
Given the following context, you must generate questions based on only the provided context.

You are to generate {n_questions} questions which should be provided in the following format:

1. QUESTION #1
2. QUESTION #2
...

Context:
{context}
"""

qa_prompt_template = ChatPromptTemplate.from_template(qa_prompt)

In [23]:
question_generation_chain = qa_prompt_template | qa_chat_model

In [25]:
import tqdm
import asyncio

"""
Sample Usage of TQDM:

for i in tqdm.tqdm(range(10)):
  time.sleep(1)
"""

async def create_questions(documents, n_questions):

    questions = {}
    relevant_docs = {}

    for doc in tqdm.tqdm(documents, desc="Generating questions"):
        # Prepare the input for the chain
        input_context = doc.page_content
        doc_id = doc.metadata["id"]

        # Call the question generation chain
        response = await question_generation_chain.ainvoke({"context": input_context, "n_questions": n_questions})

        # Extract questions
        generated_questions = response.content.split("\n")
        generated_questions = [q.strip() for q in generated_questions if q.strip()]

        # Some outputs might be numbered like "1. What is ...?", so clean numbering
        cleaned_questions = []
        for q in generated_questions:
            if q[0].isdigit() and q[1] == '.':
                cleaned_questions.append(q[2:].strip())
            elif q[0].isdigit() and q[1] == ' ':
                cleaned_questions.append(q[1:].strip())
            else:
                cleaned_questions.append(q)

        # Now save each question
        for q in cleaned_questions:
            question_id = str(uuid.uuid4())
            questions[question_id] = q
            relevant_docs[question_id] = [doc_id]

    return questions, relevant_docs

In [26]:
training_questions, training_relevant_contexts = await create_questions(training_split_documents, 2)

Generating questions: 100%|██████████| 700/700 [16:07<00:00,  1.38s/it]


In [27]:
val_questions, val_relevant_contexts = await create_questions(val_split_documents, 2)

Generating questions: 100%|██████████| 88/88 [01:59<00:00,  1.36s/it]


In [28]:
test_questions, test_relevant_contexts = await create_questions(test_split_documents, 2)

Generating questions: 100%|██████████| 87/87 [02:00<00:00,  1.39s/it]


In [29]:
import json

training_corpus = {train_item.metadata["id"] : train_item.page_content for train_item in training_split_documents}

train_dataset = {
    "questions" : training_questions,
    "relevant_contexts" : training_relevant_contexts,
    "corpus" : training_corpus
}

with open("train_dataset.jsonl", "w") as f:
  json.dump(train_dataset, f)    

In [30]:
val_corpus = {val_item.metadata["id"] : val_item.page_content for val_item in val_split_documents}

val_dataset = {
    "questions" : val_questions,
    "relevant_contexts" : val_relevant_contexts,
    "corpus" : val_corpus
}

with open("val_dataset.jsonl", "w") as f:
  json.dump(val_dataset, f)

In [31]:
train_corpus = {test_item.metadata["id"] : test_item.page_content for test_item in test_split_documents}

test_dataset = {
    "questions" : test_questions,
    "relevant_contexts" : test_relevant_contexts,
    "corpus" : train_corpus
}

with open("test_dataset.jsonl", "w") as f:
  json.dump(test_dataset, f)

In [32]:
from sentence_transformers import SentenceTransformer

model_id = "Snowflake/snowflake-arctic-embed-m"
model = SentenceTransformer(model_id)

In [33]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sentence_transformers import InputExample

In [34]:
BATCH_SIZE = 10

In [35]:
corpus = train_dataset['corpus']
queries = train_dataset['questions']
relevant_docs = train_dataset['relevant_contexts']

examples = []
for query_id, query in queries.items():
    doc_id = relevant_docs[query_id][0]
    text = corpus[doc_id]
    example = InputExample(texts=[query, text])
    examples.append(example)

In [36]:
loader = DataLoader(
    examples, batch_size=BATCH_SIZE
)

In [37]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

matryoshka_dimensions = [768, 512, 256, 128, 64]
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

In [38]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

corpus = val_dataset['corpus']
queries = val_dataset['questions']
relevant_docs = val_dataset['relevant_contexts']

evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

In [39]:
EPOCHS = 10

In [40]:
import wandb
wandb.init(mode="disabled")

In [41]:
warmup_steps = int(len(loader) * EPOCHS * 0.1)

model.fit(
    train_objectives=[(loader, train_loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path='finetuned_arctic_FT',
    show_progress_bar=True,
    evaluator=evaluator,
    evaluation_steps=50
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]



Step,Training Loss,Validation Loss,Cosine Accuracy@1,Cosine Accuracy@3,Cosine Accuracy@5,Cosine Accuracy@10,Cosine Precision@1,Cosine Precision@3,Cosine Precision@5,Cosine Precision@10,Cosine Recall@1,Cosine Recall@3,Cosine Recall@5,Cosine Recall@10,Cosine Ndcg@10,Cosine Mrr@10,Cosine Map@100
50,No log,No log,0.715909,0.886364,0.948864,0.971591,0.715909,0.295455,0.189773,0.097159,0.715909,0.886364,0.948864,0.971591,0.853984,0.815111,0.816803
100,No log,No log,0.806818,0.943182,0.982955,0.994318,0.806818,0.314394,0.196591,0.099432,0.806818,0.943182,0.982955,0.994318,0.909813,0.881534,0.88185
140,No log,No log,0.806818,0.965909,0.982955,1.0,0.806818,0.32197,0.196591,0.1,0.806818,0.965909,0.982955,1.0,0.912373,0.88316,0.88316
150,No log,No log,0.806818,0.954545,0.982955,1.0,0.806818,0.318182,0.196591,0.1,0.806818,0.954545,0.982955,1.0,0.910841,0.881266,0.881266
200,No log,No log,0.806818,0.960227,0.988636,1.0,0.806818,0.320076,0.197727,0.1,0.806818,0.960227,0.988636,1.0,0.913873,0.885006,0.885006
250,No log,No log,0.823864,0.965909,0.982955,1.0,0.823864,0.32197,0.196591,0.1,0.823864,0.965909,0.982955,1.0,0.920502,0.893962,0.893962
280,No log,No log,0.829545,0.965909,0.988636,1.0,0.829545,0.32197,0.197727,0.1,0.829545,0.965909,0.988636,1.0,0.923896,0.898359,0.898359
300,No log,No log,0.8125,0.960227,0.982955,1.0,0.8125,0.320076,0.196591,0.1,0.8125,0.960227,0.982955,1.0,0.915895,0.887784,0.887784
350,No log,No log,0.818182,0.965909,0.982955,1.0,0.818182,0.32197,0.196591,0.1,0.818182,0.965909,0.982955,1.0,0.918467,0.891193,0.891193
400,No log,No log,0.818182,0.960227,0.988636,1.0,0.818182,0.320076,0.197727,0.1,0.818182,0.960227,0.988636,1.0,0.920366,0.893466,0.893466


In [42]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [43]:
hf_username = "shradharp"

In [44]:
import uuid

model.push_to_hub(f"{hf_username}/legal-ft-{uuid.uuid4()}")

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

'https://huggingface.co/shradharp/legal-ft-658ea1b1-1d08-4417-8a4e-4920ae593642/commit/a3be0484739c8e3f29a03b46588341435360007c'

In [45]:
import pandas as pd

from langchain_community.vectorstores import FAISS
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document

In [49]:
def evaluate_openai(
    dataset,
    embed_model,
    top_k=5,
    verbose=False,
):
  corpus = dataset['corpus']
  questions = dataset['questions']
  relevant_docs = dataset['relevant_contexts']
  documents = [Document(page_content=content, metadata={"id": doc_id}) for doc_id, content in corpus.items()]
  vectorstore = FAISS.from_documents(documents, embed_model)

  retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})

  eval_results = []
  for id, question in tqdm.tqdm(questions.items()):
    retrieved_nodes = retriever.invoke(question)
    retrieved_ids = [node.metadata["id"] for node in retrieved_nodes]
    expected_id = relevant_docs[id][0]
    is_hit = expected_id in retrieved_ids
    eval_results.append({"id": id, "question": question, "expected_id": expected_id, "is_hit": is_hit})

  return eval_results

In [50]:
te3_openai = OpenAIEmbeddings(model="text-embedding-3-small")
te3_results = evaluate_openai(test_dataset, te3_openai)

100%|██████████| 174/174 [00:54<00:00,  3.21it/s]


In [51]:
te3_results_df = pd.DataFrame(te3_results)

In [52]:
te3_hit_rate = te3_results_df["is_hit"].mean()
te3_hit_rate

0.9425287356321839

### `Snowflake/snowflake-arctic-embed-m` (base)

In [53]:
from langchain_huggingface import HuggingFaceEmbeddings

huggingface_embeddings = HuggingFaceEmbeddings(model_name="Snowflake/snowflake-arctic-embed-m")
arctic_embed_m_results = evaluate_openai(test_dataset, huggingface_embeddings)

100%|██████████| 174/174 [00:01<00:00, 92.65it/s]


In [54]:
arctic_embed_m_results_df = pd.DataFrame(arctic_embed_m_results)

In [55]:
arctic_embed_m_hit_rate = arctic_embed_m_results_df["is_hit"].mean()
arctic_embed_m_hit_rate

0.6666666666666666

### `Snowflake/snowflake-arctic-embed-m` (fine-tuned)

In [56]:
finetune_embeddings = HuggingFaceEmbeddings(model_name="finetuned_arctic_FT")
finetune_results = evaluate_openai(test_dataset, finetune_embeddings)

Some weights of BertModel were not initialized from the model checkpoint at finetuned_arctic_FT and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 174/174 [00:01<00:00, 101.43it/s]


In [57]:
finetune_results_df = pd.DataFrame(finetune_results)

In [58]:
finetune_hit_rate = finetune_results_df["is_hit"].mean()
finetune_hit_rate

0.9482758620689655

# Vibe Checking RAG pipeline

In [59]:
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
split_docs = splitter.split_documents(all_docs)
print(f'Total split documents: {len(split_docs)}')

Total split documents: 1815


In [60]:
base_vectorstore = FAISS.from_documents(training_documents, huggingface_embeddings)
base_retriever = base_vectorstore.as_retriever(search_kwargs={"k": 6})

In [61]:
from langchain_core.prompts import ChatPromptTemplate

RAG_PROMPT = """\
You are an expert assistant that answers questions using ONLY the provided CONTEXT.
Do NOT make up any information.

CONTEXT:
{context}

USER QUESTION:
{question}

Instructions:
- If the context fully covers the answer, respond concisely and accurately.
- If the context is missing information needed to answer, respond exactly:
  INSUFFICIENT_CONTEXT
"""

In [62]:
rag_prompt_template = ChatPromptTemplate.from_template(RAG_PROMPT)

In [63]:
rag_llm =  ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0
)

In [64]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel

base_rag_chain = (
    {"context": itemgetter("question") | base_retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt_template | rag_llm | StrOutputParser(), "context": itemgetter("context")}
)

In [65]:
base_rag_chain.invoke({"question" : "Tell me about Radio in Chicago"})["response"]

'INSUFFICIENT_CONTEXT'

In [66]:
finetune_vectorstore = FAISS.from_documents(training_documents, finetune_embeddings)
finetune_retriever = finetune_vectorstore.as_retriever(search_kwargs={"k": 6})

In [67]:
finetune_rag_chain = (
    {"context": itemgetter("question") | finetune_retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": rag_prompt_template | rag_llm | StrOutputParser(), "context": itemgetter("context")}
)

In [68]:
finetune_rag_chain.invoke({"question" : "Tell me about Radio in Chicago?"})["response"]

"Chicago has five 50,000 watt AM radio stations: the Audacy-owned WBBM and WSCR; the Tribune Broadcasting-owned WGN; the Cumulus Media-owned WLS; and the ESPN Radio-owned WMVP. Chicago Public Radio produces nationally aired programs such as PRI's This American Life and NPR's Wait Wait...Don't Tell Me!."