<a href="https://colab.research.google.com/github/stellaevat/ontology-mapping/blob/main/colabs/bi_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers[torch] datasets evaluate

In [None]:
import os
# CUDA-related switches, exported in a cell that runs before the cell importing torch.
os.environ.update({
  # Synchronous kernel launches (per the CUDA docs this makes errors surface at
  # the failing call, at some cost in speed).
  'CUDA_LAUNCH_BLOCKING': "1",
  # cuBLAS workspace setting; PyTorch's reproducibility notes require it when
  # deterministic algorithms are enabled on recent CUDA versions.
  'CUBLAS_WORKSPACE_CONFIG': ":4096:8",
})

In [None]:
import gc
import random
import evaluate
import torch
import numpy as np
from pprint import pprint
from tqdm import tqdm
from sklearn.model_selection import train_test_split, ParameterGrid
from datasets import Dataset, DatasetDict
from transformers import get_scheduler, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Collate input dataset

In [None]:
def read_bi_tokenized_from_file(filepath):
  """Load pre-tokenized source/target pairs and their labels from a text file.

  Every non-blank line must hold six bracketed, comma-separated integer lists
  followed by an integer label, for example
  ``[1,2],[0,0],[1,1],[3,4],[0,0],[1,1],1``. The first three lists are the
  source ``input_ids``, ``token_type_ids`` and ``attention_mask``; the next
  three are the same fields for the target. Blank lines are skipped and
  whitespace around brackets, commas and the label is tolerated.

  Args:
    filepath: Path of the file to read.

  Returns:
    A ``(source_tokenized, target_tokenized, labels)`` tuple. The first two are
    dicts keyed by ``input_ids``, ``token_type_ids`` and ``attention_mask``,
    each holding one list of ints per parsed line; ``labels`` is a list of ints.

  Raises:
    ValueError: If a line does not contain exactly six lists plus a label, or
      if any field is not an integer.
  """
  keys = ("input_ids", "token_type_ids", "attention_mask")
  source_tokenized = {key: [] for key in keys}
  target_tokenized = {key: [] for key in keys}
  labels = []

  with open(filepath) as f:
    for line_number, line in enumerate(f, start=1):
      line = line.strip()
      if not line:
        # A stray empty line (e.g. at the end of the file) carries no example.
        continue

      # Splitting on "]," leaves the label as the final piece.
      *list_fields, label_field = line.split("],")
      if len(list_fields) != 2 * len(keys):
        raise ValueError(
          f"{filepath}:{line_number}: expected {2 * len(keys)} lists and a label, "
          f"found {len(list_fields)} lists"
        )

      # Parse the whole line before storing anything so the columns stay aligned.
      parsed = [[int(token) for token in field.strip(" []").split(",")] for field in list_fields]
      label = int(label_field)

      for key, source_values, target_values in zip(keys, parsed[:len(keys)], parsed[len(keys):]):
        source_tokenized[key].append(source_values)
        target_tokenized[key].append(target_values)
      labels.append(label)

  return source_tokenized, target_tokenized, labels

def read_cross_tokenized_from_file(filepath):
  """Load pre-tokenized single-sequence examples and their labels from a text file.

  Every non-blank line must hold three bracketed, comma-separated integer lists
  (``input_ids``, ``token_type_ids``, ``attention_mask``) followed by an integer
  label, for example ``[1,2,3],[0,0,1],[1,1,1],0``. Blank lines are skipped and
  whitespace around brackets, commas and the label is tolerated.

  Args:
    filepath: Path of the file to read.

  Returns:
    A ``(tokenized, labels)`` tuple: ``tokenized`` is a dict keyed by
    ``input_ids``, ``token_type_ids`` and ``attention_mask`` with one list of
    ints per parsed line, and ``labels`` is a list of ints.

  Raises:
    ValueError: If a line does not contain exactly three lists plus a label, or
      if any field is not an integer.
  """
  keys = ("input_ids", "token_type_ids", "attention_mask")
  tokenized = {key: [] for key in keys}
  labels = []

  with open(filepath) as f:
    for line_number, line in enumerate(f, start=1):
      line = line.strip()
      if not line:
        # A stray empty line (e.g. at the end of the file) carries no example.
        continue

      # Splitting on "]," leaves the label as the final piece.
      *list_fields, label_field = line.split("],")
      if len(list_fields) != len(keys):
        raise ValueError(
          f"{filepath}:{line_number}: expected {len(keys)} lists and a label, "
          f"found {len(list_fields)} lists"
        )

      # Parse the whole line before storing anything so the columns stay aligned.
      parsed = [[int(token) for token in field.strip(" []").split(",")] for field in list_fields]
      label = int(label_field)

      for key, values in zip(keys, parsed):
        tokenized[key].append(values)
      labels.append(label)

  return tokenized, labels

def onto_cross_tokenized_from_file(filepath):
  """Load pre-tokenized, unlabelled single-sequence examples from a text file.

  Every non-blank line must hold exactly three bracketed, comma-separated
  integer lists (``input_ids``, ``token_type_ids``, ``attention_mask``), for
  example ``[1,2,3],[0,0,1],[1,1,1]``. Blank lines are skipped and whitespace
  around brackets and commas is tolerated.

  Args:
    filepath: Path of the file to read.

  Returns:
    A dict keyed by ``input_ids``, ``token_type_ids`` and ``attention_mask``
    with one list of ints per parsed line.

  Raises:
    ValueError: If a line does not contain exactly three lists, or if any field
      is not an integer.
  """
  keys = ("input_ids", "token_type_ids", "attention_mask")
  tokenized = {key: [] for key in keys}

  with open(filepath) as f:
    for line_number, line in enumerate(f, start=1):
      line = line.strip()
      if not line:
        # A stray empty line (e.g. at the end of the file) carries no example.
        continue

      fields = line.split("],")
      if len(fields) != len(keys):
        raise ValueError(
          f"{filepath}:{line_number}: expected {len(keys)} lists, found {len(fields)}"
        )

      # There is no trailing label here, so the last field still ends in "]";
      # strip both bracket characters, not just the opening one.
      parsed = [[int(token) for token in field.strip(" []").split(",")] for field in fields]

      for key, values in zip(keys, parsed):
        tokenized[key].append(values)

  return tokenized

In [None]:
def filter_source_target(Xi, source, target):
  """Select the rows listed in ``Xi`` from every column of ``source`` and ``target``.

  Args:
    Xi: Iterable of row positions, in the order the rows should appear.
    source: Dict mapping a column name to an indexable sequence of rows.
    target: Dict with the same layout as ``source``.

  Returns:
    A ``(X_source, X_target)`` tuple of new dicts with the same keys as the
    inputs, each column reduced to the requested rows.
  """
  def take_rows(columns):
    picked = {}
    for name, values in columns.items():
      picked[name] = [values[row] for row in Xi]
    return picked

  return take_rows(source), take_rows(target)


def collate_dataset(X_train, X_val, X_test, y_train, y_val, y_test):
  """Bundle the three feature/label splits into one torch-formatted ``DatasetDict``.

  Args:
    X_train, X_val, X_test: Dicts of feature columns for each split.
    y_train, y_val, y_test: Labels for the matching split; stored under the
      ``labels`` column.

  Returns:
    A ``DatasetDict`` with ``train``, ``val`` and ``test`` entries whose output
    format is set to ``"torch"``.
  """
  splits = {
    'train' : (X_train, y_train),
    'val' : (X_val, y_val),
    'test' : (X_test, y_test),
  }
  dataset = DatasetDict({
    split_name : Dataset.from_dict(features | {'labels' : targets})
    for split_name, (features, targets) in splits.items()
  })
  dataset.set_format(type="torch")
  return dataset


def get_bi_datasets_from_tokenized(source_tokenized, target_tokenized, labels, test_size=0.2, val_size=0.25, random_state=3):
  """Split aligned source/target examples into train/val/test datasets.

  Row positions, not the features themselves, are split so that the source and
  target datasets receive exactly the same rows in the same order.

  Args:
    source_tokenized: Dict of feature columns for the source side.
    target_tokenized: Dict of feature columns for the target side, row-aligned
      with ``source_tokenized``.
    labels: One label per row.
    test_size: Share of all rows held out for the test split.
    val_size: Share of the remaining (non-test) rows held out for validation.
      With the defaults, 0.25 of the remaining 80% gives a 60/20/20 split.
    random_state: Seed passed to both ``train_test_split`` calls.

  Returns:
    A ``(source_data, target_data)`` tuple of datasets built by
    ``collate_dataset``, each with ``train``, ``val`` and ``test`` splits.
  """
  Xi = np.arange(len(labels))
  y = np.array(labels)

  # First carve off the test rows, then split what is left into train and validation.
  Xi_train_val, Xi_test, y_train_val, y_test = train_test_split(Xi, y, test_size=test_size, random_state=random_state)
  Xi_train, Xi_val, y_train, y_val = train_test_split(Xi_train_val, y_train_val, test_size=val_size, random_state=random_state)

  X_source_train, X_target_train = filter_source_target(Xi_train, source_tokenized, target_tokenized)
  X_source_val, X_target_val = filter_source_target(Xi_val, source_tokenized, target_tokenized)
  X_source_test, X_target_test = filter_source_target(Xi_test, source_tokenized, target_tokenized)

  print(f"Train: {len(y_train)}")
  print(f"Validation: {len(y_val)}")
  print(f"Test: {len(y_test)}")

  # Both sides carry the same labels so either dataset can supply them later.
  source_data = collate_dataset(X_source_train, X_source_val, X_source_test, y_train, y_val, y_test)
  target_data = collate_dataset(X_target_train, X_target_val, X_target_test, y_train, y_val, y_test)
  return source_data, target_data

In [None]:
# Available dataset variants; the choices below decide which tokenized file is loaded.
negative_sampling = ['random', 'multi', 'neighbour']
features = ['term', 'int', 'ext']
direction = "ncit2doid"

# Variant used for this run.
feature, negatives = features[0], negative_sampling[0]

tokenized_path = f"bi_tokenized_{feature}_{negatives}_{direction}.csv"
source_tokenized, target_tokenized, labels = read_bi_tokenized_from_file(tokenized_path)
source_data, target_data = get_bi_datasets_from_tokenized(source_tokenized, target_tokenized, labels)

Train: 3860
Validation: 1287
Test: 1287


# Bi-encoder

In [None]:
def full_determinism(seed):
  """Seed every random number generator in use and request deterministic kernels.

  Args:
    seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
      (CPU and all CUDA devices).
  """
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

  torch.use_deterministic_algorithms(True)
  torch.backends.cudnn.deterministic = True
  # The cuDNN autotuner may pick a different algorithm from run to run, so make
  # sure it is off even if something enabled it earlier in the session.
  torch.backends.cudnn.benchmark = False

In [None]:
class BiEncoderForSequenceClassification(torch.nn.Module):
  """Two-tower model that scores a (source, target) pair of token sequences.

  Each side is encoded by its own pretrained encoder. The score of a pair is
  the dot product of the two encoders' second output element (index 1;
  presumably the pooled sentence representation for BERT-style models -- TODO
  confirm for other architectures) after dropout. Scores are returned as raw
  logits and, when labels are given, trained with ``BCEWithLogitsLoss``.
  """

  def __init__(self, model_name, num_labels, id2label=None, label2id=None, token_embeddings_size=None, hidden_layer=-1):
    """Build both encoders from the same pretrained checkpoint.

    Args:
      model_name: Checkpoint name or path handed to ``AutoModel.from_pretrained``.
      num_labels: Stored on the instance; not used in the computation here.
      id2label: Accepted but not used.
      label2id: Accepted but not used.
      token_embeddings_size: If truthy, both encoders' token embeddings are
        resized to this size.
      hidden_layer: Stored on the instance; not used in the computation here.
    """
    super().__init__()
    # The source tower is created first, then the target tower.
    self.source_model = AutoModel.from_pretrained(model_name)
    self.target_model = AutoModel.from_pretrained(model_name)
    if token_embeddings_size:
      for encoder in (self.source_model, self.target_model):
        encoder.resize_token_embeddings(token_embeddings_size)

    self.num_labels = num_labels
    self.hidden_layer = hidden_layer

    # `sigmoid` and `similarity` are not used by forward(); only `dropout` and `loss` are.
    self.sigmoid = torch.nn.Sigmoid()
    self.dropout = torch.nn.Dropout(0.1)
    self.similarity = torch.nn.CosineSimilarity(dim=-1)
    self.loss = torch.nn.BCEWithLogitsLoss()

  def forward(
      self,
      s_input_ids=None, t_input_ids=None,
      s_attention_mask=None, t_attention_mask=None,
      s_token_type_ids=None, t_token_type_ids=None,
      s_position_ids=None, t_position_ids=None,
      s_head_mask=None, t_head_mask=None,
      s_inputs_embeds=None, t_inputs_embeds=None,
      labels=None
    ):
    """Score each source/target pair and optionally compute the loss.

    Arguments prefixed ``s_`` are forwarded to the source encoder and those
    prefixed ``t_`` to the target encoder.

    Returns:
      A ``SequenceClassifierOutput`` whose ``logits`` hold one raw score per
      pair and whose ``loss`` is the binary cross-entropy with logits against
      ``labels`` (``None`` when no labels are passed).
    """
    # Run both encoders before applying dropout to either result.
    source_encoded = self.source_model(
      s_input_ids,
      attention_mask=s_attention_mask,
      token_type_ids=s_token_type_ids,
      position_ids=s_position_ids,
      head_mask=s_head_mask,
      inputs_embeds=s_inputs_embeds,
    )
    target_encoded = self.target_model(
      t_input_ids,
      attention_mask=t_attention_mask,
      token_type_ids=t_token_type_ids,
      position_ids=t_position_ids,
      head_mask=t_head_mask,
      inputs_embeds=t_inputs_embeds,
    )

    source_vector = self.dropout(source_encoded[1])
    target_vector = self.dropout(target_encoded[1])

    # Dot product along the feature dimension: one unnormalised score per pair.
    logits = (source_vector * target_vector).sum(dim=-1)

    loss = None if labels is None else self.loss(logits.view(-1), labels.view(-1).float())

    return SequenceClassifierOutput(loss=loss, logits=logits)

  def get_source_model_outputs(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
    """Return element 1 of the source encoder's output; ``labels`` is ignored."""
    encoded = self.source_model(
      input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids
    )
    return encoded[1]

  def get_target_model_outputs(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
    """Return element 1 of the target encoder's output; ``labels`` is ignored."""
    encoded = self.target_model(
      input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids
    )
    return encoded[1]

In [None]:
def _detached_copy(value):
  """Return an independent copy of ``value`` with every tensor copied to CPU memory."""
  import copy

  if isinstance(value, torch.Tensor):
    return value.detach().to("cpu", copy=True)
  if isinstance(value, dict):
    # A shallow copy keeps the mapping type and any attributes attached to it;
    # each entry is then replaced by its own independent copy.
    copied = copy.copy(value)
    for key, item in value.items():
      copied[key] = _detached_copy(item)
    return copied
  if isinstance(value, list):
    return [_detached_copy(item) for item in value]
  if isinstance(value, tuple):
    return tuple(_detached_copy(item) for item in value)
  return copy.deepcopy(value)


class EarlyStopper:
  """Tracks the best validation loss and signals when training should stop.

  The state belonging to the best epoch so far is kept as an independent CPU
  snapshot so that it can be written out after training has moved on.
  """

  def __init__(self, patience=1, delta=0):
    """
    Args:
      patience: Number of non-improving checks (see ``early_stop``) after which
        stopping is signalled.
      delta: Margin above the best loss that a loss must reach to count as a
        non-improving check.
    """
    self.patience = patience
    self.delta = delta
    self.counter = 0
    self.min_loss = float('inf')

    self.best_epoch = 0
    self.best_metrics = {}
    self.best_model_state = None
    self.best_optimizer_state = None

  def early_stop(self, loss, epoch, model_state, optimizer_state, metrics):
    """Record one epoch's result and report whether training should stop.

    Args:
      loss: Loss to minimise for this epoch.
      epoch: Epoch number stored alongside the best state.
      model_state: Model ``state_dict`` for this epoch.
      optimizer_state: Optimizer ``state_dict`` for this epoch.
      metrics: Dict of metrics for this epoch.

    Returns:
      True once ``patience`` checks with ``loss >= min_loss + delta`` have
      accumulated since the last improvement, otherwise False.
    """
    if loss < self.min_loss:
      self.min_loss = loss
      self.counter = 0

      self.best_epoch = epoch
      # state_dict() hands out references to the live tensors, which the
      # optimizer keeps updating in place; copy them so the "best" state does
      # not silently turn into the latest state.
      self.best_metrics = _detached_copy(metrics)
      self.best_model_state = _detached_copy(model_state)
      self.best_optimizer_state = _detached_copy(optimizer_state)

    elif loss >= (self.min_loss + self.delta):
      self.counter += 1
      if self.counter >= self.patience:
        return True
    return False

  def save_best_checkpoint(self, filepath, metrics_filepath="bi_encoder_metrics.csv"):
    """Append the best epoch's metrics to a CSV log and save its checkpoint.

    Args:
      filepath: Destination of the ``torch.save`` checkpoint; also written as
        the first column of the CSV row.
      metrics_filepath: CSV file the summary row is appended to.
    """
    import csv

    metric_names = ('accuracy', 'precision', 'recall', 'f1')
    best = {name: self.best_metrics.get(name) for name in metric_names}

    # Checkpoint names may contain commas (e.g. "...(lr=1e-06,bs=32).pt"), so
    # let the csv module quote fields instead of joining them by hand.
    row = [filepath, self.best_epoch, self.min_loss, *best.values()]
    with open(metrics_filepath, 'a', newline='') as f:
      csv.writer(f, lineterminator="\n").writerow([str(field) for field in row])

    torch.save({
      'epoch': self.best_epoch,
      'model_state_dict': self.best_model_state,
      'optimizer_state_dict': self.best_optimizer_state,
      'loss': self.min_loss,
      **best,
      }, filepath)

In [None]:
def show_results(epoch, loss, metrics):
  """Print an epoch banner, the training loss and the pretty-printed metrics.

  Args:
    epoch: Epoch number shown in the banner.
    loss: Training loss for the epoch.
    metrics: Object handed to ``pprint`` (a dict of metrics in this notebook).
  """
  report_lines = (f"\n\nEPOCH {epoch}\n", f"Training loss: {loss}")
  print(*report_lines, sep="\n")
  pprint(metrics)
  print()


def evaluate_biencoder(model, source_data, target_data, batch_size=32, split="val"):
  """Compute the average loss and classification metrics on one dataset split.

  Args:
    model: Bi-encoder whose forward pass accepts ``s_``/``t_`` prefixed inputs
      plus ``labels`` and returns an output with ``loss`` and ``logits``.
    source_data: Dataset dict for the source side; must contain ``split`` with
      a ``labels`` column.
    target_data: Dataset dict for the target side, row-aligned with
      ``source_data``.
    batch_size: Number of pairs scored per forward pass.
    split: Name of the split to evaluate.

  Returns:
    A dict with the example-weighted average loss under ``"Validation loss"``
    (the key is kept regardless of ``split``), plus accuracy and macro-averaged
    precision, recall and F1.
  """
  num_examples = len(source_data[split])
  eval_dataloader_index = DataLoader(Dataset.from_dict({'index' : range(num_examples)}), batch_size=batch_size)
  metrics = [evaluate.load("accuracy"), evaluate.load("precision"), evaluate.load('recall'), evaluate.load('f1')]

  # Use the device the model already lives on rather than a notebook-level global.
  model_device = next(model.parameters()).device

  was_training = model.training
  model.eval()
  total_loss = 0.0
  try:
    for batch in eval_dataloader_index:
      batch_index = [int(i) for i in batch["index"]]
      source_batch = source_data[split][batch_index]
      target_batch = target_data[split][batch_index]
      labels = source_batch["labels"]

      source_batch = {"s_" + k: v.to(model_device) for (k, v) in source_batch.items() if k != "labels"}
      target_batch = {"t_" + k: v.to(model_device) for (k, v) in target_batch.items() if k != "labels"}
      params = source_batch | target_batch
      params["labels"] = labels.to(model_device)

      with torch.no_grad():
        outputs = model(**params)

      # The batch loss is a mean, so weight it by batch size before averaging.
      total_loss += outputs.loss.item() * len(batch_index)

      # The model emits raw logits (it is trained with BCEWithLogitsLoss), so
      # map them to probabilities before applying the 0.5 decision threshold.
      # view(-1) keeps a 1-D shape even when the final batch has a single row.
      probabilities = torch.sigmoid(outputs.logits.cpu().view(-1))
      predictions = (probabilities >= 0.5).long().numpy()
      for metric in metrics:
        metric.add_batch(predictions=predictions, references=labels.cpu())
  finally:
    # Hand the model back in the mode it arrived in, so a training loop that
    # calls this between epochs keeps dropout enabled afterwards.
    model.train(was_training)

  metric_dict = {"Validation loss" : total_loss / num_examples}
  metric_dict.update(metrics[0].compute())
  for metric in metrics[1:]:
    metric_dict.update(metric.compute(average='macro'))

  return metric_dict


def train_biencoder(model, source_data, target_data, learning_rate=1e-5, epochs=3, batch_size=32, save_filepath="bi_encoder_state.pt", verbose=True, patience=3):
  """Train the bi-encoder with early stopping on the validation loss.

  Args:
    model: Bi-encoder to train; it is trained on the device it already lives on.
    source_data: Dataset dict for the source side with ``train`` and ``val``
      splits, each including a ``labels`` column.
    target_data: Dataset dict for the target side, row-aligned with
      ``source_data``.
    learning_rate: Initial AdamW learning rate; decays linearly with no warm-up.
    epochs: Maximum number of epochs.
    batch_size: Batch size for both training and validation.
    save_filepath: Where the best checkpoint is written once training ends.
    verbose: If True, print the training loss and validation metrics every epoch.
    patience: Passed to ``EarlyStopper``; number of non-improving epochs
      tolerated before stopping.
  """
  num_train = len(source_data["train"])
  train_dataloader_index = DataLoader(Dataset.from_dict({'index' : range(num_train)}), shuffle=True, batch_size=batch_size)
  num_training_steps = epochs * len(train_dataloader_index)

  optimizer = AdamW(model.parameters(), lr=learning_rate) ### consider changing to adam (bertsubs)
  scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
  early_stopper = EarlyStopper(patience=patience)

  # Use the device the model already lives on rather than a notebook-level global.
  model_device = next(model.parameters()).device

  # The context manager closes the bar even when training stops early or fails.
  with tqdm(total=num_training_steps, position=0, leave=True) as progress_bar:
    for epoch in range(1, epochs+1):
      # Validation at the end of the previous epoch may have left the model in
      # eval mode; switch dropout back on at the start of every epoch.
      model.train()
      total_loss = 0.0
      for batch in train_dataloader_index:
        batch_index = [int(i) for i in batch["index"]]
        source_batch = source_data["train"][batch_index]
        target_batch = target_data["train"][batch_index]
        labels = source_batch["labels"]

        source_batch = {"s_" + k: v.to(model_device) for (k, v) in source_batch.items() if k != "labels"}
        target_batch = {"t_" + k: v.to(model_device) for (k, v) in target_batch.items() if k != "labels"}
        params = source_batch | target_batch
        params["labels"] = labels.to(model_device)

        outputs = model(**params)
        loss = outputs.loss
        # .item() detaches the value; summing the tensor itself would keep every
        # batch's autograd graph alive until the end of the epoch.
        total_loss += loss.item() * len(batch_index)
        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

      avg_loss = total_loss / num_train
      metrics = evaluate_biencoder(model, source_data, target_data, batch_size=batch_size)
      if verbose:
        show_results(epoch, avg_loss, metrics)

      val_loss = metrics["Validation loss"]
      if early_stopper.early_stop(val_loss, epoch, model.state_dict(), optimizer.state_dict(), metrics):
        break

  early_stopper.save_best_checkpoint(save_filepath)

# Experiments

In [None]:
# Hyper-parameter grid: 1 epoch budget x 4 learning rates x 4 batch sizes = 16 runs.
epochs = [10]
batch_size = [2 ** exponent for exponent in range(4, 8)]
learning_rate = [10 ** exponent for exponent in range(-6, -2)]
param_grid = ParameterGrid(dict(epochs=epochs, learning_rate=learning_rate, batch_size=batch_size))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

# Only grid entries 6..15 are run here (presumably the first six were covered
# in an earlier session -- adjust the slice to rerun them).
for params in list(param_grid)[6:16]:
  # Re-seed before building each model so every configuration starts identically.
  full_determinism(seed=3)
  model = BiEncoderForSequenceClassification(pretrained, num_labels=1).to(device)

  save_filepath = f"bi_encoder_state_(lr={params['learning_rate']},bs={params['batch_size']}).pt"
  train_biencoder(
    model, source_data, target_data,
    save_filepath=save_filepath,
    verbose=False,
    **params,
  )

  # Release the model and cached GPU memory before the next configuration.
  model.to(torch.device("cpu"))
  model = None
  gc.collect()
  torch.cuda.empty_cache()

In [None]:
# # # Transfer files to GDrive
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !cp "/content/bi_encoder_state_(lr=1e-06,bs=32).pt" "/content/drive/My Drive/bi_encoder_state_(lr=1e-06,bs=32).pt"

In [None]:
# # Test loading saved model metrics
# checkpoint = torch.load("/content/drive/My Drive/bi_encoder_state_(lr=1e-05,bs=16).pt", map_location=torch.device('cpu'))
# print(checkpoint['epoch'], checkpoint['accuracy'], checkpoint['precision'], checkpoint['recall'], checkpoint['f1'])

In [None]:
# # Test loading saved model
# torch.cuda.empty_cache()
# full_determinism(seed=3)
# pretrained = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# model = BiEncoderForSequenceClassification(pretrained, num_labels=1)
# model.to(device)
# train_biencoder(model, source_data, target_data)

# model2 = BiEncoderForSequenceClassification(pretrained, num_labels=1)
# checkpoint = torch.load("bi_encoder_state.pt")
# model2.load_state_dict(checkpoint['model_state_dict'])
# model2.to(device)

# source_try = model2.get_source_model_outputs(source_data["test"]["input_ids"][:1].to(device), source_data["test"]["token_type_ids"][:1].to(device), source_data["test"]["attention_mask"][:1].to(device))
# target_try = model2.get_target_model_outputs(target_data["test"]["input_ids"][:1].to(device), target_data["test"]["token_type_ids"][:1].to(device), target_data["test"]["attention_mask"][:1].to(device))
# source_real = model.get_source_model_outputs(source_data["test"]["input_ids"][:1].to(device), source_data["test"]["token_type_ids"][:1].to(device), source_data["test"]["attention_mask"][:1].to(device))
# target_real = model.get_target_model_outputs(target_data["test"]["input_ids"][:1].to(device), target_data["test"]["token_type_ids"][:1].to(device), target_data["test"]["attention_mask"][:1].to(device))

# print(torch.all(source_try == source_real))
# print(torch.all(target_try == target_real))