In [1]:
import sys
sys.path.append('../../..')

In [3]:
import torch

from omegaconf import OmegaConf
from transformers import (AutoTokenizer, 
                          AutoModelForSeq2SeqLM, 
                          DataCollatorForSeq2Seq, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer)
from tqdm import tqdm

from src.utils import seed_everything
from src.data_prepocessing import load_ds, tokenize_ds
from src.evaluation import Evaluator

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment setup

In [5]:
config = OmegaConf.load("t5_large_4_config.yaml")
print(OmegaConf.to_yaml(config))

model_name: ai-forever/ruT5-large
training_args:
  eval_strategy: epoch
  learning_rate: 0.0004
  batch_size: 8
  gradient_accumulation_steps: 64
  weight_decay: 0.01
  save_total_limit: 3
  num_train_epochs: 20
  predict_with_generate: true
  fp16: false
  push_to_hub: false
  logging_steps: 10
  overwrite_output_dir: true
lora_args:
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules:
  - q
  - v
  bias: none



In [6]:
seed_everything(42)

# Model and data loading

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [8]:
def preprocess_function(sample):
    inputs = [f"{context[-1]}</s>{phrase}</s>"
    for context, phrase in zip(sample["history"], sample["phrase"])]

    model_inputs = tokenizer(inputs, text_target=sample["rewrite"], max_length=128, truncation=True)
    return model_inputs

In [9]:
ds = load_ds("2rca_checked_version.json")
tokenized_ds = tokenize_ds(ds, preprocess_function)

Map: 100%|██████████| 4411/4411 [00:01<00:00, 3441.94 examples/s]
Map: 100%|██████████| 551/551 [00:00<00:00, 3604.16 examples/s]
Map: 100%|██████████| 551/551 [00:00<00:00, 3284.12 examples/s]


# Model training

In [10]:
from peft import LoraConfig, get_peft_model
from peft.utils.peft_types import TaskType


lora_config = LoraConfig(
    r=config.lora_args.r,
    lora_alpha=config.lora_args.lora_alpha,
    lora_dropout=config.lora_args.lora_dropout,
    target_modules=config.lora_args.target_modules,
    bias=config.lora_args.bias,
    task_type=TaskType.SEQ_2_SEQ_LM
)

peft_model = get_peft_model(model, lora_config)

In [11]:
peft_model.print_trainable_parameters()

trainable params: 4,718,592 || all params: 742,386,688 || trainable%: 0.6356


In [None]:
training_args = Seq2SeqTrainingArguments(
    overwrite_output_dir=config.training_args.overwrite_output_dir,
    eval_strategy=config.training_args.eval_strategy,
    learning_rate=config.training_args.learning_rate,
    per_device_train_batch_size=config.training_args.batch_size,
    per_device_eval_batch_size=config.training_args.batch_size,
    gradient_accumulation_steps=config.training_args.gradient_accumulation_steps,
    weight_decay=config.training_args.weight_decay,
    save_total_limit=config.training_args.save_total_limit,
    num_train_epochs=config.training_args.num_train_epochs,
    predict_with_generate=config.training_args.predict_with_generate,
    fp16=config.training_args.fp16,
    push_to_hub=config.training_args.push_to_hub,
    logging_steps=config.training_args.logging_steps,
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=config.model_name
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenize_ds["train"],
    eval_dataset=tokenize_ds["val"],
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[34m[1mwandb[0m: Currently logged in as: [33mpvlshkunov[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,No log,0.899513
2,2.124000,0.64162
3,0.946900,0.551066
4,0.723500,0.513759
5,0.635800,0.494639
6,0.583100,0.485009
7,0.535500,0.474099
8,0.496800,0.459754
9,0.495600,0.475097
10,0.444000,0.464069


KeyboardInterrupt: 

In [None]:
model = peft_model.merge_and_unload()

# Evaluation

In [None]:
def infer_ds(ds, model, **kwargs):
    raw_test_results = []
    model = model.to(DEVICE)

    for enc in tqdm(ds['test']['input_ids']):
        input_ids=torch.tensor([enc]).to(DEVICE)
        out = model.generate(inputs=input_ids,
                             eos_token_id=tokenizer.eos_token_id,
                             max_length=config.inference_args.max_length,
                             num_beams=config.inference_args.num_beams,
                             early_stopping=config.inference_args.early_stopping)
        out = tokenizer.decode(out[0][1:],
                               skip_special_tokens=config.inference_args.skip_special_tokens,
                               clean_up_tokenization_spaces=config.inference_args.clean_up_tokenization_spaces)
        raw_test_results.append(out)

    return raw_test_results

In [None]:
evaluator = Evaluator(dataset=tokenized_ds, 
                      model=model, 
                      tokenizer=tokenizer, 
                      infer_func=infer_ds)

evaluator.evaluate()

Unnamed: 0_level_0,bleu_score,rouge-1,rouge-2,rouge-3,rouge-4,rouge-l,rf_score_1,rf_score_2,rf_score_3,rf_score_4
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2rca,73.723347,0.742012,0.634155,0.551436,0.476842,0.741643,0.401312,0.323508,0.292859,0.274809
