In [2]:
from functools import lru_cache

import nltk
import numpy as np
import pandas as pd
import torch
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import (
    BertForMaskedLM,
    BertTokenizer,
)

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

!unzip -qn /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

batch_size = 32  # NOTE(review): not referenced by any of the cells shown here

# Load the Fine-Tuned Bio_ClinicalBERT Encoder from Hugging Face

In [4]:
from huggingface_hub import hf_hub_download


# Base Bio_ClinicalBERT tokenizer and MLM architecture; the fine-tuned weights are loaded below.
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = BertForMaskedLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

# Load the MLM fine-tuned checkpoint from the Hugging Face Hub.
repo_id = "alibababeig/nlp-hw4"
filename = "BioClinicalBert-MLM-Finetuned-20k-15epoch.pth"
checkpoint_file = hf_hub_download(repo_id=repo_id, filename=filename)

# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on a
# CPU-only machine; load_state_dict then copies the tensors onto the model's device.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from trusted repos, and consider weights_only=True once the
# checkpoint is confirmed to contain only tensors and plain containers.
checkpoint = torch.load(checkpoint_file, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
del checkpoint  # weights are now copied into `model`; free the duplicate state dict

model = model.bert  # drop the MLM head, keeping only the encoder
model.eval();  # trailing ';' suppresses the very long module repr

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


(…)inicalBert-MLM-Finetuned-20k-15epoch.pth:   0%|          | 0.00/433M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

# Load the MedMCQA Dataset (with Precomputed CLS Embeddings) from Hugging Face

In [5]:
from huggingface_hub import hf_hub_download


# Question dataset whose "question_cls" column holds a precomputed embedding per question.
repo_id, filename = "alibababeig/nlp-hw4", "MEDMCQA-dataset-with-CLS-20k-nltk.json"
dataset_path = hf_hub_download(filename=filename, repo_id=repo_id)
loaded_df = pd.read_json(dataset_path)
display(loaded_df)

MEDMCQA-dataset-with-CLS-20k-nltk.json:   0%|          | 0.00/249M [00:00<?, ?B/s]

Unnamed: 0,question,exp,question_cls
0,"All of the following are pyrogenic cytokines, ...",Interleukin 18 is not a pyrogenic cytokine. IL...,"[0.210591122508049, 0.035236056894064005, -0.1..."
1,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[0.145646214485168, 0.25639796257019, -0.31620..."
2,Following statement regarding dislocation of t...,Anterior dislocation is more common in which h...,"[-0.20996041595935802, -0.067066535353661, -0...."
3,The active search for unrecognized disease or ...,Screening is the search for unrecognized disea...,"[0.061764661222696006, 0.059835571795702, -0.4..."
4,Fir tree pattern lesion is seen in,Fir tree pattern of distribution of lesions is...,"[0.36311072111129805, -0.022785292938352002, -..."
...,...,...,...
16826,Carcinoma sigmoid colon with obstruction Manag...,- Obstruction due to rectosigmoid growth with ...,"[0.196510925889015, 0.418803453445435, -0.2798..."
16827,ADHD in childhood can lead to which of the fol...,"ADHD can lead to substance abuse,mood disorder...","[0.24358224868774403, 0.581579089164734, -0.31..."
16828,Nerve for adductor compament of thigh ?,Ans. B) Obturator nerveObturator nerve is the ...,"[0.187239721417427, -0.011040580458939001, -0...."
16829,The &;a&;wave of jugular venous pulse is produ...,JVP or jugular venous is a reflection of the r...,"[0.21560895442962602, 0.297360360622406, -0.24..."


In [6]:
# Show the question and its explanation for one sample row, separated by a blank line.
print(*loaded_df.iloc[1][["question", "exp"]], sep="\n\n")

40-year old female presented with neck swelling. Gross and histology is shown below.  What is your diagnosis?

Ref. Robbins Pathology. 9th edition. Page. 1099
Medullary carcinoma thyroid
Gross

Single or multiple
Typically nonencapsulated
Solid, gray / tan / yellow, firm, may be infiltrative

Microscopy

Round, polygonal or spindle cells in nests, cords or follicles, defined by sharply outlined fibrous bands
Tumor cells have granular cytoplasm and uniform round / oval nuclei with punctate chromatin
Stroma has amyloid deposits from calcitonin, prominent vascularity with glomeruloid configuration or long cords of vessels, coarse calcifications

 
IHC – Calcitonin


In [7]:
@lru_cache(maxsize=None)
def _preprocess_resources():
    """Build the English stop-word set and the WordNet lemmatizer once and reuse them."""
    return frozenset(stopwords.words("english")), WordNetLemmatizer()


def preprocess_text(text):
    """Normalize free text into a space-separated string of lemmatized content words.

    Pipeline: NLTK word tokenization -> lowercase -> keep purely alphabetic tokens
    -> drop English stop words -> WordNet lemmatization.

    Args:
        text: Raw input string.

    Returns:
        The surviving lemmas joined by single spaces (may be the empty string).
    """
    # Previously the stop-word corpus was re-read and a new lemmatizer built on
    # every call; they are loop-invariant, so they are cached in a helper.
    stop_words, lemmatizer = _preprocess_resources()
    # Lowercase before the isalpha() check, as before: lowercasing can change
    # whether a token is purely alphabetic for some Unicode characters.
    lowered = (token.lower() for token in word_tokenize(text))
    return " ".join(
        lemmatizer.lemmatize(word)
        for word in lowered
        if word.isalpha() and word not in stop_words
    )


def encode_text(text, tokenizer, model, max_length=512, device=None):
    """Embed a single piece of text as one vector taken from the model's [CLS] output.

    The text is first normalized with ``preprocess_text``, then tokenized
    (padded/truncated to ``max_length``) and run through ``model`` without
    gradient tracking.

    Args:
        text: Raw input string.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Hugging Face encoder returning a ``ModelOutput``.
        max_length: Token length to pad/truncate to.
        device: Device to place the inputs on. Defaults to the device holding
            the model's parameters, so inputs and weights always agree.

    Returns:
        A CPU tensor holding the embedding with the batch axis removed. This is
        ``pooler_output`` when the model provides one, otherwise the final hidden
        state at position 0 (the [CLS] token). Both branches now have the same
        shape; previously the pooler branch kept a leading batch axis of size 1.

    Raises:
        ValueError: if the output has neither ``pooler_output`` nor
            ``last_hidden_state`` (subclass of the ``Exception`` raised before).
    """
    if device is None:
        device = next(model.parameters()).device

    text = preprocess_text(text)
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # A BertModel built without a pooling layer reports pooler_output=None.
    # In that case, fall back to the raw [CLS] hidden state.
    pooled = getattr(outputs, "pooler_output", None)
    if pooled is not None:
        cls_embedding = pooled
    elif getattr(outputs, "last_hidden_state", None) is not None:
        cls_embedding = outputs.last_hidden_state[:, 0, :]
    else:
        raise ValueError("No CLS Token found in the given model")

    # A single text gives a batch of size 1; drop that axis in either branch.
    return cls_embedding.squeeze(0).cpu()

In [31]:
def cosine_similarity(query, dataset):
    """Cosine similarity between one query vector and every row of ``dataset``.

    Args:
        query: Query vector. Any shape with ``d`` elements is accepted (e.g.
            ``(d,)`` or ``(1, d)``); it is flattened.
        dataset: 2-D array with one candidate vector of length ``d`` per row.

    Returns:
        1-D float array with one similarity in [-1, 1] per dataset row.
        Rows (or a query) with zero norm get similarity 0.
    """
    query = np.asarray(query, dtype=float).ravel()
    dataset = np.asarray(dataset, dtype=float)

    dots = dataset @ query
    denom = np.linalg.norm(dataset, axis=1) * np.linalg.norm(query)
    # A zero-length vector has no direction. Dividing by its zero norm would give
    # NaN, and NaN sorts as the *largest* value, so it would be ranked as the
    # nearest neighbour. Report similarity 0 for such rows instead.
    return np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)


def MSE_similarity(query, dataset):
    """Similarity score equal to the reciprocal of the squared Euclidean distance.

    Despite the name, the distance is the *sum* (not the mean) of squared
    differences. Dividing by the dimension would only rescale every score by the
    same factor, so the neighbour ranking is identical. Larger scores mean
    closer rows. An exact match gives a zero distance and hence a division by
    zero (NumPy yields ``inf`` for float inputs).

    Args:
        query: Query vector, broadcast against each row of ``dataset``.
        dataset: 2-D array with one candidate vector per row.

    Returns:
        1-D array of ``1 / ||row - query||^2`` per row.
    """
    squared_distance = np.sum(np.square(dataset - query), axis=1)
    return 1.0 / squared_distance


def k_nearest_embeddings(query, dataset, k, similarity_metric=cosine_similarity):
    """Find the ``k`` rows of ``dataset`` most similar to ``query``.

    Args:
        query: Query embedding, passed unchanged to ``similarity_metric``.
        dataset: 2-D array of candidate embeddings, one per row.
        k: Number of neighbours to return. It is clamped to ``len(dataset)``,
            and ``k <= 0`` yields empty results.
        similarity_metric: Callable ``(query, dataset) -> 1-D scores`` where a
            higher score means "more similar".

    Returns:
        Tuple ``(indices, embeddings, similarities)`` ordered from most to least
        similar: the row indices into ``dataset``, those rows, and their scores.
    """
    similarities = np.asarray(similarity_metric(query, dataset))

    # argpartition raises when k exceeds the number of rows, and k == 0 would
    # slice [-0:] (i.e. the whole array), so clamp k into [0, n] first.
    k = max(0, min(int(k), similarities.shape[0]))
    if k == 0:
        nearest_indices = np.empty(0, dtype=np.intp)
    else:
        # Select the top k in linear time, then sort only those k, descending.
        candidate_indices = np.argpartition(similarities, -k)[-k:]
        descending = np.argsort(similarities[candidate_indices])[::-1]
        nearest_indices = candidate_indices[descending]

    top_k_similarities = similarities[nearest_indices]
    top_k_embeddings = dataset[nearest_indices]

    return nearest_indices, top_k_embeddings, top_k_similarities


k = 3  # number of nearest questions to retrieve
query = "female with neck swelling. Gross and histology. diagnosis?"

# encode_text returns a CPU torch tensor; convert it to a flat 1-D numpy vector.
query_tensor = encode_text(query, tokenizer, model)
cls_emb = np.squeeze(query_tensor.numpy())
del query_tensor

TAG1 = female with neck swelling. Gross and histology. diagnosis?
TAG2 = female neck swelling gross histology diagnosis


In [32]:
# Stack the stored per-question embeddings into a 2-D matrix, one row per question
# (assumes every stored "question_cls" vector has the same length).
question_embeddings = np.asarray(loaded_df["question_cls"].tolist())

nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    cls_emb,
    question_embeddings,
    k,
    similarity_metric=MSE_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
# NOTE: these scores are 1 / squared Euclidean distance (see MSE_similarity).
print("MSE similarities of the k nearest embeddings:", nearest_similarities)

# Build a fresh 0-based table of the hits instead of mutating an iloc result in place.
mins_mse = loaded_df.iloc[nearest_indices].reset_index(drop=True)
display(mins_mse)

Row indices of the k nearest embeddings: [7114 4865    1]
MSE similarities of the k nearest embeddings: [0.05882364 0.05835901 0.05810885]


Unnamed: 0,question,exp,question_cls
0,A female patient presents with patchy hair los...,"Ans. is 'c' i.e., Alopecia areata * Alopecia a...","[0.058671485632658005, 0.49155479669570906, -0..."
1,A female patient is having diarrhea and abdomi...,"Ans. is 'a' i.e., Celiac sprueo Villous atroph...","[0.12823261320591, 0.582038938999176, -0.32525..."
2,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[0.145646214485168, 0.25639796257019, -0.31620..."


In [33]:
# Stack the stored per-question embeddings into a 2-D matrix, one row per question
# (assumes every stored "question_cls" vector has the same length).
question_embeddings = np.asarray(loaded_df["question_cls"].tolist())

nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    cls_emb,
    question_embeddings,
    k,
    similarity_metric=cosine_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
print("Cosine similarities of the k nearest embeddings:", nearest_similarities)

# Build a fresh 0-based table of the hits instead of mutating an iloc result in place.
mins_cosine = loaded_df.iloc[nearest_indices].reset_index(drop=True)
display(mins_cosine)

Row indices of the k nearest embeddings: [7114    1 4865]
Cosine similarities of the k nearest embeddings: [0.9545285  0.95404987 0.95404229]


Unnamed: 0,question,exp,question_cls
0,A female patient presents with patchy hair los...,"Ans. is 'c' i.e., Alopecia areata * Alopecia a...","[0.058671485632658005, 0.49155479669570906, -0..."
1,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[0.145646214485168, 0.25639796257019, -0.31620..."
2,A female patient is having diarrhea and abdomi...,"Ans. is 'a' i.e., Celiac sprueo Villous atroph...","[0.12823261320591, 0.582038938999176, -0.32525..."


In [15]:
idx = 0  # row label in the re-indexed cosine top-k table (0 = best match)
print(mins_cosine.at[idx, "question"])
print(mins_cosine.at[idx, "exp"])

Egg Shell Calcification" in chest X ray is seen in?
ANSWER: (D) All aboveREF: Chapman 4th ed p. 148EGG SHELL CALCIFICATION:SilicosisPneumoconiosisSarcoidosisLymphoma following radiotherapyTB, Histoplasmosis, BlastomycosisAmyloidosis
