In [1]:
import chromadb
from sentence_transformers import SentenceTransformer
from client import Client
from prompt_gen import PromptGen

client = chromadb.PersistentClient(path="./clinical_trials_chroma")
embed_model = SentenceTransformer("malteos/scincl")
collection = client.get_or_create_collection("clinical_trials_studies")


client = Client(
  client=client,
  collection=collection,
  embed_model=embed_model
)

prompt_gen = PromptGen(
  client=client
)

  from tqdm.autonotebook import tqdm, trange


In [2]:
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
import uuid


ravis_dataset = load_dataset("ravistech/clinical-trial-llm-cancer-restructure")


info_data = {}  # key is uuid, value is the info for the prompt
for study in tqdm(ravis_dataset['test']):

    info_for_prompt = prompt_gen.get_info_for_prompt_gen(study)
    
    if info_for_prompt:
        unique_id = str(uuid.uuid4())  # generate uuid for each entry
        encoded_related_studies, title, description, desired_criteria = info_for_prompt
        messages = prompt_gen.user_prompt_template(encoded_related_studies, title, description, desired_criteria)
        
        # store info with unique ID for later use
        info_data[unique_id] = {
            "encoded_related_studies": encoded_related_studies,
            "title": title,
            "description": description,
            "desired_criteria": desired_criteria,
            "messages": messages,
            "response": None  # setting response to None
        }

df = pd.DataFrame.from_dict(info_data, orient="index")


100%|██████████| 3993/3993 [00:41<00:00, 95.50it/s] 


In [3]:
import torch

# Clear CUDA memory
torch.cuda.empty_cache()

In [4]:
impose_input_len = True
df['input'] = df.apply(lambda x: prompt_gen.gen_input(x['encoded_related_studies'], x['title'], x['description']), axis=1)

if impose_input_len:
  from transformers import AutoTokenizer
  tokenizer = AutoTokenizer.from_pretrained("neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a16")
  df['input_len'] = df['input'].apply(lambda x: len(tokenizer(x)['input_ids']))
  df = df[df['input_len'] < 7000]
df["output"] = ""

df

Loading fast tokenizer from /home/swiss/.cache/huggingface/hub/models--neuralmagic--Meta-Llama-3.1-8B-Instruct-quantized.w8a16/snapshots/38e03ba250017bf8ed3eeecd3a744e21f6b994a9/tokenizer.json


Unnamed: 0,encoded_related_studies,title,description,desired_criteria,messages,response,input,input_len,output
59b1f1e6-877d-4571-9f0c-517661042dec,<STUDY>\n Example Title: An Open Registry t...,An Open Registry to Measure the Impact of Addi...,#Study Description \nBrief Summary \nThis regi...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: A...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,2848,
2825b026-e008-4f01-b2df-030095f88009,"<STUDY>\n Example Title: An Open-Label, Mul...","A Phase I, Open-Label, Dose-Escalation Study o...",#Study Description \nBrief Summary \nThis is a...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: A...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,3268,
5986d1a9-a70b-47b8-8b07-4c089429b3e5,<STUDY>\n Example Title: REal-world Pattern...,Treatment Patterns And Clinical Outcomes Among...,#Study Description \nBrief Summary \nCDK4/6 in...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: R...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,3365,
e2f3bc13-1f1d-4d9d-968a-aa13ba990a85,<STUDY>\n Example Title: HI-CHART: A Phase ...,A Phase I/II Trial of Isotoxic Accelerated Rad...,#Study Description \nBrief Summary \nThe I-STA...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: H...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,5937,
d2dfa392-3b5d-4edb-a59e-33df7f03c00b,<STUDY>\n Example Title: Resistance Trainin...,Exercise and Nutrition for Head and Neck Cance...,#Study Description \nBrief Summary \nResearch ...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: R...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,4480,
...,...,...,...,...,...,...,...,...,...
2f071aff-2a03-487c-99ee-40da999afa85,"<STUDY>\n Example Title: A Phase 1, Open-la...","A Phase 1, Multiple-Dose Study of the Safety a...",#Study Description \nBrief Summary \nThe purpo...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: A...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,3452,
b169f354-dcf0-4ac7-af0d-e0d6fede930d,<STUDY>\n Example Title: The Impact of Preo...,Do Omega-3 Fatty Acids Have Any Impact On Seru...,#Study Description \nBrief Summary \nPre- and ...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: T...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,5500,
948a89a6-7a6d-4526-886f-a5d1818b091b,<STUDY>\n Example Title: Aerobic and Resist...,Effects of Water-based Versus Land-based Exerc...,#Study Description \nBrief Summary \nPhysical ...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: A...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,4941,
559ea7d5-9c0b-4229-8736-465c7d45bc83,<STUDY>\n Example Title: Hepassocin Levels ...,Clusterin Level Determination and Its Associat...,#Study Description \nBrief Summary \nClusterin...,#Eligibility Criteria:\nInclusion Criteria:\n\...,<EXAMPLE_STUDIES><STUDY>\n Example Title: H...,,<RELATED_STUDIES>\n<STUDY>\n Example Title:...,3179,


In [13]:
## for the best perf will pickle the df here and load it in the next cell
df.to_pickle("df_temp.pkl")

## please restart the kernel and run the following cells

# Restart the kernal here for the best performance

In [1]:
## load the df
import pandas as pd
df = pd.read_pickle("df_temp.pkl")

In [2]:
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
model_id = "swissnp/finetuned_gemini_CoT_studies"
number_gpus = 1
repetition_penalty = 1
llm = LLM(model=model_id, tensor_parallel_size=number_gpus, max_model_len=12000, gpu_memory_utilization=0.93)
def pipe(messages):
    sampling_params = SamplingParams(temperature=0, top_p=0.9, max_tokens=4096, repetition_penalty=repetition_penalty)
    prompts = llm.get_tokenizer().apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    outputs = llm.generate(prompts, sampling_params)
    print([i.outputs[0].text for i in outputs], len(outputs))
    return [i.outputs[0].text for i in outputs]

INFO 11-26 03:10:57 config.py:350] This model supports multiple tasks: {'embedding', 'generate'}. Defaulting to 'generate'.
INFO 11-26 03:10:57 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post1) with config: model='swissnp/finetuned_gemini_CoT_studies', speculative_config=None, tokenizer='swissnp/finetuned_gemini_CoT_studies', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=12000, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=swissn

Loading pt checkpoint shards:   0% Completed | 0/9 [00:00<?, ?it/s]


  state = torch.load(bin_file, map_location="cpu")


INFO 11-26 03:11:09 model_runner.py:1077] Loading model weights took 14.9888 GB
INFO 11-26 03:11:12 worker.py:232] Memory profiling results: total_gpu_memory=23.68GiB initial_memory_usage=15.39GiB peak_torch_memory=16.25GiB memory_usage_post_profile=15.41GiB non_torch_memory=0.41GiB kv_cache_size=5.84GiB gpu_memory_utilization=0.95
INFO 11-26 03:11:12 gpu_executor.py:113] # GPU blocks: 2988, # CPU blocks: 2048
INFO 11-26 03:11:12 gpu_executor.py:117] Maximum concurrency for 12000 tokens per request: 3.98x
INFO 11-26 03:11:14 model_runner.py:1400] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 11-26 03:11:14 model_runner.py:1404] If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INF

In [None]:
from prompt_gen import PromptGen

batch_size = 20

for i in range(0, len(df), batch_size):
    print(i)
    batch = df.iloc[i:i + batch_size]
    batch_inputs = [PromptGen.gen_messages(row["input"]) for _, row in batch.iterrows()]
    batch_outputs = pipe(batch_inputs)
    df.loc[batch.index, 'output'] = batch_outputs

print(df)

In [None]:
# import re

# def extract_criteria(text):
#     match = re.search(r"<CRITERIA>(.*?)</CRITERIA>", text, re.DOTALL)
#     return match.group(1).strip() if match else None

# def improved_parse_with_raw(criteria_text):
#     criteria_text = extract_criteria(criteria_text)
#     if not criteria_text:
#         print("No criteria found")
#         return {}
#     result = {
#         "raw_criteria": criteria_text, 
#         "inclusion_criteria": [],
#         "exclusion_criteria": [],
#         "sex": "ALL",
#         "ages": {
#             "minimum_age": None,
#             "maximum_age": None,
#             "age_group": []
#         },
#         "accepts_healthy_volunteers": False
#     }
    
#     inclusion_match = re.search(r"Inclusion Criteria:(.*?)(?:Exclusion Criteria:|##|$)", criteria_text, re.DOTALL)
#     exclusion_match = re.search(r"Exclusion Criteria:(.*?)(?:##|$)", criteria_text, re.DOTALL)

#     if inclusion_match:
#         result["inclusion_criteria"] = [
#             item.strip() for item in inclusion_match.group(1).split("\n") if item.strip()
#         ]
#     if exclusion_match:
#         result["exclusion_criteria"] = [
#             item.strip() for item in exclusion_match.group(1).split("\n") if item.strip()
#         ]

#     sex_match = re.search(r"##Sex\s*:\s*(Male|Female|All)", criteria_text, re.IGNORECASE)
#     if sex_match:
#         result["sex"] = sex_match.group(1).upper()

#     min_age_match = re.search(r"- Minimum Age\s*:\s*(\d+)", criteria_text, re.IGNORECASE)
#     if min_age_match:
#         result["ages"]["minimum_age"] = int(min_age_match.group(1))

#     max_age_match = re.search(r"- Maximum Age\s*:\s*(\d+)", criteria_text, re.IGNORECASE)
#     if max_age_match:
#         result["ages"]["maximum_age"] = int(max_age_match.group(1))

#     age_group_match = re.findall(r"Age Group.*?:(.*?)$", criteria_text, re.MULTILINE)
#     if age_group_match:
#         age_groups = re.findall(r"(Child|Adult|Older Adult)", " ".join(age_group_match), re.IGNORECASE)
#         result["ages"]["age_group"] = list(set(group.upper() for group in age_groups))  # Unique values

#     healthy_volunteers_match = re.search(r"##Accepts Healthy Volunteers:\s*(Yes|No)", criteria_text, re.IGNORECASE)
#     if healthy_volunteers_match:
#         result["accepts_healthy_volunteers"] = healthy_volunteers_match.group(1).strip().lower() == "yes"

#     return result

# df.dropna(subset=['output'], inplace=True)
# improved_criteria_with_raw_json = df['output'].apply(improved_parse_with_raw)

# df['json'] = improved_criteria_with_raw_json