In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, MBartForConditionalGeneration

from usecrets import WANDB_API_KEY
from lstm import BiLSTMSeq2SeqMHA_Residual
from distill import collate_fn, preprocess_function, train_distillation


from config_distill_d6 import (
    TEACHER_MODEL_NAME,
    MAX_SOURCE_LEN,
    MAX_TARGET_LEN,
    EMBED_DIM,
    ENC_HIDDEN_DIM,
    DEC_HIDDEN_DIM,
    NUM_LAYERS,
    DROPOUT,
    MHA_NUM_HEADS,
    BATCH_SIZE,
    LEARNING_RATE,
    NUM_EPOCHS,
    TEMPERATURE,
    BEAM_SIZE,
    BEAM_MAX_LENGTH,
    WANDB_PROJECT,
    WANDB_RUN_NAME,
    MIN_LEN,
)


# ====================== CONFIG / CONSTANTS ======================
# Fixed seed for the torch RNGs so repeated runs start from the same state.
seed = 52

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(seed)
if DEVICE.type == "cuda":
    torch.cuda.manual_seed(seed)
    # Disable the cuDNN autotuner and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

print(f"Using device: {DEVICE}")
if DEVICE.type == "cuda":
    print(f"Running on: {torch.cuda.get_device_name()}")

# Expose the API key through the environment before wandb is initialised.
os.environ["WANDB_API_KEY"] = WANDB_API_KEY


# ====================== MAIN ======================
def main():
    """Run one teacher -> student distillation experiment end to end.

    Steps, in order:
      1. Start a wandb run and record the hyperparameters from the config
         module.
      2. Load the mBART teacher and its tokenizer (source and target
         language both set to ``ru_RU``).
      3. Load ``train_smart.jsonl``, hold out 2.5% for validation, and
         preprocess every example with ``preprocess_function``.
      4. Build the BiLSTM student and hand everything to
         ``train_distillation``.
      5. Save the student weights and the tokenizer under
         ``students/<WANDB_RUN_NAME>/``.

    Takes no arguments and returns nothing; all settings come from the
    imported config constants and the module-level ``DEVICE``.
    """
    wandb.init(
        project=WANDB_PROJECT,
        name=WANDB_RUN_NAME,
        config={
            "num_train_epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "temperature": TEMPERATURE,
            "beam_size": BEAM_SIZE,
            "teacher": TEACHER_MODEL_NAME,
            "embed_dim": EMBED_DIM,
            "ENC_HIDDEN_DIM": ENC_HIDDEN_DIM,
            "DEC_HIDDEN_DIM": DEC_HIDDEN_DIM,
            "NUM_LAYERS": NUM_LAYERS,
            "DROPOUT": DROPOUT,
            "MHA_NUM_HEADS": MHA_NUM_HEADS,
            "LEARNING_RATE": LEARNING_RATE,
        },
    )

    # ---- teacher + tokenizer ----
    teacher_model = MBartForConditionalGeneration.from_pretrained(
        TEACHER_MODEL_NAME
    ).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME)

    tokenizer.src_lang = "ru_RU"
    tokenizer.tgt_lang = "ru_RU"

    # ---- data ----
    def _tokenize(example):
        # Shared by the train and validation splits so both are
        # preprocessed with exactly the same settings.
        return preprocess_function(
            example, tokenizer, MAX_SOURCE_LEN, MAX_TARGET_LEN
        )

    def _collate(batch):
        return collate_fn(batch, tokenizer.pad_token_id)

    dataset = load_dataset("json", data_files="train_smart.jsonl")["train"]
    # Split seed matches the module-level RNG seed (52).
    split_data = dataset.train_test_split(test_size=0.025, seed=52)

    train_ds = split_data["train"].map(_tokenize, batched=False)
    val_ds = split_data["test"].map(_tokenize, batched=False)

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=_collate,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=_collate,
    )

    # ---- student ----
    student_model = BiLSTMSeq2SeqMHA_Residual(
        vocab_size=len(tokenizer),
        embed_dim=EMBED_DIM,
        enc_hidden_dim=ENC_HIDDEN_DIM,
        dec_hidden_dim=DEC_HIDDEN_DIM,
        pad_idx=tokenizer.pad_token_id,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
        num_heads=MHA_NUM_HEADS,
    ).to(DEVICE)

    # Register the wandb hooks BEFORE training starts. Previously this call
    # came after train_distillation() had returned, so no training step was
    # ever observed by the hooks.
    wandb.watch(student_model, log="all")

    train_distillation(
        teacher_model=teacher_model,
        student_model=student_model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        tokenizer=tokenizer,
        num_epochs=NUM_EPOCHS,
        lr=LEARNING_RATE,
        temperature=TEMPERATURE,
        device=DEVICE,
        wandb_run_name=WANDB_RUN_NAME,
    )

    # ---- persist ----
    output_dir = f"students/{WANDB_RUN_NAME}"
    # torch.save does not create missing parent directories; make sure the
    # target exists so a finished run is not lost at the very last step.
    os.makedirs(output_dir, exist_ok=True)

    torch.save(student_model.state_dict(), f"{output_dir}/student_model.pt")
    tokenizer.save_pretrained(output_dir)


# Entry-point guard: start the run only when this module is executed
# directly (as a script or a notebook cell), not when it is imported.
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using device: cuda
Running on: NVIDIA A100 80GB PCIe


[34m[1mwandb[0m: Currently logged in as: [33mvdoninav[0m ([33mvdoninav-hse[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



===== Epoch 1/50 =====


                                                                                           

Distill Loss (Train): 4.7873


                                                                                           

Validation -- CE: 3.7744, BERT-P: 0.6696, R: 0.6081, F1: 0.6360

===== Epoch 2/50 =====


Training (epoch 2):  73%|████████   | 1490/2037 [07:28<02:45,  3.30it/s, batch_loss=4.4137]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Distill Loss (Train): 3.9371


                                                                                           

Validation -- CE: 3.1015, BERT-P: 0.6684, R: 0.6379, F1: 0.6521

===== Epoch 4/50 =====


                                                                                           

Distill Loss (Train): 3.8499


                                                                                           

Validation -- CE: 2.9644, BERT-P: 0.6401, R: 0.6381, F1: 0.6377

===== Epoch 5/50 =====


                                                                                           

Distill Loss (Train): 3.7943


                                                                                           

Validation -- CE: 2.9219, BERT-P: 0.6626, R: 0.6499, F1: 0.6556

===== Epoch 6/50 =====


                                                                                           

Distill Loss (Train): 3.7540


                                                                                           

Validation -- CE: 2.8398, BERT-P: 0.6563, R: 0.6469, F1: 0.6508

===== Epoch 7/50 =====


                                                                                           

Distill Loss (Train): 3.7233


                                                                                           

Validation -- CE: 2.8169, BERT-P: 0.6605, R: 0.6528, F1: 0.6559

===== Epoch 8/50 =====


Training (epoch 8):   3%|▎            | 51/2037 [00:15<09:55,  3.34it/s, batch_loss=3.7203]