In [None]:
import chromadb
from sentence_transformers import SentenceTransformer
from client import Client
from prompt_gen import PromptGen

# Open the on-disk Chroma store and fetch (or create) the studies collection.
chroma_client = chromadb.PersistentClient(path="./clinical_trials_chroma")
collection = chroma_client.get_or_create_collection("clinical_trials_studies")

# Embedding model passed to the Client wrapper below.
embed_model = SentenceTransformer("malteos/scincl")

# Project wrapper that bundles the Chroma client, the collection and the embedding model.
client = Client(client=chroma_client, collection=collection, embed_model=embed_model)

# Prompt builder, constructed on top of the wrapper.
prompt_gen = PromptGen(client=client)

In [None]:
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
import uuid


ravis_dataset = load_dataset("ravistech/clinical-trial-llm-cancer-restructure")

# Maps a freshly generated UUID string to the prompt ingredients of one study.
info_data = {}
for study in tqdm(ravis_dataset['test']):
    info_for_prompt = prompt_gen.get_info_for_prompt_gen(study)
    if not info_for_prompt:
        # Falsy result from the prompt generator: skip this study.
        continue

    study_id = str(uuid.uuid4())
    encoded_related_studies, title, description, desired_criteria = info_for_prompt
    messages = prompt_gen.user_prompt_template(
        encoded_related_studies, title, description, desired_criteria
    )

    # Keep every piece under the study's id; "response" starts out unset.
    info_data[study_id] = dict(
        encoded_related_studies=encoded_related_studies,
        title=title,
        description=description,
        desired_criteria=desired_criteria,
        messages=messages,
        response=None,
    )

# One row per kept study, indexed by its UUID string.
df = pd.DataFrame.from_dict(info_data, orient="index")


In [3]:
import torch

# Ask PyTorch to release cached, currently unused CUDA memory.
# NOTE(review): memory of objects that are still referenced (e.g. embed_model, if it
# lives on the GPU) presumably stays allocated — confirm this frees what you expect.
torch.cuda.empty_cache()

In [None]:
# Toggle: when True, prompts whose token count reaches MAX_INPUT_TOKENS are dropped.
impose_input_len = True
# Upper bound (exclusive) on the tokenized prompt length that is kept.
MAX_INPUT_TOKENS = 7000
# Tokenizer used only to measure prompt length.
# NOTE(review): this is not the tokenizer of the model used for generation later in the
# notebook, so the counts are presumably an approximation — confirm the margin is enough.
TOKENIZER_ID = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a16"

# Build the model input text for every row from its related studies, title and description.
df['input'] = df.apply(lambda x: prompt_gen.gen_input(x['encoded_related_studies'], x['title'], x['description']), axis=1)

if impose_input_len:
  from transformers import AutoTokenizer
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
  df['input_len'] = df['input'].apply(lambda x: len(tokenizer(x)['input_ids']))
  # .copy() makes the filtered frame independent of the original, so adding the
  # "output" column below does not write onto a slice (SettingWithCopyWarning).
  df = df[df['input_len'] < MAX_INPUT_TOKENS].copy()
# Placeholder column that the generation loop fills in later.
df["output"] = ""

df

In [13]:
## Persist the prepared DataFrame to disk so it can be reloaded in the next cell
## after the kernel restart (done for best performance).
df.to_pickle("df_temp.pkl")

## Please restart the kernel now and then run the following cells.

# Restart the kernel here for the best performance

In [1]:
## Reload the DataFrame that was written to "df_temp.pkl" before the kernel restart.
## NOTE(review): unpickling can execute arbitrary code — only load a file this notebook wrote.
import pandas as pd
df = pd.read_pickle("df_temp.pkl")

In [None]:
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
# Checkpoint to load and the number of GPUs used for tensor parallelism.
model_id = "swissnp/finetuned_gemini_CoT_studies"
number_gpus = 1
# Read by pipe() when it builds its sampling parameters.
repetition_penalty = 1

# Build the vLLM engine once; pipe() reuses it for every batch.
llm = LLM(
    model=model_id,
    tensor_parallel_size=number_gpus,
    max_model_len=12000,
    gpu_memory_utilization=0.93,
)
def pipe(messages):
    """Generate one completion per entry of ``messages`` and return the texts.

    ``messages`` is rendered to prompt text with the chat template of the
    engine's tokenizer (generation prompt appended, no tokenization), the
    result is passed to ``llm.generate`` in a single call, and the text of the
    first candidate of every returned result is collected into a list.

    The collected texts and the number of results are also printed.

    NOTE(review): presumably ``messages`` is a batch (list) of chat
    conversations — confirm against ``PromptGen.gen_messages``.
    """
    params = SamplingParams(
        temperature=0,
        top_p=0.9,
        max_tokens=4096,
        repetition_penalty=repetition_penalty,
    )
    chat_tokenizer = llm.get_tokenizer()
    rendered = chat_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    results = llm.generate(rendered, params)
    # Only the first candidate of each result is used.
    texts = [result.outputs[0].text for result in results]
    print(texts, len(results))
    return texts

In [None]:
from prompt_gen import PromptGen

# Number of rows sent to pipe() per call.
batch_size = 20

# Walk the frame in consecutive slices of batch_size rows.
for start in range(0, len(df), batch_size):
    print(start)
    chunk = df.iloc[start:start + batch_size]
    # Wrap each input text into the message structure pipe() expects.
    chunk_messages = [PromptGen.gen_messages(text) for text in chunk["input"]]
    # Write the generations back onto the same rows, addressed by index label.
    df.loc[chunk.index, 'output'] = pipe(chunk_messages)

print(df)

In [None]:
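## NOTE: optional post-processing, currently disabled. Uncomment this whole cell to parse
## the <CRITERIA>...</CRITERIA> section of each model output into a dict stored in df['json'].
# import re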

# def extract_criteria(text):
#     match = re.search(r"<CRITERIA>(.*?)</CRITERIA>", text, re.DOTALL)
#     return match.group(1).strip() if match else None

# def improved_parse_with_raw(criteria_text):
#     criteria_text = extract_criteria(criteria_text)
#     if not criteria_text:
#         print("No criteria found")
#         return {}
#     result = {
#         "raw_criteria": criteria_text, 
#         "inclusion_criteria": [],
#         "exclusion_criteria": [],
#         "sex": "ALL",
#         "ages": {
#             "minimum_age": None,
#             "maximum_age": None,
#             "age_group": []
#         },
#         "accepts_healthy_volunteers": False
#     }
    
#     inclusion_match = re.search(r"Inclusion Criteria:(.*?)(?:Exclusion Criteria:|##|$)", criteria_text, re.DOTALL)
#     exclusion_match = re.search(r"Exclusion Criteria:(.*?)(?:##|$)", criteria_text, re.DOTALL)

#     if inclusion_match:
#         result["inclusion_criteria"] = [
#             item.strip() for item in inclusion_match.group(1).split("\n") if item.strip()
#         ]
#     if exclusion_match:
#         result["exclusion_criteria"] = [
#             item.strip() for item in exclusion_match.group(1).split("\n") if item.strip()
#         ]

#     sex_match = re.search(r"##Sex\s*:\s*(Male|Female|All)", criteria_text, re.IGNORECASE)
#     if sex_match:
#         result["sex"] = sex_match.group(1).upper()

#     min_age_match = re.search(r"- Minimum Age\s*:\s*(\d+)", criteria_text, re.IGNORECASE)
#     if min_age_match:
#         result["ages"]["minimum_age"] = int(min_age_match.group(1))

#     max_age_match = re.search(r"- Maximum Age\s*:\s*(\d+)", criteria_text, re.IGNORECASE)
#     if max_age_match:
#         result["ages"]["maximum_age"] = int(max_age_match.group(1))

#     age_group_match = re.findall(r"Age Group.*?:(.*?)$", criteria_text, re.MULTILINE)
#     if age_group_match:
#         age_groups = re.findall(r"(Child|Adult|Older Adult)", " ".join(age_group_match), re.IGNORECASE)
#         result["ages"]["age_group"] = list(set(group.upper() for group in age_groups))  # Unique values

#     healthy_volunteers_match = re.search(r"##Accepts Healthy Volunteers:\s*(Yes|No)", criteria_text, re.IGNORECASE)
#     if healthy_volunteers_match:
#         result["accepts_healthy_volunteers"] = healthy_volunteers_match.group(1).strip().lower() == "yes"

#     return result

# df.dropna(subset=['output'], inplace=True)
# improved_criteria_with_raw_json = df['output'].apply(improved_parse_with_raw)

# df['json'] = improved_criteria_with_raw_json