In [None]:
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch
from torch.utils.data import Dataset

In [None]:
# Load the training and evaluation splits; downstream code reads the
# 's1' (source text) and 's2' (target text) columns from these frames.
train, test = (pd.read_csv(path) for path in ('./bird_train.csv', './bird_test.csv'))


In [None]:
# Tokenizer and seq2seq LM are loaded from the same pretrained checkpoint name.
MODEL_CHECKPOINT = 't5-base'
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_CHECKPOINT),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT),
)

In [None]:
class BirdDataset(Dataset):
    """Map-style torch Dataset of (s1 -> s2) text pairs for seq2seq fine-tuning.

    Texts are tokenized lazily in ``__getitem__`` and padded *and truncated* to
    fixed lengths, so the Trainer's default collator can stack items into a batch.

    Args:
        data: DataFrame with ``'s1'`` (source text) and ``'s2'`` (target text)
            columns.
        tokenizer: Hugging Face tokenizer to use. When ``None`` (default), the
            notebook's module-level ``tokenizer`` is looked up at item-access time.
        max_source_length: Fixed token length of ``input_ids`` / ``attention_mask``.
            The default 378 is kept from the original code; presumably chosen to
            cover the longest ``s1`` — TODO confirm against the data.
        max_target_length: Fixed token length of ``labels``.
    """

    def __init__(self, data, tokenizer=None, max_source_length=378, max_target_length=70):
        self.s1 = data['s1'].values
        self.s2 = data['s2'].values
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        """Number of (s1, s2) pairs."""
        return len(self.s1)

    def __getitem__(self, idx):
        """Return the tokenized pair at position ``idx``.

        Returns:
            dict of 1-D tensors: ``input_ids`` and ``attention_mask`` of length
            ``max_source_length``, and ``labels`` of length ``max_target_length``
            with padding positions set to -100.
        """
        # Inside this method the bare name ``tokenizer`` is the module-level global.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # truncation=True is required: padding='max_length' alone never shortens a
        # longer text, which would produce ragged tensors the collator cannot stack.
        source = tok(self.s1[idx], padding='max_length', truncation=True,
                     max_length=self.max_source_length, return_tensors='pt')
        target = tok(self.s2[idx], padding='max_length', truncation=True,
                     max_length=self.max_target_length, return_tensors='pt')
        # return_tensors='pt' adds a leading batch dim of 1; drop it per item.
        label_ids = target['input_ids'].squeeze(0)
        return {
            'input_ids': source['input_ids'].squeeze(0),
            # Without the mask the model would attend to the padding tokens.
            'attention_mask': source['attention_mask'].squeeze(0),
            # -100 is the ignore index of the seq2seq LM loss; leaving the pad id
            # here would train the model to predict padding as real target tokens.
            'labels': label_ids.masked_fill(label_ids == tok.pad_token_id, -100),
        }

In [None]:
# Wrap each split in the tokenizing Dataset consumed by the trainer.
train_b, test_b = BirdDataset(train), BirdDataset(test)

In [None]:
# Fine-tuning hyperparameters; the same batch size is used for training and evaluation.
batch_size = 4
seq2seq_hparams = {
    "output_dir": "./t5-project-test",
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # confirm against the installed version before upgrading.
    "evaluation_strategy": "epoch",
    # NOTE(review): 2e-6 is unusually low for T5 fine-tuning — confirm it is intended.
    "learning_rate": 2e-6,
    "per_device_train_batch_size": batch_size,
    "per_device_eval_batch_size": batch_size,
    "weight_decay": 0.01,
    "save_total_limit": 3,
    "num_train_epochs": 5,
    "predict_with_generate": True,
    "push_to_hub": False,
}
args = Seq2SeqTrainingArguments(**seq2seq_hparams)

In [None]:
# Fine-tune the model on train_b, using test_b as the evaluation set on the
# schedule configured in `args`.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_b,
    eval_dataset=test_b,
)
trainer.train()