In [1]:
import os
import torch
import wandb

from transformers import (
    AutoTokenizer,
    MBartForConditionalGeneration,
    MBartConfig,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from datasets import load_dataset, Dataset
from torch.utils.data import Subset
from bert_score import score as bert_score_metric

import config
from usecrets import WANDB_API_KEY


os.environ["WANDB_API_KEY"] = WANDB_API_KEY

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
if device.type == "cuda":
    print("GPU Device Name:", torch.cuda.get_device_name(0))


wandb.init(
    project=config.WANDB_PROJECT,
    name=config.WANDB_RUN_NAME,
    config={
        "num_train_epochs": config.NUM_EPOCHS,
        "batch_size": config.BATCH_SIZE,
        "weight_decay": config.WEIGHT_DECAY,
    }
)

raw_dataset = load_dataset("json", data_files="train_smart.jsonl", field=None)["train"]
split_dataset = raw_dataset.train_test_split(test_size=0.1, seed=52)
train_raw = split_dataset["train"]
val_raw = split_dataset["test"]

model_name = "d0rj/ru-mbart-large-summ"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
#model.config.dropout = 0.30
#model.config.attention_dropout = 0.20
#model.config.classifier_dropout = 0.30

tokenizer.src_lang = "ru_RU"
tokenizer.tgt_lang = "ru_RU"

max_input_length = 1024
max_target_length = 128


def preprocess_function(example):
    encoded_inputs = tokenizer(
        example["text"],
        max_length=max_input_length,
        truncation=True,
        padding="max_length"
    )
    
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(
            example["summary"],
            max_length=max_target_length,
            truncation=True,
            padding="max_length"
        )
    
    labels = [
        (tid if tid != tokenizer.pad_token_id else -100)
        for tid in encoded_targets["input_ids"]
    ]
    encoded_inputs["labels"] = labels

    return encoded_inputs

train_dataset = train_raw.map(preprocess_function, batched=False)
val_dataset = val_raw.map(preprocess_function, batched=False)

class MetricsLoggerCallback(TrainerCallback):
    def __init__(self, model, tokenizer, eval_dataset, log_file="training_log.txt", max_eval_samples=200):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.eval_dataset = eval_dataset
        self.log_file = log_file
        self.max_eval_samples = max_eval_samples

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None:
            metrics = {}
        
        subset_size = min(len(self.eval_dataset), self.max_eval_samples)
        small_eval = self.eval_dataset.select(range(subset_size))
        
        all_preds = []
        all_labels = []
        
        for start_idx in range(0, subset_size, args.per_device_eval_batch_size):
            sub_eval = small_eval[start_idx : start_idx + args.per_device_eval_batch_size]
        
            batch_input_ids = sub_eval["input_ids"]
            batch_attn_mask = sub_eval["attention_mask"]
            batch_labels = sub_eval["labels"]
        
            input_ids = torch.tensor(batch_input_ids, dtype=torch.long).to(self.model.device)
            attention_mask = torch.tensor(batch_attn_mask, dtype=torch.long).to(self.model.device)
            labels = torch.tensor(batch_labels, dtype=torch.long).to(self.model.device)
        
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_length=128,
                    num_beams=4,
                    early_stopping=True
                )
        
            preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            tgts = []
            for seq in labels.tolist():
                seq = [tokenizer.pad_token_id if tok < 0 else tok for tok in seq]
                tgts.append(seq)
            
            tgts = self.tokenizer.batch_decode(tgts, skip_special_tokens=True)

            all_preds.extend(preds)
            all_labels.extend(tgts)


        P, R, F1 = bert_score_metric(
            all_preds, 
            all_labels, 
            lang="ru",
            model_type="google-bert/bert-base-multilingual-cased",
            num_layers=9,
            verbose=False
        )

        p_mean = float(torch.mean(P))
        r_mean = float(torch.mean(R))
        f1_mean = float(torch.mean(F1))

        metrics["eval_bert_score_precision"] = p_mean
        metrics["eval_bert_score_recall"] = r_mean
        metrics["eval_bert_score_f1"] = f1_mean

        wandb.log(metrics)

        os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(f"Epoch {state.epoch} evaluation metrics:\n")
            for k, v in metrics.items():
                f.write(f"{k}: {v}\n")
            f.write("\n")


optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=config.RLR_MODE,
    patience=config.RLR_PATIENCE,
    factor=config.RLR_FACTOR
)

optims = (optimizer, scheduler)

training_args = TrainingArguments(
    output_dir=f"./tuned/{config.WANDB_RUN_NAME}",
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
    lr_scheduler_type='reduce_lr_on_plateau',
    lr_scheduler_kwargs={'mode': config.RLR_MODE, 'patience': config.RLR_PATIENCE, 'factor': config.RLR_FACTOR},
    logging_steps=50,
    eval_strategy="epoch",
    save_total_limit=config.NUM_EPOCHS,
    save_strategy="epoch",
    report_to=["wandb"],
    overwrite_output_dir=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_bert_score_f1",
    greater_is_better=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    optimizers=optims,
)

trainer.add_callback(MetricsLoggerCallback(model, tokenizer, val_dataset, log_file=f"logs/{config.WANDB_RUN_NAME}/training_log.txt"))

wandb.watch(model, log="all")

trainer.train()

model.save_pretrained(os.path.join(training_args.output_dir, "final_model"))
tokenizer.save_pretrained(os.path.join(training_args.output_dir, "final_model"))

wandb.finish()


Running on: cuda
GPU Device Name: NVIDIA A100 80GB PCIe


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvdoninav[0m ([33mvdoninav-hse[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,2.0758,1.990182
2,1.7196,1.871663
3,1.4412,1.854473
4,1.2519,1.915237




KeyboardInterrupt: 

In [None]:
!tar -cvf lstm_sota.tar students/v3_2/e50