In [1]:
import sys
sys.path.append('../../..')

In [None]:
import torch
from omegaconf import OmegaConf
from transformers import (AutoTokenizer, 
                          AutoModelForSeq2SeqLM, 
                          DataCollatorForSeq2Seq, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer)
from tqdm import tqdm

from src.utils import seed_everything
from src.data_prepocessing import load_ds, tokenize_ds
from src.evaluation import Evaluator

# Experiment set up

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
config = OmegaConf.load("t5_base_config.yaml")
print(OmegaConf.to_yaml(config))

model_name: ai-forever/ruT5-base
training_args:
  eval_strategy: epoch
  learning_rate: 0.0002
  batch_size: 16
  gradient_accumulation_steps: 16
  weight_decay: 0.01
  save_total_limit: 3
  num_train_epochs: 50
  predict_with_generate: true
  fp16: false
  push_to_hub: false
  logging_steps: 10
  overwrite_output_dir: true
inference_args:
  max_length: 50
  num_beams: 2
  early_stopping: true
  skip_special_tokens: true
  clean_up_tokenization_spaces: true



In [6]:
seed_everything(42)

# Model and data loading

In [7]:
model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [8]:
model.num_parameters()

222903552

In [9]:
def preprocess_function(sample):
    inputs = [f"{context[-1]}</s>{phrase}</s>"
    for context, phrase in zip(sample["history"], sample["phrase"])]

    model_inputs = tokenizer(inputs, text_target=sample["rewrite"], max_length=128, truncation=True)
    return model_inputs

In [10]:
ds = load_ds("2rca_checked_version.json")
tokenized_ds = tokenize_ds(ds, preprocess_function)

Map: 100%|██████████| 4411/4411 [00:01<00:00, 4070.57 examples/s]
Map: 100%|██████████| 551/551 [00:00<00:00, 4135.96 examples/s]
Map: 100%|██████████| 551/551 [00:00<00:00, 4107.86 examples/s]


## Training

In [None]:
training_args = Seq2SeqTrainingArguments(
    overwrite_output_dir=config.training_args.overwrite_output_dir,
    eval_strategy=config.training_args.eval_strategy,
    learning_rate=config.training_args.learning_rate,
    per_device_train_batch_size=config.training_args.batch_size,
    per_device_eval_batch_size=config.training_args.batch_size,
    gradient_accumulation_steps=config.training_args.gradient_accumulation_steps,
    weight_decay=config.training_args.weight_decay,
    save_total_limit=config.training_args.save_total_limit,
    num_train_epochs=config.training_args.num_train_epochs,
    predict_with_generate=config.training_args.predict_with_generate,
    fp16=config.training_args.fp16,
    push_to_hub=config.training_args.push_to_hub,
    logging_steps=config.training_args.logging_steps,
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=config.model_name
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["val"],
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()

  trainer = Seq2SeqTrainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpvlshkunov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,5.4098,0.765807
2,0.9398,0.572155
3,0.667,0.525874
4,0.5479,0.507955
5,0.4267,0.489329
6,0.4055,0.486064
7,0.355,0.490348
8,0.3102,0.511001


KeyboardInterrupt: 

: 

# Evaluation

In [None]:
def infer_ds(ds, model, **kwargs):
    raw_test_results = []
    model = model.to(DEVICE)

    for enc in tqdm(ds['test']['input_ids']):
        input_ids=torch.tensor([enc]).to(DEVICE)
        out = model.generate(inputs=input_ids,
                             eos_token_id=tokenizer.eos_token_id,
                             max_length=config.inference_args.max_length,
                             num_beams=config.inference_args.num_beams,
                             early_stopping=config.inference_args.early_stopping)
        out = tokenizer.decode(out[0][1:],
                               skip_special_tokens=config.inference_args.skip_special_tokens,
                               clean_up_tokenization_spaces=config.inference_args.clean_up_tokenization_spaces)
        raw_test_results.append(out)

    return raw_test_results

In [None]:
evaluator = Evaluator(dataset=tokenized_ds["test"], 
                      model=model, 
                      tokenizer=tokenizer, 
                      infer_func=infer_ds)

evaluator.evaluate()

Unnamed: 0_level_0,bleu_score,rouge-1,rouge-2,rouge-3,rouge-4,rouge-l,rf_score_1,rf_score_2,rf_score_3,rf_score_4
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2rca,73.343847,0.734728,0.618773,0.530793,0.46508,0.734235,0.40779,0.324883,0.289028,0.270044


: 