In [26]:
import torch
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
)
import numpy as np
import pandas as pd

In [27]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Batch size setting for this notebook.
batch_size = 32

# Load Finetuned Model From Hugging Face

In [28]:
from huggingface_hub import hf_hub_download


# Start from the public Bio_ClinicalBERT tokenizer and masked-LM architecture.
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = BertForMaskedLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

# Download the fine-tuned checkpoint from the Hugging Face Hub.
repo_id = "alibababeig/nlp-hw4"
filename = "BioClinicalBert-MLM-Finetuned.pth"
checkpoint_file = hf_hub_download(repo_id=repo_id, filename=filename)

# `map_location` remaps the stored tensors onto `device`, so a checkpoint that
# was saved from CUDA tensors still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the downloaded file; only use checkpoints
# from a trusted repo, and consider `weights_only=True` if the checkpoint
# contents allow it -- TODO confirm.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.bert  # dropping MLM head
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

# Load Dataset From Hugging Face

In [29]:
from huggingface_hub import hf_hub_download


# Download the JSON table from the Hub and read it into a DataFrame.
repo_id, filename = "alibababeig/nlp-hw4", "MEDMCQA-dataset-with-CLS.json"
dataset_path = hf_hub_download(filename=filename, repo_id=repo_id)
loaded_df = pd.read_json(dataset_path)
display(loaded_df)

Unnamed: 0,question,exp,question_cls
0,Chronic urethral obstruction due to benign pri...,Chronic urethral obstruction because of urinar...,"[0.263091504573822, -0.059610839933157, -0.209..."
1,Which vitamin is supplied from only animal sou...,Ans. (c) Vitamin B12 Ref: Harrison's 19th ed. ...,"[0.442456096410751, 0.018535641953349002, -0.6..."
2,All of the following are surgical options for ...,"Ans. is 'd' i.e., Roux en Y Duodenal Bypass Ba...","[0.6987701058387761, 0.264356791973114, -0.003..."
3,Following endaerectomy on the right common car...,The central aery of the retina is a branch of ...,"[0.11031772941350901, 0.002846830990165, -0.61..."
4,Growth hormone has its effect on growth through?,"Ans. is 'b' i.e., IGI-1GH has two major functi...","[0.409950464963913, 0.489176720380783, -0.5050..."
...,...,...,...
160864,Organism that causes emphysematous cholecystit...,Ref: Harrison's 18th editionExplanation:Emphys...,"[0.7292048335075381, 0.35229077935218805, -0.1..."
160865,Most common site for extra mammary Paget&;s di...,.It is superficial manifestation of an intradu...,"[0.558634519577026, 0.08569207042455701, -0.39..."
160866,Inferior Rib notching is seen in all except?,Answer is D (Neurofibromatosis) Neurofibromato...,"[0.718835175037384, -0.18725897371768999, 0.02..."
160867,Which is false regarding cryptococcus neoformans?,"Ans. is 'c' i e., Urease negative Cryptococcus...","[0.574232339859009, 0.218485400080681, -0.4199..."


In [40]:
# Show one example row: the question, a blank line, then its explanation.
sample_row = loaded_df.iloc[18000]
print(sample_row["question"], sample_row["exp"], sep="\n\n")

Not true about acdes mosquito

Eggs are cigar shaped.


In [31]:
def encode_text(text, tokenizer, model, max_length=512):
    """Encode ``text`` with ``model`` and return its sentence-level embedding.

    The text is tokenized (padded and truncated to ``max_length``), moved to
    the device the model's parameters live on, and run through the model with
    gradient tracking disabled.

    Args:
        text: Input passed straight to ``tokenizer``.
        tokenizer: Callable accepting ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments, whose
            result has a ``.to(device)`` method and can be unpacked with ``**``.
        model: ``torch.nn.Module`` whose output supports both ``in`` and
            attribute access for ``pooler_output`` / ``last_hidden_state``.
        max_length: Padding/truncation length handed to the tokenizer.

    Returns:
        A CPU tensor: ``pooler_output`` when the model output contains it,
        otherwise position 0 of ``last_hidden_state`` with size-1 dimensions
        squeezed away. NOTE(review): only the second branch squeezes, so the
        two branches can return different ranks for a single input.

    Raises:
        ValueError: If the model output has neither ``pooler_output`` nor
            ``last_hidden_state``.
    """
    # Follow the model's own device rather than a notebook-level global, so the
    # inputs always land where the weights are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    tokens = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(target_device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**tokens)

    if "pooler_output" in outputs:
        cls_embedding = outputs.pooler_output
    elif "last_hidden_state" in outputs:
        # Position 0 along the sequence axis is the first token of each input.
        cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()
    else:
        raise ValueError("No CLS Token found in the given model")

    return cls_embedding.cpu()

In [45]:
def cosine_similarity(query, dataset):
    """Cosine similarity between ``query`` and every row of ``dataset``.

    Args:
        query: 1-D array.
        dataset: 2-D array with one vector per row and the same width as
            ``query``.

    Returns:
        1-D array with one similarity per row of ``dataset``. A zero-length
        query or row has no direction, so its similarity is reported as 0
        rather than NaN.
    """
    query_len = np.linalg.norm(query)
    row_lens = np.linalg.norm(dataset, axis=1)

    # Substitute 1 for zero norms: the vector stays all-zero, so its dot
    # product is 0 and we avoid 0/0 -> NaN (which would poison the ranking).
    query_norm = query / (query_len if query_len != 0 else 1.0)
    dataset_norm = dataset / np.where(row_lens == 0, 1.0, row_lens)[:, np.newaxis]

    similarities = np.dot(dataset_norm, query_norm)
    return similarities


def MSE_similarity(query, dataset):
    """Inverse squared-distance similarity between ``query`` and each row.

    The distance is the *sum* of squared differences per row (not the mean).

    Args:
        query: 1-D array.
        dataset: 2-D array with one vector per row, broadcastable against
            ``query``.

    Returns:
        1-D array of ``1 / distance``; larger means closer. A row identical to
        ``query`` has distance 0 and scores ``inf``, so it always ranks first.
    """
    dists = ((dataset - query) ** 2).sum(axis=1)
    # An exact match divides by zero; `inf` is the intended "maximally similar"
    # score, so silence the divide-by-zero RuntimeWarning.
    with np.errstate(divide="ignore"):
        return 1.0 / dists  # the inverse of a distance acts as a similarity


def k_nearest_embeddings(query, dataset, k, similarity_metric=cosine_similarity):
    """Return the ``k`` rows of ``dataset`` most similar to ``query``.

    Args:
        query: 1-D query vector.
        dataset: 2-D array with one candidate vector per row.
        k: Number of neighbours wanted. Values larger than the number of rows
            are clamped to it; ``0`` yields empty results.
        similarity_metric: Callable ``(query, dataset) -> 1-D scores`` where a
            higher score means more similar.

    Returns:
        Tuple ``(nearest_indices, top_k_embeddings, top_k_similarities)``,
        each ordered from most to least similar.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    similarities = np.asarray(similarity_metric(query, dataset))

    # Clamp so argpartition never gets an out-of-bounds kth.
    k = min(k, similarities.shape[0])
    if k == 0:
        # `-0 == 0`, so the slice `[-k:]` below would select *every* row.
        nearest_indices = np.empty(0, dtype=np.intp)
        return nearest_indices, dataset[nearest_indices], similarities[nearest_indices]

    # Get the indices of the top k highest similarities (unordered)
    nearest_indices = np.argpartition(similarities, -k)[-k:]

    # Sort these indices by the actual similarities, best first
    nearest_indices = nearest_indices[np.argsort(similarities[nearest_indices])[::-1]]

    # Get the top k similarities and corresponding embeddings
    top_k_similarities = similarities[nearest_indices]
    top_k_embeddings = dataset[nearest_indices]

    return nearest_indices, top_k_embeddings, top_k_similarities


# Number of neighbours to retrieve and the free-text question to look up.
k = 3
query = "which is Not true about acdes mosquito?"

# Embed the query, then drop any size-1 axes from the resulting NumPy array.
query_embedding = encode_text(query, tokenizer, model)
cls_emb = np.squeeze(query_embedding.numpy())

In [46]:
# Rank every stored question embedding against the query by inverse squared distance.
nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    cls_emb,
    np.asarray(loaded_df["question_cls"].tolist()),
    k,
    similarity_metric=MSE_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
print("MSE similarities of the k nearest embeddings:", nearest_similarities)
# Build the renumbered result as a new frame instead of mutating the `.iloc`
# selection in place, which keeps the cell safe to re-run.
mins_mse = loaded_df.iloc[nearest_indices].reset_index(drop=True)
display(mins_mse)

Row indices of the k nearest embeddings: [127303 120590  76524]
MSE similarities of the k nearest embeddings: [0.13077756 0.13054807 0.11076013]


Unnamed: 0,question,exp,question_cls
0,Which is not true about carbuncle?,CARBUNCLE\nWord meaning of carbuncle is charco...,"[0.542609930038452, 0.30032971501350403, -0.45..."
1,Which is not true about hepatitis B virus?,"Ans. is 'b' i.e., Transmitted by faeco - oral ...","[0.44905552268028304, 0.23448909819126101, -0...."
2,Which is not true about Polio vaccine?,There is no long term carrier state for poliov...,"[0.485345751047134, 0.22388730943203, -0.45827..."


In [47]:
# Rank every stored question embedding against the query by cosine similarity.
nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    cls_emb,
    np.asarray(loaded_df["question_cls"].tolist()),
    k,
    similarity_metric=cosine_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
print("Cosine similarities of the k nearest embeddings:", nearest_similarities)
# Build the renumbered result as a new frame instead of mutating the `.iloc`
# selection in place, which keeps the cell safe to re-run.
mins_cosine = loaded_df.iloc[nearest_indices].reset_index(drop=True)
display(mins_cosine)

Row indices of the k nearest embeddings: [127303 120590  76524]
Cosine similarities of the k nearest embeddings: [0.98005391 0.98004156 0.97650476]


Unnamed: 0,question,exp,question_cls
0,Which is not true about carbuncle?,CARBUNCLE\nWord meaning of carbuncle is charco...,"[0.542609930038452, 0.30032971501350403, -0.45..."
1,Which is not true about hepatitis B virus?,"Ans. is 'b' i.e., Transmitted by faeco - oral ...","[0.44905552268028304, 0.23448909819126101, -0...."
2,Which is not true about Polio vaccine?,There is no long term carrier state for poliov...,"[0.485345751047134, 0.22388730943203, -0.45827..."


In [48]:
# Print the question and explanation of one retrieved row (0 = best match).
idx = 0
for column in ("question", "exp"):
    print(mins_cosine.loc[idx, column])

Which is not true about carbuncle?
CARBUNCLE
Word meaning of carbuncle is charcoal.

It is an infective gangrene of skin and subcutaneous tissue
StaphyIococcs aureus is the main culprit.
Common site of occurrence is nape of the neck and back. Skin in this area is thick. Condition also can occur in shoulder, cheek. hand. forearm.
It is common in diabetics and after forty years of age.
It is common in males.
