In [None]:
import os
import warnings

import pandas as pd
import torch
from datasets import Dataset

from data_utils import ABSADataset, NonABSADataset, Pattern, Prompter
from model import ABSAGenerativeModelWrapper
from training import ABSAGenerativeTrainer

In [None]:
# Model / tokenizer wrapper: which pretrained checkpoint to load, plus extra
# keyword arguments for the model and the tokenizer (none by default).
wrapper_config = dict(
    model_name_or_path="t5-base",
    model_args={},
    tokenizer_args={},
)

# Textual layout of a target tuple: brackets around each tuple, `intra_sep`
# between fields of one tuple, `inter_sep` between tuples.
pattern_config = dict(
    open_bracket='(',
    close_bracket=')',
    intra_sep=',',
    inter_sep=';',
    categories=["NONE"],
)

# Prompt templates. The prompter replaces these placeholder masks:
#   PATTERN          -> the pattern mask, e.g. ( <A> , <O> , <S> )
#   CATEGORY         -> the list of categories, e.g. [CAT0,CAT1]
#   IMPUTATION_FIELD -> a partially filled tuple, e.g. ( pizza , yummy , <S> )
prompter_config = dict(
    prompter_template=dict(
        extraction="Extract with the format PATTERN for the following text",
        imputation="Impute the following IMPUTATION_FIELD for the following text",
    )
)

# Data sources and dataset-construction settings, one entry per split:
#   absa              - triplet file and the element order of its targets ("aos")
#   non_absa          - list of auxiliary (non-ABSA) sources concatenated into the split
#   absa_builder_args - (train/val) kwargs for ABSADataset.build_train_val_data
#   task_tree         - (test) task tree passed to trainer.predict_absa
# NOTE(review): train, val and test all read the same non-ABSA file, so the
# non-ABSA part of validation and test is the training data itself.
# TODO point val/test at held-out non-ABSA files if they exist.
NON_ABSA_DATA_PATH = "non_absa.csv"


def _absa_source(data_path):
    """Return a fresh ABSA source entry for a triplet file whose targets are in "aos" order."""
    return {"data_path": data_path, "target_format": "aos"}


def _non_absa_sources():
    """Return a fresh one-element list describing the non-ABSA source."""
    return [{"data_path": NON_ABSA_DATA_PATH}]


def _absa_builder_args():
    """Return a fresh copy of the train/val ABSA builder arguments."""
    return {
        "tasks": {
            "extraction": ["as", "os", "a", "o"],
            "imputation": {"aos": ["ao", "os"]},
        },
        "multiply": True,
        "shuffle": True,
        "random_state": 0,
    }


data_config = {
    "train": {
        "absa": _absa_source("train_triplets.txt"),
        "non_absa": _non_absa_sources(),
        "absa_builder_args": _absa_builder_args(),
    },
    "val": {
        "absa": _absa_source("eval_triplets.txt"),
        "non_absa": _non_absa_sources(),
        "absa_builder_args": _absa_builder_args(),
    },
    "test": {
        "absa": _absa_source("test_triplets.txt"),
        "non_absa": _non_absa_sources(),
        "task_tree": {"aos": {"ao": ["a"], "os": []}},
    },
}

# Tokenizer settings shared by training-data preparation and by both
# prediction calls at the end of the notebook.
encoding_args = dict(
    max_length=256,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

# Passed to trainer.compile_train_args(); the keys match
# transformers.TrainingArguments fields.
# - save_strategy and evaluation_strategy are kept equal because
#   load_best_model_at_end requires the save and eval schedules to match.
# - Effective train batch = per_device_train_batch_size
#   * gradient_accumulation_steps * number of visible GPUs.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy`. If compile_train_args forwards keys verbatim, rename the key
# when upgrading (TODO confirm against the pinned version).
train_args = dict(
    num_train_epochs=2,
    learning_rate=3e-4,
    gradient_accumulation_steps=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    save_total_limit=2,
    metric_for_best_model="overall_f1_score",
    load_best_model_at_end=True,
    adam_epsilon=1e-8,
    eval_accumulation_steps=10,
    output_dir="./output",
    logging_dir="./output/log_history",
)

# Seed handed to trainer.train().
train_random_seed = 0
# Comma-separated CUDA device ids exposed to this kernel via CUDA_VISIBLE_DEVICES.
gpu = "0,1"

In [None]:
# Restrict the GPUs visible to this kernel. CUDA reads CUDA_VISIBLE_DEVICES
# only when it is first initialised. torch and the project modules are already
# imported at this point, so warn instead of silently doing nothing if CUDA has
# already been initialised with a different device set.
if torch.cuda.is_initialized() and os.environ.get("CUDA_VISIBLE_DEVICES") != gpu:
    warnings.warn(
        f"CUDA is already initialised; CUDA_VISIBLE_DEVICES={gpu!r} will not take "
        "effect until the kernel is restarted.",
        RuntimeWarning,
    )
os.environ["CUDA_VISIBLE_DEVICES"] = gpu

In [None]:
# Instantiate the model/tokenizer wrapper, the target pattern and the prompter
# from their configs (evaluated left to right).
wrapper, pattern, prompter = (
    ABSAGenerativeModelWrapper(**wrapper_config),
    Pattern(**pattern_config),
    Prompter(**prompter_config),
)

In [None]:
# Project trainer built around the model/tokenizer wrapper and the target pattern.
trainer = ABSAGenerativeTrainer(
    absa_model_and_tokenizer=wrapper,
    pattern=pattern,
)

In [None]:
# ABSA Datasets
# Each split's arguments are built from a *copy* of data_config[split]["absa"].
# This keeps data_config a plain, serializable config: no live prompter or
# pattern objects are written back into it, and re-running the cell has no
# hidden side effects.


def build_absa_dataset(split):
    """Return the ABSADataset for ``split`` ("train", "val" or "test").

    The split's ``data_config[split]["absa"]`` entries are merged with the shared
    ``prompter``, ``pattern`` and ``wrapper.prompt_side`` (these override any
    same-named config keys). ``data_config`` itself is left unmodified.
    """
    absa_args = {
        **data_config[split]["absa"],
        "prompter": prompter,
        "prompt_side": wrapper.prompt_side,
        "pattern": pattern,
    }
    return ABSADataset(**absa_args)


train_absa = build_absa_dataset("train")
val_absa = build_absa_dataset("val")
test_absa = build_absa_dataset("test")

In [None]:
# Non ABSA Datasets
# Each source entry is copied before prompt_side is added, so the dicts inside
# data_config[split]["non_absa"] are never mutated.


def build_non_absa_datasets(split):
    """Return one NonABSADataset per entry of ``data_config[split]["non_absa"]``.

    Every entry is merged with ``wrapper.prompt_side`` (overriding a same-named
    key, if present) without modifying the config entry itself.
    """
    return [
        NonABSADataset(**{**source_args, "prompt_side": wrapper.prompt_side})
        for source_args in data_config[split]["non_absa"]
    ]


non_absa_train = build_non_absa_datasets("train")
non_absa_val = build_non_absa_datasets("val")
non_absa_test = build_non_absa_datasets("test")

In [None]:
# Concatenate the non-ABSA frames with the task-expanded ABSA frame per split.
# ignore_index=True gives the merged frame a fresh RangeIndex. Without it, the
# concatenated frame carries duplicate row labels, which Dataset.from_pandas
# stores as an extra "__index_level_0__" column. preserve_index=False keeps the
# index out of the Dataset regardless.
train_df = pd.concat(
    [non_absa_ds.build_data().to_pandas() for non_absa_ds in non_absa_train]
    + [train_absa.build_train_val_data(**data_config["train"]["absa_builder_args"])],
    ignore_index=True,
)
val_df = pd.concat(
    [non_absa_ds.build_data().to_pandas() for non_absa_ds in non_absa_val]
    + [val_absa.build_train_val_data(**data_config["val"]["absa_builder_args"])],
    ignore_index=True,
)

train_data = Dataset.from_pandas(train_df, preserve_index=False)
val_data = Dataset.from_pandas(val_df, preserve_index=False)

In [None]:
# Hand both splits to the trainer together with the shared encoding settings.
trainer.prepare_data(
    train_dataset=train_data,
    eval_dataset=val_data,
    **encoding_args,
)

In [None]:
# Pass the plain train_args dict to the trainer. Per the method name, it is
# compiled into the trainer's training arguments.
trainer.compile_train_args(
    train_args_dict=train_args,
)

In [None]:
# Called after prepare_data() and compile_train_args(). Presumably it assembles
# the underlying training loop from their results (TODO confirm in training.py).
trainer.prepare_trainer()

In [None]:
# Fine-tune. output_dir is read from train_args so both settings name the same
# directory; the seed comes from the config cell.
trainer.train(
    output_dir=train_args["output_dir"],
    random_seed=train_random_seed,
)

# Prediction

In [None]:
# Options for turning generated ids back into text (skip_special_tokens is a
# tokenizer decode option — presumably forwarded there by the trainer).
decoding_args = dict(skip_special_tokens=True)

# Run inference on the first visible GPU when CUDA is available, else on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Predict on the list of non-ABSA test datasets built earlier.
non_absa_preds = trainer.predict_non_absa(
    dataset=non_absa_test,
    device=device,
    encoding_args=encoding_args,
    decoding_args=decoding_args,
)

In [None]:
# Predict on the ABSA test split, following the configured test task tree.
# Returns the predictions and a summary score.
test_task_tree = data_config["test"]["task_tree"]
absa_preds, summary_score = trainer.predict_absa(
    dataset=test_absa,
    task_tree=test_task_tree,
    device=device,
    encoding_args=encoding_args,
    decoding_args=decoding_args,
)