In [None]:
from torch.nn import DataParallel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import chromadb
from chromadb.config import Settings
import pymupdf
import re
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import pandas as pd
import os

In [2]:
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): this runs after `import torch`; presumably it only needs to be
# set before the first CUDA allocation to take effect — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Local cache directory for the downloaded tokenizer and model weights.
dir_path = "/secure/shared_data/rag_embedding_model"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("nvidia/NV-Embed-v2", trust_remote_code=True, cache_dir=dir_path)

# Load the embedding model
embedding_model = AutoModel.from_pretrained("nvidia/NV-Embed-v2", trust_remote_code=True, cache_dir=dir_path)

# Cast to fp16 *before* moving to the device so that only half-precision weights
# are ever resident on the GPU (moving first would transiently hold the model in
# the precision it was loaded in). Assigning the result also keeps the cell from
# printing the full module repr.
embedding_model = embedding_model.half().to(device)

In [None]:
# # Each query needs to be accompanied by a corresponding instruction describing the task.
# task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question",}

# query_prefix = "Instruct: "+task_name_to_instruct["example"]+"\nQuery: "
# queries = [
#     'are judo throws allowed in wrestling?', 
#     'how to become a radiology technician in michigan?'
#     ]

# # No instruction needed for retrieval passages
# passage_prefix = ""
# passages = [
#     "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
#     "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
# ]

# # get the embeddings
# max_length = 32768
# query_embeddings = embedding_model.encode(queries, instruction=query_prefix, max_length=max_length)
# passage_embeddings = embedding_model.encode(passages, instruction=passage_prefix, max_length=max_length)

# # normalize embeddings
# query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
# passage_embeddings = F.normalize(passage_embeddings, p=2, dim=1)

# scores = (query_embeddings @ passage_embeddings.T) * 100
# print(scores.tolist())


In [6]:
def load_pdf(files="/home/yl3427/cylab/rag_tnm/selfCorrectionAgent/ajcc_7thed_cancer_staging_manual.pdf"):
    """Load one or more PDF files into LangChain ``Document`` objects.

    Each PDF's pages are extracted with PyMuPDF and concatenated, then isolated
    line breaks are merged (``group_broken_paragraphs``) and runs of spaces/tabs
    are collapsed (``clean_extra_whitespace_within_paragraphs``).

    Args:
        files: A single path, or a list/tuple of paths, to PDF files.

    Returns:
        A list with one ``Document`` per input file; ``metadata["source"]``
        holds the path exactly as it was passed in.
    """
    # Accept a single path as well as a list/tuple of paths.
    if not isinstance(files, (list, tuple)):
        files = [files]

    documents = []
    for file_path in files:
        # The context manager releases the PDF handle even if extraction fails.
        with pymupdf.open(file_path) as doc:
            # Join once instead of repeated `+=` on a growing string.
            text = "".join(page.get_text() for page in doc)

        text = group_broken_paragraphs(text)
        text = clean_extra_whitespace_within_paragraphs(text)

        documents.append(
            Document(
                page_content=text,
                metadata={"source": file_path}
            )
        )

    return documents

def clean_extra_whitespace_within_paragraphs(text):
    """Collapse every run of spaces and/or tabs in ``text`` into one space.

    Newlines are not touched, so paragraph boundaries are preserved.
    """
    horizontal_ws = re.compile(r'[ \t]+')
    return horizontal_ws.sub(' ', text)

def group_broken_paragraphs(text):
    """Replace each isolated newline in ``text`` with a single space.

    A newline is "isolated" when it is neither preceded nor followed by another
    newline; runs of two or more newlines are left unchanged.
    """
    lone_newline = re.compile(r"(?<!\n)\n(?!\n)")
    return lone_newline.sub(" ", text)

In [7]:
# Load the PDF at `load_pdf`'s default path into a list of Document objects.
documents = load_pdf()

In [None]:
# Number of loaded documents (`load_pdf` produces one per input file).
len(documents)

In [9]:
def plot_docs_tokens(docs_processed, tokenizer):
    """Print the longest chunk length and plot a histogram of chunk lengths.

    Lengths are measured as ``len(tokenizer.encode(doc.page_content))`` for each
    chunk.

    Args:
        docs_processed: Iterable of objects exposing a ``page_content`` string.
        tokenizer: Object exposing ``encode(text)`` that returns a sized sequence.

    Returns:
        None. When ``docs_processed`` is empty a message is printed and nothing
        is plotted (``max()`` of an empty list would otherwise raise).
    """
    lengths = [len(tokenizer.encode(doc.page_content)) for doc in docs_processed]
    if not lengths:
        print("No chunks to plot.")
        return

    print(f"Maximum sequence length in chunks: {max(lengths)}")
    # `Series.hist()` returns the Axes it drew on; label it so the figure stands alone.
    ax = pd.Series(lengths).hist()
    ax.set_title("Distribution of document lengths in the knowledge base (in count of tokens)")
    ax.set_xlabel("Tokens per chunk")
    ax.set_ylabel("Number of chunks")
    plt.show()

def split_documents(
    chunk_size: int,
    knowledge_base,
    tokenizer
):
    """Split documents into token-bounded chunks and drop duplicate chunks.

    Args:
        chunk_size: Maximum chunk size, measured with ``tokenizer``.
        knowledge_base: Iterable of LangChain ``Document`` objects to split.
        tokenizer: Hugging Face tokenizer used to measure chunk length.

    Returns:
        A list of chunk ``Document`` objects in first-seen order; a chunk whose
        ``page_content`` exactly matches an earlier chunk is skipped.
    """
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        # The third separator is a raw string: "\s" inside a plain literal is an
        # invalid escape sequence and triggers a warning on recent Python versions.
        # The resulting pattern text is unchanged.
        separators=["\n\n", "\n", r'(?<=[.?"\s])\s+', " "],
        tokenizer=tokenizer,
        chunk_size=chunk_size,
        chunk_overlap=0,
        add_start_index=True,
        strip_whitespace=True,
        is_separator_regex=True
    )

    unique_texts = set()
    docs_processed_unique = []
    for source_doc in knowledge_base:
        # Split one source document at a time, as the original pipeline did.
        for chunk in text_splitter.split_documents([source_doc]):
            if chunk.page_content not in unique_texts:
                unique_texts.add(chunk.page_content)
                docs_processed_unique.append(chunk)

    return docs_processed_unique

In [10]:
# Chunk the loaded documents into pieces of at most 512 tokens (the same value
# as MAX_LENGTH, which is used later when embedding).
docs_processed = split_documents(
    chunk_size = 512, 
    knowledge_base = documents,
    tokenizer = tokenizer,
)

In [None]:
# Print the longest chunk length and show the histogram of chunk lengths in tokens.
plot_docs_tokens(docs_processed, tokenizer)

In [None]:
# Report how many unique chunks the splitter produced.
print(f"Number of chunks: {len(docs_processed)}")

In [13]:
# Ask PyTorch's caching allocator to release unused cached GPU memory before embedding.
torch.cuda.empty_cache()

In [14]:
# Maximum token length passed to `embedding_model.encode`.
MAX_LENGTH = 512

def embed_docs_in_chroma(docs, collection):
    """Embed each chunk and add it, with its text, to a Chroma collection.

    Chunks are encoded one at a time with the notebook-level ``embedding_model``
    (``max_length=MAX_LENGTH``) and stored via ``collection.add``.

    Args:
        docs: Sized sequence of chunk documents; each must expose
            ``page_content`` and ``metadata["start_index"]``.
        collection: Target collection exposing ``add(embeddings, documents, ids)``.

    Returns:
        None.
    """
    # Using tqdm as a context manager closes the bar even if encoding or
    # insertion raises part-way through.
    with tqdm(total=len(docs)) as pbar:
        for doc in docs:
            # NOTE(review): the id is derived from `start_index` alone, so chunks
            # from different source files could collide — confirm the knowledge
            # base is single-source before relying on this.
            doc_id = str(doc.metadata["start_index"])
            doc_text = doc.page_content

            with torch.no_grad():
                embeddings = embedding_model.encode([doc_text], max_length=MAX_LENGTH)
                embeddings = embeddings.detach().cpu().numpy().tolist()

            collection.add(
                embeddings=embeddings,
                documents=[doc_text],
                ids=[doc_id]
            )
            pbar.update(1)
            # Release cached GPU memory between chunks.
            torch.cuda.empty_cache()

In [None]:
import chromadb
from chromadb.config import Settings
# client = chromadb.Client()
# Persistent, on-disk Chroma client.
# NOTE(review): the name `client` is rebound to an OpenAI client in a later
# cell; after that cell runs, Chroma calls through `client` will fail.
client = chromadb.PersistentClient(path="/home/yl3427/cylab/chroma_db",
                                   settings=Settings(allow_reset=True))

# Open the "brca" collection (creating it on first use), configured for cosine distance.
brca_collection = client.get_or_create_collection(name = "brca", metadata={"hnsw:space": "cosine"})

print(brca_collection)

In [None]:
# Number of entries currently stored in the "brca" collection.
brca_collection.count()

In [None]:
# embed_docs_in_chroma(docs_processed, brca_collection)

In [None]:
# Re-check the entry count (the embedding call in the previous cell is commented out).
brca_collection.count()

# Test whether everything is computed correctly

In [56]:
# Use get_or_create so that re-running this cell against the persistent store
# does not fail because the "morpheus" collection already exists.
morpheus_collection = client.get_or_create_collection(
     name="morpheus", metadata={"hnsw:space": "cosine"}
)

In [None]:
# Three sample quotes used as a tiny retrieval sanity check.
docs = [
    "This is your last chance. After this, there is no turning back.",
    "You take the blue pill, the story ends, you wake up in your bed and believe whatever you want to believe.",
    "You take the red pill, you stay in Wonderland, and I show you how deep the rabbit hole goes.",
]

# Store the quotes together with their embeddings under fixed ids.
morpheus_collection.add(
    documents=docs,
    embeddings=(
        embedding_model.encode(docs, max_length=MAX_LENGTH)
        .detach()
        .cpu()
        .numpy()
        .tolist()
    ),
    ids=["quote1", "quote2", "quote3"],
)

# Show what the collection now contains.
morpheus_collection.get()

In [None]:
# Querying by a set of query_texts
queries = ["Take the blow pill", "chance", "yewon"]
results = morpheus_collection.query(query_embeddings=embedding_model.encode(queries, max_length=MAX_LENGTH).detach().cpu().numpy().tolist(),
                                                    include=["metadatas", "documents", "distances"],
                                                    n_results=2,
                                                    )

results

In [None]:
# Print the retrieved documents for every query, best match first.
for query, hits in enumerate(results['documents']):
    print(f"For {query}st query: ")
    for top, hit_text in enumerate(hits):
        print(f"----top {top}st----")
        print(hit_text)
    print()

# vllm

In [29]:
from openai import OpenAI

# OpenAI-compatible client pointed at a server on localhost:8000; the API key
# is a placeholder value.
# NOTE(review): this rebinds `client`, which earlier in the notebook refers to
# the Chroma PersistentClient — Chroma calls through `client` stop working once
# this cell has run.
client = OpenAI(api_key = "empty",
                base_url = "http://localhost:8000/v1")

def agent(client, prompt, output_schema=None,
          model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1):
    """Send a single-turn user prompt to a chat-completions endpoint.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        prompt: Text sent as the sole user message.
        output_schema: Optional JSON schema. When given, it is forwarded as
            ``extra_body={"guided_json": output_schema}``; when ``None`` (the
            default) no ``extra_body`` is sent, so the call works for callers
            that want a free-text answer.
        model: Model name passed to the endpoint.
        temperature: Sampling temperature passed to the endpoint.

    Returns:
        The ``content`` of the first choice's message.
    """
    request_kwargs = {}
    if output_schema is not None:
        request_kwargs["extra_body"] = {"guided_json": output_schema}

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        **request_kwargs,
    )

    return response.choices[0].message.content

In [None]:
from pydantic import BaseModel, Field

# Return your reasoning and the T stage in the following JSON format:
# {
#   "reasoning": "Step-by-step explanation of how you interpreted the report to determine the T stage.",
#   "T_stage": "T1, T2, T3, or T4"
# }

# Return your reasoning and the N stage in the following JSON format:
# {
#   "reasoning": "Step-by-step explanation of how you interpreted the report to determine the N stage.",
#   "N_stage": "N0, N1, N2, or N3"
# }

# Structured-output model with two string fields: the model's reasoning and the
# resulting stage. Documented with comments rather than a docstring so that the
# generated JSON schema is left exactly as it was.
class Response(BaseModel):
    reasoning: str = Field(description="Step-by-step explanation of how you interpreted the report to determine the cancer stage.")
    stage: str = Field(description="The cancer stage determined from the report.")
 
# JSON schema derived from the model; displayed as the cell output.
schema = Response.model_json_schema()
schema

In [37]:
# Top-level question for the pipeline: infer T-stage rules for breast cancer.
main_query = '''Please infer a list of general rules that help predict the T stage for breast cancer based on the AJCC's TNM Staging System. Ensure there is at least one rule for each T stage (T1, T2, T3, T4) in the list of rules.'''
# Alternative N-stage version of the question (disabled):
# main_query = '''Please infer a list of general rules that help predict the N stage for breast cancer based on the AJCC's TNM Staging System. Ensure there is at least one rule for each N stage (N0, N1, N2, N3) in the list of rules.'''

# Prompt template for breaking a question into sub-queries; `{question}` is
# filled in with str.format.
query_decomposer_prompt = """
You are a helpful assistant that decomposes an input query into multiple sub-queries.
Your goal is to break down the input into a set of specific sub-questions that can be answered individually to cover the full scope of the original question.

Generate at least 5 sub-queries related to the following input query: {question}

"""


# Prompt template for answering a question from retrieved text; `{context}` and
# `{question}` are filled in with str.format.
rule_generator_prompt = """
You are a helpful assistant. Based on the provided context, answer the question.

Context:
{context}

Question:
{question}
"""



In [None]:
# Baseline: ask the main question directly, with no retrieved context.
# No JSON schema is enforced here, so the answer is free text; the schema
# argument is passed explicitly so the call matches `agent`'s signature.
answer = agent(client, main_query, output_schema=None)
print(answer)

In [None]:
# Ask the model to decompose the main question into sub-queries.
# The answer is split line by line afterwards, so it must stay free text: no
# JSON schema is enforced, and the argument is passed explicitly so the call
# matches `agent`'s signature.
answer = agent(client, query_decomposer_prompt.format(question=main_query), output_schema=None)
print(answer)

In [None]:
# One sub-query per non-blank line of the model's answer. Blank lines are
# dropped so that empty strings are never embedded and sent as queries.
subqueries = [line.strip() for line in answer.splitlines() if line.strip()]
subqueries

In [None]:
# Retrieve the single closest chunk from the "brca" collection for each sub-query.
# queries = [main_query]
queries = subqueries
results = brca_collection.query(
    query_embeddings=(
        embedding_model.encode(queries, max_length=MAX_LENGTH)
        .detach()
        .cpu()
        .numpy()
        .tolist()
    ),
    include=["metadatas", "documents", "distances"],
    n_results=1,
)

results

In [None]:
# Concatenate the retrieved chunks into one context string, one chunk per line,
# skipping any chunk id that has already been added for an earlier query.
retrieved_context = ""
id_set = set()
for query, hits in enumerate(results['documents']):
    for top, hit_text in enumerate(hits):
        hit_id = results['ids'][query][top]
        if hit_id not in id_set:
            id_set.add(hit_id)
            retrieved_context += hit_text + "\n"
            print(f"Add at {query}, {top}")
        else:
            print(f"Skip at {query}, {top}")

In [None]:
# Inspect the assembled context string.
retrieved_context

In [None]:
# Answer the main question again, this time grounded in the retrieved context.
# No JSON schema is enforced, so the answer is free text; the schema argument is
# passed explicitly so the call matches `agent`'s signature.
answer = agent(
    client,
    rule_generator_prompt.format(context = retrieved_context, question = main_query),
    output_schema=None,
)
print(answer)

In [None]:
# Show the raw answer string (unlike print, this displays escape sequences literally).
answer