In [1]:
import torch
import numpy as np
import random
import transformers
import datasets
from kbqa.seq2seq.utils import convert_to_features
from kbqa.seq2seq.train import train as train_seq2seq

In [2]:
import os

# Expose only GPU index 2 to this process. The variable is read when CUDA is
# first initialised, so this cell has to run before any CUDA work starts.
os.environ.update(CUDA_VISIBLE_DEVICES='2')

In [3]:
# Seed every random number generator this notebook touches (PyTorch, the
# stdlib `random` module and NumPy's legacy global RNG) from ONE value.
# Previously NumPy was seeded with 0 while the other two used 8, which made
# the run depend on two unrelated magic numbers.
SEED = 8

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [4]:
# Hub identifier shared by the model weights and the tokenizer.
model_checkpoint = 'google/t5-large-ssm-nq'

# Columns that `set_format` is asked to return as torch tensors.
columns = ["input_ids", "labels", "attention_mask"]

# The two loads are independent of each other; the tokenizer is needed by the
# feature-conversion step below.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Load Mintaka and run `convert_to_features` over it in batches, using the
# "answerText" field as the label feature.
dataset = datasets.load_dataset('AmazonScience/mintaka').map(
    lambda batch: convert_to_features(
        batch, tokenizer, label_feature_name="answerText"
    ),
    batched=True,
)

# Applied in place on the mapped dataset object.
dataset.set_format(type="torch", columns=columns)

No config specified, defaulting to: mintaka/en
Found cached dataset mintaka (/root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d/cache-f20e8bccc090c16d.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d/cache-42272688b433a0e2.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/AmazonScience___mintaka/en/1.0.0/bb35d95f07aed78fa590601245009c5f585efe909dbd4a8f2a4025ccf65bb11d/cache-55611f91ffefca45.arrow


In [5]:
# Root directory for everything this run writes. It is defined once so that
# the model and log locations cannot drift apart (the original repeated the
# full prefix in two f-strings that had no placeholders).
run_dir = '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned/model_t5_large_ssm_nq'

# Fine-tune the loaded model on the Mintaka train split, using the validation
# split as `valid_dataset`. The numeric arguments are forwarded as-is to the
# project's `train_seq2seq` helper.
# NOTE(review): presumably `eval_steps` is the evaluation interval and
# `save_total_limit` caps the number of kept checkpoints; confirm in
# kbqa.seq2seq.train.
trainer = train_seq2seq(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    valid_dataset=dataset["validation"],
    output_dir=f'{run_dir}/models/',
    logging_dir=f'{run_dir}/logs/',
    max_steps=10000,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    save_total_limit=1,
    eval_steps=1000,
)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss


In [8]:
# Write the tokenizer files into the checkpoint directory.
# NOTE(review): the checkpoint number is hard-coded and this cell's execution
# count skips ahead of the training cell; confirm that checkpoint-7000 is the
# intended checkpoint after a fresh Restart & Run All.
checkpoint_dir = '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned/model_t5_large_ssm_nq/models/checkpoint-7000/'
tokenizer.save_pretrained(checkpoint_dir)

('/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned//model_t5_large_ssm_nq/models/checkpoint-7000/tokenizer_config.json',
 '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned//model_t5_large_ssm_nq/models/checkpoint-7000/special_tokens_map.json',
 '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned//model_t5_large_ssm_nq/models/checkpoint-7000/spiece.model',
 '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned//model_t5_large_ssm_nq/models/checkpoint-7000/added_tokens.json',
 '/mnt/storage/QA_System_Project/seq2seq_runs/mintaka_only_experiments_mintaka_tunned//model_t5_large_ssm_nq/models/checkpoint-7000/tokenizer.json')