In [1]:
import transformers 
import pandas as pd
import numpy as np
import json
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


from typing import Any, Union, List, Optional

import re
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
import datasets
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    default_data_collator,
    set_seed,
)

from torch.optim import lr_scheduler
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#for pretrain data
data = pd.read_csv("soft_data/data/out_data/entail_articles_url_coref3.csv.done.csv", index_col = 0)
data = data.rename(columns = {"label":"triplets"})

x_train, x_val, y_train, y_val = train_test_split(data.text, data.triplets, test_size = 0.15)
x_val, x_test, y_val, y_test = train_test_split(x_val, y_val, test_size = 0.5)
data_dict = {
    "train": pd.DataFrame(zip(x_train, y_train), columns = ["text","triplets"]), 
    "val": pd.DataFrame(zip(x_val, y_val), columns = ["text","triplets"]), 
    "test": pd.DataFrame(zip(y_test, y_test), columns = ["text","triplets"])
}

# #### just random splits for testing, change later to proper splitting
# ###### and then change later again for CV
# split = ds.train_test_split(test_size = 0.15)
# split_val = split["test"].train_test_split(test_size = 0.5)
# ds = datasets.DatasetDict({"train":split["train"],"val":split_val["train"], "test":split_val["test"]})

# #ds = ds.with_format("torch")
# ds

In [None]:
cameo_to_penta = {
    "Make Public Statement" : "Make a statement",
    "Appeal" : "Verbal Cooperation",  	#Statement?
    "Express intent to cooperate" : "Verbal Cooperation",
    "Consult" : "Verbal Cooperation",
    "Engage in Diplomatic Cooperation" : "Verbal Cooperation",
    "Engange in Material Cooperation" : "Material Cooperation",
    "Provide Aid" : "Material Cooperation",
    "Yield" : "Material Cooperation",  #Verbal?
    "Investigate" : "Material Conflict",
    "Demand" : "Verbal Conflict",
    "Disapprove" : "Verbal Conflict",
    "Reject" : "Verbal Conflict",
    "Threaten" : "Verbal Conflict",
    "Exhibit Military Posture" : "Material Conflict",
    "Protest" : "Verbal Conflict", #Material?
    "Reduce Relations" : "Verbal Conflict",
    "Coerce" : "Material Conflict",
    "Assault" : "Material Conflict",
    "Fight" : "Material Conflict",
    "Engange in unconventional mass violence" : "Material Conflict"
}

pentacode = True
if pentacode == True:
    for row in data.iterrows()


In [6]:
#for train data

read = pd.read_json("data_src/raw/out_label/summaries_out1.json", lines = True)
df = read[read.answer == "accept"]
df.head()

Unnamed: 0,text,article_id,_input_hash,_task_hash,_is_binary,spans,tokens,_view_id,relations,answer,_timestamp
0,Ukraine war: Russian tactics on eastern front ...,0,1192826571,1725000051,False,"[{'start': 0, 'end': 7, 'token_start': 0, 'tok...","[{'text': 'Ukraine', 'start': 0, 'end': 7, 'id...",relations,"[{'head': 3, 'child': 0, 'head_span': {'start'...",accept,1666853824
2,Lee Jae-yong: Samsung appoints Lee Jae-yong to...,2,-386898449,495824270,False,"[{'start': 0, 'end': 12, 'token_start': 0, 'to...","[{'text': 'Lee', 'start': 0, 'end': 3, 'id': 0...",relations,"[{'head': 31, 'child': 18, 'head_span': {'star...",accept,1666854104
3,New Zealand Instagram couple 'relieved' after ...,3,94784817,-1047716132,False,"[{'start': 0, 'end': 28, 'token_start': 0, 'to...","[{'text': 'New', 'start': 0, 'end': 3, 'id': 0...",relations,"[{'head': 3, 'child': 9, 'head_span': {'start'...",accept,1666854141
4,Iran protests: Police fire on Mahsa Amini mour...,4,-1204324385,1067222802,False,"[{'start': 0, 'end': 4, 'token_start': 0, 'tok...","[{'text': 'Iran', 'start': 0, 'end': 4, 'id': ...",relations,"[{'head': 3, 'child': 8, 'head_span': {'start'...",accept,1666854246
5,Germany plans to legalise recreational cannabi...,5,1523244431,980527859,False,"[{'start': 0, 'end': 7, 'token_start': 0, 'tok...","[{'text': 'Germany', 'start': 0, 'end': 7, 'id...",relations,"[{'head': 0, 'child': 5, 'head_span': {'start'...",accept,1666854286


In [4]:
#prune df
df = df.drop(columns = ["_input_hash","_task_hash","_is_binary","tokens","_view_id","answer","_timestamp"])
df = df.reset_index().drop(columns = ["index"])

NameError: name 'df' is not defined

In [9]:
#get data into needed format and create new DataFrame
data = []
for index,row in df.iterrows():
    rel_list = []
    for rel in row.relations:
        subj = row.text[rel["head_span"]["start"]:rel["head_span"]["end"]]
        obj = row.text[rel["child_span"]["start"]:rel["child_span"]["end"]]
        rel_list.append(f"<triplet> {subj} <subj> {obj} <obj> {rel['label']}")
    data.append({"doc_id":row.article_id, "text": row.text, "triplets": " ".join(rel_list)})

data = pd.DataFrame(data, columns = ["doc_id","text","triplets"])

In [10]:
df

Unnamed: 0,text,article_id,spans,relations,relations_count
0,Ukraine war: Russian tactics on eastern front ...,0,"[{'start': 0, 'end': 7, 'token_start': 0, 'tok...","[{'head': 3, 'child': 0, 'head_span': {'start'...",4
1,Lee Jae-yong: Samsung appoints Lee Jae-yong to...,2,"[{'start': 0, 'end': 12, 'token_start': 0, 'to...","[{'head': 31, 'child': 18, 'head_span': {'star...",2
2,New Zealand Instagram couple 'relieved' after ...,3,"[{'start': 0, 'end': 28, 'token_start': 0, 'to...","[{'head': 3, 'child': 9, 'head_span': {'start'...",1
3,Iran protests: Police fire on Mahsa Amini mour...,4,"[{'start': 0, 'end': 4, 'token_start': 0, 'tok...","[{'head': 3, 'child': 8, 'head_span': {'start'...",2
4,Germany plans to legalise recreational cannabi...,5,"[{'start': 0, 'end': 7, 'token_start': 0, 'tok...","[{'head': 0, 'child': 5, 'head_span': {'start'...",1
...,...,...,...,...,...
148,Macri calls for unity but gives little away at...,194,"[{'start': 0, 'end': 5, 'token_start': 0, 'tok...","[{'head': 20, 'child': 14, 'head_span': {'star...",2
149,Whirlpool touts bet on troubled Argentina as o...,195,"[{'start': 0, 'end': 15, 'token_start': 0, 'to...","[{'head': 1, 'child': 5, 'head_span': {'start'...",2
150,LGBT football fans fight for safe space in Bra...,197,"[{'start': 0, 'end': 18, 'token_start': 0, 'to...","[{'head': 2, 'child': 8, 'head_span': {'start'...",2
151,ANSES: Half a million sign up for emergency fo...,198,"[{'start': 7, 'end': 21, 'token_start': 2, 'to...","[{'head': 15, 'child': 32, 'head_span': {'star...",1


In [11]:
#ds = datasets.Dataset.from_pandas(data.drop(columns=["doc_id"]))

#### just random splits for testing, change later to proper splitting
###### and then change later again for CV
#split = ds.train_test_split(test_size = 0.3)
#split_val = split["test"].train_test_split(test_size = 0.3)
#ds = datasets.DatasetDict({"train":split["train"],"val":split_val["train"], "test":split_val["test"]})

#ds = ds.with_format("torch")
#ds

In [3]:
class conf:
    #general
    seed = 0
    gpus = 1 #more?
    
    #input
    batch_size = 8
    max_length = 128
    ignore_pad_token_for_loss = True
    use_fast_tokenizer = True
    gradient_acc_steps = 1
    gradient_clip_value = 10.0
    load_workers = 8 #8 are local laptop

    #optimizer
    lr = 0.00005
    weight_decay = 0.01
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 0.00000001
    warm_up = 2 #num of epochs to warm_up on

    #training
    max_steps = 1200000
    samples_interval = 1000

    monitor_var  = "val_loss"
    monitor_var_mode = "min"
    # val_check_interval = 0.5
    # val_percent_check = 0.1

    model_name = "model1.pth"
    checkpoint_path = f"models/{model_name}"
    save_top_k = 1

    early_stopping = False
    patience = "5"

    length_penalty = 0
    no_repeat_ngram_size = 0
    num_beams = 3
    precision = 16
    amp_level = None

    

In [22]:
class GetData(pl.LightningDataModule):
    def __init__(self, conf: conf, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM):
        """init params from config"""
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.datasets = data_dict
        
        # Data collator
        label_pad_token_id = -100                       
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer, self.model, label_pad_token_id=label_pad_token_id, padding = True)

    def preprocess_function(self, data):
        """tokenize, pad, truncate"""
        #split into input and labels
        inputs = data["text"]       
        outputs = data["triplets"]

        #process input
        model_inputs = self.tokenizer(inputs, max_length = conf.max_length, padding = True, truncation = True)

        #process labels
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(outputs, max_length = conf.max_length, padding = True, truncation = True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def apply_mask(self, data):        
        #masking entities in sentence
            print("applying mask ...")
            new = []
            for row in data.iterrows():
                split = re.split("<\w*>", row[1]["triplets"])[1:]   #first one is empty
                for i in range(int(len(split)/3)):                  #always pairs of 3
                    sub = split[i*3:i*3+3]
                    subj = sub[0]
                    obj = sub[1]             
                    if np.random.binomial(1, 0.15, 1) == 1:                  
                        tok = np.random.choice(sub).rstrip().lstrip()               
                        new.append([row[1]["text"].replace(tok, "<MASK>"), row[1]["triplets"].replace(tok, "<MASK>")])                    
                        break
                    else:
                        new.append([row[1]["text"], row[1]["triplets"]])
                    
            df = pd.DataFrame(new, columns = ["text", "triplets"])
            df = df.drop_duplicates("text")
            print(df.shape)
            
            ds = datasets.Dataset.from_pandas(df)
            ds = ds.remove_columns(["__index_level_0__"])
            return ds 
        
    #apply the preprocessing and load data
    def train_dataloader(self, *args, **kwargs): 
        self.train_dataset = self.datasets["train"]
        self.train_dataset = self.apply_mask(self.train_dataset)
        self.train_dataset = self.train_dataset.map(self.preprocess_function, remove_columns = ["text", "triplets"], batched = True)
        return DataLoader(self.train_dataset, batch_size = conf.batch_size, collate_fn = self.data_collator, shuffle = True, num_workers= conf.load_workers)
    
    def val_dataloader(self, *args, **kwargs): 
        self.eval_dataset = self.datasets["val"]
        self.eval_dataset = self.apply_mask(self.eval_dataset)
        self.eval_dataset = self.eval_dataset.map(self.preprocess_function, remove_columns = ["text", "triplets"], batched = True)
        return DataLoader(self.eval_dataset, batch_size = conf.batch_size, collate_fn = self.data_collator, num_workers= conf.load_workers)

    def test_dataloader(self, *args, **kwargs): 
        self.test_dataset = self.datasets["test"]
        self.test_dataset = self.apply_mask(self.test_dataset)
        self.test_dataset = self.test_dataset.map(self.preprocess_function,  remove_columns = ["text", "triplets"], batched = True)
        return DataLoader(self.test_dataset, batch_size = conf.batch_size, collate_fn = self.data_collator, num_workers= conf.load_workers)

In [15]:
#from REBEL
from typing import Sequence

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from torch import nn
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
import wandb

class GenerateTextSamplesCallback(Callback):
    """
    PL Callback to generate triplets along training
    """

    def __init__(self, logging_batch_interval):
        """
        Args:
            logging_batch_interval: How frequently to inspect/potentially plot something
        """
        super().__init__()
        self.logging_batch_interval = logging_batch_interval

    def on_train_batch_end(self,trainer: Trainer,pl_module: LightningModule, outputs: Sequence, batch: Sequence, batch_idx: int) -> None:
        wandb_table = wandb.Table(columns=["Source", "Pred", "Gold"])
        # pl_module.logger.info("Executing translation callback")
        labels = batch.pop("labels")
        gen_kwargs = {
            "max_length": conf.max_length,
            "early_stopping": False,
            "no_repeat_ngram_size": 0,
            "num_beams": conf.num_beams
        }
        pl_module.eval()

        decoder_inputs = torch.roll(labels, 1, 1)[:,0:2]
        decoder_inputs[:, 0] = 0
        generated_tokens = pl_module.model.generate(
            batch["input_ids"].to(pl_module.model.device),
            attention_mask=batch["attention_mask"].to(pl_module.model.device),
            decoder_input_ids=decoder_inputs.to(pl_module.model.device),
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = pl_module._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        pl_module.train()
        decoded_preds = pl_module.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

        #If ignore pad token for loss == True 
            # Replace -100 in the labels as we can't decode them.
        labels = torch.where(labels != -100, labels, pl_module.tokenizer.pad_token_id)

        decoded_labels = pl_module.tokenizer.batch_decode(labels, skip_special_tokens=False)
        decoded_inputs = pl_module.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)

        # pl_module.logger.experiment.log_text('generated samples', '\n'.join(decoded_preds).replace('<pad>', ''))
        # pl_module.logger.experiment.log_text('original samples', '\n'.join(decoded_labels).replace('<pad>', ''))
        for source, translation, gold_output in zip(decoded_inputs, decoded_preds, decoded_labels):
            wandb_table.add_data(
                source.replace('<pad>', ''), translation.replace('<pad>', ''), gold_output.replace('<pad>', '')
            )
        pl_module.logger.experiment.log({"Triplets": wandb_table})

In [16]:
#from REBEL
def shift_tokens_left(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, :-1] = input_ids[:, 1:].clone()
    shifted_input_ids[:, -1] = pad_token_id

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

    return shifted_input_ids

In [17]:
#from REBEL

##### can maybe be rewritten more effective
def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
    return triplets

In [18]:
#from REBEL 
def score(key, prediction, verbose=False):
    correct_by_relation = Counter()
    guessed_by_relation = Counter()
    gold_by_relation    = Counter()

    # Loop over the data to compute a score
    for row in range(len(prediction)):
        gold = key[row]
        guess = prediction[row]
         
        if gold == NO_RELATION and guess == NO_RELATION:
            pass
        elif gold == NO_RELATION and guess != NO_RELATION:
            guessed_by_relation[guess] += 1
        elif gold != NO_RELATION and guess == NO_RELATION:
            gold_by_relation[gold] += 1
        elif gold != NO_RELATION and guess != NO_RELATION:
            guessed_by_relation[guess] += 1
            gold_by_relation[gold] += 1
            if gold == guess:
                correct_by_relation[guess] += 1

    # Print verbose information
    if verbose:
        print("Per-relation statistics:")
        relations = gold_by_relation.keys()
        longest_relation = 0
        for relation in sorted(relations):
            longest_relation = max(len(relation), longest_relation)
        for relation in sorted(relations):
            # (compute the score)
            correct = correct_by_relation[relation]
            guessed = guessed_by_relation[relation]
            gold    = gold_by_relation[relation]
            prec = 1.0
            if guessed > 0:
                prec = float(correct) / float(guessed)
            recall = 0.0
            if gold > 0:
                recall = float(correct) / float(gold)
            f1 = 0.0
            if prec + recall > 0:
                f1 = 2.0 * prec * recall / (prec + recall)
            # (print the score)
            sys.stdout.write(("{:<" + str(longest_relation) + "}").format(relation))
            sys.stdout.write("  P: ")
            if prec < 0.1: sys.stdout.write(' ')
            if prec < 1.0: sys.stdout.write(' ')
            sys.stdout.write("{:.2%}".format(prec))
            sys.stdout.write("  R: ")
            if recall < 0.1: sys.stdout.write(' ')
            if recall < 1.0: sys.stdout.write(' ')
            sys.stdout.write("{:.2%}".format(recall))
            sys.stdout.write("  F1: ")
            if f1 < 0.1: sys.stdout.write(' ')
            if f1 < 1.0: sys.stdout.write(' ')
            sys.stdout.write("{:.2%}".format(f1))
            sys.stdout.write("  #: %d" % gold)
            sys.stdout.write("\n")
        print("")

    # Print the aggregate score
    if verbose:
        print("Final Score:")
    prec_micro = 1.0
    if sum(guessed_by_relation.values()) > 0:
        prec_micro   = float(sum(correct_by_relation.values())) / float(sum(guessed_by_relation.values()))
    recall_micro = 0.0
    if sum(gold_by_relation.values()) > 0:
        recall_micro = float(sum(correct_by_relation.values())) / float(sum(gold_by_relation.values()))
    f1_micro = 0.0
    if prec_micro + recall_micro > 0.0:
        f1_micro = 2.0 * prec_micro * recall_micro / (prec_micro + recall_micro)
    print("Precision (micro): {:.3%}".format(prec_micro))
    print("   Recall (micro): {:.3%}".format(recall_micro))
    print("       F1 (micro): {:.3%}".format(f1_micro))
    return prec_micro, recall_micro, f1_micro

'''Adapted from: https://github.com/btaille/sincere/blob/6f5472c5aeaf7ef7765edf597ede48fdf1071712/code/utils/evaluation.py'''
def re_score(pred_relations, gt_relations, relation_types, mode="boundaries"):
    """Evaluate RE predictions
    Args:
        pred_relations (list) :  list of list of predicted relations (several relations in each sentence)
        gt_relations (list) :    list of list of ground truth relations
            rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                    "tail": (start_idx (inclusive), end_idx (exclusive)),
                    "head_type": ent_type,
                    "tail_type": ent_type,
                    "type": rel_type}
        vocab (Vocab) :         dataset vocabulary
        mode (str) :            in 'strict' or 'boundaries' """

    assert mode in ["strict", "boundaries"]
    relation_types = ["MakePublicStatement","Appeal","ExpressIntentToCooperate","Consult","EngageInDiplomaticCooperation","EngageInMaterialCooperation","ProvideAid","Yield","Investigate","Demand","Disapprove","Reject","Threaten","ExhibitMilitaryPosture","Protest","ReduceRelations","Coerce","Assault","Fight","EngageInUnconventialMassViolence"]
        
    # relation_types = [v for v in relation_types if not v == "None"]
    scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}

    # Count GT relations and Predicted relations
    n_sents = len(gt_relations)
    n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
    n_found = sum([len([rel for rel in sent]) for sent in pred_relations])

    # Count TP, FP and FN per type
    for pred_sent, gt_sent in zip(pred_relations, gt_relations):
        for rel_type in relation_types:
            # strict mode takes argument types into account
            if mode == "strict":
                pred_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in pred_sent if
                             rel["type"] == rel_type}
                gt_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in gt_sent if
                           rel["type"] == rel_type}

            # boundaries mode only takes argument spans into account
            elif mode == "boundaries":
                pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
                gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}

            scores[rel_type]["tp"] += len(pred_rels & gt_rels)
            scores[rel_type]["fp"] += len(pred_rels - gt_rels)
            scores[rel_type]["fn"] += len(gt_rels - pred_rels)

    # Compute per relation Precision / Recall / F1
    for rel_type in scores.keys():
        if scores[rel_type]["tp"]:
            scores[rel_type]["p"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
            scores[rel_type]["r"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
        else:
            scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

        if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
            scores[rel_type]["f1"] = 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (
                    scores[rel_type]["p"] + scores[rel_type]["r"])
        else:
            scores[rel_type]["f1"] = 0

    # Compute micro F1 Scores
    tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
    fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
    fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

    if tp:
        precision = 100 * tp / (tp + fp)
        recall = 100 * tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)

    else:
        precision, recall, f1 = 0, 0, 0

    scores["ALL"]["p"] = precision
    scores["ALL"]["r"] = recall
    scores["ALL"]["f1"] = f1
    scores["ALL"]["tp"] = tp
    scores["ALL"]["fp"] = fp
    scores["ALL"]["fn"] = fn

    # Compute Macro F1 Scores
    scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
    scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
    scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])

    print(f"RE Evaluation in *** {mode.upper()} *** mode")

    print(
        "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(n_sents, n_rels, n_found,
                                                                                             tp))
    print(
        "\tALL\t TP: {};\tFP: {};\tFN: {}".format(
            scores["ALL"]["tp"],
            scores["ALL"]["fp"],
            scores["ALL"]["fn"]))
    print(
        "\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(
            precision,
            recall,
            f1))
    print(
        "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
            scores["ALL"]["Macro_p"],
            scores["ALL"]["Macro_r"],
            scores["ALL"]["Macro_f1"]))

    for rel_type in relation_types:
        print("\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
            rel_type,
            scores[rel_type]["tp"],
            scores[rel_type]["fp"],
            scores[rel_type]["fn"],
            scores[rel_type]["p"],
            scores[rel_type]["r"],
            scores[rel_type]["f1"],
            scores[rel_type]["tp"] +
            scores[rel_type][
                "fp"]))

    return scores, precision, recall, f1

In [19]:
class BaseModule(pl.LightningModule):

    def __init__(self, conf, config: AutoConfig, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM):
        super().__init__()
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        #self.loss_fn = label_smoothed_nll_loss
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
        self.num_beams = conf.num_beams

    def forward(self, inputs, labels, *args):
        ##### Check later if smooth labeled loss is better 
        outputs = self.model(**inputs, labels = labels, use_cache = False, return_dict = True, output_hidden_states = True)
        output_dict = {'loss': outputs['loss'], 'logits': outputs['logits']}
        return output_dict

    def training_step(self, batch: dict, batch_idx: int):
        ##### check later if labels = batch["labels"] also works
        labels = batch.pop("labels")
        labels_original = labels.clone()
        batch["decoder_input_ids"] = torch.where(labels != -100, labels, self.config.pad_token_id)
        labels = shift_tokens_left(labels, -100)

        forward_output = self.forward(batch, labels)
        self.log('loss', forward_output['loss'])

        batch["labels"] = labels_original

        #### ig i dont have this
        if 'loss_aux' in forward_output:
            self.log('loss_classifier', forward_output['loss_aux'])
            return forward_output['loss'] + forward_output['loss_aux']

        return forward_output['loss']

    def validation_step(self, batch: dict, batch_idx):
        #### pop maybe not needed?
        labels = batch.pop("labels")
        batch["decoder_input_ids"] = torch.where(labels != -100, labels, self.config.pad_token_id)
        labels = shift_tokens_left(labels, -100)
        
        with torch.no_grad():
            # compute loss on predict data
            forward_output = self.forward(batch, labels)

        forward_output['loss'] = forward_output['loss'].mean().detach()

        #### probably not needed ? why would i want only pred loss
        # if self.hparams.prediction_loss_only:
        #     self.log('val_loss', forward_output['loss'])
        #     return

        forward_output['logits'] = forward_output['logits'].detach()

        if labels.shape[-1] < conf.max_length:
            forward_output['labels'] = self._pad_tensors_to_max_len(labels, conf.max_length)
        else:
            forward_output['labels'] = labels

        metrics = {}
        metrics['val_loss'] = forward_output['loss']

        #### only 1? so why loop lmao
        for key in sorted(metrics.keys()):
            self.log(key, metrics[key])

        outputs = {}
        outputs['predictions'], outputs['labels'] = self.generate_triples(batch, labels)
        return outputs


    def test_step(self, batch, batch_idx):

        #### popping again
        labels = batch.pop("labels")
        batch["decoder_input_ids"] = torch.where(labels != -100, labels, self.config.pad_token_id)
        labels = shift_tokens_left(labels, -100)

        with torch.no_grad():
            # compute loss on predict data
            forward_output = self.forward(batch, labels)

        forward_output['loss'] = forward_output['loss'].mean().detach()

        forward_output['logits'] = forward_output['logits'].detach()

        if labels.shape[-1] < conf.max_length:
            forward_output['labels'] = self._pad_tensors_to_max_len(labels, conf.max_length)
        else:
            forward_output['labels'] = labels


        metrics = {}
        metrics['test_loss'] = forward_output['loss']

        #### dont i only have one metric anyways?
        for key in sorted(metrics.keys()):
            self.log(key, metrics[key], prog_bar=True)
        
        self.log("lr", optimizer.get_last_lr())

        #### what does this actually do? how does this change everything
        # if self.hparams.finetune:
        #     return {'predictions': self.forward_samples(batch, labels)}
        # else:

        outputs = {}
        outputs['predictions'], outputs['labels'] = self.generate_triples(batch, labels)
        return outputs



    def validation_epoch_end(self, output: dict):
        
        relations = ["MakePublicStatement","Appeal","ExpressIntentToCooperate","Consult","EngageInDiplomaticCooperation","EngageInMaterialCooperation","ProvideAid","Yield","Investigate","Demand","Disapprove","Reject","Threaten","ExhibitMilitaryPosture","Protest","ReduceRelations","Coerce","Assault","Fight","EngageInUnconventialMassViolence"]
        scores, precision, recall, f1 = re_score([item for pred in output for item in pred['predictions']], [item for pred in output for item in pred['labels']], relations)
        self.log('val_prec_micro', precision)
        self.log('val_recall_micro', recall)
        self.log('val_F1_micro', f1)

    def test_epoch_end(self, output: dict):

        relations = ["MakePublicStatement","Appeal","ExpressIntentToCooperate","Consult","EngageInDiplomaticCooperation","EngageInMaterialCooperation","ProvideAid","Yield","Investigate","Demand","Disapprove","Reject","Threaten","ExhibitMilitaryPosture","Protest","ReduceRelations","Coerce","Assault","Fight","EngageInUnconventialMassViolence"]
        scores, precision, recall, f1 = re_score([item for pred in output for item in pred['predictions']], [item for pred in output for item in pred['labels']], relations)
        self.log('test_prec_micro', precision)
        self.log('test_recall_micro', recall)
        self.log('test_F1_micro', f1)


    # additional functions called in main functions

    def generate_triples(self, batch, labels) -> None:

        generated_tokens = self.model.generate(
            batch["input_ids"].to(self.model.device),
            attention_mask=batch["attention_mask"].to(self.model.device),
            use_cache = True, max_length = conf.max_length, early_stopping = conf.early_stopping, length_penalty = conf.length_penalty, 
            no_repeat_ngram_size = conf.no_repeat_ngram_size, num_beams = conf.num_beams)

        decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
        decoded_labels = self.tokenizer.batch_decode(torch.where(labels != -100, labels, self.config.pad_token_id), skip_special_tokens=False)

        return [extract_triplets(rel) for rel in decoded_preds], [extract_triplets(rel) for rel in decoded_labels]

    def _pad_tensors_to_max_len(self, tensor, max_length):
        # If PAD token is not defined at least EOS token has to be defined
        pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id

        if pad_token_id is None:
            raise ValueError(
                f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}"
            )

        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor

    def configure_optimizers(self):

        ##### HUH
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": conf.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        
        optimizer = AdamW(optimizer_grouped_parameters, lr = conf.lr, betas = (conf.beta1, conf.beta2), eps = conf.epsilon, weight_decay = conf.weight_decay)

        def lr_schedule(epoch):
            if epoch < conf.warm_up: lr_scale =  0.1
            else: lr_scale = 1 / (epoch**0.3)
            return lr_scale

        scheduler = lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lr_schedule
        )
        #scheduler = inverse_square_root(optimizer, num_warmup_steps= conf.warmup_steps)

        return [optimizer], [scheduler]

    #def compute_metrics():
        #looks not needed    

In [20]:
def train(conf):
    pl.seed_everything(conf.seed)

    tokenizer = transformers.AutoTokenizer.from_pretrained("Babelscape/rebel-large", use_fast = conf.use_fast_tokenizer,
        additional_special_tokens = ["<obj>", "<subj>", "<triplet>", "<head>", "</head>", "<tail>", "</tail>", "<MASK>"])
    config = transformers.AutoConfig.from_pretrained("Babelscape/rebel-large")
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")#, config = config)

    model.resize_token_embeddings(len(tokenizer))

    pl_data_module = GetData(conf, tokenizer, model)
    pl_module = BaseModule(conf, config, tokenizer, model)

    wandb_logger = WandbLogger(project = "project/finetune".split('/')[-1].replace('.py', ''), name = "finetune")

    callbacks_store = []

    if conf.early_stopping:
        callbacks_store.append(
            EarlyStopping(
                monitor=conf.monitor_var,
                mode=conf.monitor_var_mode,
                patience=conf.patience
            )
        )

    # callbacks_store.append(
    #     ModelCheckpoint(
    #         monitor=conf.monitor_var,
    #         # monitor=None,
    #         dirpath=f'models/{conf.model_name}',
    #         save_top_k=conf.save_top_k,
    #         verbose=True,
    #         save_last=True,
    #         mode=conf.monitor_var_mode
    #     )
    # )
    callbacks_store.append(GenerateTextSamplesCallback(conf.samples_interval))
    callbacks_store.append(LearningRateMonitor(logging_interval='step'))

    trainer = pl.Trainer(
        accelerator = "gpu",
        devices = conf.gpus,
        accumulate_grad_batches=conf.gradient_acc_steps,
        gradient_clip_val=conf.gradient_clip_value,
        #val_check_interval=conf.val_check_interval,
        max_epochs = 30,
        min_epochs = 5,
        callbacks=callbacks_store,
        max_steps=conf.max_steps,
        # max_steps=total_steps,
        precision=conf.precision,
        amp_level=conf.amp_level,
        logger=wandb_logger,
        #resume_from_checkpoint=conf.checkpoint_path,
        #limit_val_batches=conf.val_percent_check
    )

    # module fit
    trainer.fit(pl_module, datamodule=pl_data_module)

In [23]:
train(conf)
#api key is fcfb005aa20d2c3af3389e0a2a6d58a829bfd2ee

Global seed set to 0
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  rank_zero_warn(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name    | Type                         | Params
---------------------------------------------------------
0 | model   | BartForConditionalGeneration | 406 M 
1 | loss_fn | CrossEntropyLoss             | 0     
---------------------------------------------------------
406 M     Trainable params
0         Non-trainable params
406 M     Total params
812.599   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]applying mask ...
(197, 2)


100%|██████████| 1/1 [00:00<00:00, 17.99ba/s]


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s]RE Evaluation in *** BOUNDARIES *** mode
processed 16 sentences with 16 relations; found: 20 relations; correct: 0.
	ALL	 TP: 0;	FP: 0;	FN: 6
		(m avg): precision: 0.00;	recall: 0.00;	f1: 0.00 (micro)
		(M avg): precision: 0.00;	recall: 0.00;	f1: 0.00 (Macro)

	MakePublicStatement: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Appeal: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	ExpressIntentToCooperate: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Consult: 	TP: 0;	FP: 0;	FN: 1;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	EngageInDiplomaticCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	EngageInMaterialCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	ProvideAid: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Yield: 	TP: 0;	FP: 0;	FN: 1;	precision: 0.00;	recall: 0.00;	f1: 0

100%|██████████| 3/3 [00:00<00:00,  6.27ba/s]


Epoch 1: 100%|██████████| 303/303 [06:21<00:00,  1.26s/it, loss=0.105, v_num=e81h]
Epoch 0:  79%|███████▉  | 239/303 [05:41<01:31,  1.43s/it, loss=3.54, v_num=e81h]



Epoch 0: 100%|██████████| 303/303 [06:43<00:00,  1.33s/it, loss=3.27, v_num=e81h]RE Evaluation in *** BOUNDARIES *** mode
processed 197 sentences with 211 relations; found: 196 relations; correct: 1.
	ALL	 TP: 1;	FP: 79;	FN: 100
		(m avg): precision: 1.25;	recall: 0.99;	f1: 1.10 (micro)
		(M avg): precision: 1.00;	recall: 0.42;	f1: 0.59 (Macro)

	MakePublicStatement: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Appeal: 	TP: 0;	FP: 0;	FN: 8;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	ExpressIntentToCooperate: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Consult: 	TP: 0;	FP: 46;	FN: 38;	precision: 0.00;	recall: 0.00;	f1: 0.00;	46
	EngageInDiplomaticCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	EngageInMaterialCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	ProvideAid: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Yield: 	TP: 0;	FP: 15;	FN: 18;	precision: 0.00;	re



Epoch 12:  28%|██▊       | 85/303 [01:22<03:31,  1.03it/s, loss=0.0029, v_num=e81h] 



Epoch 12: 100%|██████████| 303/303 [04:29<00:00,  1.12it/s, loss=0.00658, v_num=e81h]RE Evaluation in *** BOUNDARIES *** mode
processed 197 sentences with 211 relations; found: 201 relations; correct: 51.
	ALL	 TP: 51;	FP: 46;	FN: 50
		(m avg): precision: 52.58;	recall: 50.50;	f1: 51.52 (micro)
		(M avg): precision: 24.73;	recall: 21.94;	f1: 22.66 (Macro)

	MakePublicStatement: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Appeal: 	TP: 6;	FP: 6;	FN: 2;	precision: 50.00;	recall: 75.00;	f1: 60.00;	12
	ExpressIntentToCooperate: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Consult: 	TP: 21;	FP: 19;	FN: 17;	precision: 52.50;	recall: 55.26;	f1: 53.85;	40
	EngageInDiplomaticCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	EngageInMaterialCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	ProvideAid: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Yield: 	TP: 9;	FP: 8;	FN: 9;	pr



Epoch 17: 100%|██████████| 303/303 [04:12<00:00,  1.20it/s, loss=0.00535, v_num=e81h]RE Evaluation in *** BOUNDARIES *** mode
processed 197 sentences with 211 relations; found: 209 relations; correct: 54.
	ALL	 TP: 54;	FP: 55;	FN: 47
		(m avg): precision: 49.54;	recall: 53.47;	f1: 51.43 (micro)
		(M avg): precision: 19.95;	recall: 20.14;	f1: 19.80 (Macro)

	MakePublicStatement: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Appeal: 	TP: 6;	FP: 6;	FN: 2;	precision: 50.00;	recall: 75.00;	f1: 60.00;	12
	ExpressIntentToCooperate: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Consult: 	TP: 19;	FP: 18;	FN: 19;	precision: 51.35;	recall: 50.00;	f1: 50.67;	37
	EngageInDiplomaticCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	EngageInMaterialCooperation: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	ProvideAid: 	TP: 0;	FP: 0;	FN: 0;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	Yield: 	TP: 10;	FP: 11;	FN: 8;	