<a href="https://colab.research.google.com/github/stellaevat/ontology-mapping/blob/main/colabs/bi_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pronto transformers[torch] datasets evaluate \
&& pip install accelerate -U \
&& wget -O doid.obo https://gla-my.sharepoint.com/:u:/g/personal/2526934t_student_gla_ac_uk/EfUC_RdrfZdOsOrtmNATjuoBPDaIkSTUMyxJXyO2KKC6yw?download=1 \
&& wget -O ncit.obo https://gla-my.sharepoint.com/:u:/g/personal/2526934t_student_gla_ac_uk/ETmaJIC0fAlItdsp8WQxS_wBzKN_6x08EZrtsOxVnbzvSg?download=1

In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

In [26]:
import random
import pronto
import evaluate
import torch
import torch.nn.functional as F
import numpy as np
from copy import copy, deepcopy
from pprint import pprint
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import get_scheduler, RobertaModel, AutoTokenizer, AutoModel, BioGptTokenizer, BioGptModel, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
from scipy.special import softmax
from torch.utils.data import DataLoader
from torch.optim import AdamW

In [4]:
ncit = pronto.Ontology("ncit.obo")
doid = pronto.Ontology("doid.obo")

In [5]:
# Get subsumptions from CSV file to a dictionary

def get_mappings_from_file(filename):
  mappings = {}
  with open(filename) as f:
    for line in f:
      source_id, target_id = line.strip().split(',')
      mappings[source_id] = target_id
  return mappings

In [6]:
equiv_doid2ncit = get_mappings_from_file("equiv_doid2ncit.csv")
equiv_ncit2doid = get_mappings_from_file("equiv_ncit2doid.csv")
subs_doid2ncit = get_mappings_from_file("subs_doid2ncit.csv")
subs_ncit2doid = get_mappings_from_file("subs_ncit2doid.csv")
neg_subs_doid2ncit = get_mappings_from_file("neg_subs_doid2ncit.csv")
neg_subs_ncit2doid = get_mappings_from_file("neg_subs_ncit2doid.csv")

# Convert relations to sentences

> Currently considering parents, children & siblings for conceptual reasons, but could also take 'n-hop' appraoch, e.g. 1-hop only with parents and children, or 2-hop to include grandparents, grandchildren and siblings.

> How do I incorporate the desired mapping for training? Should I incorporate both all this AND target info, or too much? Could be SELF + desired relatives instead, or SELF + PARENT + DESIRED PARENT, etc.

In [7]:
entity_markers = ["[SUB]", "[/SUB]", "[SUP]", "[/SUP]"]
sep_token = "[SEP]"
cls_token = "[CLS]"

In [8]:
# Create sentence from the given entity, containing its direct parents & siblings

def get_sentence(entity_id, onto):
  sub_in, sub_out, sup_in, sup_out = entity_markers

  subsumer = onto.get_term(entity_id)
  supersumers = list(subsumer.superclasses(distance=1, with_self=False))

  sentence = [sub_in, subsumer.name, sub_out]
  for supersumer in supersumers:
    sentence.extend([sup_in, supersumer.name, sup_out])

  return "".join(sentence)

In [9]:
# Create sentence from the given source entity, containing its mapping's parent & siblings

def get_combined_sentence(source_id, target_id, source_onto, target_onto):
  sub_in, sub_out, sup_in, sup_out = entity_markers

  if source_id not in source_onto.terms() or target_id not in target_onto.terms():
    return

  subsumer = source_onto.get_term(source_id)
  equivalent = target_onto.get_term(target_id)
  parents = list(equivalent.superclasses(distance=1, with_self=False))

  if len(parents) != 1:
    return
  supersumer = parents[0]

  #sentence = "".join([sub_in, subsumer.name, sub_out, sup_in, supersumer.name, sup_out])
  sentence = (sep_token + cls_token).join([subsumer.name, supersumer.name])
  return sentence

In [10]:
source_id = "DOID:0014667"
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
# tokenizer.add_tokens(entity_markers)

sentence = get_combined_sentence(source_id, subs_doid2ncit[source_id], doid, ncit)
tokenized = tokenizer(sentence)

print(sentence)
tokenizer.convert_ids_to_tokens(tokenized['input_ids'])

disease of metabolism[SEP][CLS]Non-Neoplastic Disorder


['[CLS]',
 'disease',
 'of',
 'metabolism',
 '[SEP]',
 '[CLS]',
 'non',
 '-',
 'neoplastic',
 'disorder',
 '[SEP]']

# Train end-to-end BERT model

In [11]:
def generate_labelled_biencoder_samples(subs, negsubs, source_onto, target_onto):
  random.seed(3)
  source_samples = []
  target_samples = []
  labels = []

  pairs = list(subs.items()) + list(negsubs.items())
  zeros_and_ones = [0] * len(negsubs) + [1] * len(subs)
  labelled_pairs = list(zip(zeros_and_ones, pairs))
  random.shuffle(labelled_pairs)

  for label, (source_id, target_id) in tqdm(labelled_pairs):
    source_sentence = get_sentence(source_id, source_onto)
    target_sentence = get_sentence(target_id, target_onto)
    if source_sentence and target_sentence:
      source_samples.append(source_sentence)
      target_samples.append(target_sentence)
      labels.append(label)

  print()
  return source_samples, target_samples, labels

In [12]:
source_samples, target_samples, labels = generate_labelled_biencoder_samples(subs_doid2ncit, neg_subs_doid2ncit, doid, ncit)
print(f"Samples: {len(source_samples)}, {len(target_samples)}")
print(f"Labels: {len(labels)}")

100%|██████████| 3766/3766 [00:00<00:00, 11737.27it/s]


Samples: 3766, 3766
Labels: 3766





In [13]:
def filter_source_target(Xi, source_samples, target_samples):
  X_source = np.array([source_samples[i] for i in Xi])
  X_target = np.array([target_samples[i] for i in Xi])
  return X_source, X_target

In [14]:
def collate_dataset(X_train, X_val, X_test, y_train, y_val, y_test):
  dataset_train = Dataset.from_dict({'sample':X_train, 'label':y_train})
  dataset_val = Dataset.from_dict({'sample':X_val, 'label':y_val})
  dataset_test = Dataset.from_dict({'sample':X_test, 'label':y_test})
  dataset = DatasetDict({'train':dataset_train,'val':dataset_val,'test':dataset_test})
  return dataset

In [15]:
Xi = np.arange(len(source_samples))
y = np.array(labels)
Xi_train_val, Xi_test, y_train_val, y_test = train_test_split(Xi, y, test_size=0.2, random_state=3)
Xi_train, Xi_val, y_train, y_val = train_test_split(Xi_train_val, y_train_val, test_size=0.25, random_state=3)

X_source_train, X_target_train = filter_source_target(Xi_train, source_samples, target_samples)
X_source_val, X_target_val = filter_source_target(Xi_val, source_samples, target_samples)
X_source_test, X_target_test = filter_source_target(Xi_test, source_samples, target_samples)

source_data = collate_dataset(X_source_train, X_source_val, X_source_test, y_train, y_val, y_test)
target_data = collate_dataset(X_target_train, X_target_val, X_target_test, y_train, y_val, y_test)

In [16]:
class BiEncoderModel(torch.nn.Module):
  def __init__(self, model_name, num_labels, id2label=None, label2id=None, token_embeddings_size=None, hidden_layer=-1):
    super().__init__()
    self.source_model = AutoModel.from_pretrained(model_name)
    self.target_model = AutoModel.from_pretrained(model_name)
    if token_embeddings_size:
      self.source_model.resize_token_embeddings(token_embeddings_size)
      self.target_model.resize_token_embeddings(token_embeddings_size)
    self.source_model.config.pad_token_id = tokenizer.pad_token_id
    self.target_model.config.pad_token_id = tokenizer.pad_token_id

    self.num_labels = num_labels
    self.hidden_layer = hidden_layer

    self.linear = torch.nn.Linear(32, num_labels)
    self.dropout = torch.nn.Dropout(0.1)
    self.similarity = torch.nn.CosineSimilarity(dim=-1)
    self.loss = torch.nn.BCEWithLogitsLoss()

  def forward(
      self,
      s_input_ids=None, t_input_ids=None,
      s_attention_mask=None, t_attention_mask=None,
      s_token_type_ids=None, t_token_type_ids=None,
      s_position_ids=None, t_position_ids=None,
      s_head_mask=None, t_head_mask=None,
      s_inputs_embeds=None, t_inputs_embeds=None,
      labels=None
    ):

    source_outputs = self.source_model(
      s_input_ids,
      attention_mask=s_attention_mask,
      token_type_ids=s_token_type_ids,
      position_ids=s_position_ids,
      head_mask=s_head_mask,
      inputs_embeds=s_inputs_embeds,
    )

    target_outputs = self.target_model(
      t_input_ids,
      attention_mask=t_attention_mask,
      token_type_ids=t_token_type_ids,
      position_ids=t_position_ids,
      head_mask=t_head_mask,
      inputs_embeds=t_inputs_embeds,
    )

    pooled_source_outputs = self.dropout(source_outputs[1])
    pooled_target_outputs = self.dropout(target_outputs[1])
    logits = self.similarity(pooled_source_outputs, pooled_target_outputs)

    loss = None
    if labels is not None:
      loss = self.loss(logits.view(-1), labels.view(-1).float())

    return SequenceClassifierOutput(loss=loss, logits=logits)

In [17]:
def full_determinism(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.use_deterministic_algorithms(True)
  torch.backends.cudnn.deterministic = True

In [18]:
def format_tokenized_dataset(dataset, preprocess_tokenize):
  tokenized_data = dataset.map(preprocess_tokenize, batched=True, batch_size=len(dataset["train"]["sample"]))
  tokenized_data = tokenized_data.remove_columns(["sample"])
  tokenized_data = tokenized_data.rename_column("label", "labels")
  tokenized_data.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
  return tokenized_data

In [29]:
def show_results(epoch, loss, metrics):
  print(f"\n\nEPOCH {epoch}\n")
  print(f"Training loss: {loss}")
  pprint(metrics)
  print()

In [20]:
def setup_biencoder(source_data, target_data, id2label=None, label2id=None, entity_markers=[]):
  pretrained = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
  tokenizer = AutoTokenizer.from_pretrained(pretrained)
  tokenizer.add_tokens(entity_markers)
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})

  preprocess_tokenize = lambda data: tokenizer(data["sample"], padding="longest")
  tokenized_source = format_tokenized_dataset(source_data, preprocess_tokenize)
  tokenized_target = format_tokenized_dataset(target_data, preprocess_tokenize)

  model = BiEncoderModel(pretrained, num_labels=1, id2label=id2label, label2id=label2id, token_embeddings_size=len(tokenizer))
  return model, tokenized_source, tokenized_target

def evaluate_biencoder(model, tokenized_source, tokenized_target, batch_size=32):
  eval_dataloader_index = DataLoader(Dataset.from_dict({'index' : range(len(tokenized_source["val"]))}), batch_size=batch_size)
  metrics = [evaluate.load("accuracy"), evaluate.load("precision"), evaluate.load('recall'), evaluate.load('f1')]

  model.eval()
  for batch in eval_dataloader_index:
    batch_index = list(batch["index"])
    source_batch = tokenized_source["val"][batch_index]
    target_batch = tokenized_target["val"][batch_index]
    labels = source_batch["labels"]

    source_batch = {"s_" + k: v.to(device) for (k, v) in source_batch.items() if k != "labels"}
    target_batch = {"t_" + k: v.to(device) for (k, v) in target_batch.items() if k != "labels"}
    params = source_batch | target_batch
    params["labels"] = labels.to(device)

    with torch.no_grad():
      outputs = model(**params)

    logits = outputs.logits.cpu()
    predictions = np.where(logits.squeeze() >= 0.95, 1, 0)
    for metric in metrics:
      metric.add_batch(predictions=predictions, references=labels.cpu())

  metric_dict = metrics[0].compute()
  for metric in metrics[1:]:
    metric_dict.update(metric.compute(average='macro'))

  return metric_dict

def train_biencoder(model, tokenized_source, tokenized_target, learning_rate=1e-5, epochs=3, batch_size=32):
  train_dataloader_index = DataLoader(Dataset.from_dict({'index' : range(len(tokenized_source["train"]))}), shuffle=True, batch_size=batch_size)
  num_training_steps = epochs * len(train_dataloader_index)

  optimizer = AdamW(model.parameters(), lr=learning_rate)
  scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

  progress_bar = tqdm(range(num_training_steps), position=0, leave=True)

  model.train()
  for epoch in range(1, epochs+1):
    for i, batch in enumerate(train_dataloader_index):
      batch_index = list(batch["index"])
      source_batch = tokenized_source["train"][batch_index]
      target_batch = tokenized_target["train"][batch_index]
      labels = source_batch["labels"]

      source_batch = {"s_" + k: v.to(device) for (k, v) in source_batch.items() if k != "labels"}
      target_batch = {"t_" + k: v.to(device) for (k, v) in target_batch.items() if k != "labels"}
      params = source_batch | target_batch
      params["labels"] = labels.to(device)

      outputs = model(**params)
      loss = outputs.loss
      loss.backward()

      optimizer.step()
      scheduler.step()
      optimizer.zero_grad()
      progress_bar.update(1)

    metrics = evaluate_biencoder(model, tokenized_source, tokenized_target)
    show_results(epoch, loss, metrics)

In [21]:
torch.cuda.empty_cache()

full_determinism(seed=3)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model, tokenized_source, tokenized_target = setup_biencoder(source_data, target_data)
model.to(device)

print("")

Map:   0%|          | 0/2259 [00:00<?, ? examples/s]

Map:   0%|          | 0/753 [00:00<?, ? examples/s]

Map:   0%|          | 0/754 [00:00<?, ? examples/s]

Map:   0%|          | 0/2259 [00:00<?, ? examples/s]

Map:   0%|          | 0/753 [00:00<?, ? examples/s]

Map:   0%|          | 0/754 [00:00<?, ? examples/s]




In [28]:
torch.cuda.empty_cache()
train_biencoder(model, tokenized_source, tokenized_target)

 33%|███▎      | 71/213 [00:58<01:50,  1.28it/s]



EPOCH 1
Training loss: 0.40505069494247437
{'accuracy': 0.9442231075697212,
 'f1': 0.9430736554107566,
 'precision': 0.9527942794279427,
 'recall': 0.9390137239564815}



 67%|██████▋   | 142/213 [02:07<00:52,  1.34it/s]



EPOCH 2
Training loss: 0.3478681147098541
{'accuracy': 0.9614873837981408,
 'f1': 0.9610020841034872,
 'precision': 0.9638654617031365,
 'recall': 0.9591552300362653}



100%|██████████| 213/213 [03:40<00:00,  1.04s/it]



EPOCH 3
Training loss: 0.38828930258750916
{'accuracy': 0.9588313413014609,
 'f1': 0.9582872449353392,
 'precision': 0.9616032116032116,
 'recall': 0.9562397781412216}




