In [None]:
import os, asyncio, json, tqdm, dotenv
import numpy as np
import pandas as pd
from langchain import llm_cache
from langchain.chat_models import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.schema import Document
from langchain.document_loaders import CSVLoader
from langchain.cache import InMemoryCache
from typing import Tuple, Dict, Iterable, Callable
# Load variables from a local .env file into the process environment
# (presumably the API credentials the chat model needs -- confirm against the .env in use).
dotenv.load_dotenv()

# cache llm calls (faster when repeating queries and prompts)
# LangChain reads the active cache from the `langchain` module attribute, so it
# must be assigned there. Rebinding only the bare name brought in by
# `from langchain import llm_cache` changes this notebook's global and nothing
# else, which leaves caching disabled.
import langchain

langchain.llm_cache = llm_cache = InMemoryCache()

# Unconditionally sets the model name, replacing any value already in the environment.
os.environ["PRETRAINED_SUMMARY_MODEL_NAME"] = 'gpt-3.5-turbo-16k'

In [None]:
# Path of the input CSV that the rest of the notebook works from.
DATA_PATH = "data/ctg-studies.csv"

# Read the whole file into a DataFrame and preview its first three rows.
df = pd.read_csv(filepath_or_buffer=DATA_PATH)
df.head(n=3)

In [None]:
# Display the column labels of the loaded frame.
df.columns

In [None]:
# Display the frame's (row count, column count).
df.shape

In [None]:
# Where the cleaned copy of the table is written.
PROCESSED_DATA_PATH = "data/processed-ctg-studies.csv"

# Substitute the literal string "unknown" for every NaN cell, then save the
# result as CSV without the index column.
df = df.replace(to_replace=np.nan, value="unknown")
df.to_csv(path_or_buf=PROCESSED_DATA_PATH, index=False)

In [None]:
# Load the processed CSV through CSVLoader, then run the loaded documents
# through a recursive splitter configured with chunk_size=1_024 and
# chunk_overlap=100.
doc_loader = CSVLoader(file_path=PROCESSED_DATA_PATH, encoding="utf8")
doc_splitter = RecursiveCharacterTextSplitter(chunk_overlap=100, chunk_size=1_024)
docs = doc_splitter.split_documents(doc_loader.load())

# Report how many chunks were produced and show the first one.
print(f"number of documents: {len(docs)}")
docs[0]

In [None]:
class DataQueue:
    """List-backed collector for records.

    Entries are kept in insertion order in ``self.data``; that list is a public
    attribute and may be read or replaced directly by callers.
    """

    def __init__(self):
        # Backing store; entries are appended at the end.
        self.data = []

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.data)

    def enqueue(self, entry: Dict[str, object]):
        """Append ``entry`` to the end of the queue."""
        self.data.append(entry)

    def dequeue(self, idx: int = 0):
        """Remove and return the entry at position ``idx``.

        Args:
            idx: Position to remove; defaults to 0, the oldest entry.

        Returns:
            The removed entry.

        Raises:
            IndexError: If the queue is empty or ``idx`` is out of range.
        """
        return self.data.pop(idx)

    def __repr__(self):
        """Return the string form of the underlying list."""
        return f"{self.data}"

    def __getitem__(self, idx: int):
        """Return the entry at position ``idx`` (enables ``queue[idx]``)."""
        return self.data[idx]

    # The accessor was originally defined under this misspelled name, which
    # Python never consults for subscripting. Keep it as an alias so any code
    # that called it explicitly continues to work.
    __getiitem__ = __getitem__

In [None]:
def summarize(summary_chain: BaseCombineDocumentsChain, doc_id: int, document: Document, queue: DataQueue):
    """Summarize a single document and record the outcome on ``queue``.

    Args:
        summary_chain: Chain whose ``run`` method is called with a one-element
            list containing ``document``.
        doc_id: Identifier stored under the ``"id"`` key of the record.
        document: Document to summarize; its ``page_content`` is stored under
            the ``"document"`` key.
        queue: Collector that receives the record via ``enqueue``.

    Returns:
        None. The record is delivered through ``queue`` only.
    """
    generated = summary_chain.run([document])
    queue.enqueue(
        {
            "id": doc_id,
            "document": document.page_content,
            "summary": generated,
        }
    )


async def summary_coroutine(f: Callable, args: Tuple, semaphore: asyncio.Semaphore):
    """Run the blocking call ``f(*args)`` off the event loop, gated by ``semaphore``.

    The call is handed to the running loop's default executor while
    ``semaphore`` is held, so the semaphore's value bounds how many of these
    calls are in flight at once. The callable's return value is discarded.

    Args:
        f: Callable to invoke.
        args: Positional arguments unpacked into ``f``.
        semaphore: Semaphore acquired for the duration of the executor call.
    """
    def _invoke():
        return f(*args)

    loop = asyncio.get_running_loop()
    async with semaphore:
        # Passing None as the executor selects the loop's default executor.
        await loop.run_in_executor(None, _invoke)


async def main(
    docs: Iterable[Document],
    summary_chain: BaseCombineDocumentsChain,
    queue: DataQueue,
    n_concurrency: int=10):
    """Summarize every document in ``docs`` with bounded concurrency.

    One ``summary_coroutine`` wrapping ``summarize`` is created per document,
    using the document's enumeration index as its id. All of them share a
    single semaphore, so at most ``n_concurrency`` summarizations hold it at
    any moment. Results are delivered through ``queue``; nothing is returned.

    Args:
        docs: Documents to summarize.
        summary_chain: Chain forwarded to ``summarize`` for each document.
        queue: Collector forwarded to ``summarize`` for each document.
        n_concurrency: Initial value of the shared semaphore.

    Raises:
        Whatever exception a summarization raises, as soon as that task is
        awaited.
    """
    semaphore = asyncio.Semaphore(value=n_concurrency)
    tasks = [
        summary_coroutine(summarize, args=(summary_chain, i, doc, queue), semaphore=semaphore)
        for i, doc in enumerate(docs)
    ]
    # as_completed() yields a plain iterator with no len(), so the total has to
    # be passed explicitly for the progress bar to show a percentage and ETA.
    for finished in tqdm.tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        await finished

    

In [None]:
llm = ChatOpenAI(model=os.environ["PRETRAINED_SUMMARY_MODEL_NAME"], temperature=0.2)
summary_chain = load_summarize_chain(llm, chain_type="map_reduce")
queue = DataQueue()

await main(docs, summary_chain, n_concurrency=20, queue=queue)

In [None]:
# Destination file for the collected document/summary records.
JSON_DATA_PATH = "data/doc_summary_pair.json"

# Records are enqueued in completion order, so order them by id before saving.
queue.data = sorted(queue.data, key=lambda entry: entry["id"])

# The with-block closes the file on exit, so no explicit close() call is needed.
with open(JSON_DATA_PATH, "w") as f:
    json.dump(queue.data, f, indent=4)