
## Chain of Thought (CoT) Creation

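In this notebook, we generate Chain of Thought (CoT) justifications for clinical trial studies. For each study, a language model explains step by step how the desired criteria follow from the study's title and description, using similar studies retrieved from ChromaDB as examples. Generation runs on Llama 3.1 (via vLLM and a HuggingFace tokenizer) or on Gemini. Below is a step-by-step guide to our workflow: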

1. **Setup and Initialization**:
  - We initialize the ChromaDB client and load the SentenceTransformer model for encoding queries.
  - We define `fix_invalid_json`, which repairs the malformed JSON strings found in the clinical trial studies' metadata so they can be parsed.

2. **Retrieving Relevant Studies**:
  - We define a function `retrieve_relevant_studies` that queries the ChromaDB collection for studies similar to a given query, excluding the study the query was built from.

3. **Crafting Context from Studies**:
  - We define a function `craft_context_from_studies_documents` to create a context string from the documents of related studies. This context is used to provide examples in the CoT creation process.

4. **Generating Messages for CoT**:
  - We define a function `get_messages_for_CoT_huggingface` that builds the system and user messages for CoT generation. The user message contains the example studies, the study title, description, and desired criteria, followed by the task instructions.

5. **Prompt Creation**:
  - We define a function `get_info_for_prompt_gen` that parses a study's metadata, retrieves related studies, and crafts their context. It returns the inputs for the prompt templates: the related-study context, title, description, and desired criteria.

6. **Model Inference**:
  - We load the quantized Llama 3.1 model with vLLM, together with its HuggingFace tokenizer, and define a function `pipe` that applies the chat template to the generated messages and returns the model's output text.
  - For Gemini, use the code in the Gemini section, which sends the same user prompt and passes the system prompt as the model's system instruction.

By following this workflow, we can systematically generate a Chain of Thought for each clinical trial study in the dataset.
```

In [None]:
import chromadb
from sentence_transformers import SentenceTransformer

# Persistent ChromaDB store where the clinical trial studies are kept
client = chromadb.PersistentClient(path="./clinical_trials_chroma")
# SciNCL embedding model, used to encode retrieval queries
embed_model = SentenceTransformer("malteos/scincl")
collection = client.get_or_create_collection("clinical_trials_studies")
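For intuition about what `collection.query` returns: ChromaDB compares the query embedding with the stored embeddings and returns the closest entries. A collection created without a `hnsw:space` setting, as here, uses squared L2 distance, and ChromaDB searches an approximate HNSW index rather than scanning every entry. The brute-force sketch below reproduces only the ranking idea on made-up 3-dimensional vectors and hypothetical IDs; it is not ChromaDB's implementation.

```python
def squared_l2(a: list[float], b: list[float]) -> float:
    # Sum of squared component differences (no square root), as in the "l2" space
    return sum((x - y) ** 2 for x, y in zip(a, b))


def brute_force_query(store: dict[str, list[float]], query: list[float], n_results: int) -> list[tuple[str, float]]:
    # Score every stored vector and keep the n_results closest (smallest distance first)
    scored = [(doc_id, squared_l2(vec, query)) for doc_id, vec in store.items()]
    scored.sort(key=lambda pair: pair[1])
    return scored[:n_results]


# Toy "embeddings" with hypothetical IDs; real SciNCL embeddings are much higher-dimensional
store = {
    "NCT_A": [0.9, 0.1, 0.0],
    "NCT_B": [0.0, 1.0, 0.0],
    "NCT_C": [0.8, 0.2, 0.1],
}
query = [1.0, 0.0, 0.0]
print(brute_force_query(store, query, n_results=2))
```

Smaller distances mean more similar studies, which is why the `distance` values kept by `retrieve_relevant_studies` grow as results get less relevant.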

In [13]:
import re

def fix_invalid_json(input_str):
    # 1. Quote a single unquoted element inside square brackets,
    #    e.g. [Breast Cancer] -> ["Breast Cancer"]; already-quoted elements are skipped
    fixed_str = re.sub(r'(?<=\[)([^\[\],"]+)(?=\])', lambda x: '"' + x.group(0).strip() + '"', input_str)

    # 2. Quote each comma-separated item in unquoted lists such as Conditions and Interventions,
    #    e.g. [Drug: A, Drug: B] -> ["Drug: A", "Drug: B"]
    fixed_str = re.sub(r'(?<=\[)([^\"\]]+?)(?=\])', lambda x: '"' + x.group(0).strip().replace(", ", '", "') + '"', fixed_str)

    # 3. Quote bare dictionary keys, e.g. {NCT_ID: ...} -> {"NCT_ID": ...}
    fixed_str = re.sub(r'(?<!")(\b[A-Za-z_]+\b)(?=\s*:)', r'"\1"', fixed_str)

    return fixed_str
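Because the regex rewrites run unconditionally, they can damage metadata that was already valid JSON: the key-quoting pattern also matches a word followed by a colon inside a string value. A minimal sketch of a strict-first alternative, assuming most metadata strings are valid JSON, is to try `json.loads` and repair only on failure. `quote_bare_tokens` and `parse_metadata` are hypothetical helpers modeled on two of the regexes above, and the sample strings are invented.

```python
import json
import re


def quote_bare_tokens(s: str) -> str:
    # Quote a lone unquoted list element: [Breast Cancer] -> ["Breast Cancer"]
    s = re.sub(r'(?<=\[)([^\[\],"]+)(?=\])', lambda m: '"' + m.group(0).strip() + '"', s)
    # Quote bare dictionary keys: NCT_ID: -> "NCT_ID":
    s = re.sub(r'(?<!")(\b[A-Za-z_]+\b)(?=\s*:)', r'"\1"', s)
    return s


def parse_metadata(raw: str) -> dict:
    # Strict parse first, so well-formed strings are never touched by the regexes
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return json.loads(quote_bare_tokens(raw))


raw = '{NCT_ID: "NCT00000000", Conditions: [Breast Cancer]}'
print(parse_metadata(raw))
# Applying the repair unconditionally corrupts a valid value that contains "word:"
print(quote_bare_tokens('{"Note": "Arm A: placebo"}'))
```

In the notebook, the `except` branch would call `json_repair.loads(fix_invalid_json(raw))` instead of the trimmed helper. The last line shows the failure mode that the strict-first order avoids: the key regex wraps `A` in quotes and breaks the value `"Arm A: placebo"`.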

In [None]:
# Retrieve relevant studies from ChromaDB, excluding the study the query was built from
def retrieve_relevant_studies(query, existing_study, n_results=5):
    query_embedding = embed_model.encode(query).tolist()

    # Ask for one extra result in case the study itself is among the nearest neighbors
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results + 1,
    )

    filtered_results = []
    # Results are nested per query embedding; [0] selects the hits for our single query
    for doc_id, distance, document in zip(results['ids'][0], results['distances'][0], results['documents'][0]):
        if doc_id != existing_study:
            filtered_results.append({
                "id": doc_id,
                "distance": distance,
                "document": document,
            })

        if len(filtered_results) == n_results:
            break

    return filtered_results

print(
    retrieve_relevant_studies("Effect of Kinesiotaping on Edema Management, Pain and Function on Patients With Bilateral Total Knee Arthroplasty [SEP] After being informed about the study and potential risk, all patients undergoing inpatient rehabilitation after bilateral total knee arthroplasty will have Kinesio(R)Tape applied to one randomly selected leg while the other leg serves as a control. Measurement of bilateral leg circumference, knee range of motion, numerical rating scale for pain, and selected questions from the Knee Injury and Osteoarthritis Outcome Score will occur at regular intervals throughout the rehabilitation stay. Patients will receive standard rehabilitation.", 
                              "NCT05013879"))

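The nested indexing in `retrieve_relevant_studies` follows ChromaDB's result layout: `collection.query` returns a dict whose `ids`, `distances`, and `documents` entries each hold one inner list per query embedding. The sketch below runs the same exclude-then-trim logic on a hand-written result of that shape; `exclude_study`, the IDs other than `NCT05013879`, and the distances are made up for illustration.

```python
def exclude_study(results: dict, existing_study: str, n_results: int) -> list[dict]:
    # One inner list per query embedding; we sent a single query, so [0] selects its hits
    hits = zip(results["ids"][0], results["distances"][0], results["documents"][0])
    filtered = [
        {"id": doc_id, "distance": distance, "document": document}
        for doc_id, distance, document in hits
        if doc_id != existing_study
    ]
    # We over-fetched by one, so trim any surplus hit
    return filtered[:n_results]


# Hand-written result in the shape returned by collection.query(..., n_results=3)
fake_results = {
    "ids": [["NCT05013879", "NCT_X", "NCT_Y"]],
    "distances": [[0.0, 0.41, 0.57]],
    "documents": [["<self>", "<doc X>", "<doc Y>"]],
}
print(exclude_study(fake_results, "NCT05013879", n_results=2))
print(exclude_study(fake_results, "NCT00000000", n_results=2))
```

When the source study is not among the hits (second call), the extra result is simply trimmed, which is why requesting `n_results + 1` is safe in both cases.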
In [15]:
import json
import json_repair
def related_studies_template(title: str, description: str, criteria: str):
    return f"""Example Title: {title}
Example Description: {description}
Example Criteria: {criteria}
"""

def craft_context_from_studies_documents(related_studies: list[str]):
    json_related_studies = [json.loads(i) for i in related_studies]
    context = ""
    for i in json_related_studies:
        title = i.get('metadata', {}).get('Official_title', "")
        description = i.get('description', "")
        criteria = i.get('criteria', "")
        if title and description:
            context += f"""<STUDY>
{related_studies_template(title, description, criteria)}
</STUDY>"""
    return context

def user_prompt_template(encoded_related_studies: str, title: str, description: str, desired_criteria: str):
    template = """<EXAMPLE_STUDIES>{encoded_related_studies}</EXAMPLE_STUDIES>

Title: {title}
Description: {description}
Desired criteria: {desired_criteria}

Task Instructions:
1. Derive a step-by-step justification starting from the Title and Description provided, gradually building up to support the Desired criteria.
2. You may use the example studies (in the <EXAMPLE_STUDIES> section) if they support your justification, but ensure the reasoning is well explained and relevant to the study's context.
3. Do not mention that the criteria were already provided, and do not cite the given criteria directly in your justification.
4. Give the justification first, then the criteria.

Response Format:
<STEP-BY-STEP-JUSTIFICATION>
Your step-by-step justification here.
</STEP-BY-STEP-JUSTIFICATION>
<Criteria>
The copied desired criteria here.
</Criteria>
"""

    return template.format(encoded_related_studies=encoded_related_studies, title=title, description=description, desired_criteria=desired_criteria)

system_prompt = "You are a justifier chatbot designed to generate step-by-step justifications that are derived from the Title and Description of a study and gradually build up to the Desired criteria. Your task is to analyze the title and description of a study and build logical, step-by-step justifications that connect the study’s key elements to the desired criteria. Reference related example studies if they reinforce your justifications. You must assume the desired criteria are correct (they have already been reviewed by specialists) and develop arguments to support them based on the study context and relevant research insights."

def get_messages_for_CoT_huggingface(encoded_related_studies: str, title: str, description: str, desired_criteria: str):
    # Reuse the shared system prompt so the Llama and Gemini runs use identical instructions
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt_template(encoded_related_studies, title, description, desired_criteria)},
    ]
    

def get_info_for_prompt_gen(study_info: dict):
    # Repair and parse the metadata string; fall back to an empty dict if it is missing or unparsable
    metadata = json_repair.loads(fix_invalid_json(study_info.get('metadata') or '{}'))
    if not isinstance(metadata, dict):
        metadata = {}
    title = metadata.get('Official_title')
    description = study_info.get('data')
    study_id = metadata.get('NCT_ID')
    desired_criteria = study_info.get('criteria')

    # Ensure we have the minimum required information
    if not title or not description or not desired_criteria or not study_id:
        print(f"Skipping study {study_id}: missing title, description, desired criteria, or study ID")
        return None

    # Retrieve similar studies (excluding this one) and format them as in-context examples
    query = f'{title} [SEP] {description}'
    relevant_studies = retrieve_relevant_studies(query, study_id)
    encoded_related_studies = craft_context_from_studies_documents([i['document'] for i in relevant_studies])
    return encoded_related_studies, title, description, desired_criteria
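Each prompt embeds up to five full example studies, and the Llama section caps the context at `max_model_len=20000` tokens, so long examples plus the study itself and the `max_tokens=4096` answer budget may not fit. One option, sketched here as a hypothetical helper, is to keep whole `<STUDY>` blocks in retrieval order until a character budget is reached. Characters are only a rough proxy for tokens (the tokenizer gives exact counts), so the budget value is an assumption to tune. Joining with a newline also keeps `</STUDY>` and the next `<STUDY>` on separate lines; `craft_context_from_studies_documents` currently concatenates them directly.

```python
def join_within_budget(blocks: list[str], max_chars: int, sep: str = "\n") -> str:
    # Keep blocks in retrieval order (closest first) and stop before exceeding the budget,
    # so no <STUDY> block is ever cut in half
    kept: list[str] = []
    used = 0
    for block in blocks:
        # Count the separator too, except before the first block
        extra = len(block) + (len(sep) if kept else 0)
        if used + extra > max_chars:
            break
        kept.append(block)
        used += extra
    return sep.join(kept)


example_blocks = [
    "<STUDY>\nExample Title: A\n</STUDY>",
    "<STUDY>\nExample Title: B\n</STUDY>",
]
print(join_within_budget(example_blocks, max_chars=60))
```

Dropping whole blocks from the end sacrifices the least similar examples first, which seems preferable to truncating text mid-study.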




## Using Llama 3.1

In [None]:
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_id = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a16"
number_gpus = 1

tokenizer = AutoTokenizer.from_pretrained(model_id)

llm = LLM(model=model_id, tensor_parallel_size=number_gpus, max_model_len=20000)

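def pipe(messages):
    # temperature=0 makes decoding greedy (deterministic), so top_p has no practical effect here
    sampling_params = SamplingParams(temperature=0, top_p=0.9, max_tokens=4096)
    # Render the chat messages into a single prompt string using the model's chat template
    prompts = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    outputs = llm.generate(prompts, sampling_params)
    # One RequestOutput per prompt; keep the text of its first completion
    return [i.outputs[0].text for i in outputs]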
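`llm.generate` accepts a list of prompt strings and processes them together, whereas the loop below sends one study at a time. A minimal sketch of grouping items into fixed-size batches; `batched` is a hand-written helper (Python 3.12 ships `itertools.batched`, which yields tuples instead of lists), and the batch size would be a tuning assumption.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], size: int) -> Iterator[list[T]]:
    # Yield consecutive chunks of at most `size` items; the last chunk may be shorter
    if size < 1:
        raise ValueError("size must be at least 1")
    chunk: list[T] = []
    for item in items:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


print(list(batched(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```

Each batch of message lists could be rendered one by one with `tokenizer.apply_chat_template(..., tokenize=False)`, and the resulting strings passed as a single list to `llm.generate`, which returns one output per prompt in the same order.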

In [None]:
from datasets import load_dataset
from tqdm import tqdm

# Load the clinical trial dataset from the HuggingFace Hub
ravis_dataset = load_dataset("ravistech/clinical-trial-llm-cancer-restructure")

for study in tqdm(ravis_dataset['train']):
    print(f"Processing {study['metadata']}")
    info_for_prompt = get_info_for_prompt_gen(study)
    if info_for_prompt is None:
        continue  # study lacks required fields; already reported by get_info_for_prompt_gen
    messages = get_messages_for_CoT_huggingface(*info_for_prompt)
    print(f"Prompt: {messages}")
    print(f"Response: {pipe(messages)}")
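The prompt asks for two tagged sections, so the raw text returned by `pipe` can be split back into a justification and the criteria. A minimal sketch of such a parser; `parse_cot_response` is hypothetical, it matches the tags case-insensitively in case the model changes their casing, and it returns `None` when either section is missing so malformed generations can be flagged rather than stored. The sample response is invented.

```python
import re
from typing import Optional

# DOTALL lets '.' span newlines; the lazy '.*?' stops at the first closing tag
_JUSTIFICATION_RE = re.compile(
    r"<STEP-BY-STEP-JUSTIFICATION>(.*?)</STEP-BY-STEP-JUSTIFICATION>",
    re.DOTALL | re.IGNORECASE,
)
_CRITERIA_RE = re.compile(r"<Criteria>(.*?)</Criteria>", re.DOTALL | re.IGNORECASE)


def parse_cot_response(text: str) -> Optional[dict]:
    justification = _JUSTIFICATION_RE.search(text)
    criteria = _CRITERIA_RE.search(text)
    if justification is None or criteria is None:
        return None  # response does not follow the requested format
    return {
        "justification": justification.group(1).strip(),
        "criteria": criteria.group(1).strip(),
    }


sample = """<STEP-BY-STEP-JUSTIFICATION>
1. The study targets adults after knee surgery.
</STEP-BY-STEP-JUSTIFICATION>
<Criteria>
Inclusion: adults after bilateral TKA
</Criteria>"""
print(parse_cot_response(sample))
```

Since `pipe` returns a list with one string per prompt, a single study's output would be parsed with `parse_cot_response(pipe(messages)[0])`.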


## Using Gemini

In [None]:
!pip install -U google-generativeai
from datasets import load_dataset
from tqdm import tqdm
import google.generativeai as genai

# Configure the Gemini model
genai.configure(api_key="YOUR_API_KEY")
model = genai.GenerativeModel("gemini-1.5-flash",
                              system_instruction=system_prompt)

# Load the clinical trial dataset from the HuggingFace Hub
ravis_dataset = load_dataset("ravistech/clinical-trial-llm-cancer-restructure")

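for study in tqdm(ravis_dataset['train']):
    print(f"Processing {study['metadata']}")
    info_for_prompt = get_info_for_prompt_gen(study)
    if info_for_prompt is None:
        continue  # study lacks required fields; already reported by get_info_for_prompt_gen
    # The system prompt is set via system_instruction, so only the user prompt is sent here
    prompt = user_prompt_template(*info_for_prompt)
    print(f"Prompt: {prompt}")
    response = model.generate_content(prompt)
    print(f"Response: {response.text}")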
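Both loops only print their output, so an interrupted run over the dataset (for example after a rate-limit error from the API) loses everything generated so far. A minimal sketch of persisting results as JSON Lines and skipping studies that are already saved; the file layout, field names, and the helpers `load_done_ids` and `append_result` are assumptions.

```python
import json
from pathlib import Path


def load_done_ids(path: Path) -> set[str]:
    # Collect the study IDs already written; a missing file means nothing is done yet
    if not path.exists():
        return set()
    with path.open(encoding="utf-8") as f:
        return {json.loads(line)["study_id"] for line in f if line.strip()}


def append_result(path: Path, study_id: str, prompt: str, response: str) -> None:
    # One JSON object per line; appending leaves earlier results intact
    record = {"study_id": study_id, "prompt": prompt, "response": response}
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```

`get_info_for_prompt_gen` does not return the NCT ID, so the loop would need it as well (for instance by also returning `study_id` from that function) before checking `if study_id in done_ids: continue` and calling `append_result` after each response.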