In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, MBartForConditionalGeneration

from usecrets import WANDB_API_KEY
from lstm import BiLSTM
from distill import collate_fn, preprocess_function, train_distillation


from config_distill_v3_2 import (
    TEACHER_MODEL_NAME,
    MAX_SOURCE_LEN,
    MAX_TARGET_LEN,
    EMBED_DIM,
    ENC_HIDDEN_DIM,
    DEC_HIDDEN_DIM,
    NUM_LAYERS,
    DROPOUT,
    BATCH_SIZE,
    LEARNING_RATE,
    NUM_EPOCHS,
    TEMPERATURE,
    BEAM_SIZE,
    BEAM_MAX_LENGTH,
    WANDB_PROJECT,
    WANDB_RUN_NAME,
    MIN_LEN,
)


# ====================== CONFIG / CONSTANTS ======================
# Single fixed seed for the torch RNGs used by this script.
seed = 52

# Resolve the compute device once; the CUDA-only setup below keys off it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(seed)
print(f"Using device: {DEVICE}")

if DEVICE.type == "cuda":
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning speed for repeatable kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print(f"Running on: {torch.cuda.get_device_name()}")

# Expose the API key through the environment for wandb to authenticate with.
os.environ["WANDB_API_KEY"] = WANDB_API_KEY


# ====================== MAIN ======================
def _build_loader(raw_split, tokenizer, shuffle):
    """Tokenize one raw dataset split and wrap it in a ``DataLoader``.

    Args:
        raw_split: A split produced by ``train_test_split``; every example is
            passed through ``preprocess_function`` one at a time.
        tokenizer: Teacher tokenizer; also supplies the padding id used when
            collating a batch.
        shuffle: Whether the loader reshuffles the examples each epoch.

    Returns:
        A ``DataLoader`` yielding batches of ``BATCH_SIZE`` examples built by
        ``collate_fn``.
    """
    encoded = raw_split.map(
        lambda example: preprocess_function(
            example, tokenizer, MAX_SOURCE_LEN, MAX_TARGET_LEN
        ),
        batched=False,
    )
    return DataLoader(
        encoded,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=lambda batch: collate_fn(batch, tokenizer.pad_token_id),
    )


def main():
    """Distill the mBART teacher into a BiLSTM student and save the result.

    Steps: open a wandb run, load the teacher model and tokenizer, split
    ``train_smart.jsonl`` into train/validation (2.5% held out), build the
    student, run ``train_distillation``, then write the student weights and
    the tokenizer to ``students/<WANDB_RUN_NAME>/``.
    """
    # Using the run as a context manager guarantees it is finished (and
    # flagged as failed on an exception) instead of being left open.
    with wandb.init(
        project=WANDB_PROJECT,
        name=WANDB_RUN_NAME,
        config={
            "num_train_epochs": NUM_EPOCHS,
            "batch_size": BATCH_SIZE,
            "temperature": TEMPERATURE,
            "beam_size": BEAM_SIZE,
            "teacher": TEACHER_MODEL_NAME,
            "embed_dim": EMBED_DIM,
            "ENC_HIDDEN_DIM": ENC_HIDDEN_DIM,
            "DEC_HIDDEN_DIM": DEC_HIDDEN_DIM,
            "NUM_LAYERS": NUM_LAYERS,
            "DROPOUT": DROPOUT,
            "LEARNING_RATE": LEARNING_RATE,
        },
    ):
        teacher_model = MBartForConditionalGeneration.from_pretrained(
            TEACHER_MODEL_NAME
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME)

        # Source and target are both tagged as Russian.
        tokenizer.src_lang = "ru_RU"
        tokenizer.tgt_lang = "ru_RU"

        dataset = load_dataset("json", data_files="train_smart.jsonl")["train"]
        split_data = dataset.train_test_split(test_size=0.025, seed=52)

        train_loader = _build_loader(split_data["train"], tokenizer, shuffle=True)
        val_loader = _build_loader(split_data["test"], tokenizer, shuffle=False)

        # NOTE(review): the student vocabulary is sized from the tokenizer;
        # presumably this matches the teacher's output dimension -- confirm.
        student_model = BiLSTM(
            vocab_size=len(tokenizer),
            embed_dim=EMBED_DIM,
            enc_hidden_dim=ENC_HIDDEN_DIM,
            dec_hidden_dim=DEC_HIDDEN_DIM,
            pad_idx=tokenizer.pad_token_id,
            num_layers=NUM_LAYERS,
            dropout=DROPOUT,
        ).to(DEVICE)

        # Hooks must be attached before training: registering them after the
        # last backward pass (as was done previously) records nothing.
        wandb.watch(student_model, log="all")

        train_distillation(
            teacher_model=teacher_model,
            student_model=student_model,
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            tokenizer=tokenizer,
            num_epochs=NUM_EPOCHS,
            lr=LEARNING_RATE,
            temperature=TEMPERATURE,
            device=DEVICE,
            wandb_run_name=WANDB_RUN_NAME,
        )

        # Make sure the target directory exists so a long training run is not
        # lost to a missing-directory error at save time.
        output_dir = f"students/{WANDB_RUN_NAME}"
        os.makedirs(output_dir, exist_ok=True)
        torch.save(student_model.state_dict(), f"{output_dir}/student_model.pt")
        tokenizer.save_pretrained(output_dir)


# Entry point: run the distillation pipeline when this module is executed
# directly (importing it only defines `main`).
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using device: cuda
Running on: NVIDIA A100 80GB PCIe


[34m[1mwandb[0m: Currently logged in as: [33mvdoninav[0m ([33mvdoninav-hse[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



===== Epoch 1/50 =====


                                                                                           

Distill Loss (Train): 4.7004


                                                                                           

Validation -- CE: 3.6556, BERT-P: 0.6618, R: 0.6171, F1: 0.6380

===== Epoch 2/50 =====


                                                                                           

Distill Loss (Train): 4.2218


                                                                                           

Validation -- CE: 3.2987, BERT-P: 0.5599, R: 0.6132, F1: 0.5839

===== Epoch 3/50 =====


Training (epoch 3):  53%|█████▊     | 1084/2037 [06:47<03:55,  4.04it/s, batch_loss=4.3772]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 3.0904, BERT-P: 0.6336, R: 0.6349, F1: 0.6338

===== Epoch 5/50 =====


                                                                                           

Distill Loss (Train): 4.0087


                                                                                           

Validation -- CE: 3.0453, BERT-P: 0.5924, R: 0.6188, F1: 0.6042

===== Epoch 6/50 =====


                                                                                           

Validation -- CE: 2.9760, BERT-P: 0.6197, R: 0.6263, F1: 0.6225

===== Epoch 7/50 =====


Training (epoch 7):  72%|███████▉   | 1473/2037 [06:36<04:01,  2.34it/s, batch_loss=3.4398]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Distill Loss (Train): 3.9460


                                                                                           

Distill Loss (Train): 3.9373


                                                                                           

Validation -- CE: 2.9248, BERT-P: 0.6314, R: 0.6369, F1: 0.6336

===== Epoch 10/50 =====


                                                                                           

Distill Loss (Train): 3.9222


                                                                                           

Validation -- CE: 2.9190, BERT-P: 0.6364, R: 0.6332, F1: 0.6344

===== Epoch 11/50 =====


                                                                                           

Distill Loss (Train): 3.9143


                                                                                           

Distill Loss (Train): 3.9067


                                                                                           

Validation -- CE: 2.8760, BERT-P: 0.6602, R: 0.6494, F1: 0.6544

===== Epoch 13/50 =====


                                                                                           

Distill Loss (Train): 3.8995


                                                                                           

Validation -- CE: 2.8634, BERT-P: 0.6585, R: 0.6518, F1: 0.6549

===== Epoch 14/50 =====


                                                                                           

Distill Loss (Train): 3.8938


                                                                                           

Validation -- CE: 2.8812, BERT-P: 0.6471, R: 0.6387, F1: 0.6425

===== Epoch 15/50 =====


                                                                                           

Distill Loss (Train): 3.8913


                                                                                           

Validation -- CE: 2.8606, BERT-P: 0.6610, R: 0.6570, F1: 0.6586

===== Epoch 16/50 =====


                                                                                           

Distill Loss (Train): 3.8853


                                                                                           

Validation -- CE: 2.8477, BERT-P: 0.6575, R: 0.6520, F1: 0.6543

===== Epoch 17/50 =====


                                                                                           

Distill Loss (Train): 3.8792


                                                                                           

Validation -- CE: 2.8511, BERT-P: 0.6619, R: 0.6530, F1: 0.6571

===== Epoch 18/50 =====


                                                                                           

Distill Loss (Train): 3.8753


                                                                                           

Validation -- CE: 2.8226, BERT-P: 0.6647, R: 0.6574, F1: 0.6607

===== Epoch 19/50 =====


Training (epoch 19):  98%|█████████▊| 1998/2037 [13:01<00:09,  4.05it/s, batch_loss=3.8169]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Training (epoch 30):  99%|█████████▉| 2020/2037 [14:24<00:07,  2.32it/s, batch_loss=3.3071]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Training (epoch 31):  94%|█████████▍| 1924/2037 [12:29<00:48,  2.34it/s, batch_loss=4.3457]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the

Distill Loss (Train): 3.8395


Training (epoch 33):  79%|███████▉  | 1606/2037 [07:31<03:04,  2.33it/s, batch_loss=3.5025]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7883, BERT-P: 0.6791, R: 0.6672, F1: 0.6727

===== Epoch 34/50 =====


Training (epoch 34):  73%|███████▎  | 1488/2037 [06:15<02:24,  3.81it/s, batch_loss=4.5951]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7659, BERT-P: 0.6754, R: 0.6645, F1: 0.6697

===== Epoch 35/50 =====


Training (epoch 35):  73%|███████▎  | 1486/2037 [07:15<02:20,  3.92it/s, batch_loss=4.1672]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7755, BERT-P: 0.6801, R: 0.6656, F1: 0.6724

===== Epoch 36/50 =====


                                                                                           

Distill Loss (Train): 3.8328


                                                                                           

Validation -- CE: 2.7694, BERT-P: 0.6791, R: 0.6667, F1: 0.6725

===== Epoch 37/50 =====


Training (epoch 37):  62%|██████▏   | 1268/2037 [08:55<03:15,  3.93it/s, batch_loss=4.4519]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Distill Loss (Train): 3.8299


                                                                                           

Validation -- CE: 2.7447, BERT-P: 0.6740, R: 0.6638, F1: 0.6685

===== Epoch 39/50 =====


Training (epoch 39):  26%|██▊        | 525/2037 [03:42<10:39,  2.36it/s, batch_loss=3.8303]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Distill Loss (Train): 3.8285


Training (epoch 40):  80%|████████  | 1632/2037 [09:49<02:52,  2.35it/s, batch_loss=3.5797]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7557, BERT-P: 0.6746, R: 0.6639, F1: 0.6687

===== Epoch 42/50 =====


Training (epoch 42):  73%|███████▎  | 1477/2037 [07:44<03:58,  2.35it/s, batch_loss=3.3946]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7593, BERT-P: 0.6794, R: 0.6670, F1: 0.6727

===== Epoch 43/50 =====


                                                                                           

Distill Loss (Train): 3.8243


                                                                                           

Validation -- CE: 2.7496, BERT-P: 0.6720, R: 0.6632, F1: 0.6672

===== Epoch 44/50 =====


                                                                                           

Distill Loss (Train): 3.8231


BERT eval:  30%|████████████▉                              | 16/53 [03:21<07:25, 12.05s/it]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7622, BERT-P: 0.6810, R: 0.6686, F1: 0.6743

===== Epoch 45/50 =====


Training (epoch 45):  73%|███████▎  | 1489/2037 [06:08<02:14,  4.07it/s, batch_loss=4.1602]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7524, BERT-P: 0.6725, R: 0.6627, F1: 0.6671

===== Epoch 46/50 =====


Training (epoch 46):  73%|███████▎  | 1484/2037 [06:08<02:16,  4.05it/s, batch_loss=3.3357]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7540, BERT-P: 0.6769, R: 0.6622, F1: 0.6692

===== Epoch 47/50 =====


Training (epoch 47):  73%|███████▎  | 1487/2037 [08:10<02:15,  4.06it/s, batch_loss=3.7393]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

                                                                                           

Validation -- CE: 2.7503, BERT-P: 0.6852, R: 0.6672, F1: 0.6758

===== Epoch 48/50 =====


Training (epoch 48):  73%|███████▎  | 1487/2037 [10:35<03:53,  2.35it/s, batch_loss=3.6604]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Training (epoch 49):  73%|███████▎  | 1493/2037 [10:38<03:53,  2.33it/s, batch_loss=3.4050]