In [6]:
import transformers 
import pandas as pd
import numpy as np
import json
import os

In [7]:
from typing import Any, Union, List, Optional
import omegaconf
from omegaconf import DictConfig

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
import datasets
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    default_data_collator,
    set_seed,
)

from torch.optim import lr_scheduler

from pytorch_lightning.loggers.wandb import WandbLogger

In [8]:
read = pd.read_json("data_src/raw/out1.json", lines = True)

In [9]:
df = read[read.answer == "accept"]
df.head()

Unnamed: 0,text,article_id,sentence_id,_input_hash,_task_hash,_is_binary,spans,tokens,_view_id,relations,answer,_timestamp
0,"Still, the restrictions imposed on the project...",183,18315,-1595512685,351917106,False,"[{'start': 35, 'end': 46, 'token_start': 6, 't...","[{'text': 'Still', 'start': 0, 'end': 5, 'id':...",relations,"[{'head': 18, 'child': 14, 'head_span': {'star...",accept,1666082270
1,"Days later, a released fighter and others, mos...",13,1317,399796307,1377411962,False,"[{'start': 12, 'end': 30, 'token_start': 3, 't...","[{'text': 'Days', 'start': 0, 'end': 4, 'id': ...",relations,"[{'head': 20, 'child': 5, 'head_span': {'start...",accept,1666082881
2,Are countries allowed to turn away asylum seek...,11,1111,-1470052778,-160075520,False,"[{'start': 4, 'end': 13, 'token_start': 1, 'to...","[{'text': 'Are', 'start': 0, 'end': 3, 'id': 0...",relations,"[{'head': 1, 'child': 7, 'head_span': {'start'...",accept,1666082907
3,"One man, 28, spoke to The Washington Post on t...",167,1673,-61112669,-1028177687,False,"[{'start': 0, 'end': 11, 'token_start': 0, 'to...","[{'text': 'One', 'start': 0, 'end': 3, 'id': 0...",relations,"[{'head': 3, 'child': 9, 'head_span': {'start'...",accept,1666082929
4,A U.S. defense official said Monday that two N...,174,1749,-2118794475,-1006559555,False,"[{'start': 0, 'end': 23, 'token_start': 0, 'to...","[{'text': 'A', 'start': 0, 'end': 1, 'id': 0, ...",relations,[],accept,1666082966


In [10]:
#fix sentence_id mistake
df = df.astype({"sentence_id":str,"article_id":str}) 

def fix_id(row):
    return row.sentence_id[:len(row.article_id)] + "_" + row.sentence_id[len(row.article_id):]

df["sentence_id"] = df.apply(fix_id, axis = 1)

#add annotator column in case more annotators will be used
df["annotator"] = 0

In [11]:
#for EDA
df["relations_count"] = df.relations.apply(lambda x: len(x))

In [12]:
df.head()

Unnamed: 0,text,article_id,sentence_id,_input_hash,_task_hash,_is_binary,spans,tokens,_view_id,relations,answer,_timestamp,annotator,relations_count
0,"Still, the restrictions imposed on the project...",183,183_15,-1595512685,351917106,False,"[{'start': 35, 'end': 46, 'token_start': 6, 't...","[{'text': 'Still', 'start': 0, 'end': 5, 'id':...",relations,"[{'head': 18, 'child': 14, 'head_span': {'star...",accept,1666082270,0,1
1,"Days later, a released fighter and others, mos...",13,13_17,399796307,1377411962,False,"[{'start': 12, 'end': 30, 'token_start': 3, 't...","[{'text': 'Days', 'start': 0, 'end': 4, 'id': ...",relations,"[{'head': 20, 'child': 5, 'head_span': {'start...",accept,1666082881,0,2
2,Are countries allowed to turn away asylum seek...,11,11_11,-1470052778,-160075520,False,"[{'start': 4, 'end': 13, 'token_start': 1, 'to...","[{'text': 'Are', 'start': 0, 'end': 3, 'id': 0...",relations,"[{'head': 1, 'child': 7, 'head_span': {'start'...",accept,1666082907,0,1
3,"One man, 28, spoke to The Washington Post on t...",167,167_3,-61112669,-1028177687,False,"[{'start': 0, 'end': 11, 'token_start': 0, 'to...","[{'text': 'One', 'start': 0, 'end': 3, 'id': 0...",relations,"[{'head': 3, 'child': 9, 'head_span': {'start'...",accept,1666082929,0,1
4,A U.S. defense official said Monday that two N...,174,174_9,-2118794475,-1006559555,False,"[{'start': 0, 'end': 23, 'token_start': 0, 'to...","[{'text': 'A', 'start': 0, 'end': 1, 'id': 0, ...",relations,[],accept,1666082966,0,0


In [13]:
#prune df
df = df.drop(columns = ["_input_hash","_task_hash","_is_binary","tokens","_view_id","answer","_timestamp"])
df = df.reset_index().drop(columns = ["index"])

In [14]:
##### check if it performs better with empty sentences or without
#df = df[df.relations_count > 0]

In [15]:
#get data into needed format and create new DataFrame
relations = []
for index,row in df.iterrows():
    for rel in row.relations:
        relations.append(rel['label'])

relations = list(set(relations))
relations

['Demand',
 'Assault',
 'Consult',
 'Coerce',
 'Disapprove',
 'ExpressIntentToCooperate',
 'Appeal',
 'EngageInMaterialCooperation',
 'EngageInDiplomaticCooperation',
 'ProvideAid',
 'ReduceRelations',
 'Yield',
 'Reject',
 'Fight',
 'MakePublicStatement',
 'Protest',
 'Investigate',
 'ExhibitMilitaryPosture']

In [16]:
#get data into needed format and create new DataFrame
data = []
for index,row in df.iterrows():
    rel_list = []
    for rel in row.relations:
        subj = row.text[rel["head_span"]["start"]:rel["head_span"]["end"]]
        obj = row.text[rel["child_span"]["start"]:rel["child_span"]["end"]]
        rel_list.append(f"<triplet> {subj} <subj> {obj} <obj> {rel['label']}")
    data.append({"doc_id":row.sentence_id, "text": row.text, "triplets": " ".join(rel_list)})

data = pd.DataFrame(data, columns = ["doc_id","text","triplets"])

In [17]:
ds = datasets.Dataset.from_pandas(data.drop(columns=["doc_id"]))

#### just random splits for testing, change later to proper splitting
###### and then change later again for CV
split = ds.train_test_split(test_size = 0.2)
split_val = split["test"].train_test_split(test_size = 0.3)
ds = datasets.DatasetDict({"train":split["train"],"val":split_val["train"], "test":split_val["test"]})

#ds = ds.with_format("torch")
ds

DatasetDict({
    train: Dataset({
        features: ['text', 'triplets'],
        num_rows: 131
    })
    val: Dataset({
        features: ['text', 'triplets'],
        num_rows: 23
    })
    test: Dataset({
        features: ['text', 'triplets'],
        num_rows: 10
    })
})

In [29]:
class conf:
    #general
    seed = 0
    gpus = 0
    
    #input
    batch_size = 32
    max_length = 128
    ignore_pad_token_for_loss = True
    use_fast_tokenizer = True
    gradient_acc_steps = 1
    gradient_clip_value = 10.0
    load_workers = 8

    #optimizer
    lr = 0.00005
    weight_decay = 0.01
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 0.00000001
    warmup_steps = 1000

    #training
    max_steps = 1200000
    samples_interval = 1000

    monitor_var  = "val_loss"
    monitor_var_mode = "min"
    # val_check_interval = 0.5
    # val_percent_check = 0.1

    model_name = "model1.pth"
    checkpoint_path = f"models/{model_name}"
    save_top_k = 1

    early_stopping = False
    patience = "5"

    length_penalty = 0
    no_repeat_ngram_size = 0
    num_beams = 3
    precision = 16
    amp_level = None

    

In [19]:
tokenizer = transformers.AutoTokenizer.from_pretrained("Babelscape/rebel-large", use_fast = conf.use_fast_tokenizer,
        additional_special_tokens = ["<obj>", "<subj>", "<triplet>", "<head>", "</head>", "<tail>", "</tail>"])
config = transformers.AutoConfig.from_pretrained("Babelscape/rebel-large")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large", config = config)

In [32]:
class GetData(pl.LightningDataModule):
    def __init__(self, conf: conf, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM):
        """init params from config"""
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.datasets = ds
        
        # Data collator
        label_pad_token_id = -100                       
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer, self.model, label_pad_token_id=label_pad_token_id, padding = True)

    def preprocess_function(self, data):
        """tokenize, pad, truncate"""
        #split into input and labels
        inputs = data["text"]       
        outputs = data["triplets"]

        #process input
        model_inputs = self.tokenizer(inputs, max_length = conf.max_length, padding = True, truncation = True)

        #process labels
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(outputs, max_length = conf.max_length, padding = True, truncation = True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
        
    #apply the preprocessing and load data
    def train_dataloader(self, *args, **kwargs): 
        self.train_dataset = self.datasets["train"]
        self.train_dataset = self.train_dataset.map(self.preprocess_function, remove_columns = ["text", "triplets"], batched = True)
        return DataLoader(self.train_dataset, batch_size = conf.batch_size, collate_fn = self.data_collator, shuffle = True, num_workers= conf.load_workers)

    def val_dataloader(self, *args, **kwargs): 
        self.eval_dataset = self.datasets["val"]
        self.eval_dataset = self.eval_dataset.map(self.preprocess_function, remove_columns = ["text", "triplets"], batched = True)
        return DataLoader(self.eval_dataset, batch_size = conf.batch_size, collate_fn = self.data_collator, num_workers= conf.load_workers)

    def test_dataloader(self, *args, **kwargs): 
        self.test_dataset = self.datasets["test"]
        self.test_dataset = self.test_dataset.map(self.preprocess_function,  remove_columns = ["text", "triplets"], batched = True)
        return DataLoader(self.test_dataset, batch_size = conf.batch_size, collate_fn = self.data_collator, num_workers= conf.load_workers)

In [21]:
# #from REBEL
# def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
#     """From fairseq"""
#     if target.dim() == lprobs.dim() - 1:
#         target = target.unsqueeze(-1)
#     nll_loss = -lprobs.gather(dim=-1, index=target)
#     smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
#     if ignore_index is not None:
#         pad_mask = target.eq(ignore_index)
#         nll_loss.masked_fill_(pad_mask, 0.0)
#         smooth_loss.masked_fill_(pad_mask, 0.0)
#     else:
#         nll_loss = nll_loss.squeeze(-1)
#         smooth_loss = smooth_loss.squeeze(-1)

#     nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
#     smooth_loss = smooth_loss.sum()
#     eps_i = epsilon / lprobs.size(-1)
#     loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
#     return loss, nll_loss


In [22]:
#from REBEL
from typing import Sequence

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from torch import nn
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
import wandb

class GenerateTextSamplesCallback(Callback):
    """
    PL Callback to generate triplets along training
    """

    def __init__(self, logging_batch_interval):
        """
        Args:
            logging_batch_interval: How frequently to inspect/potentially plot something
        """
        super().__init__()
        self.logging_batch_interval = logging_batch_interval

    def on_train_batch_end(self,trainer: Trainer,pl_module: LightningModule, outputs: Sequence, batch: Sequence, batch_idx: int, dataloader_idx: int,) -> None:
        wandb_table = wandb.Table(columns=["Source", "Pred", "Gold"])
        # pl_module.logger.info("Executing translation callback")
        if (trainer.batch_idx + 1) % self.logging_batch_interval != 0:  # type: ignore[attr-defined]
            return
        labels = batch.pop("labels")
        gen_kwargs = {
            "max_length": conf.max_length,
            "early_stopping": False,
            "no_repeat_ngram_size": 0,
            "num_beams": conf.num_beans
        }
        pl_module.eval()

        decoder_inputs = torch.roll(labels, 1, 1)[:,0:2]
        decoder_inputs[:, 0] = 0
        generated_tokens = pl_module.model.generate(
            batch["input_ids"].to(pl_module.model.device),
            attention_mask=batch["attention_mask"].to(pl_module.model.device),
            decoder_input_ids=decoder_inputs.to(pl_module.model.device),
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = pl_module._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        pl_module.train()
        decoded_preds = pl_module.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

        if pl_module.hparams.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = torch.where(labels != -100, labels, pl_module.tokenizer.pad_token_id)

        decoded_labels = pl_module.tokenizer.batch_decode(labels, skip_special_tokens=False)
        decoded_inputs = pl_module.tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)

        # pl_module.logger.experiment.log_text('generated samples', '\n'.join(decoded_preds).replace('<pad>', ''))
        # pl_module.logger.experiment.log_text('original samples', '\n'.join(decoded_labels).replace('<pad>', ''))
        for source, translation, gold_output in zip(decoded_inputs, decoded_preds, decoded_labels):
            wandb_table.add_data(
                source.replace('<pad>', ''), translation.replace('<pad>', ''), gold_output.replace('<pad>', '')
            )
        pl_module.logger.experiment.log({"Triplets": wandb_table})

In [23]:
#from REBEL
def shift_tokens_left(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, :-1] = input_ids[:, 1:].clone()
    shifted_input_ids[:, -1] = pad_token_id

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

    return shifted_input_ids

In [24]:
#from REBEL

##### can maybe be rewritten more effective
def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
    return triplets

In [25]:
class BaseModule(pl.LightningModule):

    def __init__(self, conf, config: AutoConfig, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM):
        super().__init__()
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        #self.loss_fn = label_smoothed_nll_loss
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
        self.num_beans = conf.num_beans

    def forward(self, inputs, labels):
        ##### Check later if smooth labeled loss is better 
        outputs = self.model(**inputs, labels = labels, use_cache = False, return_dict = True, output_hidden_states = True)
        output_dict = {'loss': outputs['loss'], 'logits': outputs['logits']}
        return output_dict

    def training_step(self, batch: dict, batch_idx: int):
        ##### check later if labels = batch["labels"] also works
        labels = batch.pop("labels")
        labels_original = labels.clone()
        batch["decoder_input_ids"] = torch.where(labels != -100, labels, self.config.pad_token_id)
        labels = shift_tokens_left(labels, -100)

        forward_output = self.forward(batch, labels)
        self.log('loss', forward_output['loss'])

        batch["labels"] = labels_original

        #### ig i dont have this
        if 'loss_aux' in forward_output:
            self.log('loss_classifier', forward_output['loss_aux'])
            return forward_output['loss'] + forward_output['loss_aux']

        return forward_output['loss']

    def validation_step(self, batch: dict, batch_idx):
        #### pop maybe not needed?
        labels = batch.pop("labels")
        batch["decoder_input_ids"] = torch.where(labels != -100, labels, self.config.pad_token_id)
        labels = shift_tokens_left(labels, -100)
        
        with torch.no_grad():
            # compute loss on predict data
            forward_output = self.forward(batch, labels)

        forward_output['loss'] = forward_output['loss'].mean().detach()

        #### probably not needed ? why would i want only pred loss
        # if self.hparams.prediction_loss_only:
        #     self.log('val_loss', forward_output['loss'])
        #     return

        forward_output['logits'] = forward_output['logits'].detach()

        if labels.shape[-1] < conf.max_length:
            padded_tensor = self.config.pad_token_id * torch.ones((labels.shape[0], conf.max_length), dtype=tensor.dtype, device=tensor.device)
            padded_tensor[:, : labels.shape[-1]] = labels
            forward_output["labels"] = padded_tensor
        else:
            forward_output['labels'] = labels

        metrics = {}
        metrics['val_loss'] = forward_output['loss']

        #### only 1? so why loop lmao
        for key in sorted(metrics.keys()):
            self.log(key, metrics[key])

        outputs = {}
        outputs['predictions'], outputs['labels'] = self.generate_triples(batch, labels)
        return outputs


    def test_step(self, batch, batch_idx):

        #### popping again
        labels = batch.pop("labels")
        batch["decoder_input_ids"] = torch.where(labels != -100, labels, self.config.pad_token_id)
        labels = shift_tokens_left(labels, -100)

        with torch.no_grad():
            # compute loss on predict data
            forward_output = self.forward(batch, labels)

        forward_output['loss'] = forward_output['loss'].mean().detach()

        ##### probably not needed? would would i only want pred loss
        # if self.hparams.prediction_loss_only:
        #     self.log('test_loss', forward_output['loss'])
        #     return

        forward_output['logits'] = generated_tokens.detach() if self.hparams.predict_with_generate else forward_output['logits'].detach()

        if labels.shape[-1] < gen_kwargs["max_length"]:
            padded_tensor = self.config.pad_token_id * torch.ones((labels.shape[0], conf.max_length), dtype=tensor.dtype, device=tensor.device)
            padded_tensor[:, : labels.shape[-1]] = labels
            forward_output["labels"] = padded_tensor
        else:
            forward_output['labels'] = labels


        metrics = {}
        metrics['test_loss'] = forward_output['loss']

        #### dont i only have one metric anyways?
        for key in sorted(metrics.keys()):
            self.log(key, metrics[key], prog_bar=True)

        #### what does this actually do? how does this change everything
        # if self.hparams.finetune:
        #     return {'predictions': self.forward_samples(batch, labels)}
        # else:

        outputs = {}
        outputs['predictions'], outputs['labels'] = self.generate_triples(batch, labels)
        return outputs



    def validation_epoch_end():
        
        relations = ["MakePublicStatement","Appeal","ExpressIntentToCooperate","Consult","EngageInDiplomaticCooperation","EngageInMaterialCooperation","ProvideAid","Yield","Investigate","Demand","Disapprove","Reject","Threaten","ExhibitMilitaryPosture","Protest","ReduceRelations","Coerce","Assault","Fight","EngageInUnconventialMassViolence"]
        scores, precision, recall, f1 = re_score([item for pred in output for item in pred['predictions']], [item for pred in output for item in pred['labels']], relations)
        self.log('val_prec_micro', precision)
        self.log('val_recall_micro', recall)
        self.log('val_F1_micro', f1)

    def test_epoch_end():

        relations = ["MakePublicStatement","Appeal","ExpressIntentToCooperate","Consult","EngageInDiplomaticCooperation","EngageInMaterialCooperation","ProvideAid","Yield","Investigate","Demand","Disapprove","Reject","Threaten","ExhibitMilitaryPosture","Protest","ReduceRelations","Coerce","Assault","Fight","EngageInUnconventialMassViolence"]
        scores, precision, recall, f1 = re_score([item for pred in output for item in pred['predictions']], [item for pred in output for item in pred['labels']], relations)
        self.log('test_prec_micro', precision)
        self.log('test_recall_micro', recall)
        self.log('test_F1_micro', f1)


    # additional functions called in main functions

    def generate_triples(self, batch, labels) -> None:

        generated_tokens = self.model.generate(
            batch["input_ids"].to(self.model.device),
            attention_mask=batch["attention_mask"].to(self.model.device),
            use_cache = True, max_length = conf.max_length, early_stopping = conf.early_stopping, length_penalty = conf.length_penalty, 
            no_repeat_ngram_size = conf.no_repeat_ngram_size, num_beans = conf.num_beams)

        decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
        decoded_labels = self.tokenizer.batch_decode(torch.where(labels != -100, labels, self.config.pad_token_id), skip_special_tokens=False)

        return [extract_triplets(rel) for rel in decoded_preds], [extract_triplets(rel) for rel in decoded_labels]

    def configure_optimizers(self):

        ##### HUH
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": conf.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr = conf.lr, betas = (conf.beta1, conf.beta2), eps = conf.epsilon, weight_decay = conf.weight_decay)
        #### also check warmup later
        factor = lambda epoch: 0.95
        scheduler = lr_scheduler.MultiplicativeLR(optimizer, factor)
        
        #scheduler = inverse_square_root(optimizer, num_warmup_steps= conf.warmup_steps)

        return [optimizer], [scheduler]

    #def compute_metrics():
        #looks not needed    

In [26]:
def train(conf):
    pl.seed_everything(conf.seed)

    tokenizer = transformers.AutoTokenizer.from_pretrained("Babelscape/rebel-large", use_fast = conf.use_fast_tokenizer,
        additional_special_tokens = ["<obj>", "<subj>", "<triplet>", "<head>", "</head>", "<tail>", "</tail>"])
    config = transformers.AutoConfig.from_pretrained("Babelscape/rebel-large")
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large", config = config)

    model.resize_token_embeddings(len(tokenizer))

    pl_data_module = GetData(conf, tokenizer, model)
    pl_module = BaseModule(conf, config, tokenizer, model)

    wandb_logger = WandbLogger(project = "project/finetune".split('/')[-1].replace('.py', ''), name = "finetune")

    callbacks_store = []

    if conf.early_stopping:
        callbacks_store.append(
            EarlyStopping(
                monitor=conf.monitor_var,
                mode=conf.monitor_var_mode,
                patience=conf.patience
            )
        )

    # callbacks_store.append(
    #     ModelCheckpoint(
    #         monitor=conf.monitor_var,
    #         # monitor=None,
    #         dirpath=f'models/{conf.model_name}',
    #         save_top_k=conf.save_top_k,
    #         verbose=True,
    #         save_last=True,
    #         mode=conf.monitor_var_mode
    #     )
    # )
    callbacks_store.append(GenerateTextSamplesCallback(conf.samples_interval))
    callbacks_store.append(LearningRateMonitor(logging_interval='step'))

    trainer = pl.Trainer(
        gpus=conf.gpus,
        accumulate_grad_batches=conf.gradient_acc_steps,
        gradient_clip_val=conf.gradient_clip_value,
        #val_check_interval=conf.val_check_interval,
        max_epochs = 10,
        min_epochs = 5,
        callbacks=callbacks_store,
        max_steps=conf.max_steps,
        # max_steps=total_steps,
        precision=conf.precision,
        amp_level=conf.amp_level,
        logger=wandb_logger,
        #resume_from_checkpoint=conf.checkpoint_path,
        #limit_val_batches=conf.val_percent_check
    )

    # module fit
    trainer.fit(pl_module, datamodule=pl_data_module)

In [33]:
train(conf)

Global seed set to 0
Using bfloat16 Automatic Mixed Precision (AMP)


                                                                   

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type                         | Params
---------------------------------------------------------
0 | model   | BartForConditionalGeneration | 406 M 
1 | loss_fn | CrossEntropyLoss             | 0     
---------------------------------------------------------
406 M     Trainable params
0         Non-trainable params
406 M     Total params
1,625.194 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 25.00ba/s]


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

TypeError: forward() got an unexpected keyword argument 'num_beans'