<a href="https://colab.research.google.com/github/stellaevat/ontology-mapping/blob/main/colabs/preprocess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pronto transformers[torch] datasets evaluate \
&& pip install accelerate -U \
&& wget -O doid.obo https://gla-my.sharepoint.com/:u:/g/personal/2526934t_student_gla_ac_uk/EfUC_RdrfZdOsOrtmNATjuoBPDaIkSTUMyxJXyO2KKC6yw?download=1 \
&& wget -O ncit.obo https://gla-my.sharepoint.com/:u:/g/personal/2526934t_student_gla_ac_uk/ETmaJIC0fAlItdsp8WQxS_wBzKN_6x08EZrtsOxVnbzvSg?download=1

In [2]:
import pronto
import torch
import evaluate
import numpy as np
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModel, BioGptTokenizer, BioGptModel, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer

In [3]:
ncit = pronto.Ontology("ncit.obo")
doid = pronto.Ontology("doid.obo")

In [4]:
# Get subsumptions from CSV file to a dictionary

def get_mappings_from_file(filename):
  mappings = {}
  with open(filename) as f:
    for line in f:
      source_id, target_id = line.strip().split(',')
      mappings[source_id] = target_id
  return mappings

In [5]:
doid2ncit_equiv = get_mappings_from_file("doid2ncit_equiv.csv")
ncit2doid_equiv = get_mappings_from_file("ncit2doid_equiv.csv")
doid2ncit_subs = get_mappings_from_file("doid2ncit_subs.csv")
ncit2doid_subs = get_mappings_from_file("ncit2doid_subs.csv")

# Convert relations to sentences

> Currently considering parents, children & siblings for conceptual reasons, but could also take 'n-hop' appraoch, e.g. 1-hop only with parents and children, or 2-hop to include grandparents, grandchildren and siblings.

> How do I incorporate the desired mapping for training? Should I incorporate both all this AND target info, or too much? Could be SELF + desired relatives instead, or SELF + PARENT + DESIRED PARENT, etc.

In [34]:
entity_markers = ["[SUB]", "[/SUB]", "[SUP]", "[/SUP]"]

In [35]:
# Create sentence from the given source entity, containing its mapping's parent & siblings

def get_mixed_sentence(source_id, target_id, source_onto, target_onto):
  sub_in, sub_out, sup_in, sup_out = entity_markers

  if source_id not in source_onto.terms() or target_id not in target_onto.terms():
    return

  subsumer = source_onto.get_term(source_id)
  equivalent = target_onto.get_term(target_id)
  parents = list(equivalent.superclasses(distance=1, with_self=False))

  if len(parents) != 1:
    return

  supersumer = parents[0]
  siblings = set(supersumer.subclasses(distance=1, with_self=False))
  siblings.remove(equivalent)

  sentence = [sub_in, subsumer.name, sub_out, sup_in, supersumer.name, sup_out]
  # for sibling in siblings:
  #   sentence.extend([sibl_in, sibling.name, sibl_out])

  return " ".join(sentence)

In [36]:
# Create sentence from the given entity, containing its direct parents & siblings

def get_sentence(entity_id, onto):
  sub_in, sub_out, sup_in, sup_out = entity_markers

  subsumer = onto.get_term(entity_id)
  ## What to do if multiple parents? -- sort later
  supersumers = list(subsumer.superclasses(distance=1, with_self=False))
  siblings = set()
  for supersumer in supersumers:
    siblings.update(set(supersumer.subclasses(distance=1, with_self=False)))
  siblings.remove(subsumer)

  sentence = [sub_in, subsumer.name, sub_out]
  for supersumer in supersumers:
    sentence.extend([sup_in, supersumer.name, sup_out])
  # for sibling in siblings:
  #   sentence.extend([sibl_in, sibling.name, sibl_out])

  return " ".join(sentence)

In [37]:
source_id = "DOID:0014667"
sentence = get_sentence(source_id, doid)
print(sentence)

[SUB] disease of metabolism [/SUB] [SUP] disease [/SUP]


In [38]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
tokenizer.add_tokens(entity_markers)
tokenized = tokenizer(sentence)
tokenizer.convert_ids_to_tokens(tokenized['input_ids'])

['[CLS]',
 '[SUB]',
 'disease',
 'of',
 'metabolism',
 '[/SUB]',
 '[SUP]',
 'disease',
 '[/SUP]',
 '[SEP]']

# Train end-to-end BERT model

In [58]:
samples = []
targets = []
labels = []
for (doid_id, ncit_id) in doid2ncit_subs.items():
  sample_sentence = get_sentence(doid_id, doid)
  target_sentence = get_sentence(ncit_id, ncit)
  # sample_sentence = get_mixed_sentence(doid_id, ncit_id, doid, ncit)
  # target_sentence = get_mixed_sentence(ncit_id, doid_id, ncit, doid)
  if sample_sentence and target_sentence:
    samples.append(sample_sentence)
    targets.append(target_sentence)
    labels.append(ncit_id)

In [59]:
X = np.array(samples)
y = np.array(labels)
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.2, random_state=3)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.25, random_state=3)

In [60]:
id2label = {ind:label for (ind,label) in enumerate([term.id for term in ncit.terms()])}
label2id = {label:ind for (ind,label) in enumerate([term.id for term in ncit.terms()])}

dataset_train = Dataset.from_dict({'sample':X_train, 'label':[label2id[label] for label in y_train]})
dataset_val = Dataset.from_dict({'sample':X_val, 'label':[label2id[label] for label in y_val]})
dataset_test = Dataset.from_dict({'sample':X_test, 'label':[label2id[label] for label in y_test]})
dataset = DatasetDict({'train':dataset_train,'val':dataset_val,'test':dataset_test})

In [61]:
def compute_metrics(eval_pred):
  accuracy = evaluate.load("accuracy")
  macro_metrics = {'precision' : evaluate.load("precision"),
                   'recall' : evaluate.load('recall'),
                   'f1' : evaluate.load('f1')}

  predictions, labels = eval_pred
  predictions = np.argmax(predictions, axis=1)

  metric_dict = {name : metric.compute(predictions=predictions, references=labels, average='macro') for (name, metric) in macro_metrics.items()}
  metric_dict['accuracy'] = accuracy.compute(predictions=predictions, references=labels)
  return metric_dict

In [62]:
def train_model(dataset, id2label, label2id, entity_markers=entity_markers, learning_rate=1e-4, epochs=1, batch_size=16):
  pretrained = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

  tokenizer = AutoTokenizer.from_pretrained(pretrained)
  tokenizer.add_tokens(entity_markers)
  tokenizer.add_special_tokens({'pad_token': '[PAD]'})

  preprocess_tokenize = lambda examples: tokenizer(examples["sample"], padding="longest")
  tokenized_data = dataset.map(preprocess_tokenize, batched=True)
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
  model = AutoModelForSequenceClassification.from_pretrained(
      pretrained, num_labels=len(id2label), id2label=id2label, label2id=label2id
  )
  model.resize_token_embeddings(len(tokenizer))
  model.config.pad_token_id = tokenizer.pad_token_id

  training_args = TrainingArguments(
    output_dir="testing",
    evaluation_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
  )

  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=tokenized_data['train'],
      eval_dataset=tokenized_data['val'],
      tokenizer=tokenizer,
      data_collator=data_collator,
      compute_metrics=compute_metrics,
  )

  return trainer, tokenized_data

In [63]:
torch.cuda.empty_cache()
trainer, tokenized_data = train_model(dataset, id2label, label2id)
trainer.train()

Map:   0%|          | 0/1129 [00:00<?, ? examples/s]

Map:   0%|          | 0/377 [00:00<?, ? examples/s]

Map:   0%|          | 0/377 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tokenizer size: 28899
Training shape 0: (1129, 33)
Training shape 1: (1129, 33)
Validation shape 0: (377, 46)
Validation shape 1: (377, 46)


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,No log,11.34644,{'precision': 7.445669877611802e-05},{'recall': 0.0035087719298245615},{'f1': 0.00014581909318751423},{'accuracy': 0.021220159151193633}


  _warn_prf(average, modifier, msg_start, len(result))
Trainer is attempting to log a value of "{'precision': 7.445669877611802e-05}" of type <class 'dict'> for key "eval/precision" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'recall': 0.0035087719298245615}" of type <class 'dict'> for key "eval/recall" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'f1': 0.00014581909318751423}" of type <class 'dict'> for key "eval/f1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'accuracy': 0.021220159151193633}" of type <class 'dict'> for key "eval/accuracy" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.


TrainOutput(global_step=71, training_loss=11.754845847546214, metrics={'train_runtime': 24.0388, 'train_samples_per_second': 46.966, 'train_steps_per_second': 2.954, 'total_flos': 51317062456968.0, 'train_loss': 11.754845847546214, 'epoch': 1.0})

In [64]:
predictions, label_ids, metrics = trainer.predict(tokenized_data['val'])
labels_predicted = [id2label[prediction] for prediction in np.argmax(predictions, axis=1)]
print(metrics)

  _warn_prf(average, modifier, msg_start, len(result))


{'test_loss': 11.346440315246582, 'test_precision': {'precision': 7.445669877611802e-05}, 'test_recall': {'recall': 0.0035087719298245615}, 'test_f1': {'f1': 0.00014581909318751423}, 'test_accuracy': {'accuracy': 0.021220159151193633}, 'test_runtime': 4.4393, 'test_samples_per_second': 84.924, 'test_steps_per_second': 5.406}
