In [1]:
from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer

class SentenceTransformerWrapper(Embeddings):
    """Adapt a ``SentenceTransformer`` model to LangChain's ``Embeddings`` interface.

    Every vector is produced with ``normalize_embeddings=True`` (unit L2 norm), so a
    plain dot product between two vectors from this wrapper is their cosine similarity.
    The underlying encoder stays reachable as ``.model`` for callers that want to reuse it.
    """

    def __init__(self, model_name: str):
        encoder = SentenceTransformer(model_name)
        self.model = encoder

    def embed_documents(self, docs: list[str]) -> list[list[float]]:
        """Encode a batch of texts; returns one list of floats per input text."""
        vectors = self.model.encode(docs, normalize_embeddings=True)
        return vectors.tolist()

    def embed_query(self, query: str) -> list[float]:
        """Encode a single query string into a list of floats."""
        vector = self.model.encode(query, normalize_embeddings=True)
        return vector.tolist()


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import logging
import re
from typing import List, Tuple
from pathlib import Path
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader,
)
from langchain.schema import Document
from langchain_chroma import Chroma
from sentence_transformers import SentenceTransformer
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration

class RAGFusionProcessor:
    """Ingest a document into Chroma, retrieve chunks with MMR, then re-rank them.

    Pipeline: ``load`` -> ``process`` (clean + chunk) -> ``embed`` (store in Chroma)
    -> ``search`` (MMR retrieval) -> ``rank`` (dot-product re-rank + T5 generations).

    Any key in ``DEFAULTS`` can be overridden with a keyword argument, e.g.
    ``RAGFusionProcessor(size=256, depth=20)``. Usable as a context manager; leaving
    the ``with`` block drops the reference to the vector store.
    """

    DEFAULTS = {
        'size': 512,        # chunk_size passed to RecursiveCharacterTextSplitter
        'overlap': 50,      # chunk_overlap passed to the splitter
        'model': "sentence-transformers/all-MiniLM-L6-v2",  # embedding + re-ranking model
        'name': "Docs",     # Chroma collection name; emptied on construction
        'level': logging.INFO,
        'depth': 10,        # MMR fetch_k: candidates fetched before diversification
        'factor': 0.5,      # MMR lambda_mult: relevance (1) vs. diversity (0) trade-off
        # Hub id rather than an absolute path into one user's Hugging Face cache, so the
        # notebook runs on any machine; from_pretrained reuses an already-cached copy.
        # NOTE(review): this checkpoint is fine-tuned for *question generation* (see its
        # name), and the recorded sample output is a question rather than a summary; a
        # summarization checkpoint (e.g. "t5-small") may be what generate_summary wants.
        't5_model': "mrm8488/t5-base-finetuned-question-generation-ap",
        # Optional hub revision to pin for reproducibility, e.g. the previously used local
        # snapshot "c81cbaf0ec96cc3623719e3d8b0f238da5456ca8". None = library default.
        't5_revision': None,
    }

    # File suffix (lower-cased) -> LangChain loader class used by `load`.
    LOADERS = {
        '.pdf': PyPDFLoader,
        '.docx': Docx2txtLoader,
        '.txt': TextLoader
    }

    def __init__(self, **kwargs):
        self.cfg = {**self.DEFAULTS, **kwargs}
        logging.basicConfig(level=self.cfg['level'], format='%(asctime)s - %(levelname)s - %(message)s')
        self.log = logging.getLogger(__name__)
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.cfg['size'],
            chunk_overlap=self.cfg['overlap']
        )
        # One encoder backs both the vector store and the re-ranker in `rank`.
        self.wrapper = SentenceTransformerWrapper(self.cfg['model'])
        self.model = self.wrapper.model
        self.store = Chroma(
            collection_name=self.cfg['name'],
            embedding_function=self.wrapper
        )
        # Start from an empty collection: it may already hold chunks from an earlier run
        # (e.g. re-executing this cell), which would otherwise leak into search results.
        ids = self.store._collection.get()['ids']
        if ids:
            self.store._collection.delete(ids)

        # Only pass `revision` when one is configured, so overriding `t5_model` with a
        # different checkpoint never inherits a revision that belongs to another repo.
        t5_kwargs = {}
        if self.cfg.get('t5_revision'):
            t5_kwargs['revision'] = self.cfg['t5_revision']
        self.t5_tokenizer = T5Tokenizer.from_pretrained(self.cfg['t5_model'], **t5_kwargs)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(self.cfg['t5_model'], **t5_kwargs)

    def __enter__(self):
        return self

    def __exit__(self, typ, val, tb):
        # Drop the store reference; returning None lets any exception propagate.
        self.store = None

    @staticmethod
    def clean(text: str) -> str:
        """Strip every character except word chars, whitespace, '.' and ',', then collapse whitespace.

        Note: symbols such as '$', '%', '-', ':' and '/' are removed, which can drop
        meaning from amounts, dates, or rates in the source text.
        """
        return re.sub(r'\s+', ' ', re.sub(r'[^\w\s\.,]', '', text)).strip()

    def load(self, path: str) -> List[str]:
        """Load ``path`` with the loader registered for its suffix and clean each Document.

        Returns:
            One cleaned string per Document the loader returns.

        Raises:
            ValueError: if the file suffix has no entry in ``LOADERS``.
        """
        path = Path(path)
        loader_cls = self.LOADERS.get(path.suffix.lower())
        if loader_cls is None:
            raise ValueError(
                f"Unsupported file type {path.suffix!r} for {path}; "
                f"expected one of {sorted(self.LOADERS)}"
            )
        docs = loader_cls(str(path)).load()
        return [self.clean(doc.page_content) for doc in docs]

    def process(self, docs: List[str]) -> List[str]:
        """Clean ``docs``, join them with spaces, and split the result into chunks.

        Because the documents are joined first, a chunk may span a document/page boundary.
        Re-cleaning already-cleaned text (as returned by ``load``) leaves it unchanged.
        """
        clean_docs = [self.clean(doc) for doc in docs]
        return self.splitter.split_text(" ".join(clean_docs))

    def embed(self, chunks: List[str], src: str = "default") -> None:
        """Add ``chunks`` to the vector store, tagging each with ``{"src": src}`` metadata."""
        if not chunks:
            # Nothing to add; avoid sending an empty batch to the vector store.
            self.log.warning("No chunks to embed for source %s", src)
        else:
            docs = [Document(page_content=chunk, metadata={"src": src}) for chunk in chunks]
            self.store.add_documents(docs)
        print(f"Total records: {self.store._collection.count()}")

    def search(self, qry: str, lim: int = 5) -> List[Tuple[Document, float]]:
        """Retrieve up to ``lim`` diverse chunks for ``qry`` with maximal marginal relevance.

        MMR selects ``k`` results out of ``fetch_k`` candidates, so ``fetch_k`` is raised to
        at least ``lim``; otherwise a ``lim`` above ``cfg['depth']`` would silently return
        fewer results than requested.

        Returns:
            ``(document, 0.0)`` pairs. The MMR call returns documents only, so 0.0 is a
            placeholder score; use ``rank`` to obtain real similarity scores.
        """
        results = self.store.max_marginal_relevance_search(
            qry, k=lim, fetch_k=max(self.cfg['depth'], lim), lambda_mult=self.cfg['factor']
        )
        return [(doc, 0.0) for doc in results]

    def rank(self, qry: str, docs: List[Document], top_n: int = 5) -> List[Tuple[Document, float]]:
        """Re-rank ``docs`` by similarity to ``qry`` and print T5 output for the best ones.

        Embeddings are L2-normalised, so the dot product equals cosine similarity. All
        documents are encoded in a single batch instead of one encode call per document.

        Args:
            qry: Query text.
            docs: Candidate documents, e.g. those returned by ``search``.
            top_n: Number of highest-scoring documents passed to ``generate_summary``.

        Returns:
            ``(document, score)`` pairs sorted by descending score; empty if ``docs`` is.
        """
        if not docs:
            return []
        qry_vec = self.model.encode(qry, normalize_embeddings=True)
        doc_vecs = self.model.encode([doc.page_content for doc in docs], normalize_embeddings=True)
        scores = doc_vecs @ qry_vec  # one similarity per document, in `docs` order
        ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)

        for doc, _ in ranked[:top_n]:
            summary = self.generate_summary(doc.page_content)
            print(f"Summary of chunk from {doc.metadata.get('src')}: {summary}")

        return ranked

    def generate_summary(self, text: str, max_input_length: int = 512,
                         max_output_length: int = 150, num_beams: int = 4) -> str:
        """Run the configured T5 model on ``"summarize: " + text`` with beam search.

        Args:
            text: Input passage; the prompt is truncated to ``max_input_length`` tokens.
            max_input_length: Token limit applied to the encoded prompt.
            max_output_length: ``max_length`` for the generated sequence.
            num_beams: Beam width used by ``generate``.

        Returns:
            The first generated sequence decoded with special tokens removed.
        """
        inputs = self.t5_tokenizer.encode(
            "summarize: " + text, return_tensors="pt", max_length=max_input_length, truncation=True
        )
        outputs = self.t5_model.generate(
            inputs, max_length=max_output_length, num_beams=num_beams, early_stopping=True
        )
        return self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)


def main(path: str = "sample3.pdf", qry: str = "tell me Service Subscriber charges") -> None:
    """Run the end-to-end demo: ingest ``path``, retrieve for ``qry``, re-rank, and print.

    Args:
        path: Document to ingest (suffix must be supported by ``RAGFusionProcessor.LOADERS``).
        qry: Question used for retrieval and re-ranking.
    """
    import os

    # Chroma's telemetry opt-out is ANONYMIZED_TELEMETRY; the previously used
    # CHROMA_TELEMETRY_ENABLED is not a Chroma setting and had no effect.
    # Set it before constructing the processor, which creates the Chroma client.
    os.environ["ANONYMIZED_TELEMETRY"] = "False"

    with RAGFusionProcessor() as proc:
        raw_docs = proc.load(path)
        chunks = proc.process(raw_docs)
        proc.embed(chunks, src=path)
        results = proc.search(qry)
        ranked = proc.rank(qry, [doc for doc, _ in results])

        for doc, score in ranked:
            print(f"Chunk: {doc.page_content}, Source: {doc.metadata.get('src')}, Score: {score}")


# In a Jupyter kernel __name__ is also "__main__", so executing this cell runs the demo.
if __name__ == "__main__":
    main()


2024-12-17 23:31:01,311 - INFO - Use pytorch device_name: cpu
2024-12-17 23:31:01,311 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
Batches: 100%|██████████| 4/4 [00:01<00:00,  3.60it/s]


Total records: 116


Batches: 100%|██████████| 1/1 [00:00<00:00, 102.06it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 138.60it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 63.58it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 48.38it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 44.40it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 82.98it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 60.25it/s]


Answer: question: What is the quotient obtained by dividing the Net Retail Price paid by each Service Subscriber per month?
