In [None]:
import warnings

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
# NOTE(review): `load_metric` was removed in datasets 3.0; this notebook needs
# datasets<3 (or a port of the metrics to the `evaluate` library).
from datasets import Dataset, load_metric

# Enable tqdm progress bars for pandas (`progress_apply` / `progress_map`).
tqdm.pandas()
pd.set_option('display.max_columns', None)

# To restrict the run to a single GPU, `import os` and uncomment the line below;
# it must execute before CUDA is first initialized.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Hide library warnings to keep cell output readable.
warnings.filterwarnings('ignore')

In [None]:
from datasets import ClassLabel

TRAIN_FRAC = 0.90
SPLIT_SEED = 42

# Start from a clean RangeIndex so the label-based `drop` below removes exactly
# the sampled rows, even if the parquet file stored a non-unique index.
df = pd.read_parquet('../data/parquet/dataset.parquet').reset_index(drop=True)

# Normalize dtypes BEFORE deriving the label set: the model generates text, so
# class names must be the string form of each target for `lconv.str2int` to match.
df['text'] = df['text'].astype(str)
df['target'] = df['target'].astype(str)

labels = df['target'].unique().tolist()
lconv = ClassLabel(num_classes=len(labels), names=labels)

# Keep the sampled rows' original index until the test split is carved out.
# Resetting it first would make `drop` remove rows 0..n_train-1 instead of the
# sampled rows, so train and test would overlap.
train_sampled = df.sample(frac=TRAIN_FRAC, random_state=SPLIT_SEED)
test = df.drop(index=train_sampled.index).reset_index(drop=True)
train = train_sampled.reset_index(drop=True)

assert len(train) + len(test) == len(df), "train/test split must partition the data"

In [None]:
# Smoke test: uncomment the two lines below to run the whole pipeline on a tiny subset.
#train = train.sample(10).reset_index(drop=True)
#test = test.sample(2).reset_index(drop=True)

In [None]:
# Wrap the training split in a Hugging Face Dataset. End with the bare name
# (like the tokenization cells) so the notebook uses its rich display.
ds_train = Dataset.from_pandas(train)
ds_train

In [None]:
import os

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig

model_name = 't5-base'
# Single source for the Hugging Face download cache. Set SMDC_HF_CACHE_DIR to
# override it when this NAS mount is not available.
hf_cache_dir = os.environ.get("SMDC_HF_CACHE_DIR", "/mnt/dmif-nas/SMDC/HF-Cache/")

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=hf_cache_dir)

In [None]:
max_input_length = 512
max_target_length = 64

def preprocess_function(examples):
    """Tokenize a batch of rows for seq2seq training.

    Args:
        examples: batched mapping passed by `Dataset.map(batched=True)`, with a
            "text" column (source strings) and a "target" column (label strings).

    Returns:
        The tokenizer encoding of "text" (truncated to `max_input_length`) plus a
        "labels" key holding the token ids of "target" (truncated to
        `max_target_length`).
    """
    # No padding here. DataCollatorForSeq2Seq pads every batch dynamically and
    # pads "labels" with -100, so padded positions are ignored by the loss.
    # Padding at this stage would fill labels with pad_token_id, and the loss
    # would then count those pads as real target tokens.
    model_inputs = tokenizer(list(examples["text"]), max_length=max_input_length, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["target"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_train = ds_train.map(preprocess_function, batched=True)
tokenized_train

In [None]:
# Held-out split: build a Dataset and tokenize it with the same preprocessing as training.
ds_test = Dataset.from_pandas(df=test)
tokenized_val = ds_test.map(function=preprocess_function, batched=True)
tokenized_val

In [None]:
import json
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

# Metrics from every evaluation round, in order (written to JSON after training).
training_metrics = []

def string2int(x):
    """Map a decoded class name to its ClassLabel id, or -1 if it is not a known class.

    The model can generate arbitrary text, so unknown strings are expected. They
    are reported as -1 ("not valid") instead of raising.
    """
    try:
        return lconv.str2int(x)
    except (KeyError, ValueError):
        # Unknown names raise ValueError or KeyError, depending on the `datasets`
        # version. Anything else (e.g. KeyboardInterrupt) must still propagate.
        return -1

# Load each metric once when this cell runs, not on every evaluation call.
precision_metric = load_metric("precision")
recall_metric = load_metric("recall")
accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")

def compute_metrics(eval_pred):
    """Score one evaluation round of generated class names.

    Args:
        eval_pred: `(predictions, labels)` pair of token-id arrays; either may
            contain -100 padding.

    Returns:
        dict containing:
            - macro precision, recall and F1;
            - accuracy;
            - "Not valid": the share of predictions that do not decode to a
              known class.
        The same dict is also appended to `training_metrics`.
    """
    predictions, labels = eval_pred

    # -100 is not a token id and cannot be decoded; map it to the pad token,
    # which skip_special_tokens then drops.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = [string2int(x) for x in tokenizer.batch_decode(predictions, skip_special_tokens=True)]

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = [string2int(x) for x in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    # Macro-average over real classes only. Without `labels=`, the -1 "not
    # valid" sentinel would count as an extra class and drag every macro score
    # down. Invalid predictions still count as misses for their true class.
    # None falls back to the default label set if nothing decodes to a class.
    valid_classes = sorted((set(decoded_preds) | set(decoded_labels)) - {-1}) or None
    macro_kwargs = dict(
        predictions=decoded_preds,
        references=decoded_labels,
        average='macro',
        labels=valid_classes,
    )

    metrics = {
        **precision_metric.compute(**macro_kwargs),
        **recall_metric.compute(**macro_kwargs),
        **accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels),
        **f1_metric.compute(**macro_kwargs),
        "Not valid": decoded_preds.count(-1) / max(len(decoded_preds), 1),
    }

    training_metrics.append(metrics)

    return metrics

# Batch collation for seq2seq; `model` is passed through to the collator.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

args = Seq2SeqTrainingArguments(
    # Trainer working directory; checkpoint saving is switched off below.
    output_dir='/mnt/dmif-nas/SMDC/HF-tmp/',
    # Schedule: 3 epochs, evaluate after each one, never save checkpoints.
    num_train_epochs=3,
    evaluation_strategy = "epoch",
    save_strategy="no",
    # Hardware and batching.
    no_cuda=False,
    per_device_train_batch_size=2,
    # Generation-based evaluation flag (CustomTrainer.prediction_step always generates).
    predict_with_generate=True,
)

class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose evaluation step returns generated token ids instead of a loss.

    Each evaluation batch is run through `generate` (num_beams=1). The generated
    ids are returned as predictions so that `compute_metrics` can compare decoded
    class names with decoded labels. No evaluation loss is computed.
    """

    def prediction_step(self, model: torch.nn.Module, inputs, prediction_loss_only, ignore_keys=None):
        """Generate predictions for one evaluation batch.

        Args:
            model: the model being evaluated (possibly wrapped by the Trainer).
            inputs: collated batch with "input_ids", "attention_mask" and,
                when available, "labels".
            prediction_loss_only: unused; this step never computes a loss.
            ignore_keys: unused; accepted to match the parent signature.

        Returns:
            `(None, generated_ids, labels)`; `labels` is None when the batch has none.
        """
        # Use the Trainer's configured device instead of hard-coding 'cuda', so
        # evaluation also works on CPU-only or non-default-device setups.
        device = self.args.device
        # With several GPUs visible, the Trainer may hand us a DataParallel
        # wrapper, which has no `.generate`; call it on the wrapped module.
        generator = getattr(model, "module", model)

        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.to(device)

        with torch.no_grad():
            generated = generator.generate(
                input_ids=inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
                max_length=128,
                early_stopping=True,
                num_beams=1,
                num_return_sequences=1,
            )
        return (None, generated, labels)

# Connect model, data, collator and metric hook, then run fine-tuning.
trainer = CustomTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

In [None]:
import json
from datetime import datetime
from pathlib import Path

model_path = "../models/t5"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

# Create the log directory on demand. Otherwise the metrics dump could fail
# with FileNotFoundError after the whole training run has finished.
log_dir = Path("../logs")
log_dir.mkdir(parents=True, exist_ok=True)
metrics_file = log_dir / f't5-metrics_{datetime.now().strftime("%Y%m%d%H%M")}.json'

with open(metrics_file, 'w') as f:
    json.dump(training_metrics, f, indent=2)