In [1]:
import pandas as pd
from src.prompts import dfs_prompt_t14, system_instruction_short_ans_t14
from transformers import AutoTokenizer

# Hugging Face checkpoint whose tokenizer is loaded, and the local cache directory
# passed to `from_pretrained`.
model_name_or_path = "m42-health/med42-70b"
cache_dir = "/secure/chiahsuan/hf_cache"

# The tokenizer is used further down to count how many tokens each formatted prompt has.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    cache_dir=cache_dir,
)

def count_tokens(text, tokenizer):
    """Return the number of tokens `tokenizer` produces for `text`.

    The count is taken from the plain ``input_ids`` list, so no tensor is
    built and no GPU is needed just to measure a prompt's length.

    Args:
        text: A single string to tokenize (not a batch of strings).
        tokenizer: A callable that returns a mapping with an ``"input_ids"``
            entry holding the token ids of `text`, e.g. a Hugging Face
            tokenizer.

    Returns:
        int: Length of ``input_ids``, including whatever special tokens the
        tokenizer adds by default.
    """
    return len(tokenizer(text)["input_ids"])

# Prompt layout with three role-tagged sections; `system_instruction` and `prompt`
# are filled in with str.format, and the model's reply is expected after the
# trailing assistant tag.
MED42_PROMPT_TEMPLATE = (
    "<|system|>:{system_instruction}\n"
    "<|prompter|>:{prompt}\n"
    "<|assistant|>:"
)

# One few-shot demonstration: a leading newline, then the report text and its answer.
SHOT_TEMPLATE = "\nReport:{report}\nAnswer:{label}"


def format_demonstration(report, label):
    """Render a single (report, label) pair with SHOT_TEMPLATE.

    Args:
        report: Text placed after ``Report:``.
        label: Text placed after ``Answer:``.

    Returns:
        str: The filled-in demonstration, starting with a newline.
    """
    fields = {"report": report, "label": label}
    return SHOT_TEMPLATE.format(**fields)

# T Category

In [2]:
# Root directory of the shared report data. The trailing slash matters: paths are
# formed by appending directly to this string.
data_base_path = "/secure/shared_data/tcga_path_reports/"
# CSV retrieved from select-report-length-few-shots.ipynb
filename = "dfs-t14-report-length-k5.csv"
t14_testing_reports = pd.read_csv(
    f"{data_base_path}t14_data/text-embedding-3-small/{filename}"
)

In [3]:
""" This cell needs to be modulized later
it only works for T category and Med42 model
"""
# Settings for the T category; pass different values to build_few_shot_prompt to
# reuse the same logic elsewhere.
label_name = "t"
label_range = [0,1,2,3]
k = 1 # number of shots for each category
col_key = "dfs_{}{}_{}"


def build_few_shot_prompt(report_row, label_name, label_range, k, col_key):
    """Build the full few-shot prompt for one testing report.

    For every class index in `label_range` and every shot index in
    ``range(k)``, the demonstration report is read from the column named
    ``col_key.format(label_name, class_index, shot_index)`` of `report_row`.
    The demonstrations are joined with newlines, inserted into
    `dfs_prompt_t14` together with the row's ``"text"``, and the result is
    wrapped in `MED42_PROMPT_TEMPLATE` with `system_instruction_short_ans_t14`.

    Args:
        report_row: A row of the testing-reports frame; must provide ``"text"``
            and every demonstration column addressed through `col_key`.
        label_name: Label prefix used in the column names; its upper-cased
            form prefixes the answer shown in each demonstration.
        label_range: Iterable of integer class indices.
        k: Number of demonstrations taken per class (shot indices 0..k-1).
        col_key: Format string with three positional fields
            (label name, class index, shot index).

    Returns:
        str: The model-ready prompt.

    Raises:
        KeyError: If `report_row` lacks ``"text"`` or one of the demonstration
            columns.
    """
    demos = [
        format_demonstration(
            report=report_row[col_key.format(label_name, class_idx, shot_idx)],
            # Class indices are 0-based while the displayed label is 1-based:
            # 0 maps to T1, ..., 3 maps to T4.
            label="{}{}".format(label_name.upper(), class_idx + 1),
        )
        for class_idx in label_range
        for shot_idx in range(k)
    ]
    # Each demonstration already starts with a newline, so this join leaves a
    # blank line between consecutive demonstrations.
    demo_string = "\n".join(demos)
    formatted_prompt = dfs_prompt_t14.format(demonstrations=demo_string, report=report_row["text"])
    return MED42_PROMPT_TEMPLATE.format(
        system_instruction=system_instruction_short_ans_t14,
        prompt=formatted_prompt,
    )


# One prompt per testing report, in the same row order as t14_testing_reports.
formatted_fs_prompts = [
    build_few_shot_prompt(each_report, label_name, label_range, k, col_key)
    for _, each_report in t14_testing_reports.iterrows()
]
num_of_tokens = [count_tokens(prompt, tokenizer) for prompt in formatted_fs_prompts]

# Keep only the identifying columns plus the generated prompt and its length.
out_df = t14_testing_reports[["patient_filename", "t", "text"]].copy()
out_df["formatted_fs_prompts"] = formatted_fs_prompts
out_df["num_of_tokens"] = num_of_tokens

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [13]:
# Uncomment to read one complete formatted prompt (row 2) instead of the truncated table view:
# out_df.iloc[2]["formatted_fs_prompts"]
# Show the per-report prompts together with their token counts.
out_df

Unnamed: 0,patient_filename,t,text,formatted_fs_prompts,num_of_tokens
0,TCGA-CQ-7063.be787b58-3a45-427d-98af-68996bb3f337,0,Visit Number. Event Date and Time. Procedure O...,<|system|>:You are an expert at interpreting p...,9376
1,TCGA-63-A5MV.494B2443-C3FD-43A0-A735-29D9D801C9D1,1,Sample Type TUMOUR. Diagnosis. Squamous Cell C...,<|system|>:You are an expert at interpreting p...,1096
2,TCGA-AG-A026.800BF328-7F21-4BF0-95CA-D9092FC6B2D8,3,Internal Sample IC. This is vascularized fatty...,<|system|>:You are an expert at interpreting p...,1649
3,TCGA-CV-7437.8d2c5741-ed13-4aba-8a62-ad928f223d08,1,SUPPLEMENTAL REPORT. DIAGNOSIS: (A) TOTAL LARY...,<|system|>:You are an expert at interpreting p...,4547
4,TCGA-85-A512.A49AF0B2-6C19-47D5-BD75-A6CB7F7057B9,0,Gross Description: Microscopic Description: Di...,<|system|>:You are an expert at interpreting p...,1178
...,...,...,...,...,...
1029,TCGA-2G-AAHL.81F00ECA-C7BD-4EDD-ABE1-EED08C0E5F36,1,Summary pathology report. Left orchidectomy; s...,<|system|>:You are an expert at interpreting p...,1657
1030,TCGA-AC-A3YJ.0F890724-8554-4B2B-B4EC-6B9B24C757D8,1,MRN #: SPECIME. DIAGNOSIS. DIAGNOSIS: A. Right...,<|system|>:You are an expert at interpreting p...,6495
1031,TCGA-DD-AACV.9BC1D204-8E33-4E30-94B7-6AE5E6484AD2,0,CLINICAL DIAGNOSIS: HCC. Specimen : liver. Gro...,<|system|>:You are an expert at interpreting p...,4052
1032,TCGA-CJ-4881.36725c29-f964-4ab9-92e6-776d85380d62,2,Normal Sampleno: DIAGNOSIS. (A) RIGHT KIDNEY: ...,<|system|>:You are an expert at interpreting p...,4094


In [15]:
# Exclusive upper bound on prompt length, measured with count_tokens.
# NOTE(review): 10240 is presumably the context budget chosen for inference — confirm.
MAX_PROMPT_TOKENS = 10240

# Keep only the reports whose fully formatted prompt stays strictly under the bound.
filter_out_df = out_df[out_df["num_of_tokens"] < MAX_PROMPT_TOKENS]
filter_out_df.shape

(819, 5)

In [16]:
# Save the filtered prompts in the same directory the input CSV was read from.
# to_csv is called with its defaults, so the DataFrame index is written as the
# first (unnamed) column of the file.
filter_out_df.to_csv(
    f"{data_base_path}t14_data/text-embedding-3-small/dfs-t14-formatted-report-length-k1.csv"
)