In [4]:
# Scratch check: how None renders inside a formatted string, and what its type is.
s = None
"similarity: {}".format(s)
type(s)

NoneType

In [1]:
# import OpenAI's tokenizer
import tiktoken

# import used functions from chromadb
import chromadb
import chromadb.utils.embedding_functions as embedding_functions

# import pandas 
import pandas as pd

# import utils
from tqdm import tqdm
from collections import defaultdict 
import pickle

import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file

# Root directory of the TCGA pathology-report data. It can be overridden with the
# TCGA_DATA_BASE_PATH environment variable (e.g. from the .env file); the default
# is the original shared location. os.path.join(..., "") guarantees a trailing
# separator because later cells build paths by plain string concatenation.
data_base_path = os.path.join(
    os.getenv("TCGA_DATA_BASE_PATH", "/secure/shared_data/tcga_path_reports/"), ""
)

# Initialize collection instance in Chroma database 

In [2]:
"""
Initialize the helper for OpenAI's embedding API
and also initialize Chroma's client and select the target collection
"""

# text-embedding-3-small is better than text-embedding-ada-002
# the best-performing model is text-embedding-3-large (we can consider it later when the cost allows.)
# The API key is read from the OPENAI_API_KEY environment variable; os.getenv
# returns None when it is unset, so make sure the .env file / environment provides it.
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.getenv('OPENAI_API_KEY'),
                model_name="text-embedding-3-small"
            )

# Chroma client backed by the "chroma_data/" directory under data_base_path.
client = chromadb.PersistentClient(path=data_base_path+"chroma_data/")

# Embed documents into vector representations and store them in the collection

In [3]:
"""
Load the reports, embed them, save them in the ChromaDB's collection
"""

def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Count the tokens `string` produces under the tiktoken encoding `encoding_name`."""
    tokens = tiktoken.get_encoding(encoding_name).encode(string)
    return len(tokens)

def embed_reports_in_chroma(report_df, collection, label_name = "t", embedding_function = None):
    """Embed every report in `report_df` and add it to a Chroma collection.

    Each row is embedded individually and stored with its text as the document,
    its `patient_filename` as the id, and metadata holding the patient filename,
    the label value and the report length in cl100k_base tokens.

    Args:
        report_df: DataFrame with the columns "patient_filename", "text" and
            the column named by `label_name`.
        collection: Chroma collection (anything exposing a compatible `add`).
        label_name: Name of the label column; also used as the metadata key.
        embedding_function: Callable mapping a list of texts to a list of
            embeddings. Defaults to the notebook-level `openai_ef`, so existing
            three-argument calls behave as before.
    """
    if embedding_function is None:
        embedding_function = openai_ef

    # The context manager closes the progress bar even if an embedding or
    # insert call raises part-way through the loop.
    with tqdm(total=report_df.shape[0]) as pbar:
        for _, report in report_df.iterrows():
            report_patient_filename = report["patient_filename"]
            report_label = report[label_name]
            report_text = report["text"]
            report_length_cl100k_base = num_tokens_from_string(report_text, "cl100k_base")

            collection.add(
                # The embedding function takes a batch; send one text, take its single vector.
                embeddings=embedding_function([report_text])[0],
                metadatas={"patient_filename": report_patient_filename, label_name: report_label, "report_length_cl100k_base": report_length_cl100k_base},
                documents=report_text,
                ids=report_patient_filename
            )
            pbar.update(1)

In [4]:
# Embed the training split; the work is skipped when the collection already holds items.
collection = client.get_or_create_collection(
    "text-emb-3-small-t14-training-split",
    embedding_function=openai_ef,
)
print(collection.count())
if collection.count() != 0:
    print("Collection is already there.")
else:
    print("Start encoding the documents into database.")
    t14_training_reports = pd.read_csv(data_base_path+"t14_data/Target_Data_T14.csv")
    #samples = t14_training_reports.sample(n=10)
    embed_reports_in_chroma(t14_training_reports, collection, label_name = "t")
# To rebuild from scratch: client.delete_collection("text-emb-3-small-t14-training-split")
# Timing of the full run: 5853/5853 [22:44<00:00,  4.29it/s]

5853
Collection is already there.


In [5]:
# Embed the testing split; the work is skipped when the collection already holds items.
testing_collection = client.get_or_create_collection(
    "text-emb-3-small-t14-testing-split",
    embedding_function=openai_ef,
)
print(testing_collection.count())
if testing_collection.count() != 0:
    print("Collection is already there.")
else:
    print("Start encoding the documents into database.")
    t14_testing_reports = pd.read_csv(data_base_path+"t14_data/Target_Data_T14_test.csv")
    embed_reports_in_chroma(t14_testing_reports, testing_collection, label_name = "t")
# To rebuild from scratch: client.delete_collection("text-emb-3-small-t14-testing-split")
# Timing of the full run: 1034/1034 [03:41<00:00,  4.67it/s]

0
Start encoding the documents into database.


100%|██████████| 1034/1034 [03:41<00:00,  4.67it/s]


# Query

In [42]:
# For every test report, retrieve the top_n most similar training reports for
# each possible label and collect them column-wise in dynamic_few_shots.
t14_testing_reports = pd.read_csv(data_base_path+"t14_data/Target_Data_T14_test.csv")
label_name = "t"
top_n = 5
dynamic_few_shots = defaultdict(list)

# Passing the iterator to tqdm closes the progress bar even if a query fails.
for _, report in tqdm(t14_testing_reports.iterrows(), total=t14_testing_reports.shape[0]):

    # retrieve the stored embedding for each test report
    test_return_obj = testing_collection.get(ids=report["patient_filename"], include=["embeddings", "metadatas"])
    test_report_embedding = test_return_obj["embeddings"]

    # retrieve similar items for each possible_label, e.g., t0, t1, ...
    # similar items SHOULD be from training split (i.e., collection)
    for possible_label in [0,1,2,3]:
        retrieved_items = collection.query(query_embeddings=test_report_embedding,
                                           n_results=top_n,
                                           where={label_name: possible_label})
        # add any filter later if any
        # current way is a simple k-nn retrieval
        retrieved_docs = retrieved_items["documents"][0] # [0] is used because there is only a single query (i.e., test_report_embedding)
        # Always emit exactly top_n entries per report: if fewer documents match a
        # label, pad with None so every column stays aligned with the report rows.
        for idx in range(top_n):
            shot_key = "dfs_{}{}_{}".format(label_name, possible_label, idx) # dfs_t0_0, this format is used as the key in panda's df; every test report can easily pick the top n items for a specific category
            dynamic_few_shots[shot_key].append(retrieved_docs[idx] if idx < len(retrieved_docs) else None)
# Timing of the full run: 1034/1034 [02:58<00:00,  5.80it/s]

100%|██████████| 1034/1034 [02:58<00:00,  5.80it/s]


In [44]:
# Validation: every few-shot column must hold exactly one entry per test report
# before it is attached to the DataFrame.
n_reports = t14_testing_reports.shape[0]
print(n_reports)
for key, shots in dynamic_few_shots.items():
    print(key, len(shots))
    assert n_reports == len(shots)
    t14_testing_reports[key] = shots

1034
dfs_t0_0 1034
dfs_t0_1 1034
dfs_t0_2 1034
dfs_t0_3 1034
dfs_t0_4 1034
dfs_t1_0 1034
dfs_t1_1 1034
dfs_t1_2 1034
dfs_t1_3 1034
dfs_t1_4 1034
dfs_t2_0 1034
dfs_t2_1 1034
dfs_t2_2 1034
dfs_t2_3 1034
dfs_t2_4 1034
dfs_t3_0 1034
dfs_t3_1 1034
dfs_t3_2 1034
dfs_t3_3 1034
dfs_t3_4 1034


In [None]:
# Preview only the first rows instead of rendering the whole frame.
t14_testing_reports.head()

In [47]:
# Write the test reports, including the retrieved few-shot columns, to CSV.
filename = "dfs-t14-report-length-k5.csv"
output_path = data_base_path+"t14_data/text-embedding-3-small/"+filename
t14_testing_reports.to_csv(output_path)

### Prepare few-shot examples for the 5-fold splits

In [3]:
# List the collections currently available through this Chroma client.
client.list_collections()

[Collection(name=text-emb-3-small-n03-training-split),
 Collection(name=full_report_emb),
 Collection(name=test-bge-small-t14),
 Collection(name=full_summary_emb),
 Collection(name=text-emb-3-small-n03-testing-split)]

In [4]:
# Open the existing "full_summary_emb" collection and show how many items it holds.
full_summary_collection = client.get_collection("full_summary_emb")
full_summary_collection.count()

1517

In [6]:
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load this file when it comes from a trusted source.
# NOTE(review): both paths below are hard-coded absolute paths, so this cell only
# runs on a machine with that layout — consider deriving them from a configurable base.
# loaded_dict presumably maps a mode name (e.g. "brca_t_5folds") to its list of
# folds — verify against the code that wrote the pickle.
with open('/home/yl3427/cylab/rag_tnm/full_5folds_dict.pkl', 'rb') as file:
    loaded_dict = pickle.load(file)
# Report table used with the fold definitions; the columns are displayed below.
df = pd.read_csv("/secure/shared_data/rag_tnm_results/summary/5_folds_summary/merged_df.csv")
df.columns

Index(['patient_filename', 't', 'text', 'type', 'n'], dtype='object')

In [11]:
# Inspect the fold dictionary: for each mode and fold, show the fold's keys and
# the label name taken from the second "_"-separated token of the mode.
for mode, folds in loaded_dict.items():
    for i, fold in enumerate(folds):
        label = mode.split("_")[1]
        print(f"{mode} {i+1}st fold")
        print(fold.keys())
        print(label)

brca_t_5folds 1st fold
dict_keys(['train', 'test'])
t
brca_t_5folds 2st fold
dict_keys(['train', 'test'])
t
brca_t_5folds 3st fold
dict_keys(['train', 'test'])
t
brca_t_5folds 4st fold
dict_keys(['train', 'test'])
t
brca_t_5folds 5st fold
dict_keys(['train', 'test'])
t
brca_n_5folds 1st fold
dict_keys(['train', 'test'])
n
brca_n_5folds 2st fold
dict_keys(['train', 'test'])
n
brca_n_5folds 3st fold
dict_keys(['train', 'test'])
n
brca_n_5folds 4st fold
dict_keys(['train', 'test'])
n
brca_n_5folds 5st fold
dict_keys(['train', 'test'])
n
luad_t_5folds 1st fold
dict_keys(['train', 'test'])
t
luad_t_5folds 2st fold
dict_keys(['train', 'test'])
t
luad_t_5folds 3st fold
dict_keys(['train', 'test'])
t
luad_t_5folds 4st fold
dict_keys(['train', 'test'])
t
luad_t_5folds 5st fold
dict_keys(['train', 'test'])
t
luad_n_5folds 1st fold
dict_keys(['train', 'test'])
n
luad_n_5folds 2st fold
dict_keys(['train', 'test'])
n
luad_n_5folds 3st fold
dict_keys(['train', 'test'])
n
luad_n_5folds 4st fold
dict_

In [None]:
# For every mode/fold, retrieve the top_n most similar "good summary" documents
# per possible label for each test report of that fold.
top_n = 5
dynamic_few_shots = defaultdict(list)

# NOTE(review): results from every mode and fold are appended to the same flat
# dict, so entries from different folds end up in the same lists — confirm whether
# they should be kept per (mode, fold) before using them downstream.
# NOTE(review): the query below searches the whole full_summary_collection and is
# not restricted to fold['train'], so a test report may retrieve itself or other
# test-fold reports — confirm this is intended.
for mode, folds in loaded_dict.items():
    for i, fold in enumerate(folds):
        print(f"{mode} {i+1}st fold")
        print(fold.keys())
        label = mode.split("_")[1]
        print(label)
        list_of_test_ids = fold['test']
        test_df = df[df['patient_filename'].isin(list_of_test_ids)]

        # Passing the iterator to tqdm closes the progress bar even if a query fails.
        for _, report in tqdm(test_df.iterrows(), total=test_df.shape[0]):

            test_return_obj = full_summary_collection.get(ids=report["patient_filename"], include=["embeddings", "metadatas"])
            test_report_embedding = test_return_obj["embeddings"]

            for possible_label in [0,1,2,3]:
                # Chroma accepts only one top-level key in `where`; multiple
                # conditions must be combined explicitly with "$and".
                retrieved_items = full_summary_collection.query(
                    query_embeddings=test_report_embedding,
                    n_results=top_n,
                    where={"$and": [
                        {f'is_goodsum_{label}': {"$eq": True}},
                        {label: {"$eq": possible_label}},
                    ]},
                )

                retrieved_docs = retrieved_items["documents"][0] # [0] is used because there is only a single query (i.e., test_report_embedding)
                # Always emit exactly top_n entries per report: if fewer documents
                # pass the filter, pad with None so the lists stay row-aligned.
                for idx in range(top_n):
                    # Use this mode's own label ("t" or "n") in the key, not the
                    # `label_name` global left over from the earlier query cell.
                    shot_key = "dfs_{}{}_{}".format(label, possible_label, idx) # dfs_t0_0, this format is used as the key in panda's df; every test report can easily pick the top n items for a specific category
                    dynamic_few_shots[shot_key].append(retrieved_docs[idx] if idx < len(retrieved_docs) else None)

In [10]:
# Peek at a sample of the items stored in one per-fold training collection.
client.get_collection("luad_n_5folds_2st_train_good_sum_emb").peek()
                               

{'ids': ['TCGA-05-4244.3a844132-f813-4d8e-8f7d-dbae0b23d7fd',
  'TCGA-05-4245.902fe548-5b93-49c9-81db-2af4a4a88f3c',
  'TCGA-05-4249.7e920317-d5c2-4160-9e2b-ef0101eb5a23',
  'TCGA-05-4250.5574f2f8-f247-40e6-a285-7793edcf5358',
  'TCGA-05-4382.952c0f32-1472-49e1-8334-b0f1de4ac921',
  'TCGA-05-4389.924f5877-07dc-48dc-b920-c59ba743498e',
  'TCGA-05-4390.2e3faad1-3a5e-4efb-96ea-8c44839fec6e',
  'TCGA-05-4395.601b48ca-e99c-4d4b-854e-4aa923b63237',
  'TCGA-05-4396.96f3b48a-c5c4-40c3-a3d1-64477499c6e6',
  'TCGA-05-4397.e3f87d2c-61b5-435a-a095-318d2bf58bb3'],
 'embeddings': [[0.06585714966058731,
   0.07980936765670776,
   -0.028476255014538765,
   -0.07303035259246826,
   -0.05829576030373573,
   -0.03968682512640953,
   -0.021373150870203972,
   0.07292729616165161,
   0.03480758145451546,
   -0.0030991616658866405,
   0.009435717947781086,
   0.0018854656955227256,
   -0.007299541495740414,
   0.004778753034770489,
   0.014413821510970592,
   0.0742042288184166,
   -0.010597209446132183,
  