In [2]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
    T5Tokenizer,
    T5ForConditionalGeneration
)
from transformers.models.t5.modeling_t5 import T5LayerFF
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Fetch the NLTK resources used by preprocess_text: punkt (for word_tokenize),
# the stop-word lists, and WordNet / Open Multilingual Wordnet (for
# WordNetLemmatizer).
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

# Extract the wordnet archive next to itself (-q: quiet, -n: never overwrite
# files that already exist, so re-running the cell is harmless).
# NOTE(review): the path matches the download location reported in this
# environment's log (/usr/share/nltk_data); presumably WordNet is left zipped
# there — confirm before running in a different environment.
!unzip -qn /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...


In [3]:
# Models and tokenized inputs are moved to this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Base checkpoints pulled from the Hugging Face Hub.
base_bert_path = "emilyalsentzer/Bio_ClinicalBERT"
base_t5_path = "t5-small"

# Fine-tuned weights and the precomputed-embedding dataset, all stored in `repo_id`.
repo_id = "alibababeig/nlp-hw4"
finetuned_bert_path = "BioClinicalBert-MLM-Finetuned-40k-25epoch-exp-25epoch-questions.pth"
finetuned_t5_path = "T5-Finetuned-15k-20epoch.pth"
dataset_file_name = "MEDMCQA-dataset-with-CLS-40k-25epoch-exp-25epoch-questions-nltk.json"

batch_size = 32
bottleneck_size = 32  # hidden width of every T5 adapter inserted by mutate_model
k = 3  # KNN hyperparameter: number of neighbours returned by the search

# Load Fine-tuned BERT Model From Hugging Face

In [4]:
from huggingface_hub import hf_hub_download


bert_tokenizer = BertTokenizer.from_pretrained(base_bert_path)
bert_model = BertForMaskedLM.from_pretrained(base_bert_path).to(device)

# Overwrite the base weights with the MLM fine-tuned checkpoint from the Hub.
# map_location remaps tensors onto `device`, so a checkpoint saved on a GPU
# still loads on a CPU-only machine.
checkpoint_file = hf_hub_download(repo_id=repo_id, filename=finetuned_bert_path)
checkpoint = torch.load(checkpoint_file, map_location=device)
bert_model.load_state_dict(checkpoint["model_state_dict"])
bert_model = bert_model.bert  # dropping MLM head: only the encoder is used for embeddings
bert_model.eval()

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


(…)ed-40k-25epoch-exp-25epoch-questions.pth:   0%|          | 0.00/433M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

# Load Dataset From Hugging Face

In [5]:
from huggingface_hub import hf_hub_download


dataset_path = hf_hub_download(repo_id=repo_id, filename=dataset_file_name)
loaded_df = pd.read_json(dataset_path)

# The KNN cells stack `question_cls` into an (N, D) matrix with np.asarray;
# that only works if every stored embedding has the same length.
assert loaded_df["question_cls"].map(len).nunique() == 1, (
    "question_cls embeddings have inconsistent lengths"
)
display(loaded_df)

(…)-25epoch-exp-25epoch-questions-nltk.json:   0%|          | 0.00/249M [00:00<?, ?B/s]

Unnamed: 0,question,exp,question_cls
0,"All of the following are pyrogenic cytokines, ...",Interleukin 18 is not a pyrogenic cytokine. IL...,"[0.549001753330231, -0.12343280017375902, 0.15..."
1,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[-0.24133916199207303, 0.097042627632618, -0.2..."
2,Following statement regarding dislocation of t...,Anterior dislocation is more common in which h...,"[-0.386974930763245, -0.131634533405304, 0.241..."
3,The active search for unrecognized disease or ...,Screening is the search for unrecognized disea...,"[-0.091413952410221, -0.04559937492013, -0.082..."
4,Fir tree pattern lesion is seen in,Fir tree pattern of distribution of lesions is...,"[-0.41899171471595803, -0.24402141571044902, -..."
...,...,...,...
16826,Carcinoma sigmoid colon with obstruction Manag...,- Obstruction due to rectosigmoid growth with ...,"[0.241796687245369, 0.7256665825843811, -0.431..."
16827,ADHD in childhood can lead to which of the fol...,"ADHD can lead to substance abuse,mood disorder...","[0.20766235888004303, -0.30834984779357905, 0...."
16828,Nerve for adductor compament of thigh ?,Ans. B) Obturator nerveObturator nerve is the ...,"[0.134738609194756, 0.008813104592264, -0.2542..."
16829,The &;a&;wave of jugular venous pulse is produ...,JVP or jugular venous is a reflection of the r...,"[-0.035813611000776, 0.18877704441547402, -0.1..."


In [6]:
from functools import lru_cache


@lru_cache(maxsize=1)
def _english_stop_words():
    """Return NLTK's English stop words as a set, built once and reused."""
    return frozenset(stopwords.words("english"))


@lru_cache(maxsize=1)
def _shared_lemmatizer():
    """Return a single WordNetLemmatizer shared by all preprocess_text calls."""
    return WordNetLemmatizer()


def preprocess_text(text):
    """Normalise free text before BERT encoding.

    Pipeline: NLTK word tokenization -> lower-casing -> keep purely alphabetic
    tokens -> drop English stop words -> WordNet lemmatization, joined back
    with single spaces.

    NOTE(review): the precomputed ``question_cls`` embeddings were presumably
    produced through this same pipeline; changing it would desynchronise new
    queries from the stored vectors.

    Args:
        text: raw input string.

    Returns:
        The normalised string (empty if no token survives the filters).
    """
    # Both objects are call-invariant; caching avoids rebuilding the stop-word
    # set and a new lemmatizer on every call.
    stop_words = _english_stop_words()
    lemmatizer = _shared_lemmatizer()
    lemmas = [
        lemmatizer.lemmatize(token)
        for token in (word.lower() for word in word_tokenize(text))
        if token.isalpha() and token not in stop_words
    ]
    return " ".join(lemmas)


def encode_text(text, tokenizer, bert_model, max_length=128):
    """Encode ``text`` into a single sentence vector with a BERT-style encoder.

    The text is normalised with ``preprocess_text``, tokenized (padded and
    truncated to ``max_length`` tokens) and passed through ``bert_model`` with
    gradients disabled.

    Args:
        text: raw input string.
        tokenizer: Hugging Face tokenizer matching ``bert_model``.
        bert_model: encoder whose output exposes ``pooler_output`` and/or
            ``last_hidden_state``.
        max_length: token budget used for padding/truncation.

    Returns:
        CPU tensor: ``pooler_output`` if the model produced one, otherwise the
        hidden state of the first ([CLS]) token. The size-1 batch axis is
        removed in both branches, so a single string yields a 1-D vector.

    Raises:
        ValueError: if the model output has neither field.
    """
    text = preprocess_text(text)
    # Send inputs to wherever the model actually lives instead of relying on
    # a global device setting.
    model_device = next(bert_model.parameters()).device
    tokens = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(model_device)

    with torch.no_grad():
        outputs = bert_model(**tokens)

    if "pooler_output" in outputs:
        cls_embedding = outputs.pooler_output
    elif "last_hidden_state" in outputs:
        cls_embedding = outputs.last_hidden_state[:, 0, :]
    else:
        # ValueError subclasses Exception, so `except Exception` callers still work.
        raise ValueError("No CLS token found in the given model")

    # Drop only the batch axis, identically for both branches.
    return cls_embedding.squeeze(0).cpu()

In [7]:
def cosine_similarity(query, dataset):
    """Cosine similarity between one query vector and every row of ``dataset``.

    Args:
        query: 1-D array of length D (assumed; callers squeeze it first).
        dataset: 2-D array of shape (N, D).

    Returns:
        1-D array of length N; entry i is cos(query, dataset[i]). A zero-norm
        query or row scores 0.0 instead of NaN.
    """
    query_length = np.linalg.norm(query)
    row_lengths = np.linalg.norm(dataset, axis=1, keepdims=True)
    # A zero vector has no direction. Dividing it by 1 keeps it at zero, so its
    # similarity is 0 instead of NaN; NaN would otherwise be ranked as the
    # *largest* value by np.argpartition / np.argsort in the KNN search.
    query_norm = query / (query_length if query_length > 0 else 1.0)
    dataset_norm = dataset / np.where(row_lengths > 0, row_lengths, 1.0)
    similarities = np.dot(dataset_norm, query_norm)
    return similarities


def MSE_similarity(query, dataset):
    """Score every row of ``dataset`` by the reciprocal of its squared distance to ``query``.

    Despite the name, this uses the *sum* (not the mean) of squared
    differences per row. Ranking is unaffected by that constant factor. A row
    identical to the query gets ``inf``, the highest possible score.
    """
    squared_error_per_row = np.sum(np.square(dataset - query), axis=1)
    # Larger distance -> smaller score, so "most similar" is the maximum.
    return 1.0 / squared_error_per_row


def k_nearest_embeddings(query, dataset, k, similarity_metric=cosine_similarity):
    """Return the ``k`` rows of ``dataset`` most similar to ``query``, best first.

    Args:
        query: query vector passed to ``similarity_metric``.
        dataset: 2-D array of candidate vectors (one per row).
        k: number of neighbours wanted. Values above the number of rows are
            clamped, so every row is returned.
        similarity_metric: callable ``(query, dataset) -> 1-D scores`` where a
            higher score means more similar.

    Returns:
        Tuple ``(indices, embeddings, similarities)`` sorted by descending
        similarity. All are empty when ``k == 0``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    similarities = np.asarray(similarity_metric(query, dataset))

    # np.argpartition raises when k exceeds the number of rows; clamp instead.
    k = min(k, similarities.shape[0])
    if k == 0:
        # `[-0:]` would select *every* index, so handle k == 0 explicitly.
        nearest_indices = np.array([], dtype=np.intp)
    else:
        # Get the indices of the top k highest similarities (unordered) ...
        nearest_indices = np.argpartition(similarities, -k)[-k:]
        # ... then order just those k by descending similarity.
        nearest_indices = nearest_indices[np.argsort(similarities[nearest_indices])[::-1]]

    top_k_similarities = similarities[nearest_indices]
    top_k_embeddings = dataset[nearest_indices]

    return nearest_indices, top_k_embeddings, top_k_similarities


# Free-text probe for the retrieval demo (another probe tried: "pyrogenic cytokines").
query = "female with neck swelling. Gross and histology."
cls_emb = encode_text(query, bert_tokenizer, bert_model)
cls_emb = np.squeeze(cls_emb.numpy())

In [8]:
# Rank all stored question embeddings by inverse squared distance to the query.
nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    query=cls_emb,
    dataset=np.asarray(loaded_df["question_cls"].tolist()),
    k=k,
    similarity_metric=MSE_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
print("MSE similarities of the k nearest embeddings:", nearest_similarities)
mins_mse = loaded_df.iloc[nearest_indices].reset_index(drop=True)
display(mins_mse)

Row indices of the k nearest embeddings: [   1 3456 9274]
MSE similarities of the k nearest embeddings: [0.01963839 0.01770742 0.01723149]


Unnamed: 0,question,exp,question_cls
0,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[-0.24133916199207303, 0.097042627632618, -0.2..."
1,Max Joseph&;s space is a histopathological fea...,Max Joseph's space is a characteristic histolo...,"[-0.09414966404438, -0.05790701508522, 0.25478..."
2,'Mickey Mouse Ears' is a histological feature of:,Paracoccidioidomycosis is a deep fungal infect...,"[0.18600422143936202, -0.17785388231277502, 0...."


In [9]:
# Same search, ranked by cosine similarity instead.
nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    query=cls_emb,
    dataset=np.asarray(loaded_df["question_cls"].tolist()),
    k=k,
    similarity_metric=cosine_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
print("Cosine similarities of the k nearest embeddings:", nearest_similarities)
mins_cosine = loaded_df.iloc[nearest_indices].reset_index(drop=True)
display(mins_cosine)

Row indices of the k nearest embeddings: [   1 3456 7323]
Cosine similarities of the k nearest embeddings: [0.86631365 0.84796207 0.84498733]


Unnamed: 0,question,exp,question_cls
0,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[-0.24133916199207303, 0.097042627632618, -0.2..."
1,Max Joseph&;s space is a histopathological fea...,Max Joseph's space is a characteristic histolo...,"[-0.09414966404438, -0.05790701508522, 0.25478..."
2,A female patient presents with deep Desmoid tu...,Desmoid tumour is a tumour arising from the mu...,"[0.13216152787208602, 0.38578277826309204, -0...."


In [10]:
idx = 0
# Full text of the closest cosine match: its question, then its explanation.
print(mins_cosine.at[idx, "question"])
print(mins_cosine.at[idx, "exp"])

40-year old female presented with neck swelling. Gross and histology is shown below.  What is your diagnosis?
Ref. Robbins Pathology. 9th edition. Page. 1099
Medullary carcinoma thyroid
Gross

Single or multiple
Typically nonencapsulated
Solid, gray / tan / yellow, firm, may be infiltrative

Microscopy

Round, polygonal or spindle cells in nests, cords or follicles, defined by sharply outlined fibrous bands
Tumor cells have granular cytoplasm and uniform round / oval nuclei with punctate chromatin
Stroma has amyloid deposits from calcitonin, prominent vascularity with glomeruloid configuration or long cords of vessels, coarse calcifications

 
IHC – Calcitonin


# T5 with Fine-tuning

In [11]:
# Tokenizer plus base t5-small weights; adapters are inserted into this model
# before the fine-tuned checkpoint is loaded on top.
t5_tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path=base_t5_path)
finetuned_t5_model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=base_t5_path
)

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [12]:
# Adapter layer
class AdapterLayer(nn.Module):
    """Residual bottleneck adapter computing ``x + up(relu(down(x)))``.

    The submodule attribute name ``sharif_llm_adapter`` determines the
    state-dict keys, which must match the saved fine-tuned checkpoint, so it
    must not be renamed.
    """

    def __init__(self, emb_dim: int, bottleneck_size: int):
        """
        Args:
            emb_dim: width of the incoming (and outgoing) hidden states.
            bottleneck_size: width of the down-projection.
        """
        super().__init__()

        self.sharif_llm_adapter = nn.Sequential(
            nn.Linear(emb_dim, bottleneck_size),
            nn.ReLU(),
            nn.Linear(bottleneck_size, emb_dim),
        )

    def forward(self, x: torch.Tensor):
        """Add the adapter's bottleneck transform of ``x`` back onto ``x``."""
        adapter_output = self.sharif_llm_adapter(x)
        output = x + adapter_output
        return output


class FeedForwardAdapterWrapper(nn.Module):
    """Run a T5 feed-forward sublayer, then an AdapterLayer on its output.

    The attribute names ``original_module`` and ``adapter`` are part of the
    checkpoint's state-dict keys and must stay as they are.
    """

    def __init__(self, original_module: "T5LayerFF", bottleneck_size: int):
        """
        Args:
            original_module: the T5LayerFF being wrapped.
            bottleneck_size: adapter bottleneck width.

        Raises:
            TypeError: if ``original_module`` is not a T5LayerFF.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(original_module, T5LayerFF):
            raise TypeError(
                f"expected T5LayerFF, got {type(original_module).__name__}"
            )

        self.original_module = original_module
        # The adapter consumes this sublayer's output, whose width is the
        # out_features of the output projection `wo`. This also avoids
        # depending on `wi`, which gated T5 feed-forward variants (wi_0/wi_1)
        # do not define.
        emb_dim = original_module.DenseReluDense.wo.out_features
        self.adapter = AdapterLayer(emb_dim, bottleneck_size)

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward sublayer, then the adapter."""
        output = self.original_module(x)
        output = self.adapter(output)
        return output

In [13]:
def mutate_model_recursive(model: nn.Module, bottleneck_size: int, prefix: str = ""):
    """Replace every T5LayerFF below ``model`` with a FeedForwardAdapterWrapper, in place.

    Args:
        model: module whose children are scanned recursively.
        bottleneck_size: bottleneck width of each inserted adapter.
        prefix: dotted path of ``model`` inside the root module. It is only
            used so the log line identifies which layer was replaced.
    """
    for name, module in model.named_children():
        qualified_name = f"{prefix}.{name}" if prefix else name
        if isinstance(module, T5LayerFF):
            feed_forward_with_adapter = FeedForwardAdapterWrapper(
                module, bottleneck_size
            )
            # setattr replaces the existing child under the same name; no new
            # child names are added while iterating.
            setattr(model, name, feed_forward_with_adapter)
            print(f"Replaced {qualified_name} with FeedForwardAdapterWrapper layer.")
        else:
            mutate_model_recursive(module, bottleneck_size, qualified_name)


def mutate_model(model: nn.Module, bottleneck_size: int):
    """Insert adapters into ``model`` once; repeated calls are no-ops.

    ``model._mutated`` is set as a marker. Without it, a second call would
    recurse into the wrappers and wrap their inner T5LayerFF modules again.
    """
    if hasattr(model, "_mutated"):
        print("Model already contains adapter layers! \n Try reloading the model.")
        return

    mutate_model_recursive(model, bottleneck_size)

    model._mutated = True


# Wrap every T5 feed-forward sublayer with an adapter (guarded against re-runs).
mutate_model(model=finetuned_t5_model, bottleneck_size=bottleneck_size)

Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.


In [14]:
# Load the fine-tuned weights into the adapter-augmented model. This runs
# after mutate_model because the checkpoint presumably contains the adapter
# parameters, and load_state_dict is strict by default.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint_file = hf_hub_download(repo_id=repo_id, filename=finetuned_t5_path)
finetuned_t5_model.load_state_dict(
    torch.load(checkpoint_file, map_location=device)["model_state_dict"]
)
finetuned_t5_model = finetuned_t5_model.to(device)

T5-Finetuned-15k-20epoch.pth:   0%|          | 0.00/244M [00:00<?, ?B/s]

In [15]:
# Map the 0-based correct-option index (`cop`) to its option letter.
opt_idx2str = {
    0: "A",
    1: "B",
    2: "C",
    3: "D",
}


def generate_answer(row, model, tokenizer, max_input_length=1024, max_output_length=5):
    """Ask a seq2seq model which option (A-D) answers a multiple-choice question.

    Args:
        row: mapping (dict or pandas Series) with keys ``question``, ``opa``,
            ``opb``, ``opc``, ``opd`` and ``exp``.
        model: seq2seq model exposing ``generate``; it is switched to eval mode.
        tokenizer: tokenizer matching ``model``.
        max_input_length: the prompt is truncated to this many tokens.
        max_output_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The decoded generation with special tokens removed (e.g. "Answer: C").
    """
    model.eval()
    input_text = (
        f"Question: {row['question']}\n\n"
        f"Options:\nA: {row['opa']}\nB: {row['opb']}\nC: {row['opc']}\nD: {row['opd']}\n\n"
        f"Explanation: {row['exp']}\n\nAnswer:"
    )
    # Truncate in the same call that builds the tensors, so long retrieved
    # explanations are actually capped at `max_input_length`.
    inputs = tokenizer(
        input_text,
        truncation=True,
        max_length=max_input_length,
        return_tensors="pt",
    )
    model_device = next(model.parameters()).device
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(model_device),
        attention_mask=inputs["attention_mask"].to(model_device),
        max_length=max_output_length,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

In [99]:
# Hand-written MedMCQA-style probe questions. `cop` is the 0-based index of the
# correct option; the "exp" field is filled in later by retrieval.
sample_questions = [
    {
        'question': 'What is the term for the active searching for unrecognized problems in apparently healthy individuals using quick tests?',
        'opa': 'Vaccination',
        'opb': 'Screening',
        'opc': 'Monitoring',
        'opd': 'Diagnosis',
        'cop': 1,
    },
    {
        'question': 'Which of the following causes a decrease in ESR?',
        'opa': 'Sickle cell anaemia',
        'opb': 'Inflammation',
        'opc': 'COVID-19',
        'opd': 'Pregnancy',
        'cop': 0,
    },
    {
        'question': 'Which option could be identified using the cephalic index?',
        'opa': 'Blood type',
        'opb': 'Sex',
        'opc': 'Hair color',
        'opd': 'Race',
        'cop': 3,
    },
    {
        'question': 'How does radiotherapy work?',
        'opa': 'By using ultrasound tissue scanning',
        'opb': 'By ionization of tissues',
        'opc': 'By blocking hormone receptors',
        'opd': 'By necrosis of body cells',
        'cop': 1,
    },
    {
        'question': 'In which case is a magistrate inquest NOT required?',
        'opa': 'Death in police custody',
        'opb': 'Death in police firing',
        'opc': 'Death by suicide',
        'opd': 'Death in psychiatry hospital',
        'cop': 2,
    },
    {
        'question': "Which test is most related to Addison's disease (i.e. adrenal insufficiency)?",
        'opa': 'ACTH (Cosyntropin) test',
        'opb': 'Blood glucose test',
        'opc': 'MRI of the adrenal glands',
        'opd': 'CT scan of abdomen',
        'cop': 0,
    },
    {
        'question': 'For a woman in her 30s, under which circumstance, there is an increased risk of having a baby with Down syndrome?',
        'opa': 'Undergoing IVF treatment',
        'opb': 'Having a previous baby with Klinefelter syndrome',
        'opc': 'Having three first-trimester miscarriages',
        'opd': 'Having a previous baby with Turner syndrome',
        'cop': 2,
    },
    {
        'question': 'Which one is inhibited by Azaserine?',
        'opa': 'Ribose-phosphate diphosphokinase',
        'opb': 'Dihydrofolate reductase',
        'opc': 'Glycinamide ribonucleotide transformylase',
        'opd': 'Formyl glycinamide ribonucleotide amidotransferase (aka PurL)',
        'cop': 3,
    },
    {
        'question': 'Which metric provides information about the completed family size?',
        'opa': 'Pregnancy rate',
        'opb': 'General marital fertility rate',
        'opc': 'Gross reproductive rate',
        'opd': 'Total fertility rate',
        'cop': 3,
    },
    {
        'question': 'What is the most frequent complication associated with a Colles fracture?',
        'opa': 'Stiffness of fingers',
        'opb': 'Compartment syndrome',
        'opc': 'Malunion',
        'opd': 'Carpal tunnel syndrome',
        'cop': 0,
    },
    {
        'question': 'The ligament of Berry in the thyroid gland attaches it to which structure?',
        'opa': 'Larynx',
        'opb': 'Cricoid cartilage',
        'opc': 'Esophagus',
        'opd': 'Thyroid',
        'cop': 1,
    },
    {
        'question': 'Conn syndrome is caused by overproduction of which hormone?',
        'opa': 'ACTH',
        'opb': 'ADH',
        'opc': 'Aldosterone',
        'opd': 'Cortisol',
        'cop': 2,
    },
]

# Question evaluated below (the Conn syndrome sample, matching the recorded
# outputs). Copy it so that storing the retrieved "exp" later does not mutate
# the list entry, which keeps this cell and the next one idempotent.
t = dict(sample_questions[-1])

In [100]:
# Retrieval-augmented context: find the stored questions closest to `t` and
# use the best match's explanation as the "exp" field of the T5 prompt.
# A dedicated name keeps the global KNN hyperparameter `k` (config cell)
# untouched, so re-running the earlier KNN cells still uses its original value.
rag_k = 2
cls_emb = encode_text(t["question"], bert_tokenizer, bert_model).numpy().squeeze()

nearest_indices, _, nearest_similarities = k_nearest_embeddings(
    cls_emb,
    np.asarray(loaded_df["question_cls"].tolist()),
    rag_k,
    similarity_metric=cosine_similarity,
)
print("Row indices of the k nearest embeddings:", nearest_indices)
print("Cosine similarities of the k nearest embeddings:", nearest_similarities)
mins_cosine = loaded_df.iloc[nearest_indices]
mins_cosine = mins_cosine.reset_index(drop=True)
display(mins_cosine)

exps = mins_cosine["exp"].tolist()
# Only the top match is used. To feed every retrieved explanation instead:
# exps_str = '\n'.join([f'{i+1}. {exp}' for i, exp in enumerate(exps)])
exps_str = exps[0]
print(exps_str)

t["exp"] = exps_str

Row indices of the k nearest embeddings: [   52 10856]
Cosine similarities of the k nearest embeddings: [0.94617574 0.90021785]


Unnamed: 0,question,exp,question_cls
0,Conn syndrome is seen due to increased product...,An adrenocoical disorder caused by excessive s...,"[-0.11438417434692401, 0.267834663391113, -0.1..."
1,Zellweger syndrome is due to defect in-,"Ans. is 'b' i.e., Fatty acid oxidation in pero...","[0.192382022738457, 0.050101213157177006, 0.12..."


An adrenocoical disorder caused by excessive secretion of aldosterone. ; primary aldosteronismPrimary aldosteronism is an adrenocoical disorder caused by excessive secretion of aldosterone and characterized by headaches, nocturia, polyuria, fatigue, hypeension, potassium depletion, hypokalemic alkalosis, hypervolemia, and decreased plasma renin activity; may be associated with small benign adrenocoical adenomas.Ref: Ganong&;s review of medical physiology; 24th edition; page no:-364


In [101]:
# Fine-tuned (adapter) T5 prediction vs. the gold option letter.
answer = generate_answer(row=t, model=finetuned_t5_model, tokenizer=t5_tokenizer)
print('Model\'s output =  "{}"'.format(answer))
print('Correct output =  "Answer: {}"'.format(opt_idx2str[t["cop"]]))

Model's output =  "Answer: C"
Correct output =  "Answer: C"


# T5 without Fine-tuning

In [102]:
# Plain t5-small (no adapters, no fine-tuning) as a baseline.
t5_model = T5ForConditionalGeneration.from_pretrained(base_t5_path)
t5_model = t5_model.to(device)

In [103]:
# Baseline T5 prediction on the same prompt, for comparison.
answer = generate_answer(row=t, model=t5_model, tokenizer=t5_tokenizer)
print('Model\'s output =  "{}"'.format(answer))
print('Correct output =  "Answer: {}"'.format(opt_idx2str[t["cop"]]))

Model's output =  "True"
Correct output =  "Answer: C"
