In [18]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
)
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

!unzip -qn /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [19]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

batch_size = 32  # questions per DataLoader batch (one model forward pass each)

# Load the MedMCQA Dataset and Select Columns

In [20]:
def filter_none(example):
    """Return True when a MedMCQA row is usable for embedding.

    A row is kept only if its explanation ("exp") is present and longer than
    20 characters, and its "question" is present. The explanation is checked
    first, so "question" is never read for rows whose explanation is missing
    or too short.
    """
    explanation = example["exp"]
    if explanation is None or len(explanation) <= 20:
        return False
    return example["question"] is not None


# Rows [18000, 38000) of the MedMCQA train split: a fixed 20k-row slice.
trainset_range = list(range(18000, 38000))

# Load MedMCQA, take the slice, drop rows without a usable question/explanation,
# and keep only the two text columns.
dataset = (
    load_dataset("openlifescienceai/medmcqa")["train"]
    .select(trainset_range)
    .filter(filter_none)
    .select_columns(["question", "exp"])
)

# shuffle=False is load-bearing: embeddings are collected in loader order and later
# assigned row-for-row to `dataset_df`, which is built from the same filtered dataset.
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)
dataset_df = pd.DataFrame(dataset.to_dict())
print(f"dataset length: {len(dataset)}")

Filter:   0%|          | 0/20000 [00:00<?, ? examples/s]

dataset length: 16831


# Load the Fine-Tuned Model from Hugging Face

In [21]:
from huggingface_hub import hf_hub_download


# Base Bio_ClinicalBERT architecture + tokenizer; the weights are replaced below.
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = BertForMaskedLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

# load the fine-tuned MLM checkpoint from the Hugging Face Hub
repo_id = "alibababeig/nlp-hw4"
filename = "BioClinicalBert-MLM-Finetuned-20k-15epoch.pth"
checkpoint_file = hf_hub_download(repo_id=repo_id, filename=filename)

# map_location=device: if the checkpoint was saved from CUDA tensors, loading it
# without remapping fails on a CPU-only machine even though `device` fell back to CPU.
# NOTE: torch.load unpickles the file, which can execute arbitrary code — only load
# checkpoints from repositories you trust.
checkpoint = torch.load(checkpoint_file, map_location=device)
# Strict load: the checkpoint must contain the full BertForMaskedLM state dict.
model.load_state_dict(checkpoint["model_state_dict"])
model = model.bert  # dropping MLM head; only the encoder is needed for embeddings
model.eval()  # disable dropout so embeddings are deterministic

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [22]:
# Built once at definition time instead of on every call: the original rebuilt the
# stop-word set and a fresh lemmatizer for each question passed through the loop.
_STOP_WORDS = frozenset(stopwords.words('english'))
_LEMMATIZER = WordNetLemmatizer()


def preprocess_text(text):
    """Normalize a question string before BERT tokenization.

    Steps, in order:
    1. split into word tokens with NLTK's ``word_tokenize``;
    2. lowercase every token;
    3. drop tokens that are not purely alphabetic (after lowercasing);
    4. drop English stop words;
    5. lemmatize the remaining tokens with WordNet.

    Args:
        text: raw question text.

    Returns:
        The surviving lemmas joined by single spaces ("" if none survive).
    """
    lowered = (token.lower() for token in word_tokenize(text))
    return " ".join(
        _LEMMATIZER.lemmatize(token)
        for token in lowered
        if token.isalpha() and token not in _STOP_WORDS
    )

In [23]:
max_length = 128  # BERT input length: questions are padded/truncated to this many tokens
cls_tokens = []  # one embedding (list of floats) per dataset row, in loader order
for batch in tqdm(dataloader):
    # Normalize questions with NLTK before BERT tokenization (batch itself is not mutated).
    questions = [preprocess_text(txt) for txt in batch["question"]]
    tokens = tokenizer(
        questions,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = tokens["input_ids"].to(device)
    att_mask = tokens["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=att_mask)

    # `in` on a transformers ModelOutput only sees fields that were set (not None),
    # so the pooler branch is taken only when the model actually has a pooler.
    if "pooler_output" in outputs:
        cls_embedding = outputs.pooler_output
    elif "last_hidden_state" in outputs:
        # [:, 0, :] already yields (batch, hidden). No .squeeze(): a final batch of a
        # single question would collapse to (hidden,) and be spread into cls_tokens
        # as individual floats instead of one vector.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
    else:
        raise RuntimeError("No CLS Token found in the given model")

    cls_tokens.extend(cls_embedding.cpu().numpy().tolist())

# The embeddings are assigned row-for-row to dataset_df, so the counts must match.
assert len(cls_tokens) == len(dataset), (len(cls_tokens), len(dataset))
print(len(cls_tokens))

  0%|          | 0/526 [00:00<?, ?it/s]

16831


In [24]:
# Convert each embedding from a plain list to a NumPy vector, then store one vector
# per row (rows line up because the DataLoader was not shuffled).
cls_tokens = list(map(np.asarray, cls_tokens))
dataset_df["question_cls"] = cls_tokens
display(dataset_df)

Unnamed: 0,question,exp,question_cls
0,"All of the following are pyrogenic cytokines, ...",Interleukin 18 is not a pyrogenic cytokine. IL...,"[0.210591122508049, 0.03523605689406395, -0.16..."
1,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[0.14564621448516846, 0.25639796257019043, -0...."
2,Following statement regarding dislocation of t...,Anterior dislocation is more common in which h...,"[-0.20996041595935822, -0.06706653535366058, -..."
3,The active search for unrecognized disease or ...,Screening is the search for unrecognized disea...,"[0.061764661222696304, 0.05983557179570198, -0..."
4,Fir tree pattern lesion is seen in,Fir tree pattern of distribution of lesions is...,"[0.3631107211112976, -0.02278529293835163, -0...."
...,...,...,...
16826,Carcinoma sigmoid colon with obstruction Manag...,- Obstruction due to rectosigmoid growth with ...,"[0.1965109258890152, 0.41880345344543457, -0.2..."
16827,ADHD in childhood can lead to which of the fol...,"ADHD can lead to substance abuse,mood disorder...","[0.24358224868774414, 0.5815790891647339, -0.3..."
16828,Nerve for adductor compament of thigh ?,Ans. B) Obturator nerveObturator nerve is the ...,"[0.18723972141742706, -0.011040580458939075, -..."
16829,The &;a&;wave of jugular venous pulse is produ...,JVP or jugular venous is a reflection of the r...,"[0.21560895442962646, 0.297360360622406, -0.24..."


In [27]:
# Persist question/explanation text plus CLS vectors. double_precision=15 is the
# largest precision pandas' JSON writer accepts, keeping as many float digits as possible.
dataset_df.to_json(
    path_or_buf="MEDMCQA-dataset-with-CLS-20k-nltk.json",
    double_precision=15,
)

In [28]:
import os
from getpass import getpass

from huggingface_hub import HfApi

# Never hard-code access tokens: anyone who can read the notebook can use them.
# The write token previously written here is exposed and must be revoked
# (Profile > Settings > Access Tokens). Supply a write-scoped token through the
# HF_TOKEN environment variable, or paste it at the prompt (input is not echoed).
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face write token: ")
api = HfApi(token=hf_token)
api.upload_file(
    path_or_fileobj="./MEDMCQA-dataset-with-CLS-20k-nltk.json",
    path_in_repo="MEDMCQA-dataset-with-CLS-20k-nltk.json",
    repo_id="alibababeig/nlp-hw4",
    repo_type="model",
)

MEDMCQA-dataset-with-CLS-20k-nltk.json:   0%|          | 0.00/249M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alibababeig/nlp-hw4/commit/33033a331562d1ae6c9a8fb9eb9ec4b67a24efef', commit_message='Upload MEDMCQA-dataset-with-CLS-20k-nltk.json with huggingface_hub', commit_description='', oid='33033a331562d1ae6c9a8fb9eb9ec4b67a24efef', pr_url=None, pr_revision=None, pr_num=None)