# **PRETRAINING**
## **CSE6250 Big Data Analytics for Healthcare Spring 2024 Final Project Draft**
Paper: "Generative Biomedical Entity Linking via Knowledge Base-Guided Pre-training and Synonyms-Aware Fine-tuning" by Hongyi Yuan, Zheng Yuan, and Sheng Yu

## **Mount Notebook to Google Drive**

In [1]:
import datetime
start = datetime.datetime.now()
print(start)

import os, sys
from google.colab import drive

print('Current Working Directory:', os.getcwd())
drive.mount('/content/gdrive')
drive_path = '/content/gdrive/My Drive/bd4h-team-a4'
os.chdir(drive_path)
print('Working Directory:', os.getcwd())

import warnings
warnings.filterwarnings('ignore')

2024-04-21 12:12:32.548142
Current Working Directory: /content
Mounted at /content/gdrive
Working Directory: /content/gdrive/My Drive/bd4h-team-a4


In [2]:
%pip install transformers[torch]
%pip install peft
%pip install evaluate



In [4]:
# import packages
from collections import Counter
# import evaluate
import json
import numpy as np
import pandas as pd
# from peft import LoraConfig, get_peft_model
import pickle
import random
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import BartConfig, BartTokenizer, BartForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from typing import List, Dict

# finetuning hyperparams
evaluation = False  # change to True for test/dev
testset = True  # change to True for test
prefix_mention_is = True

pretrain_data_path = './pretrain/2023AB'
pretrain_model_path = './pretrain/model'
finetune_dict_path = './finetune/data/bc5cdr/target_kb.json'
finetune_data_path = './finetune/data/bc5cdr/'

### **3.a.1 Pretraining Data**

Load raw data (UMLS pretraining):

In [5]:
# Create Pre-Training Knowledge Base from UMLS Data

# (DRAFT cui_list = ['C1834523', 'C1998071', 'C1136989', 'C1562931', 'C5241135', 'C1334508'])
# limit dataset to relevant source ontologies
source_ontology = ['CPT', 'FMA', 'GO', 'HGNC', 'HPO', 'ICD10', 'ICD10CM', 'ICD9CM', 'MDR', 'MSH', 'MTH', 'NCBI', 'NCI', 'NDDF', 'NDFRT', 'OMIM', 'RXNORM', 'SNOMEDCT_US']
# limit dataset to 21 semantic types
semantic_type = set(['T005', 'T007', 'T017', 'T022', 'T031', 'T033', 'T037', 'T038', 'T058', 'T062', 'T074', 'T082', 'T091', 'T092', 'T097', 'T098', 'T103', 'T168', 'T170', 'T201', 'T204'])

# manually define templates for synthetic training sentences
templates = ['is defined as', 'is described as', 'is the definition of', 'describes', 'defines']
templates_nodef = ['are the synonyms of', 'indicate the same concept as', 'has synonyms, such as', 'refers to the same concept as']
templates_nosyn = ['is', 'is the same as', 'is', 'is the same as']
# define special tokens
special_tokens = ['START', 'END']

# pre-process STY.csv
semantic_type_ontology = pd.read_csv(f'{pretrain_data_path}/STY.csv')

semantic_type_size = 0

while len(semantic_type) != semantic_type_size:                           # recursively go through list to add all child and child-child TUIs to semantic_type
  semantic_type_size = len(semantic_type)
  for i in range(len(semantic_type_ontology)):
    if semantic_type_ontology['Parents'][i][-4:] in semantic_type:        # if the parent TUI is in semantic_type
      semantic_type.update([semantic_type_ontology['Class ID'][i][-4:]])  # add the child TUI to semantic_type
print('STY.csv loaded')

# pre-process MRCONSO.RRF
cui_synonyms = dict()
source_ontology_cuis = set()
with open(f'{pretrain_data_path}/MRCONSO.RRF') as mrconso:
# with open(f'{pretrain_data_path}/MRCONSO-SMALL.RRF') as mrconso:
  for line in mrconso:
    record = line.strip().split("|")
    cui = record[0]
    language = record[1]
    source = record[11]

    if source in source_ontology and language == "ENG":   # skip non-english records (DRAFT if cui in cui_list)
      cui_name = record[14]

      if cui not in cui_synonyms:             # create dictionary of cuis to mrconso strings
        cui_synonyms[cui] = [cui_name]
      else:
        cui_synonyms[cui].append(cui_name)

      source_ontology_cuis.update([cui])    # if source is in ontology list above, add it to set
print('MRCONSO.RRF loaded')

# pre-process MRSTY.RRF
semantic_type_cuis = dict()
with open(f'{pretrain_data_path}/MRSTY.RRF') as mrsty:
# with open(f'{pretrain_data_path}/MRSTY-SMALL.RRF') as mrsty:
  for line in mrsty:
    record = line.strip().split('|')
    cui = record[0]
    semantic = record[1]

    if semantic in semantic_type:         # choose only semantic types in 21 types above (DRAFT if cui in cui_list)
      type_str = record[3].lower()
      semantic_type_cuis[cui] = type_str  # if semantic type is in semantic type list above (DRAFT if cui in cui_list)
    elif cui in cui_synonyms:
      cui_synonyms.pop(cui)
print('MRSTY.RRF loaded')

# remove special characters from cui strings
def remove_special_chars(synonym):
  if isinstance(synonym, list) or isinstance(synonym, set):                              # recursively iterate through data structures and remove special characters from strings contained in them (borrowed directly from paper repo)
    return [remove_special_chars(s) for s in synonym]
  synonym = synonym.strip().lower()                                                      # convert synonym string to lowercase
  for char in ',.;{}[]()+-_*/?!`\"\'=%></':
    synonym = synonym.replace(char, ' ')
  return ' '.join([s for s in synonym.split() if s])

# pre-process MRDEF.RRF
cui_definitions = dict()
cui_set = set()
aui_set = set()
with open(f'{pretrain_data_path}/MRDEF.RRF') as mrdef:                               # get definitions of cui concepts
  for line in mrdef:
    record = line.strip().split('|')
    cui = record[0]
    if cui in cui_synonyms:
      cui_set.update([record[0]])

with open(f'{pretrain_data_path}/MRCONSO.RRF') as mrconso:
  for line in mrconso:
    record = line.strip().split('|')
    cui = record[0]
    language = record[1]

    if language != 'ENG' and cui in cui_set:     # remove non-English records (DRAFT if cui in cui_list)
      aui = record[7]
      aui_set.update([aui])

with open(f'{pretrain_data_path}/MRDEF.RRF') as mrdef:
  for line in mrdef:
    record = line.strip().split('|')
    cui = record[0]
    aui = record[1]

    if cui in cui_set and aui not in aui_set:
      definition = record[5].lower()
      if cui not in cui_definitions:
        cui_definitions[cui] = [definition]
      else:
        cui_definitions[cui].append(definition)
print('MRDEF.RRF loaded')

print('cui_synonyms', len(cui_synonyms))
print('cui_definitions', len(cui_definitions))

STY.csv loaded
MRCONSO.RRF loaded
MRSTY.RRF loaded
MRDEF.RRF loaded
cui_synonyms 2050287
cui_definitions 220355


Process raw data (UMLS pretraining):

In [6]:
# create pretraining data
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
max_length = 1024

encoder = {'input_ids': [], 'attention_mask': []}
decoder = {'input_ids': [], 'attention_mask': [], 'labels': []}
labels_dict = dict()

cui_start = 500000

# choose templates based on presense/absence of definitions and synonyms
for index, cui in enumerate(cui_synonyms):
  if index > cui_start:
    if index == cui_start + 1:
      print('start index', index)
    for synonym in cui_synonyms[cui]:

      # when no definitions are present
      if cui not in cui_definitions:
        synonyms = [s for s in cui_synonyms[cui] if s != synonym]
        if len(synonyms) > 1:
          mention = random.choice(synonyms)
          random.shuffle(synonyms)

          i = random.randint(0, 3)
          description = ' '.join([special_tokens[0], mention, special_tokens[1], templates_nodef[i], ', '.join(synonyms[:3])]) if i > 1 else ' '.join([', '.join(synonyms[:3]), templates_nodef[i], special_tokens[0], mention, special_tokens[1]])

        # when no definitions or synonyms are present
        else:
          mention = synonym
          i = random.randint(0, 3)
          description = ' '.join([special_tokens[0], mention, special_tokens[1], templates_nosyn[i], synonym]) if i > 1 else ' '.join([synonym, templates_nosyn[i], special_tokens[0], mention, special_tokens[1]])

      else:
        i = random.randint(0, 4)

        # when deinitions and synonyms are present
        synonyms = [s for s in cui_synonyms[cui] if s != synonym]
        if len(synonyms) > 1:
          mention = random.choice(synonyms)

        # when definitions and no synonyms are present
        else:
          mention = synonym

        random.shuffle(cui_definitions[cui])
        i = random.randint(0, 3)
        description = ' '.join([special_tokens[0], mention, special_tokens[1], templates[i], ' '.join(cui_definitions[cui][:2])]) if i < 2 else ' '.join([' '.join(cui_definitions[cui][:2]), templates[i], special_tokens[0], mention, special_tokens[1]])

        tokens = tokenizer(description, padding=True, truncation=True)['input_ids'] # handle very long descriptions
        if len(tokens) > 700:
          description = tokenizer.decode(tokens[:700]) if i < 2 else tokenizer.decode(tokens[-700:])

      if synonym in labels_dict.keys():
        labels_dict[synonym] += 1
      else:
        labels_dict[synonym] = 1

      # tokenize
      encoder_ids = tokenizer(description, padding=True, truncation=True)['input_ids'] # [:max_length]
      decoder_ids = tokenizer(f' {mention} is {synonym}', padding=True, truncation=True)['input_ids']
      prefix_len = len(tokenizer(f' {mention} is', padding=True, truncation=True)['input_ids'])

      # attention masks
      encoder_attention_mask = torch.tensor([1] * len(encoder_ids) + [0] * (max_length - len(encoder_ids)))
      decoder_attention_mask = torch.tensor([1] * len(decoder_ids) + [0] * (max_length - len(decoder_ids)))
      labels = torch.tensor([-100] * prefix_len + decoder_ids[prefix_len:] + [-100] * (max_length - len(decoder_ids)))

      # add padding
      encoder_ids += [1] * (max_length - len(encoder_ids))
      encoder_ids = torch.tensor(encoder_ids)
      decoder_ids += [1] * (max_length - len(decoder_ids))
      decoder_ids = torch.tensor(decoder_ids)

      encoder['input_ids'].append(encoder_ids)
      encoder['attention_mask'].append(encoder_attention_mask)
      decoder['input_ids'].append(decoder_ids)
      decoder['attention_mask'].append(decoder_attention_mask)
      decoder['labels'].append(labels)

    if index % 10000 == 0:
      print('saving index', index)
      # save encoder and decoder checkpoints to disk
      with open(f'./pretrain/data/encoder/{index}.pkl', 'wb') as encoder_pkl:
        pickle.dump(encoder, encoder_pkl)
      with open(f'./pretrain/data/decoder/{index}.pkl', 'wb') as decoder_pkl:
        pickle.dump(decoder, decoder_pkl)

      print('checkpoint saved', index, len(encoder['input_ids']))
      encoder = {'input_ids': [], 'attention_mask': []}
      decoder = {'input_ids': [], 'attention_mask': [], 'labels': []}

# save final encoder and decoder checkpoint to disk
with open(f'./pretrain/data/encoder/{index}.pkl', 'wb') as encoder_pkl:
  pickle.dump(encoder, encoder_pkl)
with open(f'./pretrain/data/decoder/{index}.pkl', 'wb') as decoder_pkl:
  pickle.dump(decoder, decoder_pkl)
print('checkpoint saved', index, len(encoder['input_ids']))

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

start index 500001
saving index 510000
checkpoint saved 510000 17285
saving index 520000
checkpoint saved 520000 18272
saving index 530000
checkpoint saved 530000 23691
saving index 540000
checkpoint saved 540000 34877
saving index 550000
checkpoint saved 550000 21709
saving index 560000
checkpoint saved 560000 17861
saving index 570000
checkpoint saved 570000 16513
saving index 580000
checkpoint saved 580000 16487
saving index 590000
checkpoint saved 590000 23568
saving index 600000
checkpoint saved 600000 33593
saving index 610000
checkpoint saved 610000 29342
saving index 620000
checkpoint saved 620000 28267
saving index 630000
checkpoint saved 630000 28167
saving index 640000
checkpoint saved 640000 47605
saving index 650000
checkpoint saved 650000 94993
saving index 660000
checkpoint saved 660000 83037
saving index 670000
checkpoint saved 670000 29866
saving index 680000
checkpoint saved 680000 27544
saving index 690000
checkpoint saved 690000 16009
saving index 700000
checkpoint 

KeyboardInterrupt: 

In [6]:
class PretrainDataset(Dataset):
  def __init__(self, encoder, decoder):
    self.encoder = encoder
    self.decoder = decoder

  def __len__(self):
    return len(self.encoder['input_ids'])

  def __getitem__(self, i):
    item = {}
    item['input_ids'] = self.encoder['input_ids'][i]
    item['attention_mask'] = self.encoder['attention_mask'][i]
    item['decoder_input_ids'] = self.decoder['input_ids'][i]
    item['decoder_attention_mask'] = self.decoder['attention_mask'][i]
    item['label_ids'] = self.decoder['labels'][i]
    return item

In [20]:
# load encoder / decoder files
encoder = {'input_ids': [], 'attention_mask': []}
decoder = {'input_ids': [], 'attention_mask': [], 'labels': []}
encoder_path = './pretrain/data/encoder-370000/'
decoder_path = './pretrain/data/decoder-370000/'

for filename in os.listdir(encoder_path):
  if '.pkl' in filename:
    with open(os.path.join(encoder_path, filename), 'rb') as encoder_file:
      print('loading encoder', filename)
      encoder_pkl = pickle.load(encoder_file)
      encoder['input_ids'] += encoder_pkl['input_ids']
      encoder['attention_mask'] += encoder_pkl['attention_mask']
      print(len(encoder_pkl['input_ids']), len(encoder['input_ids']))

for filename in os.listdir(decoder_path):
  if '.pkl' in filename:
    with open(os.path.join(decoder_path, filename), 'rb') as decoder_file:
      print('loading decoder', filename)
      decoder_pkl = pickle.load(decoder_file)
      decoder['input_ids'] += decoder_pkl['input_ids']
      decoder['attention_mask'] += decoder_pkl['attention_mask']
      decoder['labels'] += decoder_pkl['labels']
      print(len(decoder_pkl['input_ids']), len(decoder['input_ids']))

# load data as dataset
pretrain_dataset = PretrainDataset(encoder, decoder)

# split into train and validation datasets
train_split = 0.9
eval_split = 0.1
generator = torch.Generator().manual_seed(42)
train_dataset, eval_dataset = random_split(pretrain_dataset, [train_split, eval_split], generator)

# save datasets
with open('./pretrain/data/train-370000.pkl', 'wb') as train_pkl:
  print('saving train.pkl', len(train_dataset))
  pickle.dump(train_dataset, train_pkl)
  print('train.pkl saved')
with open('./pretrain/data/eval-370000.pkl', 'wb') as eval_pkl:
  print('saving eval.pkl', len(eval_dataset))
  pickle.dump(eval_dataset, eval_pkl)
  print('eval.pkl saved')

loading encoder 370000.pkl
15537 15537
loading decoder 370000.pkl
15537 15537
saving train.pkl 13984
train.pkl saved
saving eval.pkl 1553
eval.pkl saved


Calculate statistics (UMLS pretraining):

In [None]:
# pretraining - raw datasets
print('pretraining - raw datasets')
def get_size_and_distribution(filename):
  count = 0
  label_distribution = {}
  with open(f'{pretrain_data_path}/{filename}') as file:
    for line in file:
      record = line.strip().split('|')
      cui = record[0]
      if cui in label_distribution.keys():
        label_distribution[cui] += 1
      else:
        label_distribution[cui] = 1
      count += 1
  print(f'{filename} record count:', count)

  print(f'{filename} label distribution (number of unique cuis):', dict(sorted(dict(Counter(v for v in label_distribution.values())).items(), key=lambda v: v[1], reverse=True)))

filenames = ['MRCONSO-SMALL.RRF', 'MRSTY-SMALL.RRF', 'MRDEF-SMALL.RRF']
for filename in filenames:
  get_size_and_distribution(filename)

count = 0
with open(f'{pretrain_data_path}/STY.csv') as file:
  for line in file:
    count += 1
print(f'STY.csv record count:', count)

# pretraining - processed data
print('\npretraining - processed data')
print('labels count:', len(labels_dict))
print('labels distribution:', Counter([k.lower() for k in labels_dict.keys()]))

### **3.b.1 Pretraining Model**

In [7]:
# pretraining hyperparams
num_epochs = 10
learning_rate = 1e-3
batch_size = 4
gradient_accumulation_steps = 100
label_smoothing_factor = 0.1

In [32]:
# load datasets
with open('./pretrain/data/train-370000.pkl', 'rb') as train_pkl:
  print('loading train_dataset...')
  train_dataset = pickle.load(train_pkl)
  print('train_dataset loaded', len(train_dataset))
with open('./pretrain/data/eval-370000.pkl', 'rb') as eval_pkl:
  print('loading eval_dataset...')
  eval_dataset = pickle.load(eval_pkl)
  print('eval_dataset loaded', len(eval_dataset))

loading train_dataset...
train_dataset loaded 13984
loading eval_dataset...
eval_dataset loaded 1553


In [23]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
print(tokenizer.decode([i for i in eval_dataset.dataset.decoder['labels'][0].tolist() if i != -100]))

 inferior genicular artery</s>


In [33]:
# load model
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

# configure lora
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['q_proj', 'v_proj'],
    lora_dropout=0.1,
    bias='none',
    modules_to_save=['classifier']
)
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 2,359,296 || all params: 408,650,752 || trainable%: 0.5773379807704355


In [90]:
# load args
args = Seq2SeqTrainingArguments(
  output_dir=pretrain_model_path,
  evaluation_strategy='steps',
  per_device_train_batch_size=8,
  per_device_eval_batch_size=1,
  gradient_accumulation_steps=1,
  # eval_accumulation_steps=10,
  learning_rate=5e-5,
  weight_decay=0,
  adam_epsilon=1e-8,
  num_train_epochs=3.0,
  warmup_steps=0,
  logging_dir='./pretrain/logs/',
  logging_steps=100,
  save_strategy='steps',
  save_only_model=True,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  # load_best_model_at_end=True,
  # metric_for_best_model='eval_loss',
  # greater_is_better=False,
  label_smoothing_factor=0.1
)

# load evaluation metric
# metric = evaluate.load('accuracy')
# def compute_metrics(pred):
#   predictions, labels = pred
#   print(predictions.shape, labels.shape)
#   return metric.compute(predictions=predictions, references=labels)

from datasets import load_metric
metric1 = load_metric("precision")
metric2 = load_metric("recall")
metric3 = load_metric("f1")
metric4 = load_metric("accuracy")
def compute_metrics(pred):
    # predictions, labels = eval_pred
    # predictions = predictions[0]

    labels_ids = pred.label_ids
    pred_ids = pred.predictions[0]

    # precision = metric1.compute(predictions=predictions, references=labels, average="micro")["precision"]
    # recall = metric2.compute(predictions=predictions, references=labels, average="micro")["recall"]
    # f1 = metric3.compute(predictions=predictions, references=labels, average="micro")["f1"]
    accuracy = metric4.compute(predictions=pred_ids, references=labels_ids)["accuracy"]

    # return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}
    return {"accuracy": accuracy}

def preprocess_logits_for_metrics(logits, labels):
  # print(logits[0].shape, np.argmax(logits[0].cpu(), axis=-1).shape, labels.shape)
  predictions = np.argmax(logits[0].cpu(), axis=-1)
  return predictions, labels

# load trainer
trainer = Seq2SeqTrainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, compute_metrics=compute_metrics, preprocess_logits_for_metrics=preprocess_logits_for_metrics)

In [35]:
# train the Bart model on our pretraining knowledge-base dataset
print('begin pretraining...')
torch.cuda.empty_cache()
trainer.train()
print('pretraining complete!')

#save pretrained model
print('saving pretrained model...')
trainer.save_model(pretrain_model_path)
print('pretrained model saved')

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
begin pretraining...


Step,Training Loss
100,8.4935
200,3.6732
300,2.3009
400,1.9789
500,1.8089
600,1.7266
700,1.6827
800,1.6623
900,1.6676
1000,1.6365


Checkpoint destination directory ./pretrain/model/checkpoint-1600 already exists and is non-empty. Saving will proceed but saved results may be invalid.


pretraining complete!
saving pretrained model...
pretrained model saved


In [66]:
# load model
model = BartForConditionalGeneration.from_pretrained(f'./pretrain/model-370000-large-lora/checkpoint-5200')

# load trainer
trainer = Seq2SeqTrainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, compute_metrics=compute_metrics, preprocess_logits_for_metrics=preprocess_logits_for_metrics)

In [1]:
# train the Bart model on our pretraining knowledge-base dataset
print('begin evaluation...')
torch.cuda.empty_cache()
trainer.evaluate()
print('evaluation complete!')

begin evaluation...


NameError: name 'torch' is not defined

In [45]:
import gc
gc.collect()

136

In [None]:
end = datetime.datetime.now()
print("Duration:", end - start)