In [23]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
)
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32

# Load MEDMCQA Dataset and Select Columns 

In [36]:
def filter_none(example):
    return (example["exp"] is not None) and (example["question"] is not None)


# load MedMCQA
dataset = load_dataset("openlifescienceai/medmcqa")
dataset = dataset["train"].filter(filter_none).select_columns(["question", "exp"])
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
dataset_df = pd.DataFrame(dataset.to_dict())
print(f"dataset length: {len(dataset)}")

dataset length: 160869


# Load Finetuned Model From Hugging Face

In [26]:
from huggingface_hub import hf_hub_download


tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = BertForMaskedLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

# load the trained model from huggingface
repo_id = "alibababeig/nlp-hw4"
filename = "BioClinicalBert-MLM-Finetuned.pth"
checkpoint_file = hf_hub_download(repo_id=repo_id, filename=filename)

checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.bert  # dropping MLM head
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [28]:
max_length = 128
cls_tokens = []
for batch in tqdm(dataloader):
    tokens = tokenizer(
        batch["question"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = tokens["input_ids"].to(device)
    att_mask = tokens["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, att_mask)

    if "pooler_output" in outputs:
        cls_embedding = outputs.pooler_output
    elif "last_hidden_state" in outputs:
        cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()
    else:
        raise Exception("No CLS Token found in the given model")

    cls_embedding = cls_embedding.cpu().numpy().tolist()
    cls_tokens += cls_embedding

print(len(cls_tokens))

  0%|          | 0/5028 [00:00<?, ?it/s]

160869


In [31]:
cls_tokens = [np.asarray(cls) for cls in cls_tokens]
dataset_df["question_cls"] = cls_tokens
display(dataset_df)

Unnamed: 0,question,exp,question_cls
0,Chronic urethral obstruction due to benign pri...,Chronic urethral obstruction because of urinar...,"[0.263091504573822, -0.05961083993315697, -0.2..."
1,Which vitamin is supplied from only animal sou...,Ans. (c) Vitamin B12 Ref: Harrison's 19th ed. ...,"[0.44245609641075134, 0.018535641953349113, -0..."
2,All of the following are surgical options for ...,"Ans. is 'd' i.e., Roux en Y Duodenal Bypass Ba...","[0.6987701058387756, 0.264356791973114, -0.003..."
3,Following endaerectomy on the right common car...,The central aery of the retina is a branch of ...,"[0.11031772941350937, 0.002846830990165472, -0..."
4,Growth hormone has its effect on growth through?,"Ans. is 'b' i.e., IGI-1GH has two major functi...","[0.40995046496391296, 0.4891767203807831, -0.5..."
...,...,...,...
160864,Organism that causes emphysematous cholecystit...,Ref: Harrison's 18th editionExplanation:Emphys...,"[0.7292048335075378, 0.3522907793521881, -0.17..."
160865,Most common site for extra mammary Paget&;s di...,.It is superficial manifestation of an intradu...,"[0.5586345195770264, 0.08569207042455673, -0.3..."
160866,Inferior Rib notching is seen in all except?,Answer is D (Neurofibromatosis) Neurofibromato...,"[0.718835175037384, -0.18725897371768951, 0.02..."
160867,Which is false regarding cryptococcus neoformans?,"Ans. is 'c' i e., Urease negative Cryptococcus...","[0.5742323398590088, 0.21848540008068085, -0.4..."


In [33]:
dataset_df.to_json("MEDMCQA-dataset-with-CLS.json", double_precision=15)

In [35]:
from huggingface_hub import HfApi

# generate a token from Profile > Setting > Access Tokens with write access
api = HfApi(
    token="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",
)
api.upload_file(
    path_or_fileobj="./MEDMCQA-dataset-with-CLS.json",
    path_in_repo="MEDMCQA-dataset-with-CLS.json",
    repo_id="alibababeig/nlp-hw4",
    repo_type="model",
)

MEDMCQA-dataset-with-CLS.json:   0%|          | 0.00/2.37G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alibababeig/nlp-hw4/commit/4e46d951f887bcd6c51fab4570e3f11551129970', commit_message='Upload MEDMCQA-dataset-with-CLS.json with huggingface_hub', commit_description='', oid='4e46d951f887bcd6c51fab4570e3f11551129970', pr_url=None, pr_revision=None, pr_num=None)