TODO for DFS-P
- package the function
- change the embed model to a more powerful one
- create a context length gate mechanism to prevent sending prompts that are too long
- create chunks in a meaningful way

For report:
- propose an intuitive way
    - sort reports by tokens
- propose a modified version using a cross-encoder
- Discussion
    - Input context length is a limitation
        - How to chunk the report in a meaningful way
    - Need to search for a medical-domain embedding model
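
The context length gate from the TODO list could look roughly like the sketch below. It is only an illustration: `gate_shots`, `max_prompt_tokens`, and the whitespace-based counter are hypothetical names and stand-ins, not part of this notebook. In practice the counter would be the tokenizer of the target model (for example `chunk_splitter.count_tokens`).

```python
def gate_shots(query, shots, max_prompt_tokens, count_tokens):
    """Keep shots (in ranked order) while the prompt stays within the token budget."""
    used = count_tokens(query)
    if used > max_prompt_tokens:
        raise ValueError("query alone exceeds the token budget")
    kept = []
    for shot in shots:
        cost = count_tokens(shot)
        if used + cost > max_prompt_tokens:
            break  # stop at the first shot that does not fit, preserving rank order
        kept.append(shot)
        used += cost
    return kept


def whitespace_count(text):
    # stand-in counter: counts whitespace-separated words, not model tokens
    return len(text.split())


kept = gate_shots("tumour size 1.3 cm", ["a b c", "d e f g", "h"], 10, whitespace_count)
print(kept)
```

Stopping at the first shot that does not fit keeps the ranking intact; skipping it and trying smaller shots further down the list would be an alternative design.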

In [1]:
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
import pandas as pd
from tqdm import tqdm

from itertools import product
from sentence_transformers import CrossEncoder
import numpy as np

data_base_path = "/secure/shared_data/tcga_path_reports/"


# Initialize collection instance in Chroma database 

In [18]:
"""
Initialize sentence_transformer and its associated textsplitter
and also initialize Chroma's client and select the target collection
"""

# use large one later
embed_model_name = "BAAI/bge-small-en-v1.5"

chunk_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0,
                                                       model_name=embed_model_name,
                                                       tokens_per_chunk=512)

embedding_function = SentenceTransformerEmbeddingFunction(model_name=embed_model_name,
                                                          device="cpu",
                                                          normalize_embeddings=True)

client = chromadb.PersistentClient(path=data_base_path+"chroma_data/")
collection = client.get_or_create_collection("test-bge-small-t14", embedding_function=embedding_function)

In [19]:
collection.count()

218

# Embed documents into representations and store them in the collection

In [20]:
"""
Load the reports, embed them, and save them in the ChromaDB collection
"""
def embed_reports_in_chroma(report_df, collection, label_name="t"):

    # one progress-bar tick per report (total is the number of reports)
    pbar = tqdm(total=report_df.shape[0])

    for _, report in report_df.iterrows():
        report_patient_filename = report["patient_filename"]
        report_label = report[label_name]
        report_text = report["text"]

        # split the report into token-bounded chunks and embed each chunk
        # n = chunk_splitter.count_tokens(text=report_text)
        chunks = chunk_splitter.split_text(report_text)
        embedded_chunks = embedding_function(chunks)
        # print(n, len(chunks), len(embedded_chunks))
        assert len(chunks) == len(embedded_chunks)
        # store every chunk under the id "<patient_filename>_<chunk_id>"
        for chunk_id, emd in enumerate(embedded_chunks):
            collection.add(
                embeddings=emd,
                metadatas={"patient_filename": report_patient_filename, label_name: report_label, "chunk_id": chunk_id},
                documents=chunks[chunk_id],
                ids=report_patient_filename+"_"+str(chunk_id)
            )
        pbar.update(1)
    pbar.close()

In [21]:
if collection.count() == 0:
    print("Start encoding the documents into database.")
    t14_training_reports = pd.read_csv(data_base_path+"t14_data/Target_Data_T14.csv")
    samples = t14_training_reports.sample(n=100)

    embed_reports_in_chroma(samples, collection, label_name = "t")
else:
    print("Collection is already there.")
# client.delete_collection("test-bge-small-t14")

Collection is already there.


# Implement query + cross-encoder reranking

In [22]:
"""
Query phase: use a cross-encoder to re-rank the retrieved chunks by relevance
"""

from itertools import product
from sentence_transformers import CrossEncoder
import numpy as np

# load testing dataset
t14_testing_reports = pd.read_csv(data_base_path+"t14_data/Target_Data_T14_test.csv")

def cross_encoder_shot_selector(test_q, collection, cross_encoder, top_n):
    # split it into chunks
    test_qs = chunk_splitter.split_text(test_q)
    print("Get {} chunks.".format(len(test_qs)))
    # perform query from the database
    result_obj = collection.query(query_texts=test_qs, n_results=5)
    # organize retrieved items
    retrieved_chunks = set([item for item_set in result_obj["documents"] for item in item_set])

    pairs = list(product(test_qs, retrieved_chunks))

    # re-ranking
    # use cross-encoder to get the predict score
    scores = cross_encoder.predict(pairs)
    index_with_highest_score = np.argsort(scores)[::-1][:top_n]
    selected_shots = set([pairs[i][1] for i in index_with_highest_score])

    return selected_shots
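
Because `pairs` is the product of query chunks and retrieved chunks, the same retrieved chunk can appear several times among the top-scoring pairs, so the returned set may hold fewer than `top_n` shots. One possible variant, sketched below with plain Python lists, keeps the best score per retrieved chunk before taking the top `top_n`. The helper name `top_unique_chunks` and the toy pairs and scores are illustrative assumptions, not part of the notebook.

```python
def top_unique_chunks(pairs, scores, top_n):
    """Rank retrieved chunks by their best pair score and return the top_n distinct ones."""
    best = {}
    for (_, chunk), score in zip(pairs, scores):
        score = float(score)
        # remember only the highest score seen for each retrieved chunk
        if chunk not in best or score > best[chunk]:
            best[chunk] = score
    ranked = sorted(best, key=best.get, reverse=True)
    return ranked[:top_n]


# toy data: two query chunks, three retrieved chunks
pairs = [("q1", "A"), ("q1", "B"), ("q2", "A"), ("q2", "B"), ("q2", "C")]
scores = [1.5, -2.0, 1.8, 0.3, -0.5]
shots = top_unique_chunks(pairs, scores, top_n=2)
print(shots)
```

With these toy scores the two best pairs both point at chunk "A", so selecting by pair would yield a single shot, while the per-chunk variant returns two.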

In [35]:
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
test_q = t14_testing_reports.iloc[0]["text"]
selected_shots = cross_encoder_shot_selector(test_q, collection, cross_encoder, top_n=5)
selected_shots

Get 3 chunks.


{') left mandible margin - a 0. 4 x 0. 4 x 0. 2 cm aggregate of soft pink - tan tissue and bony fragments. touch preparations. are made for intraoperative evaluation, and the remainder of the tissue is submitted entirely in m for decalcification. tp / dx : bone marrow cells, no tumor present. ( n ) right tongue margin - a 2. 5 x 0. 3 x 0. 2 cm tan - pink tissue magment. submitted in toto for frozen section in n. fs / dx : squamous carcinoma in situ. ( o ) left tongue margin - a 2. 2 x 0. 4 x 0. 2 cm tan - pink tissue fragment. submitted in toto for frozen section in o. fs / dx : squamous carcinoma in situ. ( p ) left base of tongue - a 1. 0 x 0. 3 x 0. 2 cm tan - pink tissue fragment. submitted in toto for frozen section in p. fs / dx : squamous carcinoma in situ. ( q ) right base of tongue - a 1. 5 x 1. 0 x 0. 5 cm tan - pink tissue fragment. submitted in toto for frozen section in q. s / dx : squamous carcinoma in situ. ( r ) left level ii " b " - received is a 1. 2 x 1. 2 x 0. 7 of 

# For Testing
The following blocks validate the implementation of the cross_encoder_shot_selector function step by step.

In [23]:
# read report text
test_q = t14_testing_reports.iloc[0]["text"]
# split it into chunks
test_qs = chunk_splitter.split_text(test_q)
print("Get {} chunks.".format(len(test_qs)))
result_obj = collection.query(query_texts=test_qs, n_results=5)

Get 3 chunks.


In [24]:
retrieved_chunks = set([item for item_set in result_obj["documents"] for item in item_set])

In [25]:
retrieved_chunks

{') left mandible margin - a 0. 4 x 0. 4 x 0. 2 cm aggregate of soft pink - tan tissue and bony fragments. touch preparations. are made for intraoperative evaluation, and the remainder of the tissue is submitted entirely in m for decalcification. tp / dx : bone marrow cells, no tumor present. ( n ) right tongue margin - a 2. 5 x 0. 3 x 0. 2 cm tan - pink tissue magment. submitted in toto for frozen section in n. fs / dx : squamous carcinoma in situ. ( o ) left tongue margin - a 2. 2 x 0. 4 x 0. 2 cm tan - pink tissue fragment. submitted in toto for frozen section in o. fs / dx : squamous carcinoma in situ. ( p ) left base of tongue - a 1. 0 x 0. 3 x 0. 2 cm tan - pink tissue fragment. submitted in toto for frozen section in p. fs / dx : squamous carcinoma in situ. ( q ) right base of tongue - a 1. 5 x 1. 0 x 0. 5 cm tan - pink tissue fragment. submitted in toto for frozen section in q. s / dx : squamous carcinoma in situ. ( r ) left level ii " b " - received is a 1. 2 x 1. 2 x 0. 7 of 

In [27]:
pairs = list(product(test_qs, retrieved_chunks))

In [28]:
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [29]:
scores = cross_encoder.predict(pairs)

In [30]:
scores

array([-3.8295927 , -5.830491  , -7.0212574 , -7.6460786 , -8.506743  ,
       -5.339867  , -7.0146947 , -6.4062567 , -8.054098  , -3.8724298 ,
       -4.2016964 , -6.261025  , -6.1777525 , -4.1040106 , -3.4471984 ,
       -5.7553787 , -3.084394  , -4.976289  ,  0.67407274,  1.8272525 ,
       -0.71477675, -3.173746  , -4.1793504 ,  1.8512446 , -0.6694144 ,
       -0.33772734, -0.48866773], dtype=float32)

In [31]:
for o in np.argsort(scores)[::-1]:
    print(o)

23
19
18
25
26
24
20
16
21
14
0
9
13
22
10
17
5
15
1
12
11
7
6
2
3
8
4
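
The 27 scores correspond to 3 query chunks times 9 distinct retrieved chunks, and `itertools.product` varies its second element fastest, so a flat index `i` maps to query chunk `i // 9` and retrieved chunk `i % 9`. The short sketch below decodes the five highest-ranked indices printed above; the variable names are illustrative.

```python
n_retrieved = 9  # 27 scores / 3 query chunks
top_indices = [23, 19, 18, 25, 26]  # first five entries of the argsort output

# divmod(i, n_retrieved) gives (query chunk index, retrieved chunk index)
decoded = [divmod(i, n_retrieved) for i in top_indices]
for flat, (q_idx, c_idx) in zip(top_indices, decoded):
    print(f"pair {flat}: query chunk {q_idx}, retrieved chunk {c_idx}")
```

All five top pairs involve the third query chunk (index 2) and five different retrieved chunks, so in this run the top-5 selection yields five distinct shots.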


In [32]:
index_with_highest_score = np.argsort(scores)[::-1][:5]
select_shots = set([pairs[i][1] for i in index_with_highest_score])

In [33]:
pairs

[('visit number. event date and time. procedure ordered. specimen #. complete. result date : specimen ( s ) received. 1. oral - cavity : medial margin palate. 2. oral - cavity : right partial maxillectomy. 3. oral cavity : right partial glossectomy - stitch anterior margin. 4. oral cavity : deep margin tongue. diagnosis. 1. medial palate margin : negative for tumour. 2. right partial maxillectomy : squamous cell carcinoma, well differentiated. i. the maximum dimension of the tumour is 1. 3 cm. 3. the maximum thickness of the tumour is 0. 5 cm. :. no lymphovascular or perineural invasion identified. 1. bone and bone margins are negative for tumour. : all other margins are negative for tumour. 3. right partial glossectomy : squamous cell carcinoma in - situ ( cis ), focally suspicious for superficial. nvasion. l. maximum lesional dimension 0. 6 cm. ). maximum lesional thickness 0. 1 cm. :. all margins are negative for tumour ( see comment ). i. deep margin of tongue. negative for tumour.

In [34]:
select_shots

{') left mandible margin - a 0. 4 x 0. 4 x 0. 2 cm aggregate of soft pink - tan tissue and bony fragments. touch preparations. are made for intraoperative evaluation, and the remainder of the tissue is submitted entirely in m for decalcification. tp / dx : bone marrow cells, no tumor present. ( n ) right tongue margin - a 2. 5 x 0. 3 x 0. 2 cm tan - pink tissue magment. submitted in toto for frozen section in n. fs / dx : squamous carcinoma in situ. ( o ) left tongue margin - a 2. 2 x 0. 4 x 0. 2 cm tan - pink tissue fragment. submitted in toto for frozen section in o. fs / dx : squamous carcinoma in situ. ( p ) left base of tongue - a 1. 0 x 0. 3 x 0. 2 cm tan - pink tissue fragment. submitted in toto for frozen section in p. fs / dx : squamous carcinoma in situ. ( q ) right base of tongue - a 1. 5 x 1. 0 x 0. 5 cm tan - pink tissue fragment. submitted in toto for frozen section in q. s / dx : squamous carcinoma in situ. ( r ) left level ii " b " - received is a 1. 2 x 1. 2 x 0. 7 of 