In [1]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
)
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
from huggingface_hub import HfApi
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

!unzip -qn /usr/share/nltk_data/corpora/wordnet.zip -d /usr/share/nltk_data/corpora/

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

trainset_range = list(range(18000, 38000))
medmcqa_dataset_path = "openlifescienceai/medmcqa"
base_bert_path = "emilyalsentzer/Bio_ClinicalBERT"
finetuned_bert_path = (
    "BioClinicalBert-MLM-Finetuned-40k-25epoch-exp-25epoch-questions.pth"
)
dataset_file_name = (
    "MEDMCQA-dataset-with-CLS-40k-25epoch-exp-25epoch-questions-nltk.json"
)
repo_id = "alibababeig/nlp-hw4"
push_dataset_to_huggingface = False

batch_size = 32

# Load MEDMCQA Dataset and Select Columns 

In [3]:
def filter_none(example):
    return (
        (example["exp"] is not None)
        and (len(example["exp"]) > 20)
        and (example["question"] is not None)
    )


dataset = load_dataset(medmcqa_dataset_path)
dataset = dataset["train"].select(trainset_range)
dataset = dataset.filter(filter_none).select_columns(["question", "exp"])
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
dataset_df = pd.DataFrame(dataset.to_dict())
print(f"dataset length: {len(dataset)}")

Downloading readme:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/85.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/936k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/182822 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6150 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4183 [00:00<?, ? examples/s]

Filter:   0%|          | 0/20000 [00:00<?, ? examples/s]

dataset length: 16831


# Load Finetuned Model From Hugging Face

In [4]:
from huggingface_hub import hf_hub_download


tokenizer = BertTokenizer.from_pretrained(base_bert_path)
model = BertForMaskedLM.from_pretrained(base_bert_path).to(device)

checkpoint_file = hf_hub_download(repo_id=repo_id, filename=finetuned_bert_path)
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.bert  # dropping MLM head
model.eval()

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


(…)-MLM-Finetuned-40k-25epoch-questions.pth:   0%|          | 0.00/433M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [5]:
def preprocess_text(text):
    tokens = word_tokenize(text)
    tokens = [word.lower() for word in tokens]
    tokens = [word for word in tokens if word.isalpha()]
    stop_words = set(stopwords.words("english"))
    tokens = [word for word in tokens if word not in stop_words]
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(word) for word in tokens]
    return " ".join(tokens)

In [None]:
max_length = 128
cls_tokens = []
for batch in tqdm(dataloader):
    batch["question"] = [preprocess_text(txt) for txt in batch["question"]]
    tokens = tokenizer(
        batch["question"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = tokens["input_ids"].to(device)
    att_mask = tokens["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, att_mask)

    if "pooler_output" in outputs:
        cls_embedding = outputs.pooler_output
    elif "last_hidden_state" in outputs:
        cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()
    else:
        raise Exception("No CLS token found in the given model")

    cls_embedding = cls_embedding.cpu().numpy().tolist()
    cls_tokens += cls_embedding

In [7]:
cls_tokens = [np.asarray(cls) for cls in cls_tokens]
dataset_df["question_cls"] = cls_tokens
display(dataset_df)

Unnamed: 0,question,exp,question_cls
0,"All of the following are pyrogenic cytokines, ...",Interleukin 18 is not a pyrogenic cytokine. IL...,"[0.5136734843254089, -0.6453795433044434, 0.18..."
1,40-year old female presented with neck swellin...,Ref. Robbins Pathology. 9th edition. Page. 109...,"[-0.18327629566192627, -0.23770320415496826, -..."
2,Following statement regarding dislocation of t...,Anterior dislocation is more common in which h...,"[-0.4960843026638031, -0.30119097232818604, 0...."
3,The active search for unrecognized disease or ...,Screening is the search for unrecognized disea...,"[0.05901067703962326, 0.06995794922113419, -0...."
4,Fir tree pattern lesion is seen in,Fir tree pattern of distribution of lesions is...,"[-0.3408776819705963, -0.5097606778144836, -0...."
...,...,...,...
16826,Carcinoma sigmoid colon with obstruction Manag...,- Obstruction due to rectosigmoid growth with ...,"[-0.00792787317186594, 0.36405614018440247, -0..."
16827,ADHD in childhood can lead to which of the fol...,"ADHD can lead to substance abuse,mood disorder...","[-0.3736489713191986, -0.13032260537147522, -0..."
16828,Nerve for adductor compament of thigh ?,Ans. B) Obturator nerveObturator nerve is the ...,"[0.35813137888908386, 0.021466786041855812, 0...."
16829,The &;a&;wave of jugular venous pulse is produ...,JVP or jugular venous is a reflection of the r...,"[-0.3384331464767456, 0.4519999027252197, -0.2..."


In [8]:
dataset_df.to_json(dataset_file_name, double_precision=15)

In [9]:
if push_dataset_to_huggingface:
    # generate a token from Profile > Setting > Access Tokens with write access
    api = HfApi(
        token="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",
    )
    api.upload_file(
        path_or_fileobj=f"./{dataset_file_name}",
        path_in_repo=dataset_file_name,
        repo_id=repo_id,
        repo_type="model",
    )

MEDMCQA-dataset-with-CLS-40k-25epoch-questions-nltk.json:   0%|          | 0.00/249M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/alibababeig/nlp-hw4/commit/9a68fe97e1ca2ed47884caa2c4f141564d80299e', commit_message='Upload MEDMCQA-dataset-with-CLS-40k-25epoch-questions-nltk.json with huggingface_hub', commit_description='', oid='9a68fe97e1ca2ed47884caa2c4f141564d80299e', pr_url=None, pr_revision=None, pr_num=None)