In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
from pathlib import Path
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader
from torch import nn
from torch.functional import F
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader
from datasets import Dataset
import numpy as np

from src.data_utils import get_sample_from_row_original, filter_irrelevant, prune_frequent_samples
from src.inference_utils import predict
from src.metrics import lemmatization_accuracy

In [2]:
CURDIR = Path.cwd()

DATADIR = CURDIR / "data" / "original"
assert DATADIR.exists()

MODELS_DIR = CURDIR / "models"
assert MODELS_DIR.exists()

TEACHER_ID = MODELS_DIR / 'baseline'
assert TEACHER_ID.exists()

STUDENT_ID = MODELS_DIR / "checkpoint_120226"
assert STUDENT_ID.exists()

RESULT_MODEL_DIR = MODELS_DIR / "checkpoint_120226_2"
if not RESULT_MODEL_DIR.exists():
    RESULT_MODEL_DIR.mkdir()
RESULT_MODEL_PATH = RESULT_MODEL_DIR / "model.pt"

MAX_LENGTH = 512
DEVICE = "cuda"

In [3]:
df_train = pd.read_csv(DATADIR / "train.csv", index_col=0, sep="\t")
df_dev = pd.read_csv(DATADIR / "dev.csv", index_col=0, sep="\t")

In [4]:
tokenizer = AutoTokenizer.from_pretrained(TEACHER_ID)
teacher = BartForConditionalGeneration.from_pretrained(TEACHER_ID).to("cuda")
student = BartForConditionalGeneration.from_pretrained(STUDENT_ID).to("cuda")

teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False

In [5]:
df_train["sample"] = df_train.apply(lambda row: get_sample_from_row_original(row)[0], axis=1)
df_dev["sample"] = df_dev.apply(lambda row: get_sample_from_row_original(row)[0], axis=1)

df_train.shape, df_dev.shape

((2150060, 7), (255992, 7))

In [6]:
df_train = filter_irrelevant(df_train)
df_dev = filter_irrelevant(df_dev)

In [7]:
df_train.shape, df_dev.shape

((2135295, 7), (254996, 7))

In [8]:
df_train = prune_frequent_samples(df_train)  # здесь подрезаем несбалансированное начальное распределение
df_dev = df_dev.drop_duplicates(subset=["sample"])  # отсюда просто удаляем все дубли чтобы честно мериться

In [9]:
df_train.shape, df_dev.shape

((958668, 8), (52827, 7))

In [10]:
df_train = df_train[["sample", "lemma"]]
df_dev = df_dev[["sample", "lemma"]]

In [11]:
train = Dataset.from_pandas(
    df_train[["sample", "lemma"]],
).rename_columns({
    "sample": "source",
    "lemma": "target",
}).shuffle(seed=42)

In [12]:
def tokenize_function(examples,):

    model_inputs = tokenizer(
        examples["source"],
        max_length=70,
        truncation=True,
        padding=False,
    )

    labels = tokenizer(
        examples["target"],
        max_length=70,
        truncation=True,
        padding=False,
    )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [13]:
tokenized_train = train.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    remove_columns=train.column_names,
)

Map: 100%|██████████| 958668/958668 [01:03<00:00, 15208.84 examples/s]


In [14]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,
    return_tensors="pt",
    label_pad_token_id=-100,
)

In [15]:
BATCH_SIZE = 56  # 56? 48

In [16]:
train_dataloader = DataLoader(
    tokenized_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=1,
    pin_memory=True,
    drop_last=False,
)

In [17]:
class KLDivAndCELoss(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.kldiv = nn.KLDivLoss(reduction="batchmean")
        self.alpha = kwargs.get("alpha", 0.5)
        self.antialpha = 1 - self.alpha
        self.temperature = kwargs.get("temperature", 1.0)

        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, student_logits, teacher_logits, targets=None):

        _, _, vocab_size = student_logits.shape

        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1) #* (self.temperature ** 2)

        kd_loss = self.kldiv(student_log_probs, teacher_probs)  # маскировать PAD?

        ## плохо работает с таргетами

        # student_logits_flat = student_logits.contiguous().view(-1, vocab_size)
        # targets_flat = targets.contiguous().view(-1)

        # ce_loss = self.ce(student_logits_flat, targets_flat)

        # result = self.alpha * kd_loss + (1 - self.alpha) * ce_loss

        result = kd_loss
        return result

In [18]:
def distill_batch(
    teacher: BartForConditionalGeneration,
    student: BartForConditionalGeneration,
    batch,
    criterion:KLDivAndCELoss,
    optimizer,
    scheduler,
):

    with torch.no_grad():
        teacher_generation = teacher.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=32,
            num_beams=1,
            early_stopping=True,
            return_dict_in_generate=True,
            output_scores=True,
        )

    target_sequences = teacher_generation.sequences
    teacher_logits = torch.stack(teacher_generation.scores, dim=1)

    decoder_input_ids = target_sequences[:, :-1]
    student_outputs = student(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=decoder_input_ids,
    )
    student_logits = student_outputs.logits

    labels = target_sequences[:, 1:]

    loss = criterion(
        student_logits=student_logits,
        teacher_logits=teacher_logits,
        targets=labels
    )
    loss.backward()

    torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    return loss

In [19]:
def distillation_epoch(
    teacher: BartForConditionalGeneration,
    student: BartForConditionalGeneration,
    iterator: DataLoader,
    criterion,
    optimizer,
    scheduler,
    device=DEVICE,
):

    teacher.eval()
    student.train()

    losses = []

    pbar = tqdm(iterator, desc="Training", unit="bs")

    for batch_id, batch in enumerate(pbar):
        batch = {k: val.to(device) for k, val in batch.items()}
        loss = distill_batch(teacher, student, batch, criterion, optimizer, scheduler)
        loss_item = loss.item()
        losses.append(loss_item)

        if batch_id % 100 == 0:
            pbar.set_description(f"Training (loss: {np.mean(losses):.6f})")

    return np.mean(losses)

In [20]:
@torch.no_grad()
def validate(
    model: BartForConditionalGeneration,
    tokenizer:AutoTokenizer,
    df: pd.DataFrame,
    device=DEVICE,
):
    model.eval()

    preds = predict(
        df["sample"].tolist(),
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    targets = df["lemma"].tolist()

    return lemmatization_accuracy(targets, preds)

In [21]:
# optimizer = torch.optim.AdamW(student.parameters(), lr=2e-4)
# optimizer = torch.optim.SGD(student.parameters(), lr=5e-3)
# optimizer = torch.optim.SGD(student.parameters(), lr=5e-4)
# optimizer = torch.optim.SGD(student.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(student.parameters(), lr=3e-5)
# optimizer = torch.optim.AdamW(
#     student.parameters(),
#     lr=3e-4,
#     weight_decay=0.01,
#     betas=(0.9, 0.999)
# )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)  # ничего не делающий скедьюлер
criterion = KLDivAndCELoss(alpha=0.9, temperature=1.3)
# criterion = KLDivAndCELoss(alpha=0.9, temperature=1.15)

In [22]:
EPOCHS = 1

In [23]:
highest_val_acc = 0.95
patience = 0
train_losses = []
val_losses = []

max_patience = 2

for epoch in range(1, EPOCHS+1):

    print(f"Epoch {epoch}/{EPOCHS}")

    val_acc = validate(
        student,
        tokenizer,
        df_dev,
        DEVICE,
    )

    print(f"Val Acc: {val_acc:.4f}")

    train_loss = distillation_epoch(
        teacher,
        student,
        train_dataloader,
        criterion,
        optimizer,
        scheduler,
        DEVICE,
    )

    print(f"Train loss: {train_loss:.4f}")

    val_acc = validate(
        student,
        tokenizer,
        df_dev,
        DEVICE,
    )

    print(f"Val Acc: {val_acc:.4f}")

    if val_acc > highest_val_acc:
        torch.save(student.state_dict(), RESULT_MODEL_PATH)
        highest_val_acc = val_acc
        patience = 0
        best_epoch = epoch
    else:
        patience += 1
        print(f"patience {patience}/{max_patience}")
        if patience == max_patience:
            print("Early Stopping!")
            break

Epoch 1/1


1651it [01:33, 17.57it/s]                          


Val Acc: 0.9500


Training:   0%|          | 0/17120 [00:00<?, ?bs/s]You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Training (loss: 0.312932): 100%|██████████| 17120/17120 [1:33:13<00:00,  3.06bs/s]


Train loss: 0.3129


1651it [01:33, 17.57it/s]                          


Val Acc: 0.9500


In [24]:
student.save_pretrained(RESULT_MODEL_DIR)

In [None]:
tokenizer.save_pretrained(RESULT_MODEL_DIR)

In [26]:
val_acc = validate(
        teacher,
        tokenizer,
        df_dev,
        DEVICE,
    )

print(f"Val Acc: {val_acc:.4f}")

1651it [02:57,  9.28it/s]                          

Val Acc: 0.9697



