## Load Dependencies, Model and Data

In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install wandb

In [None]:
import torch
import wandb
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForMultipleChoice, AutoModelForMaskedLM, BertModel, BertConfig
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
from sklearn.metrics import classification_report
from datasets import load_dataset, load_from_disk, DatasetDict
import pandas as pd
import numpy as np
import os

from tqdm import tqdm
from tqdm import trange
import random
import math

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU when present, otherwise CPU

In [None]:
# Load the PubMed abstracts used for continued (distillation) pretraining.

train_dataset = load_dataset("ywchoi/pubmed_abstract_1", split='train')
val_dataset = load_dataset("ywchoi/pubmed_abstract_1", split='validation')

# Subset data for experimentation; 1.0 keeps every example.
SUBSET_FRACTION = 1.0

train_dataset = train_dataset.select(range(int(len(train_dataset) * SUBSET_FRACTION)))
# Subset each split from itself so the validation set comes from the validation split,
# not from the training data.
val_dataset = val_dataset.select(range(int(len(val_dataset) * SUBSET_FRACTION)))


In [None]:
def initializeStudent(save_path, teacher_checkpoint="dmis-lab/biobert-base-cased-v1.2", layer_stride=2):
  
  ''' Initializes a student model from a subset of the teacher's layers and saves it. Adapted from
      https://github.com/nlpie-research/Compact-Biomedical-Transformers/blob/main/DistilBioBERT-Distillation.py

  The student shares the teacher's configuration, except that it has
  num_hidden_layers // layer_stride encoder layers. It reuses the teacher's embeddings and
  teacher encoder layers 0, layer_stride, 2*layer_stride, ...

  Args:
    save_path: directory the initialized student is written to with save_pretrained.
    teacher_checkpoint: checkpoint the teacher is loaded from with AutoModel.
    layer_stride: keep one teacher layer out of every `layer_stride`; the default 2 halves the depth.

  Returns:
    save_path, so the result can be passed straight to from_pretrained.

  Raises:
    ValueError: if layer_stride is smaller than 1.
  '''

  if layer_stride < 1:
    raise ValueError(f"layer_stride must be >= 1, got {layer_stride}")

  teacher = AutoModel.from_pretrained(teacher_checkpoint)

  student_config = teacher.config.to_dict()
  student_config["num_hidden_layers"] //= layer_stride

  student = BertModel(config=BertConfig.from_dict(student_config))
  student.embeddings = teacher.embeddings

  # Replace each freshly initialized student layer with the matching strided teacher layer.
  for index in range(len(student.encoder.layer)):
    student.encoder.layer[index] = teacher.encoder.layer[layer_stride * index]

  student.save_pretrained(save_path)

  return save_path

In [None]:
# Load the teacher, then build the student from a subset of the teacher's layers (see initializeStudent).

TEACHER_CHECKPOINT = 'dmis-lab/biobert-base-cased-v1.2'

student_model = AutoModelForMaskedLM.from_pretrained(initializeStudent('initialized_model/'))
# NOTE(review): the student's embeddings come from BioBERT but its tokenizer comes from
# distilbert-base-cased; this relies on both using the same (BERT cased) vocabulary — confirm.
student_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')

teacher_model = AutoModelForMaskedLM.from_pretrained(TEACHER_CHECKPOINT)
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_CHECKPOINT)

In [None]:
# View model sizes (parameter counts). Only the student and the teacher exist in this notebook.

print(f"student parameters: {student_model.num_parameters():,}")
print(f"teacher parameters: {teacher_model.num_parameters():,}")

In [None]:
# Freeze the teacher: switch off requires_grad on every one of its parameters.
teacher_model.requires_grad_(False)

## Continued student pretraining with knowledge distillation (KD)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import math
from transformers.modeling_outputs import MaskedLMOutput
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

In [None]:
class DistillationWrapper(nn.Module):

  ''' Wraps a student and a frozen teacher and computes the knowledge-distillation loss.
  Adapted from https://github.com/nlpie-research/Compact-Biomedical-Transformers/blob/main/DistilBioBERT-Distillation.py

  When `labels` are given, the loss is
      alpha_mlm * L_mlm + alpha_ce * L_ce (+ alpha_cos * L_cos if alpha_cos > 0)
  where
    - L_mlm is the student's own masked-LM loss;
    - L_ce is the temperature-scaled KL divergence between teacher and student logits,
      on masked positions (or on all attended positions if restrict_ce_to_mask is False);
    - L_cos is a cosine-embedding loss pulling the student's last hidden states towards
      the teacher's on attended positions.
  Without labels only the student runs and the returned loss is None.

  The teacher is kept in eval mode even when the wrapper is switched to train mode, so its
  dropout never perturbs the distillation targets.
  '''

  def __init__(self,
               student, 
               teacher, 
               temperature=2.0, 
               alpha_ce=5.0, 
               alpha_mlm=2.0, 
               alpha_cos=1.0):
    
    super().__init__()

    self.student = student
    self.teacher = teacher
    self.teacher.eval()

    self.temperature = temperature
    self.vocab_size = self.teacher.config.vocab_size
    self.dim = self.teacher.config.hidden_size

    self.restrict_ce_to_mask = True

    self.alpha_ce = alpha_ce
    self.alpha_mlm = alpha_mlm
    self.alpha_cos = alpha_cos

    self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
    self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

  def train(self, mode=True):
    """Set train/eval mode on the wrapper and student, but always keep the teacher in eval mode.

    The teacher is a registered submodule, so nn.Module.train would otherwise also enable its
    dropout whenever the Trainer calls model.train().
    """
    super().train(mode)
    self.teacher.eval()
    return self

  def forward(self, 
              input_ids, 
              attention_mask,
              labels=None,
              **kargs):
    """Run the student (and, with labels, the teacher) and build the distillation loss.

    Args:
      input_ids, attention_mask: model inputs, forwarded to both student and teacher.
      labels: MLM labels with -100 at positions to ignore; when None, no loss is computed.
      **kargs: any extra model inputs, forwarded unchanged to both models.

    Returns:
      MaskedLMOutput carrying the combined loss (or None) and the student's logits,
      hidden states and attentions.
    """

    student_outputs = self.student(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   labels=labels,
                                   output_hidden_states=True,
                                   **kargs)   
    
    s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]

    loss = None

    if labels is not None:
      
      # Teacher outputs are targets only; no gradients flow into the frozen teacher.
      with torch.no_grad():
        teacher_outputs = self.teacher(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       output_hidden_states=True,
                                       **kargs)

      t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]    

      # Positions used for the soft-target loss: masked tokens (labels > -1) or every attended token.
      if self.restrict_ce_to_mask:
        mask = (labels > -1).unsqueeze(-1).expand_as(s_logits).bool()
      else:
        mask = attention_mask.unsqueeze(-1).expand_as(s_logits).bool()

      # Gather the selected positions into (num_selected, logits_dim) matrices.
      s_logits_slct = torch.masked_select(s_logits, mask)  
      s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))  
      t_logits_slct = torch.masked_select(t_logits, mask)  
      t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) 
      assert t_logits_slct.size() == s_logits_slct.size()
      
      loss_mlm = student_outputs.loss

      # KL(teacher || student) on temperature-softened distributions, multiplied by T^2
      # (the usual Hinton et al. scaling that keeps gradient magnitudes comparable across T).
      loss_ce = (
          self.ce_loss_fct(
              nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
              nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
          )
          * (self.temperature) ** 2
      )

      loss = (self.alpha_mlm * loss_mlm) + (self.alpha_ce * loss_ce)

      if self.alpha_cos > 0.0:
          s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
          t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
          mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states).bool()  # (bs, seq_length, dim)
          assert s_hidden_states.size() == t_hidden_states.size()
          dim = s_hidden_states.size(-1)

          s_hidden_states_slct = torch.masked_select(s_hidden_states, mask)  # (bs * seq_length * dim)
          s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)
          t_hidden_states_slct = torch.masked_select(t_hidden_states, mask)  # (bs * seq_length * dim)
          t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (bs * seq_length, dim)

          # Target +1 for every pair: maximise cosine similarity between student and teacher states.
          target = s_hidden_states_slct.new(s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
          loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
          loss += (self.alpha_cos * loss_cos)


    return MaskedLMOutput(
        loss=loss,
        logits=student_outputs.logits,
        hidden_states=student_outputs.hidden_states,
        attentions=student_outputs.attentions,
    )

In [None]:
model = DistillationWrapper(
    teacher=teacher_model,
    student=student_model,
    temperature=2.0,  # softening temperature for the teacher/student distributions
    alpha_ce=5.0,     # weight of the soft-target (KL) loss
    alpha_mlm=2.0,    # weight of the student's own MLM loss
    alpha_cos=1.0,    # weight of the hidden-state cosine loss
)

In [None]:
# Tokenize and collate

def tokenize_function(examples):
    """Tokenize a batch of abstracts, truncating each to at most 512 tokens.

    No padding is applied here: DataCollatorForLanguageModeling pads every batch to its
    longest sequence, so short abstracts are not padded out to 512 tokens.
    """
    return student_tokenizer(examples["text"], max_length=512, truncation=True)

tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_val_dataset = val_dataset.map(tokenize_function, batched=True)

# Dynamically pads each batch and masks 15% of the tokens for the MLM objective.
data_collator = DataCollatorForLanguageModeling(tokenizer=student_tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")

In [None]:
!wandb login

In [None]:
## Continued pretraining setup

savePath = "learned_student_kd_2/"

trainingArguments = TrainingArguments(
    output_dir= savePath + "checkpoints",
    logging_steps=1000,
    overwrite_output_dir=True,
    save_steps=250,
    # The Trainer holds a plain nn.Module (DistillationWrapper), so checkpoints are written from its
    # raw state dict. The BERT MLM heads tie decoder weights to the word embeddings, and safetensors
    # refuses to serialize tensors that share memory, so save with torch.save instead.
    save_safetensors=False,
    num_train_epochs=0.5, 
    learning_rate=5e-4,
    lr_scheduler_type="linear",
    warmup_steps=5000,
    per_device_train_batch_size=24, 
    weight_decay=0.0,
    save_total_limit=5,
    # Drops dataset columns the wrapper's forward() does not accept (e.g. the raw "text").
    remove_unused_columns=True,
    report_to="wandb"
) 

trainer = Trainer(
    model=model,
    args=trainingArguments,
    train_dataset=tokenized_train_dataset,
    data_collator=data_collator,
)

trainer.train()

In [None]:
# Save model

trainer.save_model("/content/learned_student_kd_trainersave/")
model.student.save_pretrained("/content/learned_student_kd/")
!zip -r /content/learned_student_kd.zip /content/learned_student_kd

In [None]:
# Copy the zipped student model to Google Drive (Drive must already be mounted at /content/drive)

import shutil
colab_link = "/content/learned_student_kd.zip"
gdrive_link = "/content/drive/MyDrive/CLS/KD-CL-SR/"

# Without a mounted Drive, makedirs below would silently create a local folder instead.
if not os.path.isdir("/content/drive/MyDrive"):
  raise RuntimeError("Google Drive is not mounted at /content/drive; mount it before copying the model.")

# shutil.copy cannot copy into a directory that does not exist yet.
os.makedirs(gdrive_link, exist_ok=True)
shutil.copy(colab_link, gdrive_link)

## Fine-tune the student and evaluate it on BioASQ

In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install wandb

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForMultipleChoice
from sklearn.metrics import classification_report
from datasets import load_dataset, DatasetDict
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
import pandas as pd
import numpy as np
import wandb
import os

In [None]:
# Load learned model

# The multiple-choice head is not part of the saved MLM checkpoint, so it is presumably newly
# initialized here and learned during fine-tuning.
learned_student = AutoModelForMultipleChoice.from_pretrained("/content/learned_student_kd")
# Load the tokenizer from the same checkpoint the pretraining section used instead of reusing
# `student_tokenizer`, so this section also runs in a fresh kernel.
learned_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')

In [None]:
def stratify(dataset, yes_max, no_max):
  
  """Cap the number of "yes" and "no" examples, then shuffle.

  Keeps the first `yes_max` rows whose 'answers' is "yes" and the first `no_max` rows whose
  'answers' is "no", in dataset order. Rows with any other answer are kept. Despite the name,
  this is a simple per-class cap rather than a stratified split.

  Args:
    dataset: dataset with an 'answers' column that supports select() and shuffle(seed=...).
    yes_max: maximum number of "yes" rows to keep.
    no_max: maximum number of "no" rows to keep.

  Returns:
    The capped dataset, shuffled with seed 42.
  """

  caps = {"yes": yes_max, "no": no_max}
  seen = {"yes": 0, "no": 0}
  keep = []

  # Read the answer column once instead of materializing every row by index.
  for i, answer in enumerate(dataset['answers']):
    if answer in caps:
      seen[answer] += 1
      if seen[answer] > caps[answer]:
        continue
    keep.append(i)

  return dataset.select(keep).shuffle(seed=42)

In [None]:
# Load the labeled BioASQ dataset

dataset = load_dataset("reginaboateng/Bioasq7b")['train']

# Balance classes: keep at most CLASS_CAP "yes" and CLASS_CAP "no" questions, then split the
# shuffled, balanced set into train / validation / test.

CLASS_CAP = 883
N_TRAIN = 1500
N_VAL = 100
N_BALANCED = 2 * CLASS_CAP  # 1766 rows when both classes reach the cap

dataset_balanced = stratify(dataset, CLASS_CAP, CLASS_CAP)
if len(dataset_balanced) < N_BALANCED:
  raise ValueError(
      f"Expected at least {N_BALANCED} balanced examples, got {len(dataset_balanced)}; "
      "adjust CLASS_CAP / N_TRAIN / N_VAL."
  )

train_dataset = dataset_balanced.select(range(0, N_TRAIN))
val_dataset = dataset_balanced.select(range(N_TRAIN, N_TRAIN + N_VAL))
test_dataset = dataset_balanced.select(range(N_TRAIN + N_VAL, N_BALANCED))

# Add a numeric label column to every split. The mapping matches the choice order used when
# tokenizing ("yes" is choice 0, "no" is choice 1).

LABEL_IDS = {'yes': 0, 'no': 1}

def add_numeric_label(split):
  """Return `split` with an integer "label" column mapped from its 'answers' column via LABEL_IDS."""
  return split.add_column("label", [LABEL_IDS[answer] for answer in split['answers']])

train_dataset = add_numeric_label(train_dataset)
val_dataset = add_numeric_label(val_dataset)
test_dataset = add_numeric_label(test_dataset)

In [None]:
# Tokenize the data: every question becomes one multiple-choice example with two choices.

def preprocess(example):
  """Tokenize a batch of BioASQ rows into per-row groups of (context, "question answer") pairs.

  For each row, one pair is built per candidate answer, in the order ["yes", "no"]. That order
  must match the label mapping (yes -> 0, no -> 1). The context is the first sequence and is the
  one truncated to fit max_length=512.

  Args:
    example: a batch dict with "context" and "question" lists (as given by Dataset.map(batched=True)).

  Returns:
    Dict mapping each tokenizer output key to one list per row, each holding len(answers) encodings.
  """

  answers = ["yes", "no"]
  num_choices = len(answers)

  # Flat, row-major pairs: (c0, "q0 yes"), (c0, "q0 no"), (c1, "q1 yes"), ...
  contexts = [context for context in example["context"] for _ in answers]
  question_answer = [f"{question} {answer}" for question in example["question"] for answer in answers]

  tokenized_examples = learned_tokenizer(contexts, question_answer, truncation='only_first', max_length=512)

  # Regroup the flat encodings into one list of `num_choices` encodings per row.
  return {
      k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
      for k, v in tokenized_examples.items()
  }

tokenized_train_dataset = train_dataset.map(preprocess, batched=True)
tokenized_val_dataset = val_dataset.map(preprocess, batched=True)
tokenized_test_dataset = test_dataset.map(preprocess, batched=True)

In [None]:
# Load data collator

from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

@dataclass
class DataCollatorForMultipleChoice:
    
    """
    Data collator that dynamically pads multiple-choice inputs. Adapted from https://huggingface.co/docs/transformers/tasks/multiple_choice

    Each feature maps every tokenizer key to one token list per choice, plus an integer label
    under "label" (or "labels"). All choices of all features are padded together with the
    tokenizer and reshaped to (batch_size, num_choices, seq_len). Labels are returned as an
    int64 tensor under "labels". The label key is popped from the input features.
    """

    tokenizer: "PreTrainedTokenizerBase"
    padding: "Union[bool, str, PaddingStrategy]" = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        # Remove the label from each feature so only tokenizer outputs are padded.
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One dict per (feature, choice), feature-major, so the view below restores the grouping.
        flattened_features = [
            {k: v[i] for k, v in feature.items()}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)

        return batch

In [None]:
# Load evaluation metrics
# (Metrics come from `evaluate`; `datasets.load_metric` no longer exists in datasets>=3.)

import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Accuracy for Trainer evaluation.

    eval_pred is the (predictions, labels) pair passed in by the Trainer. The predicted choice is
    the argmax over axis 1 of the predictions (presumably (num_examples, num_choices) logits — confirm).
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

### Hyperparameter Search

In [None]:
!wandb login

In [None]:
# W&B hyperparameter specification

# method
sweep_config = {
    'method': 'random'
}

# hyperparameters
parameters_dict = {
    'epochs': {
        'values': [1, 2]
        },
    'batch_size': {
        'values': [4, 8, 12]
        },
    'learning_rate': {
        'distribution': 'log_uniform_values',
        'min': 1e-6,
        'max': 1e-3
    },
    'weight_decay': {
        'values': [0.05, 0.1, 0.15]
    },
}

# Objective tracked by the sweep. The Hugging Face W&B integration logs evaluation metrics with an
# "eval/" prefix (and training metrics with "train/"), so a bare "loss" key never appears in the runs.
metric = {
    'name' : 'eval/loss',
    'goal' : 'minimize'
}

sweep_config['metric'] = metric
sweep_config['parameters'] = parameters_dict
sweep_id = wandb.sweep(sweep_config, project='learned-student-kd-bioasq-1')

In [None]:
# W&B trainer setup

preds = []      # test-set PredictionOutput of each sweep trial, in run order
test_accs = []  # test-set accuracy of each sweep trial, in run order

STUDENT_CHECKPOINT = "/content/learned_student_kd"

def fine_tune(config=None):
  """Run one W&B sweep trial: fine-tune a fresh copy of the distilled student on BioASQ.

  Each trial reloads the student from STUDENT_CHECKPOINT, so trials do not keep training the
  weights left behind by earlier trials. After training, the restored best checkpoint is scored
  on the held-out test split. Its predictions and accuracy are appended to `preds` /
  `test_accs` and the accuracy is logged to the W&B run. The trial's Trainer is published as
  the module-level `trainer` so later cells can inspect the most recent run.

  Args:
    config: optional sweep config passed to wandb.init; the agent normally supplies it.
  """
  global trainer

  with wandb.init(config=config) as run:
    # set sweep configuration
    config = wandb.config

    training_args = TrainingArguments(
      # Separate directory per trial so checkpoints of different runs never mix.
      output_dir=os.path.join("/content/wandb/outputs", run.id),
      eval_strategy="steps",
      save_strategy="steps",
      eval_steps=20,
      # Save on every evaluation so load_best_model_at_end can restore an evaluated checkpoint;
      # keep only the best and the latest.
      save_steps=20,
      save_total_limit=2,
      load_best_model_at_end=True,
      learning_rate=config.learning_rate,
      per_device_train_batch_size=config.batch_size,
      per_device_eval_batch_size=config.batch_size,
      num_train_epochs=config.epochs,
      weight_decay=config.weight_decay,
      logging_steps=10,
      push_to_hub=False,
      report_to="wandb"
    )

    trainer = Trainer(
        model=AutoModelForMultipleChoice.from_pretrained(STUDENT_CHECKPOINT),
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        tokenizer=learned_tokenizer,
        data_collator=DataCollatorForMultipleChoice(tokenizer=learned_tokenizer),
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Score the held-out test split; Trainer.predict prefixes metric names with "test_".
    test_output = trainer.predict(test_dataset=tokenized_test_dataset)
    preds.append(test_output)
    test_accs.append(test_output.metrics["test_accuracy"])
    wandb.log({"test/accuracy": test_accs[-1]})

wandb.agent(sweep_id, fine_tune, count=20) 

In [None]:
# Score the held-out BioASQ test split with the most recent fine-tuning Trainer.
# `trainer` must be a multiple-choice fine-tuning Trainer; the distillation Trainer from the
# pretraining section uses an MLM collator and cannot batch this data.
if not isinstance(getattr(trainer, "data_collator", None), DataCollatorForMultipleChoice):
    raise RuntimeError(
        "`trainer` is not a multiple-choice fine-tuning Trainer; make the Trainer from a "
        "fine_tune run available as `trainer` before scoring the test set."
    )

test_results = trainer.predict(test_dataset=tokenized_test_dataset)
test_results