# Fine-Tuning: Data and Model


In [None]:
# importing packages
#!pip install accelerate -U
import os
import sys
import numpy as np
from tqdm import tqdm
import pickle
import json
import argparse
import pickle
import pandas as pd
import torch
import torch.nn as nn
from transformers import TrainingArguments,BartTokenizer, BartConfig,BartForConditionalGeneration,Seq2SeqTrainer, Seq2SeqTrainingArguments
from torch.utils.data import Dataset, DataLoader
from typing import List,Dict

In [None]:
# Connect to the data files (Google Colab + Google Drive).
# Alternatively, fetch the data folder with gdown:
#   !pip install gdown
#   !gdown https://drive.google.com/drive/folders/1CPEDAJ2Kezx6U2zGQhlmVS3Je3cW8h1e --folder

# Restrict CUDA to device 0. The previous `!set CUDA_VISIBLE_DEVICES=0` ran in a
# separate subshell (and `set` does not export a variable), so it never reached
# this Python process. The variable must be set before CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from google.colab import drive

# View current working directory
print("Current Working Directory:", os.getcwd())

# Mount Google Drive
drive.mount('/content/gdrive')
# Change working directory to the project folder on Drive
path = "/content/gdrive/My Drive/Georgia Tec/BD4H/Entity_Linking"
os.chdir(path)
# Confirm the change
print("Working Directory:", os.getcwd())

Current Working Directory: /content
Mounted at /content/gdrive
Working Directory: /content/gdrive/My Drive/Georgia Tec/BD4H/Entity_Linking


Prefix tree (`Trie`) class: holds the token-ID prefix tree used to constrain decoding.

In [None]:
#class definitions

# creating prefix tree
#NOTE: The Trie data structure is reused from the provided code.
class Trie(object):
    """Prefix tree over token-id sequences.

    Nodes are nested ``dict``s mapping a token id to its child node; a node
    with no children (``{}``) ends a sequence. ``get(prefix)`` returns the
    token ids that may follow ``prefix``. In this notebook it restricts beam
    search to known entity names.

    A second trie can be chained with ``append``. When ``bos_token_id`` is
    among the allowed next tokens, it is replaced by the first tokens of the
    appended trie. A prefix that leaves this trie is looked up in the
    appended trie instead.
    """

    def __init__(self, sequences: List[List[int]] = None):
        # ``None`` instead of a shared mutable ``[]`` default; None -> empty trie.
        self.trie_dict = {}
        self.len = 0
        for sequence in sequences or ():
            Trie._add_to_trie(sequence, self.trie_dict)
            self.len += 1

        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        """Chain ``trie`` after this one, switching over at ``bos_token_id``."""
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert one token-id sequence (duplicates still increment ``len``)."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        """Return the list of token ids allowed after ``prefix_sequence``."""
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        """Wrap an existing nested-dict trie; ``len`` = number of leaf paths."""
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        # Iterative descent, creating missing child nodes along the way.
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        node = trie_dict
        for position, token in enumerate(prefix_sequence):
            if token not in node:
                # The prefix leaves this trie: continue with the remaining
                # suffix in the appended trie if there is one. Otherwise allow
                # only token 2 (presumably BART's </s>, which ends generation).
                if append_trie:
                    return append_trie.get(prefix_sequence[position:])
                return [2]
            node = node[token]

        output = list(node.keys())
        if append_trie and bos_token_id in output:
            # At the switch-over point, offer the appended trie's first tokens.
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict.keys())
        return output

    def __iter__(self):
        def _traverse(prefix_sequence, trie_dict):
            if trie_dict:
                for next_token, child in trie_dict.items():
                    yield from _traverse(prefix_sequence + [next_token], child)
            else:
                yield prefix_sequence

        # An empty trie stores no sequences. Without this guard the traversal
        # would yield one empty sequence, so load_from_dict({}) reported len 1.
        if not self.trie_dict:
            return iter(())
        return _traverse([], self.trie_dict)

    def __len__(self):
        return self.len

    def __getitem__(self, value):
        return self.get(value)




`ChemDiseaseDataset` (a subclass of PyTorch's `Dataset`) and helper functions used to build the training, dev, and test datasets.

In [None]:
# Create Custom DataSet

# data processing class/DataSet

class ChemDiseaseDataset(Dataset):
    """Map-style dataset pairing tokenized sources with decoder targets.

    Args:
        encodings: dict of per-example encoder lists (for example
            ``input_ids`` and ``attention_mask``). Every key is returned as a
            tensor.
        labels: dict with per-example ``labels``, ``decoder_input_ids`` and
            ``attention_mask`` lists. When ``test_set`` is True, it also needs
            ``decoder_input_ids_test`` and ``attention_mask_test``.
        test_set: when True, items also carry the test-time decoder prompt
            and its mask.
    """

    def __init__(self, encodings, labels, test_set=False):
        self.encodings = encodings
        self.labels = labels
        self.test_set = test_set

    def __getitem__(self, index):
        targets = self.labels
        # Encoder side: one tensor per encoding field.
        item = {name: torch.tensor(values[index])
                for name, values in self.encodings.items()}

        # Decoder side as (output key, key in `labels`) pairs.
        # The decoder attention mask has the same length as decoder_input_ids.
        decoder_fields = [
            ('label_ids', 'labels'),
            ('decoder_input_ids', 'decoder_input_ids'),
            ('decoder_attention_mask', 'attention_mask'),
        ]
        if self.test_set:
            decoder_fields += [
                ('decoder_input_ids_test', 'decoder_input_ids_test'),
                ('decoder_attention_mask_test', 'attention_mask_test'),
            ]
        for out_key, src_key in decoder_fields:
            item[out_key] = torch.tensor(targets[src_key][index])
        return item

    def __len__(self):
        return len(self.labels['labels'])




def return_target_tokens(file_path,tokenizer,prefix_mention_is):
    """Tokenize a ``.target`` file into padded decoder inputs and labels.

    Each line of ``file_path`` is a JSON array. Element 0 is the mention
    prompt (prefix) and element 1 is the target entity name. Both are
    tokenized with a leading space, and the first/last token the tokenizer
    adds (presumably <s>/</s>) is stripped.

    Args:
        file_path: path to the ``.target`` file. It is treated as the test
            split when its file name contains 'test'.
        tokenizer: callable returning a dict with ``'input_ids'``.
        prefix_mention_is: if True, the decoder input is
            ``[2] + prefix + label`` and the prefix positions of the labels
            are masked with -100. Otherwise the decoder input is
            ``[2] + label``.

    Returns:
        dict with ``labels`` (padded with -100), ``decoder_input_ids``
        (padded with 1) and ``attention_mask`` (padded with 0), all padded
        to one common length. For test files it also holds the un-padded
        ``decoder_input_ids_test`` / ``attention_mask_test`` prompts.
        ``unlikelihood_tokens`` is always empty.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        output = [line.strip('\n') for line in f]

    tokens_y = {'labels': [], 'attention_mask': [], 'decoder_input_ids': [],
                'decoder_input_ids_test': [], 'attention_mask_test': [],
                'unlikelihood_tokens': []}
    # Decide "test split" from the file name only. A substring check on the
    # whole path also matched directories such as ".../latest/".
    is_test = 'test' in os.path.basename(file_path)

    max_len_y = 0
    for item in tqdm(output):
        iy = json.loads(item)
        prefix = list(tokenizer(' ' + iy[0])['input_ids'])[1:-1]
        label = list(tokenizer(' ' + iy[1])['input_ids'])[1:-1]
        y = prefix + label
        # Room for the leading and trailing 2 around prefix + label.
        max_len_y = max(max_len_y, len(y) + 2)

        if prefix_mention_is:
            decoder_input = [2] + y  # the entire decoder sequence, starting with token 2
            labels = [-100] * len(prefix) + label + [2]  # mask the prompt part of the label
            assert len(labels) == len(y) + 1
        else:
            decoder_input = [2] + label
            labels = label + [2]
        tokens_y['decoder_input_ids'].append(decoder_input)
        tokens_y['labels'].append(labels)
        tokens_y['attention_mask'].append([1] * len(decoder_input))

        # In training the decoder sees the entire sequence. At test time it
        # only gets the prompt: the prefix, or just token 2 without a prefix.
        if is_test:
            prompt = prefix if prefix_mention_is else [2]
            tokens_y['decoder_input_ids_test'].append(prompt)
            tokens_y['attention_mask_test'].append([1] * len(prompt))

    def _pad(seq, value):
        # Right-pad `seq` with `value` up to max_len_y.
        return list(seq) + [value] * (max_len_y - len(seq))

    # The *_test prompts are left un-padded (evaluate() feeds examples one at
    # a time). 1 is presumably the tokenizer's pad id.
    tokens_y['decoder_input_ids'] = [_pad(s, 1) for s in tokens_y['decoder_input_ids']]
    tokens_y['attention_mask'] = [_pad(s, 0) for s in tokens_y['attention_mask']]
    tokens_y['labels'] = [_pad(s, -100) for s in tokens_y['labels']]

    return tokens_y


def return_source_tokens(file_path,tokenizer):
    """Tokenize the encoder inputs stored in a ``.source`` file.

    Each line of ``file_path`` is a JSON array whose first element is the
    source text. The text is tokenized with a leading space, with
    ``padding='max_length'`` and ``truncation=True``.

    Returns:
        {'input_ids': [...], 'attention_mask': [...]}, one entry per line.
    """
    # Explicit UTF-8: the platform default encoding is not guaranteed.
    with open(file_path, 'r', encoding='utf-8') as f:
        output = [line.strip('\n') for line in f]

    token_x = {'input_ids': [], 'attention_mask': []}
    for x in tqdm(output):
        ix = json.loads(x)[0]
        line = tokenizer(' ' + ix, padding='max_length', truncation=True)
        token_x['input_ids'].append(line['input_ids'])
        token_x['attention_mask'].append(line['attention_mask'])

    return token_x



def _load_split(tokenizer, dataset_path, split, prefix_mention_is, test_set):
    """Tokenize ``<dataset_path>/<split>.source`` and ``.target`` into a ChemDiseaseDataset."""
    file_path = f'{dataset_path}/{split}'
    source_tokens = return_source_tokens(file_path + '.source', tokenizer)
    target_tokens = return_target_tokens(file_path + '.target', tokenizer, prefix_mention_is)
    return ChemDiseaseDataset(source_tokens, target_tokens, test_set=test_set)


def create_pretrained_datasets(tokenizer,dataset_path,prefix_mention_is=False,evaluate=False):
    """Build the datasets needed for training or for evaluation.

    Returns a ``(train, dev, test)`` triple. It is ``(None, None, test_set)``
    when ``evaluate`` is truthy and ``(train_set, dev_set, None)`` otherwise.
    Only the test split is built with ``test_set=True``.
    """
    if not evaluate:
        built = {name: _load_split(tokenizer, dataset_path, name, prefix_mention_is, test_set=False)
                 for name in ('train', 'dev')}
        return built['train'], built['dev'], None

    test_split = _load_split(tokenizer, dataset_path, 'test', prefix_mention_is, test_set=True)
    return None, None, test_split

Training function: fine-tunes the model after pretraining.

In [None]:
#Training function

def train(prefix_mention_is,evaluation):
    """Fine-tune the pretrained BART model on the bc5cdr train split.

    Checkpoints are written under ``./model_checkpoints/`` every
    ``save_steps`` steps. The final model is not saved separately.

    Args:
        prefix_mention_is: passed to the dataset builder. If True, decoder
            targets are prefix + entity name, with the prefix tokens masked
            out of the labels.
        evaluation: must be falsy. ``create_pretrained_datasets`` only builds
            the train/dev splits when ``evaluate`` is False.

    Raises:
        ValueError: if ``evaluation`` is truthy (there would be no train split).
    """
    if evaluation:
        # Fail before loading the model and tokenizing data; training would
        # otherwise fail later with no train_dataset.
        raise ValueError('train() needs evaluation=False to build the train/dev splits')

    model_load_path = './pre_train_model/'   # pretrained model to fine-tune
    model_token_path = 'facebook/bart-large'  # base config and tokenizer
    dataset_path = './data/bc5cdr/'          # folder with the train and dev data
    output_dir = './model_checkpoints/'      # fine-tuned checkpoints are saved here
    logging_path = './logs/'

    # Model config overrides
    max_position_embeddings = 1024
    attention_dropout = 0.1
    dropout = 0.1

    # Trainer hyperparameters
    num_train_epochs = 5  # NOTE: max_steps > 0 overrides this in HF TrainingArguments
    per_device_train_batch_size = 16
    per_device_eval_batch_size = 1
    warmup_steps = 2000
    weight_decay = 0.01
    logging_steps = 1000
    save_steps = 5000
    evaluation_strategy = 'steps'
    eval_steps = 120
    init_lr = 1e-5  # 1e-5 is original
    label_smoothing_factor = 0.1
    max_grad_norm = 0.1
    max_steps = 20000
    lr_scheduler_type = 'polynomial'
    seed = 71
    gradient_accumulation_steps = 1  # 50

    bartconf = BartConfig.from_pretrained(model_token_path)
    bartconf.max_position_embeddings = max_position_embeddings
    bartconf.attention_dropout = attention_dropout
    bartconf.dropout = dropout

    print(dataset_path)
    tokenizer = BartTokenizer.from_pretrained(model_token_path, max_length=1024)

    # Start from the pretrained model with the config above. To resume from a
    # fine-tuning checkpoint instead, load e.g. output_dir + 'checkpoint-10000/'.
    model = BartForConditionalGeneration.from_pretrained(model_load_path, config=bartconf)
    # NOTE(review): the Trainer presumably moves the model to its own device,
    # so this may not force CPU training — confirm.
    model.to('cpu')

    train_dataset, eval_dataset, _ = create_pretrained_datasets(tokenizer,
                                                                dataset_path,
                                                                prefix_mention_is=prefix_mention_is,
                                                                evaluate=evaluation,
                                                                )
    print('dataset loaded')
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        warmup_steps=warmup_steps,            # warmup steps for the LR scheduler
        weight_decay=weight_decay,
        logging_dir=logging_path,
        logging_steps=logging_steps,
        save_steps=save_steps,
        evaluation_strategy=evaluation_strategy,
        eval_steps=eval_steps,
        learning_rate=init_lr,
        label_smoothing_factor=label_smoothing_factor,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        lr_scheduler_type=lr_scheduler_type,
        seed=seed,
        gradient_accumulation_steps=gradient_accumulation_steps,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset
    )

    # Fine-tuning. Checkpoints are saved by the Trainer every save_steps steps.
    trainer.train()



In [None]:
# Evaluation Function

def evaluate(prefix_mention_is,evaluation,testset):
    """Generate the top-5 trie-constrained beam-search outputs per example.

    Loads the fine-tuned checkpoint ``./model_checkpoints/checkpoint-15000/``
    and builds the evaluation split with ``create_pretrained_datasets``. For
    every example, it returns 5 of 10 beams, restricted by the prefix tree in
    ``./data/bc5cdr/trie.pkl``. The results are rewritten to a CSV after each
    example.

    Args:
        prefix_mention_is: passed to the dataset builder (decoder prompt mode).
        evaluation: passed as ``evaluate`` to ``create_pretrained_datasets``.
            The test split is only built when it is truthy.
        testset: True to score the test split, False for the dev split.

    Returns:
        list of dicts with keys 'text1'..'text5' (stripped generations).

    Raises:
        ValueError: if the selected split was not built, or was built without
            the test-time decoder prompt (``test_set=True``) this function reads.
    """
    model_load_path = './pre_train_model/'   # pre trained model (config source)
    max_position_embeddings = 1024
    attention_dropout = 0  # 0.1 during training
    dropout = 0  # 0.1 during training
    dataset_path = './data/bc5cdr/'
    model_save_path = './model_checkpoints/'   # saved fine-tuned model
    model_token_path = "facebook/bart-large"   # HF model tokenizer
    trie_path = './data/bc5cdr/trie.pkl'  # prefix tree
    # NOTE(review): 2014 looks like a typo for 1024 (max_position_embeddings).
    # It only sets bartconf.max_length. generate() below passes max_new_tokens,
    # which presumably takes precedence — confirm before relying on it.
    max_length = 2014
    num_return_sequences = 5
    result_columns = [f'text{rank}' for rank in range(1, num_return_sequences + 1)]

    print('Evaluation')
    print(torch.cuda.is_available())
    bartconf = BartConfig.from_pretrained(model_load_path)
    bartconf.max_position_embeddings = max_position_embeddings
    bartconf.attention_dropout = attention_dropout
    bartconf.dropout = dropout
    bartconf.max_length = max_length

    tokenizer = BartTokenizer.from_pretrained(model_token_path)
    # Load the model saved after fine-tuning.
    model = BartForConditionalGeneration.from_pretrained(model_save_path + 'checkpoint-15000/', config=bartconf)
    model.to('cpu')
    print(model.device)

    _, dev_dataset, test_dataset = create_pretrained_datasets(tokenizer,
                                                              dataset_path,
                                                              prefix_mention_is=prefix_mention_is,
                                                              evaluate=evaluation)

    if testset:
        print('eval on test set')
        eval_dataset = test_dataset
        csv_path = 'Test_nb10_t5_rp1.2_pmtTrue_sampleTrue_topk50000_lp0.1.csv'
    else:
        print('eval on develop set')
        eval_dataset = dev_dataset
        csv_path = 'Dev_nb10_t5_rp1.2_pmtTrue_sampleTrue_topk50000_lp0.1.csv'

    # Fail clearly here instead of with a TypeError (None) or KeyError (no
    # decoder_input_ids_test) inside the loop below.
    if eval_dataset is None or not getattr(eval_dataset, 'test_set', False):
        raise ValueError(
            'evaluate() needs a split built with test_set=True (it reads '
            'decoder_input_ids_test); got testset=%r, evaluation=%r'
            % (testset, evaluation))

    # Load the prefix tree. pickle can execute code: only load trusted files.
    print('loading trie......')
    with open(trie_path, "rb") as f:
        trie = Trie.load_from_dict(pickle.load(f))
    print('trie loaded.......')

    bad_words = ['is', 'to', ' ' + tokenizer.unk_token, ' ' + 'to', tokenizer.unk_token]
    bad_word_ids = tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids

    top5results = []
    for i, example in enumerate(eval_dataset):
        print('Example: ', i)
        prompt_ids = example['decoder_input_ids_test']
        print("Decoder prompt: ", tokenizer.decode(prompt_ids, skip_special_tokens=True))
        input_ids = example['input_ids'].unsqueeze(0).to(model.device)
        encoder_attention_mask = example['attention_mask'].unsqueeze(0).to(model.device)
        # With a prompt, the last prompt token is dropped before generation.
        # (Without a prompt, use prompt_ids.unsqueeze(0) unsliced.)
        # NOTE(review): decoder_attention_mask is not sliced the same way, so
        # it is one token longer than decoder_input_ids — confirm this is intended.
        decoder_input_ids = prompt_ids[:-1].unsqueeze(0).to(model.device)
        decoder_attention_mask = example['decoder_attention_mask_test'].unsqueeze(0).to(model.device)

        # For sampling beam search, also pass do_sample=True, top_k=500, temperature=5.0.
        beam_output = model.generate(input_ids,
                                     bad_words_ids=bad_word_ids,
                                     decoder_input_ids=decoder_input_ids,
                                     attention_mask=encoder_attention_mask,
                                     decoder_attention_mask=decoder_attention_mask,
                                     max_new_tokens=20,
                                     num_beams=10,
                                     num_return_sequences=num_return_sequences,
                                     decoder_start_token_id=2,
                                     early_stopping=True,
                                     repetition_penalty=1.2,
                                     no_repeat_ngram_size=0,
                                     length_penalty=0.1,
                                     remove_invalid_values=True,
                                     prefix_allowed_tokens_fn=lambda batch_id, sent: custom_constraint(trie, sent),
                                     )

        example_result = {}
        for rank in range(1, num_return_sequences + 1):
            text = tokenizer.decode(beam_output[rank - 1], skip_special_tokens=True)
            print(f'text{rank} :', text)
            example_result[f'text{rank}'] = text.strip()
        top5results.append(example_result)
        # Rewritten after every example, so the results so far are on disk if
        # the run is interrupted.
        pd.DataFrame(top5results, columns=result_columns).to_csv(csv_path)

    return top5results

def custom_constraint(trie,sequence):
    """Return the token ids ``trie`` allows after the decoded ``sequence``.

    Used as ``prefix_allowed_tokens_fn`` in ``model.generate``. ``sequence``
    is a tensor of the decoder tokens so far (presumably 1-D per beam);
    ``.tolist()`` converts it to the plain ints the trie is keyed on.
    """
    generated_so_far = sequence.tolist()
    return trie.get(generated_so_far)


In [None]:
# Results Function
def find_cui_str(dict_path):
    """Load the JSON knowledge base mapping each MeSH concept id to its names.

    Values may be a list of synonyms or a single name (``find_str_cui``
    handles both). Returns the parsed dict: {cui: [names] or name}.
    """
    # Explicit UTF-8: synonyms can contain non-ASCII characters, and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(dict_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def find_str_cui(cui2str):
    """Invert ``cui2str`` into a name -> [cui, ...] lookup.

    Each value of ``cui2str`` is either a list of names or a single name.
    CUIs are appended in iteration order, so ``result[name][0]`` is the
    first concept seen with that name. A repeated single-name entry prints
    'duplicated vocabulary'.
    """
    str2cui = {}
    for cui, names in cui2str.items():
        if not isinstance(names, list):
            if names in str2cui:
                print('duplicated vocabulary')
            str2cui.setdefault(names, []).append(cui)
            continue
        for name in names:
            str2cui.setdefault(name, []).append(cui)
    print('dictionary loaded......')
    return str2cui

def find_true_labels(testset,dataset_path):
    """Load the gold CUI labels for the test split (``testset``) or the dev split.

    Reads ``<dataset_path>/testlabel.txt`` or ``<dataset_path>/devlabel.txt``.
    Each line lists one mention's CUIs separated by '|' or '+'.

    Returns:
        One list per line, with duplicates removed and CUIs kept in the
        order they appear in the file.
    """
    if testset:
        file_name, loading_msg, done_msg = 'testlabel.txt', 'loading label cuis......', 'label cuis loaded'
    else:
        file_name, loading_msg, done_msg = 'devlabel.txt', 'loading dev label cuis......', 'dev label cuis loaded'

    print(loading_msg)
    with open(dataset_path + '/' + file_name, 'r', encoding='utf-8') as f:
        # dict.fromkeys dedupes while keeping file order. list(set(...)) made
        # the order depend on string hash randomization, so labels[i][0] (the
        # ground truth used by find_precision) changed between runs.
        cui_labels = [list(dict.fromkeys(line.strip('\n').replace('+', '|').split('|')))
                      for line in f]
    print(done_msg)
    return cui_labels


def find_precision(top5results,true_labels,str2cui,testset):
    """Compute Precision@1/3/5 of the top-5 generations and save their CUIs.

    Each generated string ``text1``..``text5`` is mapped to CUIs through
    ``str2cui``. It is a hit when the example's first gold CUI
    (``true_labels[i][0]``) is among them. Precision@k is the number of hits
    among the first k candidates divided by k.

    Args:
        top5results: list of dicts with keys 'text1'..'text5' in rank order.
        true_labels: per-example lists of gold CUIs, aligned with
            ``top5results``. Extra trailing entries (e.g. after an interrupted
            evaluation) are ignored.
        str2cui: name -> list of CUIs.
        testset: selects the TEST_/DEV_ CSV file and print prefix.

    Returns:
        (precision1, precision3, precision5): per-example lists of floats.

    Raises:
        ValueError: if there are fewer gold labels than results.
    """
    n_results = len(top5results)
    if len(true_labels) < n_results:
        raise ValueError(f'{n_results} results but only {len(true_labels)} gold labels')

    precision1, precision3, precision5 = [], [], []
    cui_dict = {'cui_text1': [], 'cui_text2': [], 'cui_text3': [], 'cui_text4': [], 'cui_text5': [],
                # Only the labels of scored examples, so every DataFrame column
                # has the same length even when a run stopped early.
                'ground_truth': [item[0] for item in true_labels[:n_results]]}

    for result_dict, gold in zip(top5results, true_labels):
        truth = gold[0]
        hits = []
        for key, text in result_dict.items():  # rank order: text1 first
            candidate_cuis = str2cui.get(text, [None])
            cui_dict['cui_' + key].append(candidate_cuis[0])
            hits.append(truth in candidate_cuis)
        precision1.append(sum(hits[:1]) / 1)
        precision3.append(sum(hits[:3]) / 3)
        precision5.append(sum(hits[:5]) / 5)

    print(precision5)
    avg_top1_precision = np.mean(precision1)
    avg_top3_precision = np.mean(precision3)
    avg_top5_precision = np.mean(precision5)

    cui_df = pd.DataFrame(cui_dict, columns=list(cui_dict.keys()))

    if testset:
        cui_df.to_csv('TEST_Top5outputswithprompt.csv')

        print('TEST_Average TOP 1 Precision: ', round(avg_top1_precision, 2))
        print('TEST_Average TOP 3 Precision: ', round(avg_top3_precision, 2))
        print('TEST_Average TOP 5 Precision: ', round(avg_top5_precision, 2))
    else:
        cui_df.to_csv('DEV_Top5outputswithprompt.csv')

        print('DEV_Average Top 1 Precision: ', round(avg_top1_precision, 2))
        print('DEV_Average TOP 3 Precision: ', round(avg_top3_precision, 2))
        print('DEV_Average TOP 5 Precision: ', round(avg_top5_precision, 2))

    return precision1, precision3, precision5

In [None]:
#Main Method
# Run configuration. Leave `evaluation` False to fine-tune. Set it to True to
# score the fine-tuned checkpoint; `testset` then picks the test split over dev.
# NOTE(review): dev examples are built without `decoder_input_ids_test`, which
# evaluate() reads, so confirm dev evaluation works before relying on it.
evaluation = False  #change to True for test/dev
testset = False  #change to True  for test
prefix_mention_is = True
dict_path = './data/bc5cdr/target_kb.json'
data_path = './data/bc5cdr/'

if not evaluation:
    train(prefix_mention_is, evaluation)
else:
    # Generate candidates, map them to concept ids, then score against gold.
    top5results = evaluate(prefix_mention_is, evaluation, testset)
    cui2str = find_cui_str(dict_path)
    str2cui = find_str_cui(cui2str)
    true_labels = find_true_labels(testset, data_path)
    p1, p3, p5 = find_precision(top5results, true_labels, str2cui, testset)

Evaluation
True
cpu


100%|██████████| 50/50 [00:00<00:00, 448.66it/s]
100%|██████████| 50/50 [00:00<00:00, 3633.00it/s]

eval on test set
loading trie......





trie loaded.......
Example:  0
Decoder prompt:   malondialdehyde is




text1 :  malondialdehyde
text2 :  malondialdehyde sodium
text3 :  malondialdehyde low density lipoprotein human
text4 :  malondialdehyde low density lipoprotein mouse
text5 :  malondialdehyde low density lipoprotein human
Example:  1
Decoder prompt:   nitric oxide is
text1 :  nitric oxide
text2 :  nitric oxide synthase
text3 :  nitric oxide reductase
text4 :  nitric oxide receptors
text5 :  nitric oxide synthetase
Example:  2
Decoder prompt:   glutathione is
text1 :  glutathione synthase
text2 :  glutathione oxidase
text3 :  glutathione oxidized
text4 :  glutathione sulfonate
text5 :  glutathione sulfonamide
Example:  3
Decoder prompt:   superoxide is
text1 :  superoxide anion
text2 :  superoxide radical
text3 :  superoxide
text4 :  superoxide reductase
text5 :  superoxide dismutase
Example:  4
Decoder prompt:   isoproterenol is
text1 :  isoproterenol
text2 :  isoproterenol hydrochloride
text3 :  isoproterenol sulfate
text4 :  isoproterenol bitartrate
text5 :  isoproterenol bitartrate
