## Environment

In [None]:
!pip install ray transformers

import pandas as pd
import random
import math
import json
import os
import torch.nn as nn
import numpy as np

from ray import tune, put, get
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from functools import partial

from transformers import get_linear_schedule_with_warmup, ElectraTokenizer, ElectraForSequenceClassification 
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW

In [None]:
import torch

# Pick the compute device once; later cells move the model and batches onto it.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

In [None]:
# Identifier of the pretrained checkpoint; both the tokenizer and the classifier are loaded from it via from_pretrained.
model_id = "sultan/BioM-BERT-PubMed-PMC-Large"

## Data

In [None]:
# Placeholder locations of the preprocessed train / validation JSON files; point these at real files before running.
path_to_train_df = '/path/to/preprocessed/train_df.json'
path_to_validation_df = '/path/to/preprocessed/val_df.json'

In [None]:
# Load both splits into DataFrames; get_dataset reads their 'Premise', 'Statement' and 'Label' columns.
train_df = pd.read_json(path_to_train_df)
val_df = pd.read_json(path_to_validation_df)

### Tokenization

In [None]:
# Tokenizer loaded from the same checkpoint id as the model.
tokenizer = ElectraTokenizer.from_pretrained(model_id)
# Maps the string values of the 'Label' column to integer class indices.
label_dict = {'Entailment': 0, 'Contradiction': 1}
# Passed to the tokenizer as max_length (with truncation=True) for every encoded pair.
MAX_LEN = 512

def get_dataset(df):
  """Encode every (statement, premise) pair of ``df`` into a TensorDataset.

  Reads the 'Premise', 'Statement' and 'Label' columns. Each premise (a list
  of strings) is joined with single spaces and encoded as the second segment,
  with the statement as the first. Labels are mapped through ``label_dict``.

  Returns:
    TensorDataset of (input ids, attention masks, token type ids, labels),
    where the three sequence tensors are padded to a common length with
    ``pad_sequence`` (default padding value).
  """
  encoded_ids, encoded_masks, encoded_segments, encoded_labels = [], [], [], []

  rows = zip(df['Premise'].to_list(),  # each premise is a list of strings
             df['Statement'].to_list(),
             df['Label'].to_list())

  for premise_sentences, statement, label_name in rows:
    encoding = tokenizer.encode_plus(text=statement,
                                     text_pair=' '.join(premise_sentences),
                                     add_special_tokens=True,
                                     truncation=True,
                                     max_length=MAX_LEN,
                                     return_tensors="pt",
                                     return_token_type_ids=True,
                                     return_attention_mask=True)

    # One pair was encoded, so keep row 0 of each returned tensor.
    encoded_ids.append(encoding['input_ids'][0])
    encoded_masks.append(encoding['attention_mask'][0])
    encoded_segments.append(encoding['token_type_ids'][0])
    encoded_labels.append(label_dict[label_name])

  padded_ids, padded_masks, padded_segments = (
      pad_sequence(column, batch_first=True)
      for column in (encoded_ids, encoded_masks, encoded_segments))

  return TensorDataset(padded_ids, padded_masks, padded_segments, torch.tensor(encoded_labels))

def get_dataloaders(batch_size):
  """Build the train and validation DataLoaders from the module-level DataFrames.

  The training set is drawn through a RandomSampler, the validation set through
  a SequentialSampler; both use the same ``batch_size``.

  Returns:
    (train_dataloader, validation_dataloader)
  """
  train_set = get_dataset(train_df)
  train_loader = DataLoader(train_set,
                            sampler=RandomSampler(train_set),
                            batch_size=batch_size)

  val_set = get_dataset(val_df)
  val_loader = DataLoader(val_set,
                          sampler=SequentialSampler(val_set),
                          batch_size=batch_size)

  return train_loader, val_loader

## Model

In [None]:
# Load the checkpoint as a sequence classifier configured with num_labels=2, then move it to the selected device.
bertModel = ElectraForSequenceClassification.from_pretrained(model_id, num_labels=2)
bertModel.to(device)

In [None]:
# Number of training epochs
epochs = 3

def get_scheduler_and_optimizer(model, learning_rate, batches, adam_epsilon = 1e-8, warmup_steps_ratio=0):
  """Create an AdamW optimizer and a linear warmup schedule for ``model``.

  Args:
    model: Module whose parameters are optimized.
    learning_rate: Learning rate handed to AdamW.
    batches: Multiplied by the module-level ``epochs`` to obtain the total
      number of scheduler steps.
    adam_epsilon: Epsilon handed to AdamW.
    warmup_steps_ratio: Fraction of the total steps used for warmup (floored).

  Returns:
    (scheduler, optimizer) -- note the scheduler comes first.
  """
  optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)

  total_steps = epochs * batches
  warmup_steps = math.floor(total_steps * warmup_steps_ratio)

  lr_schedule = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=total_steps)
  return lr_schedule, optimizer

In [None]:
def flat_accuracy(preds, labels):
  """Return the fraction of rows in ``preds`` whose arg-max column equals the label.

  Args:
    preds: 2-D array of per-class scores, one row per example.
    labels: NumPy array of class indices; flattened before comparison.
  """
  predicted_classes = np.argmax(preds, axis=1).ravel()
  true_classes = labels.ravel()
  hits = np.sum(predicted_classes == true_classes)
  return hits / len(true_classes)

In [None]:
# Set the seed value to make this reproducible.
# Seeds Python's `random`, NumPy and PyTorch (CPU plus all CUDA devices) in the current process.
# NOTE(review): this runs at cell level only; confirm whether the Ray Tune workers that execute
# fine_tune need their own seeding.
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

def fine_tune(config, model, checkpoint_dir=None):
  """Ray Tune trainable: fine-tune the classifier and report validation metrics.

  Runs ``epochs`` training epochs. After each epoch the model is evaluated on
  the validation set, a checkpoint holding the model and optimizer state is
  written, and the validation loss and accuracy are reported to Tune. Both
  metrics are averaged over validation examples (not over batches) and are
  reported as plain Python floats.

  Args:
    config: Sampled hyperparameters. Reads "batch_size", "learning_rate",
      "adam_epsilon" and "warmup_steps_percentage".
    model: Ray object reference to the model; resolved with ``get``.
    checkpoint_dir: Optional directory containing a "checkpoint" file with a
      ``(model_state, optimizer_state)`` tuple to restore before training.
  """
  evaluated_model_ref = get(model)

  train_dataloader, validation_dataloader = get_dataloaders(config["batch_size"])
  # Pass the number of batches in ONE epoch: get_scheduler_and_optimizer
  # multiplies it by `epochs` itself. Passing len(train_dataloader) * epochs
  # stretched the schedule to epochs**2 * batches steps, so warmup lasted
  # `epochs` times too long and the linear decay never reached zero.
  scheduler, optimizer = get_scheduler_and_optimizer(
      evaluated_model_ref,
      config["learning_rate"],
      len(train_dataloader),
      config["adam_epsilon"],
      config["warmup_steps_percentage"])
  # Summed rather than averaged so the validation loss can be normalised by
  # the exact number of examples, however they fall into batches.
  criterion = nn.CrossEntropyLoss(reduction="sum")

  if checkpoint_dir:
    # NOTE(review): only model and optimizer state are restored; the LR
    # scheduler and the epoch counter below start over from zero.
    model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
    evaluated_model_ref.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)

  for epoch_i in range(epochs):
    # ---- training pass ----
    evaluated_model_ref.train()

    for pair_token_ids, mask_ids, seg_ids, y in train_dataloader:
      evaluated_model_ref.zero_grad()

      pair_token_ids = pair_token_ids.to(device)
      mask_ids = mask_ids.to(device)
      seg_ids = seg_ids.to(device)
      labels = y.to(device)

      outputs = evaluated_model_ref(pair_token_ids,
                                    token_type_ids=seg_ids,
                                    attention_mask=mask_ids,
                                    labels=labels)
      # Read the loss by name; unpacking `.values()` into two names fails as
      # soon as the output carries anything besides loss and logits.
      loss = outputs.loss
      loss.backward()

      # Cap the global gradient norm at 1.0 before the optimizer step.
      torch.nn.utils.clip_grad_norm_(evaluated_model_ref.parameters(), 1.0)

      optimizer.step()
      scheduler.step()

    # ---- validation pass ----
    evaluated_model_ref.eval()

    eval_loss, eval_correct = 0.0, 0.0
    nb_eval_examples = 0

    for pair_token_ids, mask_ids, seg_ids, labels in validation_dataloader:
      pair_token_ids = pair_token_ids.to(device)
      mask_ids = mask_ids.to(device)
      seg_ids = seg_ids.to(device)
      labels = labels.to(device)

      with torch.no_grad():
        outputs = evaluated_model_ref(pair_token_ids,
                                      token_type_ids=seg_ids,
                                      attention_mask=mask_ids)
        logits = outputs.logits
        # .item() keeps the running total a Python float instead of a tensor.
        batch_loss = criterion(logits, labels).item()

      logits = logits.detach().cpu().numpy()
      label_ids = labels.to('cpu').numpy()
      batch_examples = len(label_ids)

      # Weight the batch accuracy by its size so a smaller final batch does
      # not count as much as a full one.
      eval_correct += flat_accuracy(logits, label_ids) * batch_examples
      eval_loss += batch_loss
      nb_eval_examples += batch_examples

    # Distinct name so the `checkpoint_dir` argument is not shadowed.
    with tune.checkpoint_dir(epoch_i) as epoch_checkpoint_dir:
      path = os.path.join(epoch_checkpoint_dir, "checkpoint")
      torch.save((evaluated_model_ref.state_dict(), optimizer.state_dict()), path)

    tune.report(loss=eval_loss / nb_eval_examples,
                accuracy=float(eval_correct / nb_eval_examples))

In [None]:
# Driver cell

# hyperparameter tuning configuration. Add values to the corresponding tune.choice list to include them in the search.
config = {
    "learning_rate": tune.choice([2e-5]),
    "batch_size": tune.choice([16]),
    "adam_epsilon": tune.choice([1e-8]),
    "warmup_steps_percentage": tune.choice([0.02])
    }

# Early-stopping scheduler driven by the reported validation accuracy; one
# training iteration corresponds to one tune.report call in fine_tune.
scheduler = ASHAScheduler(
    metric="accuracy",
    mode="max",
    max_t=epochs,
    grace_period=2,
    reduction_factor=2)

reporter = CLIReporter(
    metric_columns=["loss", "accuracy", "training_iteration"])

# Create a ref to the model to pass to the fine_tune function.
# The model is ~1.3 GiB. If it is directly referenced in our fine_tune method, Ray will throw a 'worker function too big > 95MiB' error.
modelRef = put(bertModel)

result = tune.run(
    partial(fine_tune, model=modelRef),
    resources_per_trial={"gpu": 1},
    config=config,
    num_samples=3,
    scheduler=scheduler,
    progress_reporter=reporter)

best_trial = result.get_best_trial("accuracy", "max", "all")
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
print("Best trial final validation accuracy: {}".format(best_trial.last_result["accuracy"]))

# Look up the best checkpoint of the trial selected above. The previous call
# used the undefined name `trial`, which raised NameError only after the whole
# search had finished. The trailing "all" was dropped as well: scope is an
# argument of get_best_trial, not of get_best_checkpoint.
result.get_best_checkpoint(best_trial, metric="accuracy", mode="max")