data (datasets): https://github.com/dmis-lab/biobert

based on: https://www.youtube.com/watch?v=r6XY80Z9eSA&t=793s

#### 0. Install, download

In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
!pip install -q git+https://github.com/huggingface/transformers.git@main 
!pip install -q datasets SentencePiece onnx peft pytorch-lightning

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [16]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import time
import random
import pandas as pd
import numpy as np

from datasets import load_dataset

from transformers import T5Tokenizer, T5ForConditionalGeneration, GPT2Tokenizer
from transformers import AdamW, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup

import re
from tqdm.notebook import tqdm
import textwrap
from termcolor import colored

from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split

import json
from operator import itemgetter
from distutils.util import strtobool

import argparse
import glob
import os
import logging
from itertools import chain
from string import punctuation

from pathlib import Path
from termcolor import colored
import textwrap

In [17]:
pl.seed_everything(42)

INFO:lightning_fabric.utilities.seed:Global seed set to 42


42

#### Data

In [18]:
cd /content/drive/MyDrive/ds/t5-ft

/content/drive/MyDrive/ds/t5-ft


In [19]:
def extract_questions_and_answers(factoid_path: Path):
  with Path(factoid_path).open() as json_file:
    data = json.load(json_file)

  questions = data['data'][0]['paragraphs']

  data_rows = []

  for question in questions:
    context = question['context']
    for question_and_answers in question['qas']:
      question = question_and_answers['question']
      answers = question_and_answers['answers']

      for answer in answers:
        answer_text = answer['text']
        answer_start = answer['answer_start']
        answer_end = answer_start + len(answer_text)

        data_rows.append({
            'question': question,
            'context': context,
            'answer_text': answer_text,
            'answer_start': answer_start,
            'answer_end': answer_end
        })
  return pd.DataFrame(data_rows)

In [20]:
factoid_paths = sorted(list(Path('BioASQ/').glob('BioASQ-train-factoid-*')))
factoid_paths

[PosixPath('BioASQ/BioASQ-train-factoid-4b.json'),
 PosixPath('BioASQ/BioASQ-train-factoid-5b.json'),
 PosixPath('BioASQ/BioASQ-train-factoid-6b.json'),
 PosixPath('BioASQ/BioASQ-train-factoid-7b.json')]

In [21]:
dfs = []

for factoid_path in factoid_paths[:3]:
  dfs.append(extract_questions_and_answers(factoid_path))

df = pd.concat(dfs)

In [22]:
df.head()

Unnamed: 0,question,context,answer_text,answer_start,answer_end
0,What is the inheritance pattern of Li–Fraumeni...,Balanced t(11;15)(q23;q15) in a TP53+/+ breast...,autosomal dominant,213,231
1,What is the inheritance pattern of Li–Fraumeni...,Genetic modeling of Li-Fraumeni syndrome in ze...,autosomal dominant,105,123
2,Which type of lung cancer is afatinib used for?,Clinical perspective of afatinib in non-small ...,EGFR-mutant NSCLC,1203,1220
3,Which hormone abnormalities are characteristic...,"DOCA sensitive pendrin expression in kidney, h...",thyroid,419,426
4,Which hormone abnormalities are characteristic...,Clinical and molecular characteristics of Pend...,thyroid,705,712


In [23]:
df.question.nunique(), df.context.nunique(), df.answer_text.nunique(), df.shape

(443, 2582, 661, (12988, 5))

### Tokenization

In [24]:
MODEL_NAME = "t5-base"

In [25]:
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [26]:
class BioQADataset(Dataset):

  def __init__(
      self,
      data: pd.DataFrame,
      tokenizer: T5Tokenizer,
      source_max_token_len: int = 396,
      target_max_token_len: int = 32
  ):

    self.tokenizer = tokenizer
    self.data = data
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index:int):
    data_row = self.data.iloc[index]

    source_encoding = tokenizer(
      data_row['question'],
      data_row['context'],
      max_length = self.source_max_token_len,
      padding = 'max_length',
      truncation = 'only_second',
      add_special_tokens = True,
      return_tensors = 'pt'
    )

    target_encoding = tokenizer(
      data_row['answer_text'],
      max_length = self.target_max_token_len,
      padding = 'max_length',
      truncation = True,
      add_special_tokens = True,
      return_tensors = 'pt'
    )

    labels = target_encoding['input_ids']
    labels[labels == 0] = - 100

    return dict(
        question = data_row['question'],
        context = data_row['context'],
        answer_text = data_row['answer_text'],
        input_ids = source_encoding['input_ids'].flatten(),
        attention_mask = source_encoding['attention_mask'].flatten(),
        labels = labels.flatten()
    )

In [27]:
train_df, val_df = train_test_split(df, test_size = 0.05)

In [28]:
train_df.shape, val_df.shape

((12338, 5), (650, 5))

In [29]:
class BioQADataModule(pl.LightningDataModule):
  
  def __init__(
      self,
      train_df: pd.DataFrame,
      test_df: pd.DataFrame,
      tokenizer: T5Tokenizer,
      batch_size: int = 8,
      source_max_token_len: int = 396,
      target_max_token_len: int = 32
  ):
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.test_df = test_df
    self.tokenizer = tokenizer
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def setup(self, stage=None):
    self.train_dataset = BioQADataset(
        self.train_df,
        self.tokenizer,
        self.source_max_token_len,
        self.target_max_token_len
    )
    self.test_dataset = BioQADataset(
        self.test_df,
        self.tokenizer,
        self.source_max_token_len,
        self.target_max_token_len
    )

  def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size = self.batch_size,
        shuffle = True,
        num_workers = 4
    )
    
  def val_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size = 1,
        num_workers = 4
    )

  def test_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size = 1,
        num_workers = 4
    )

In [30]:
BATCH_SIZE = 12
N_EPOCHS = 3  

data_module = BioQADataModule(train_df, val_df, tokenizer, batch_size = BATCH_SIZE)

In [31]:
data_module.setup()

In [32]:
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict = True)

### Modeling

In [33]:
class BioQAModel(pl.LightningModule):

  def __init__(self):
    super().__init__()
    self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict = True)

  def forward(self, input_ids, attention_mask, labels = None):
    output = self.model(
        input_ids = input_ids,
        attention_mask = attention_mask,
        labels = labels
    )
    return output.loss, output.logits

  def training_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log('train_loss', loss, prog_bar = True, logger = True)
    return loss

  def validation_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log('val_loss', loss, prog_bar = True, logger = True)
    return loss

  def test_step(self, batch, batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log('test_loss', loss, prog_bar = True, logger = True)
    return loss
  
  def configure_optimizers(self):
    return AdamW(self.parameters(), lr = 0.0001)

In [34]:
model = BioQAModel()

In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath = "checkpoints",
    filename = "best-checkpoint",
    save_top_k = 1,
    verbose = True,
    monitor = "val_loss",
    mode = "min"
)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger("training-logs", name = "bio-qa")

In [None]:
trainer = pl.Trainer(
    logger = logger,
    callbacks=[checkpoint_callback],
    max_epochs = N_EPOCHS,
    devices=1, 
    accelerator="gpu",
    log_every_n_steps=30
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
%load_ext tensorboard

In [None]:
!rm -rf lightning_logs

In [None]:
%tensorboard --logdir ./training_logs

ERROR: Failed to launch TensorBoard (exited with 1).
Contents of stderr:
/usr/local/lib/python3.10/dist-packages/tensorboard_data_server/bin/server: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.33' not found (required by /usr/local/lib/python3.10/dist-packages/tensorboard_data_server/bin/server)
/usr/local/lib/python3.10/dist-packages/tensorboard_data_server/bin/server: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.34' not found (required by /usr/local/lib/python3.10/dist-packages/tensorboard_data_server/bin/server)
/usr/local/lib/python3.10/dist-packages/tensorboard_data_server/bin/server: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/lib/python3.10/dist-packages/tensorboard_data_server/bin/server)
Address already in use
Port 6006 is in use by another program. Either identify and stop that program, or start the server with a different port.

In [None]:
trainer.fit(model, data_module)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 1029: 'val_loss' reached 0.09975 (best 0.09975), saving model to '/content/drive/MyDrive/ds/t5-ft/checkpoints/best-checkpoint.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 2058: 'val_loss' reached 0.08391 (best 0.08391), saving model to '/content/drive/MyDrive/ds/t5-ft/checkpoints/best-checkpoint.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 3087: 'val_loss' reached 0.07721 (best 0.07721), saving model to '/content/drive/MyDrive/ds/t5-ft/checkpoints/best-checkpoint.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.


In [None]:
trainer.test()

### Predictions

In [35]:
trained_model = BioQAModel.load_from_checkpoint('checkpoints/best-checkpoint.ckpt')
trained_model.freeze()

In [36]:
trained_model.cuda()

BioQAModel(
  (model): T5ForConditionalGeneration(
    (shared): Embedding(32128, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=768, out_features=3072, bias=False)
                (wo): Linear(in_features

In [37]:
def generate_answer(question):
  source_encoding = tokenizer(
      question["question"],
      question["context"],
      max_length = 396,
      padding = "max_length",
      truncation = "only_second",
      return_attention_mask = True,
      add_special_tokens = True,
      return_tensors = "pt"
  )

  generated_ids = trained_model.model.generate(
      input_ids = source_encoding["input_ids"].cuda(),
      attention_mask = source_encoding["attention_mask"].cuda(),
      num_beams = 1,
      max_length = 80,
      repetition_penalty = 1.0,
      early_stopping = True,
      use_cache = True
  )

  preds = [
      tokenizer.decode(generated_id, skip_special_tokens = True, clean_up_tokenization_spaces = True)
      for generated_id in generated_ids
  ]

  return "".join(preds)

In [38]:
sample_question = val_df.iloc[10]

In [39]:
sample_question["question"]

'What is the main component of the Lewy bodies?'

In [40]:
sample_question["answer_text"]

'alpha-Synuclein'

In [41]:
generate_answer(sample_question)

'alpha-Synuclein'

In [42]:
val_df

Unnamed: 0,question,context,answer_text,answer_start,answer_end
2531,Which drug should be used as an antidote in be...,Flumazenil: a benzodiazepine antagonist. The m...,Flumazenil,202,212
3400,Which bone protein is used in archaelogy for d...,Species identification by analysis of bone col...,collagen,571,579
4218,What is the typical rash associated with gluten ?,[Clinical guidelines for the diagnosis and tre...,Dermatitis herpetiformis,83,107
212,Which transcription factor is considered as a ...,Increased lysosomal biogenesis in activated mi...,transcription factor EB (TFEB),746,776
1740,Which type of myeloma is ixazomib being evalua...,"Phase 1 study of twice-weekly ixazomib, an ora...",multiple myeloma,1473,1489
...,...,...,...,...,...
2258,Which disorder is rated by Palmini classificat...,Post-surgical outcome for epilepsy associated ...,focal cortical dysplasia,2056,2080
80,Which fusion protein is involved in the develo...,Ewing sarcoma EWS protein regulates midzone fo...,EWS/FLI1,287,295
3507,Which tool employs self organizing maps for an...,INCA: synonymous codon usage analysis and clus...,INCA,529,533
3207,Which interleukin is blocked by Siltuximab?,"A phase I/II study of siltuximab (CNTO 328), a...",interleukin-6,53,66


In [61]:
def evaluation(df):
  res = []

  for index, question in tqdm(df.iterrows()):
    predicted = generate_answer(question)
    actual = question['answer_text']
    correct = predicted == actual

    res.append({
            'index': index,
            'question': question["question"],
            'context': question["context"],
            'actual': actual,
            'predicted': predicted,
            'correct': correct
        })
    
  return pd.DataFrame(res)

In [62]:
ev = evaluation(val_df)

0it [00:00, ?it/s]

In [63]:
ev

Unnamed: 0,index,question,context,actual,predicted,correct
0,2531,Which drug should be used as an antidote in be...,Flumazenil: a benzodiazepine antagonist. The m...,Flumazenil,Flumazenil,True
1,3400,Which bone protein is used in archaelogy for d...,Species identification by analysis of bone col...,collagen,collagen,True
2,4218,What is the typical rash associated with gluten ?,[Clinical guidelines for the diagnosis and tre...,Dermatitis herpetiformis,Dermatitis herpetiformis,True
3,212,Which transcription factor is considered as a ...,Increased lysosomal biogenesis in activated mi...,transcription factor EB (TFEB),transcription factor EB (TFEB),True
4,1740,Which type of myeloma is ixazomib being evalua...,"Phase 1 study of twice-weekly ixazomib, an ora...",multiple myeloma,multiple myeloma,True
...,...,...,...,...,...,...
645,2258,Which disorder is rated by Palmini classificat...,Post-surgical outcome for epilepsy associated ...,focal cortical dysplasia,focal cortical dysplasia,True
646,80,Which fusion protein is involved in the develo...,Ewing sarcoma EWS protein regulates midzone fo...,EWS/FLI1,EWS/FLI1,True
647,3507,Which tool employs self organizing maps for an...,INCA: synonymous codon usage analysis and clus...,INCA,INCA,True
648,3207,Which interleukin is blocked by Siltuximab?,"A phase I/II study of siltuximab (CNTO 328), a...",interleukin-6,interleukin-6,True


In [65]:
ev.correct.unique()

array([ True, False])

In [73]:
acc = sum(list(ev.correct.values))/len(ev)
acc

0.8769230769230769