In [None]:
from datetime import datetime
import os
import random
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
import transformers
from transformers import LlamaTokenizer, LlamaForSequenceClassification

# Only let transformers emit ERROR-level log messages; this prevents many
# tokenizer warnings from flooding the notebook output.
transformers.logging.set_verbosity_error()

In [None]:
class Utils:
    """Utility functions called by the model operations.

    Bundles run-scoped logging, relation-list loading and DataLoader worker
    seeding behind a single object that carries the run's time stamp.
    """

    def __init__(self, time_stamp):
        """Store the run identifier.

        Args:
            time_stamp: String that is prepended (by plain concatenation) to
                log file names in ``write_log`` and returned unchanged by
                ``get_timestamp``.
        """

        self.time_stamp = time_stamp

    def get_timestamp(self):
        """Return the run identifier passed to the constructor."""
        return self.time_stamp

    def write_log(self, s, path = 'log.out', prnt = True):
        """Append a message to the run's log file and optionally echo it.

        Args:
            s: Message text; it is written preceded by a newline.
            path: File name suffix; the file opened is ``time_stamp + path``.
            prnt: When true, also print ``s`` to stdout.
        """

        # The context manager guarantees the handle is closed even if the
        # write fails, which the manual open()/close() pair did not.
        with open(self.time_stamp + path, "a") as f:
            f.write('\n' + s)
        if prnt:
            print(s)

    def load_relations(self, path: str):
        """Read the relation vocabulary, one relation per line.

        Surrounding whitespace is stripped from every line. Blank lines are
        skipped so that they do not become an empty-string relation (which
        would add a bogus class to the label space).

        Args:
            path: Path of the relations file.

        Returns:
            List of relation strings in file order.
        """

        relations = []
        with open(path) as f:
            for line in f:
                relation = line.strip()
                if relation:
                    relations.append(relation)

        print('\nRelations loaded')
        return relations

    # Reproduce
    def seed_worker(self, worker_id):
        """Seed NumPy and ``random`` from the current torch seed.

        Passed as ``worker_init_fn`` to the DataLoaders in ``train``.

        Args:
            worker_id: Worker index supplied by the DataLoader (unused).
        """
        # Reduce the 64-bit torch seed to the 32-bit range NumPy accepts.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

In [None]:
class KGDataset(Dataset):
    """Map-style dataset of knowledge-graph triples for relation classification.

    Every item pairs the tokenized text of a triple's first and third fields
    (looked up in the entity translation table) with a one-hot vector over
    ``relations`` that marks the triple's second field.
    """

    def __init__(self, args,
                 triples_filename: str,
                 relations: list,
                ) -> None:
        """Load the triples, the tokenizer and the entity translations.

        Args:
            args: Namespace providing ``dataset_directory``, ``data_size``
                (fraction of the file's lines to keep, taken from the top),
                ``model_id``, ``repo_token``, ``padding`` (used as the
                tokenizer's ``model_max_length``) and ``entities_filename``.
            triples_filename: Tab-separated triples file inside
                ``args.dataset_directory``.
            relations: List of relation strings; an item's label index is the
                position of its relation in this list.
        """

        self.args = args
        self.relations = relations
        # index -> (inputs, label); filled lazily by __getitem__.
        self.cache = {}

        # Read file (context manager so the handle is released on error too)
        triples_path = os.path.join(args.dataset_directory, triples_filename)
        with open(triples_path) as f:
            file_lines = f.readlines()

        # Load lines based on the specified data amount
        self.lines = file_lines[: int(len(file_lines) * args.data_size) ]
        print(f'\n{len(self.lines)} triples loaded')

        # Tokenizer
        tokenizer = LlamaTokenizer.from_pretrained(args.model_id,
                                                   token = self.args.repo_token,)
        tokenizer.padding_side = 'right'
        tokenizer.truncation_side = 'right'
        # NOTE(review): registering '<pad>' may extend the vocabulary beyond
        # the checkpoint's embedding matrix - confirm the model is resized
        # and uses the same pad id.
        tokenizer.add_special_tokens({'pad_token': '<pad>'})
        tokenizer.model_max_length = args.padding
        self.tokenizer = tokenizer

        # Load entity translations (first tab-separated field -> second field)
        self.entities_dict = {}
        entities_path = os.path.join(args.dataset_directory, args.entities_filename)
        with open(entities_path) as f:
            for line in f:
                fields = [p.strip() for p in line.split('\t')]
                # Blank or malformed lines carry no translation; skipping
                # them avoids an IndexError on fields[1].
                if len(fields) < 2:
                    continue
                self.entities_dict[ fields[0] ] = fields[1]

        print(f'\n{len(self.entities_dict.keys())} entity translations loaded')

    def __len__(self) -> int:
        """Return the number of triples kept from the file."""

        return len(self.lines)

    def __getitem__(self, index):
        """Return ``(inputs, label)`` for the triple at ``index``.

        ``inputs`` is the tokenizer output for the (first field, third field)
        text pair with the leading batch axis removed from ``input_ids`` and
        ``attention_mask``; ``label`` is a one-hot float tensor over
        ``self.relations``. Results are memoized in ``self.cache``.

        Raises:
            ValueError: If the triple's relation is not in ``self.relations``.
            KeyError: If an entity has no entry in the translation table.
        """

        # Check the cache
        if index in self.cache:
            return self.cache[index]

        # Create triple from a dataset line
        fields = [p.strip() for p in self.lines[index].split('\t')]

        # Prepare Y label
        rel_index = self.relations.index(fields[1])
        relations_tagged = [0.0] * len(self.relations)
        relations_tagged[ rel_index ] = 1.0

        # Tokenize: a batch holding a single (text, text_pair) entry.
        inputs = self.tokenizer([[self.entities_dict[fields[0]],
                                  self.entities_dict[fields[2]]]],
                                padding='max_length',
                                truncation = True,
                                return_attention_mask=True,
                                return_tensors="pt")
        # Drop only the batch axis. A second squeeze would also collapse the
        # sequence axis whenever the padded length is 1.
        inputs['input_ids'] = inputs['input_ids'].squeeze(0)
        inputs['attention_mask'] = inputs['attention_mask'].squeeze(0)

        result = (inputs, torch.tensor(relations_tagged))

        self.cache[index] = result

        return result

In [None]:
class Llama():
    """Builds the Llama sequence-classification model used for relation
    prediction, with one output label per relation.
    """

    def __init__(self,relations, model_id, repo_token) -> None:
        """Remember what is needed to instantiate the model later.

        Args:
            relations: Relation list; its length becomes ``num_labels``.
            model_id: Identifier passed to ``from_pretrained``.
            repo_token: Access token passed to ``from_pretrained``.
        """

        self.repo_token = repo_token
        self.model_id = model_id
        self.relations = relations

    def get_model(self):
        """Load the classifier and cast its float parameters to bfloat16.

        Returns:
            The ``LlamaForSequenceClassification`` instance whose float32 and
            float16 parameters have been converted to ``torch.bfloat16``.
        """
        classifier = LlamaForSequenceClassification.from_pretrained(
            self.model_id,
            token=self.repo_token,
            num_labels=len(self.relations),
        )
        # NOTE(review): 4 is a hard-coded id; confirm it matches the pad
        # token id of the tokenizer that produces the inputs.
        classifier.config.pad_token_id = 4

        castable_dtypes = (torch.float32, torch.float16)
        for weight in classifier.parameters():
            if weight.dtype in castable_dtypes:
                weight.data = weight.data.to(torch.bfloat16)

        return classifier

In [None]:
def train(relations,
    repo_token,
    args,
    utils,
    generator,
    training_triples,
    validation_triples,
    device
    ):
    """Fine-tune the Llama relation classifier with early stopping.

    Each epoch runs one pass over the training triples, then computes the
    mean validation loss, logs it through ``utils.write_log`` and saves the
    model's ``state_dict`` whenever that loss improves. Training stops early
    once the validation loss has failed to improve for more than
    ``args.patience`` consecutive epochs.

    Args:
        relations: List of relation strings (defines the label space).
        repo_token: Access token forwarded to the model loader.
        args: Namespace providing ``batch_size``, ``model_id``,
            ``learning_rate``, ``decay``, ``epochs``, ``patience`` and
            ``verbose`` plus everything ``KGDataset`` reads.
        utils: Object providing ``seed_worker``, ``write_log`` and
            ``get_timestamp``.
        generator: ``torch.Generator`` handed to both DataLoaders.
        training_triples: File name of the training triples.
        validation_triples: File name of the validation triples.
        device: Device the model and batches are moved to.

    Raises:
        ValueError: If either split contains no triples.
    """

    # Load training set
    # NOTE(review): neither loader sets shuffle=True, so batches follow file
    # order - confirm this is intended for the training split.
    training_set = KGDataset(args=args,
                             relations=relations,
                             triples_filename=training_triples)
    training_generator = DataLoader(training_set,
                                    batch_size = args.batch_size,
                                    worker_init_fn=utils.seed_worker,
                                    generator=generator,)
    validation_set = KGDataset(args=args,
                               relations=relations,
                               triples_filename=validation_triples)
    validation_generator = DataLoader(validation_set,
                                      batch_size = args.batch_size,
                                      worker_init_fn=utils.seed_worker,
                                      generator=generator,)

    # Fail early with a clear message: an empty training split would leave
    # the reported train loss undefined and an empty validation split would
    # divide by zero when averaging.
    if len(training_set) == 0 or len(validation_set) == 0:
        raise ValueError(
            'Training and validation sets must each contain at least one '
            f'triple (got {len(training_set)} and {len(validation_set)}); '
            'check data_size and the triples files.')

    # Initializing the model
    llama = Llama(relations, args.model_id, repo_token)
    model = llama.get_model()

    # Keep the model consistent with the dataset's tokenizer: the tokenizer
    # registers an extra '<pad>' token, so the embedding matrix must cover
    # its id, and the classification head needs the real pad id to locate
    # the last non-padding token of each sequence. This has to happen before
    # the optimizer captures the parameters.
    tokenizer = training_set.tokenizer
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    loss_f = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate )
    scheduler = lr_scheduler.StepLR(optimizer, gamma=args.decay, step_size = 1)
    model.to(device)

    v_loss = 1_000_000
    no_change_counter = 1
    for epoch in range(args.epochs):
        print(f'\nEpoch {epoch + 1}\n-------------------------------')
        lr = optimizer.param_groups[0]['lr']
        model.train()
        loop = tqdm(training_generator, disable = not args.verbose)

        # Loop over batches in an epoch using DataLoader
        for data in loop:
            inputs = data[0].to(device)
            optimizer.zero_grad()
            logits = model(**inputs).logits
            loss = loss_f(logits , data[1].to(device))
            loss.backward()
            optimizer.step()
            # Only the final batch's loss is reported for the epoch.
            last_loss = loss.item()

        v_losses = []
        model.eval()
        with torch.no_grad():
            for data in validation_generator:
                inputs = data[0].to(device)
                logits = model(**inputs).logits
                loss = loss_f(logits , data[1].to(device))
                # Store plain floats so no device tensors are kept alive.
                v_losses.append(loss.item())

        v_loss_epoch = sum(v_losses) / len(v_losses)
        utils.write_log(f'lr {lr:8f} train loss {last_loss:.8f} val loss {v_loss_epoch:.8f}')

        if v_loss_epoch < v_loss:
            # New best: checkpoint and reset the patience counter.
            v_loss = v_loss_epoch
            no_change_counter = 0
            torch.save(model.state_dict(), utils.get_timestamp()+'chkpnt.pt')
        elif no_change_counter > args.patience - 1:
            break
        else:
            no_change_counter += 1

        scheduler.step()


In [None]:
# Notebook-only run configuration that stands in for command-line flags.
#To be removed from online repo
class Args(argparse.Namespace):
    """Plain namespace holding the run settings as attributes."""
    pass

args=Args()
args.model_id = 'meta-llama/Llama-2-7b-hf'
args.data_size = 0.01                   # fraction of each triples file to load
args.dataset_directory = 'data/FB15K'
args.entities_filename = 'entity2text.txt'
args.padding = 40                       # tokenizer model_max_length
args.batch_size = 16
args.patience = 5
# Access tokens must never be committed to source control: read the Hugging
# Face token from the environment instead (None when HF_TOKEN is unset).
# The token that used to be hard-coded here should be revoked.
args.repo_token = os.environ.get('HF_TOKEN')
args.task = 'train'
args.learning_rate = 2e-5
args.decay = 0.35                       # StepLR gamma applied every epoch
args.epochs = 3
args.verbose = True

In [None]:
def main():
    """Run the task selected in the module-level ``args`` namespace.

    Declares the command-line interface (parsing is currently disabled, so
    the settings come from the global ``args``), seeds a torch generator,
    loads the relation list, picks the device and, when ``args.task`` is
    ``'train'``, starts training. Any other task value does nothing further.
    """

    parser = argparse.ArgumentParser()
    typed_flags = (
        ('--data_size', float),
        ('--dataset_directory', str),
        ('--entities_filename', str),
        ('--model_id', str),
        ('--repo_token', str),
        ('--padding', int),
        ('--learning_rate', float),
        ('--decay', float),
        ('--task', str),
        ('--batch_size', int),
        ('--patience', int),
        ('--epochs', int),
    )
    for flag, flag_type in typed_flags:
        parser.add_argument(flag, required = True, type = flag_type)
    # Declared without a type, so a parsed value would stay a string.
    parser.add_argument('--verbose', required = True)

    # Keep for online repo
    # args = parser.parse_args()

    # Random seeds
    g = torch.Generator()
    g.manual_seed(0)

    # Initializations
    # The stamp (including its '.out' suffix) is used as a file name prefix.
    run_stamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    time_stamp = f'{run_stamp}.out'

    # Load utilities
    utils = Utils(time_stamp)
    relations = utils.load_relations(f'{args.dataset_directory}/relations.txt')

    # Loading GPU
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    utils.write_log('\ndevice ' + str(device))

    # Only the training task is implemented; other values fall through.
    if args.task == 'train':
        train(
            relations=relations,
            repo_token=args.repo_token,
            args=args,
            utils=utils,
            generator=g,
            training_triples='train.tsv',
            validation_triples='dev.tsv',
            device=device,
        )


if __name__ == "__main__":
    main()