In [None]:
# download packages
%%capture
! pip install datasets transformers

In [41]:
# import library
from tqdm import tqdm
from IPython import display

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from sklearn.metrics import accuracy_score

from transformers import T5Tokenizer, T5ForConditionalGeneration, DataCollatorForSeq2Seq
from transformers.models.t5.modeling_t5 import T5LayerFF

from datasets import load_dataset

In [42]:
# define configuration
SMALL_MODEL_NAME = 't5-small'
BASE_MODEL_NAME = 't5-base'
LARGE_MODEL_NAME = 't5-large'

BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EPOCHS = 20
BOTTLENECK_SIZE = 8

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [43]:
# load the model and tokenizer
tokenizer = T5Tokenizer.from_pretrained(SMALL_MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(SMALL_MODEL_NAME)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [44]:
# load MedMCQA
dataset = load_dataset('openlifescienceai/medmcqa')
dataset

Using the latest cached version of the dataset since openlifescienceai/medmcqa couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /home/mmd/.cache/huggingface/datasets/openlifescienceai___medmcqa/default/0.0.0/91c6572c454088bf71b679ad90aa8dffcd0d5868 (last modified on Mon Jul 15 03:16:13 2024).


DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
        num_rows: 182822
    })
    test: Dataset({
        features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
        num_rows: 6150
    })
    validation: Dataset({
        features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
        num_rows: 4183
    })
})

In [64]:
def filter_none(example):
    return example['exp'] is not None

dataset['validation'] = dataset['validation'].filter(filter_none)
dataset['train'] = dataset['train'].filter(filter_none)

Filter: 100%|██████████| 182822/182822 [00:01<00:00, 93598.39 examples/s]


In [70]:
# clean the dataset for our task
def format_example_training(row):
    input_text = f"context:{row['exp']}\n question: {row['question']}\n options: 0:{row['opa']}, 1:{row['opb']}, 2:{row['opc']}, 3:{row['opd']}"
    target_text = f"{row['cop']}"
    return {'input_text': input_text, 'target_text': target_text}

def format_example_validation(row):
    input_text = f"context:{row['exp']}\n question: {row['question']}\n options: 0:{row['opa']}, 1:{row['opb']}, 2:{row['opc']}, 3:{row['opd']}"
    target_text = f"{row['cop']}"
    return {'input_text': input_text, 'target_text': target_text}

dataset['train'] = dataset['train'].map(format_example_training, remove_columns=dataset['train'].column_names)
dataset['validation'] = dataset['validation'].map(format_example_validation, remove_columns=dataset['validation'].column_names)

Map: 100%|██████████| 160869/160869 [00:22<00:00, 7053.56 examples/s]
Map: 100%|██████████| 2206/2206 [00:00<00:00, 6061.33 examples/s]


In [71]:
# use tokenzier and prepare the data to feed the model
def map_function(row):
  input_info = tokenizer(row['input_text'], truncation=True, max_length=1024)
  output_info = tokenizer(row['target_text'])
  return {
      **input_info,
      'labels': output_info.input_ids
  }

dataset['train'] = dataset['train'].map(map_function, batched=True)
dataset['train'].set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

dataset['validation'] = dataset['validation'].map(map_function, batched=True)
dataset['validation'].set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

Map: 100%|██████████| 160869/160869 [01:53<00:00, 1421.95 examples/s]
Map: 100%|██████████| 2206/2206 [00:01<00:00, 1522.36 examples/s]


In [72]:
col_fn = DataCollatorForSeq2Seq(tokenizer, return_tensors='pt', padding='longest')

indices = list(range(0, 14000))
train_dataset = Subset(dataset['train'], indices)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=col_fn, shuffle=True)

indices = list(range(15000, 17000))
test_dataset = Subset(dataset['train'], indices)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=col_fn)

val_loader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, collate_fn=col_fn)

In [73]:
import gc

torch.cuda.empty_cache()
gc.collect()

26

In [74]:
def train_loop(model, loader, optimizer):
  model.train()

  batch_losses = []

  for row in tqdm(loader):
    row = row.to(model.device)

    optimizer.zero_grad()
    out = model(**row)
    loss = out.loss
    batch_loss_value = loss.item()
    loss.backward()
    optimizer.step()

    batch_losses.append(batch_loss_value)
  loss_value = np.mean(batch_losses)
  return {'train_loss': loss_value}

def predict(model, row):
  return model.generate(input_ids=row.input_ids, attention_mask=row.attention_mask, max_length=5)

def tokenizer_ids_to_label(all_input_ids):
  vocab_size = tokenizer.vocab_size

  filtered_input_ids = [[token_id for token_id in seq if 0 <= token_id < vocab_size] for seq in all_input_ids]

  return tokenizer.batch_decode(filtered_input_ids, skip_special_tokens=True)

def valid_loop(model, loader, compute_metrics):
  model.eval()

  all_true = []
  all_pred = []

  with torch.no_grad():
    for row in tqdm(loader):
      row = row.to(model.device)
      pred = predict(model, row)

      all_true += row.labels.detach().cpu().tolist()
      all_pred += pred.detach().cpu().tolist()

  all_true = tokenizer_ids_to_label(all_true)
  all_pred = tokenizer_ids_to_label(all_pred)

  return {'valid_acc': compute_metrics(y_true=all_true, y_pred=all_pred)}

In [75]:
# adapter layer
class AdapterLayer(nn.Module):
    def __init__(self, emb_dim: int, bottleneck_size: int):

        super().__init__()

        self.sharif_llm_adapter = nn.Sequential(
            nn.Linear(emb_dim, bottleneck_size),
            nn.ReLU(),
            nn.Linear(bottleneck_size, emb_dim)
        )

    def forward(self, x: torch.Tensor):
        adapter_output = self.sharif_llm_adapter(x)
        output =  x + adapter_output
        return output

class FeedForwardAdapterWrapper(nn.Module):
    def __init__(self, original_module: T5LayerFF, bottleneck_size: int):

        super().__init__()
        assert isinstance(original_module, T5LayerFF)

        self.original_module = original_module
        emb_dim = original_module.DenseReluDense.wi.in_features
        self.adapter = AdapterLayer(emb_dim, bottleneck_size)

    def forward(self, x: torch.Tensor):
        output = self.original_module(x)
        output = self.adapter(output)
        return output

In [76]:
# add adapter to the model
def mutate_model_recursive(model: nn.Module, bottleneck_size: int):
    for name, module in model.named_children():
        if isinstance(module, T5LayerFF):
            feed_forward_with_adapter = FeedForwardAdapterWrapper(module, bottleneck_size)
            setattr(model, name, feed_forward_with_adapter)
            print(f"Replaced {name} with FeedForwardAdapterWrapper layer.")
        else:
            mutate_model_recursive(module, bottleneck_size)

def mutate_model(model: nn.Module, bottleneck_size: int):
    if hasattr(model, '_mutated'):
        print("Model already contains adapter layers! \n Try reloading the model.")
        return

    mutate_model_recursive(model, bottleneck_size)

    model._mutated = True


mutate_model(model, bottleneck_size=BOTTLENECK_SIZE)

Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 1 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.
Replaced 2 with FeedForwardAdapterWrapper layer.


In [77]:
# freeze non adapter parameters
def freeze_non_adapter(model, peft_key):
    print('Non freezed weights:')
    total_params = 0
    for param_name, weights in model.named_parameters():
        weights.requires_grad = peft_key in param_name
        if weights.requires_grad:
            print(param_name)
            total_params += weights.numel()
    print(f'Total number of parameters should be update: {total_params}')

freeze_non_adapter(model, peft_key='sharif_llm')

Non freezed weights:
encoder.block.0.layer.1.adapter.sharif_llm_adapter.0.weight
encoder.block.0.layer.1.adapter.sharif_llm_adapter.0.bias
encoder.block.0.layer.1.adapter.sharif_llm_adapter.2.weight
encoder.block.0.layer.1.adapter.sharif_llm_adapter.2.bias
encoder.block.1.layer.1.adapter.sharif_llm_adapter.0.weight
encoder.block.1.layer.1.adapter.sharif_llm_adapter.0.bias
encoder.block.1.layer.1.adapter.sharif_llm_adapter.2.weight
encoder.block.1.layer.1.adapter.sharif_llm_adapter.2.bias
encoder.block.2.layer.1.adapter.sharif_llm_adapter.0.weight
encoder.block.2.layer.1.adapter.sharif_llm_adapter.0.bias
encoder.block.2.layer.1.adapter.sharif_llm_adapter.2.weight
encoder.block.2.layer.1.adapter.sharif_llm_adapter.2.bias
encoder.block.3.layer.1.adapter.sharif_llm_adapter.0.weight
encoder.block.3.layer.1.adapter.sharif_llm_adapter.0.bias
encoder.block.3.layer.1.adapter.sharif_llm_adapter.2.weight
encoder.block.3.layer.1.adapter.sharif_llm_adapter.2.bias
encoder.block.4.layer.1.adapter.sha

In [78]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
compute_metrics = accuracy_score

In [79]:
model.to(DEVICE)

all_results = []
for epoch in range(EPOCHS):
    epoch_results = {'epoch': epoch}

    epoch_results.update(train_loop(model=model, loader=train_loader, optimizer=optimizer))
    epoch_results.update(valid_loop(model=model, loader=val_loader, compute_metrics=compute_metrics))
    all_results.append(epoch_results)

    display.clear_output()
    display.display(pd.DataFrame(all_results).set_index('epoch'))

Unnamed: 0_level_0,train_loss,valid_acc
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.637607,0.669538
1,0.419869,0.729828
2,0.355666,0.765639
3,0.319466,0.768359
4,0.302168,0.782865
5,0.288083,0.781505
6,0.2764,0.787851
7,0.269804,0.790118
8,0.267213,0.787398
9,0.258894,0.789665


In [80]:
model.eval()

all_true = []
all_pred = []

with torch.no_grad():
  for row in tqdm(test_loader):
    row = row.to(model.device)
    pred = predict(model, row)

    all_true += row.labels.detach().cpu().tolist()
    all_pred += pred.detach().cpu().tolist()

all_true = tokenizer_ids_to_label(all_true)
all_pred = tokenizer_ids_to_label(all_pred)
print(compute_metrics(y_true=all_true, y_pred=all_pred))

100%|██████████| 250/250 [00:11<00:00, 20.94it/s]


0.779


In [81]:
print(all_true)
print(all_pred)

['1', '1', '3', '1', '0', '3', '0', '2', '0', '0', '3', '2', '3', '2', '0', '1', '0', '2', '2', '0', '1', '0', '1', '0', '1', '3', '3', '1', '0', '2', '3', '0', '2', '0', '0', '0', '1', '0', '3', '3', '3', '1', '0', '1', '3', '2', '2', '1', '1', '0', '1', '3', '1', '3', '1', '3', '0', '1', '0', '3', '1', '2', '0', '3', '1', '2', '0', '1', '0', '1', '0', '0', '3', '0', '1', '0', '2', '3', '3', '1', '3', '3', '3', '2', '0', '1', '0', '0', '2', '0', '0', '1', '1', '2', '2', '0', '2', '1', '1', '3', '2', '1', '0', '1', '2', '2', '1', '3', '0', '3', '2', '1', '1', '0', '0', '2', '3', '0', '3', '1', '0', '2', '3', '0', '0', '3', '1', '0', '1', '3', '2', '2', '1', '3', '0', '2', '1', '3', '2', '1', '0', '3', '1', '1', '0', '0', '2', '2', '1', '3', '2', '1', '2', '2', '1', '0', '1', '1', '2', '1', '3', '0', '2', '0', '2', '2', '3', '0', '0', '2', '0', '1', '0', '0', '0', '2', '0', '1', '2', '2', '2', '0', '2', '3', '0', '2', '0', '1', '1', '0', '2', '1', '1', '3', '0', '2', '2', '3', '1', '2',

In [82]:
torch.save(model.state_dict(), 't5_adapter_weights.pth')

In [86]:
# code to generate answer based on model
def generate_answer(row):
    input_text = f"context:{row['exp']}\n question: {row['question']}\n options: 0:{row['opa']}, 1:{row['opb']}, 2:{row['opc']}, 3:{row['opd']}"
    input_ids = tokenizer(input_text, truncation=True, max_length=1024)
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    outputs = model.generate(input_ids.to(DEVICE), max_length=5)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

In [91]:
t = {'id': '71ed60b6-c4cc-4fad-9c58-1ae876a59842',
 'question': 'Which of the following is most characteristic of diabetic neuropathy?',
 'opa': 'it is usually bilateral',
 'opb': 'pain is not a feature',
 'opc': 'it most commonly affects the brain',
 'opd': 'it spares the autonomic system',
 'cop': 0,
 'choice_type': 'multi',
 'exp': "Diabetic neuropathy usually presents as peripheral polyneuropathy, usually bilateral, including symptoms of numbness, paresthesia, severe hyperesthesia, and pain. Impairment of proprioceptive fibers can lead to gait abnormalities and Charcot's joints. Mononeuropathy is less common and is often spontaneously reversible. Common syndromes include wrist or foot drop and third, fourth, or sixth cranial nerve palsies. Autonomic neuropathy may cause gastroesophageal dysfunction, bladder dysfunction, and orthostatic hypotension.",
 'subject_name': 'Medicine',
 'topic_name': 'Endocrinology'
 }

In [92]:
answer = generate_answer(t)
print(f"Answer: {answer}")
print(t['cop'])

Answer: 0
0
