In [None]:
! pip -q install transformers wandb pytorch-lightning

In [2]:
import os
import logging
import pickle
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import pytorch_lightning as pl
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
import wandb
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import torchmetrics

In [None]:
SEED = 42


def set_seed(seed: int = 42, set_torch=True):
    """Seed the Python, NumPy and (optionally) PyTorch random number generators.

    Args:
        seed: value used for every generator.
        set_torch: when True, also seed torch (CPU and current CUDA device) and
            force deterministic, non-benchmarking cuDNN kernels.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    # Note: Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # affects child processes, not str hashing in the running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)
    if not set_torch:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(SEED)

# Config: hyper-parameters

In [None]:
class Config():
    """Hyper-parameters for the ruDialoGPT fine-tuning run.

    Args:
        n_samples: number of rows in the conversation CSV. It is used to derive
            the number of optimizer steps for the LR schedule. The default
            13832 is the row count of ``data2.csv``.
    """

    def __init__(self, n_samples: int = 13832):
        import math

        self.seed = 42
        self.word_dropout = 0
        self.batch_size = 2
        self.val_size = 0.2
        self.learning_rate = 1e-4
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.max_epochs = 5
        # The scheduler counts optimizer steps, not rows: train_test_split puts
        # ceil(val_size * n) rows in validation, and the train DataLoader keeps
        # the final partial batch (drop_last is not set).
        n_val = math.ceil(self.val_size * n_samples)
        steps_per_epoch = math.ceil((n_samples - n_val) / self.batch_size)
        self.total_steps = steps_per_epoch * self.max_epochs
        # Linear warmup over the first 5% of training.
        self.warmup_steps = int(0.05 * self.total_steps)

cfg = Config()

# Data

In [None]:
# Colab upload location of the conversation CSV; change it when running elsewhere.
DATA_PATH = "/content/data2.csv"
# Missing context turns are read as NaN, which the tokenizer rejects. Store them
# as "" so the dataset's empty-turn filter skips them.
data = pd.read_csv(DATA_PATH).fillna("")
# A fixed random_state keeps the preview reproducible and does not consume the
# global NumPy RNG that later cells may draw from.
data.sample(5, random_state=SEED)

Unnamed: 0,context_3,context_2,context_1,response
5828,,,,в clearml можно тегами версионировать и фильтр...
6729,,,,А зеркально или параллельно учить на русском и...
12126,,,"Очень легко сделать симуляцию в GTA5, использу...",Есть какой-нибудь гайд как это сделать? Уже да...
2974,,,,"Хм, осталось 4 реакции"
2308,,,"привет всем! \n\nя Сюзанна, \nr&d в RobotMIA, ...",А что делаете в RobotMIA если не секрет?


In [None]:
len(data.index), len(data.columns)

(13832, 4)

# Dataset

In [None]:
class ApplyWordDropout:
    """Word dropout: randomly replace token ids with a filler id, always keeping EOS.

    Args:
        replace_with: id written in place of dropped tokens.
        eos_token_id: id that is never dropped.
        word_dropout: probability of replacing each other token (0 disables dropout).
    """

    def __init__(self, replace_with, eos_token_id, word_dropout=0.0):
        self.keep_prop = 1.0 - word_dropout
        self.replace_with = replace_with
        self.eos_token_id = eos_token_id

    def _apply_word_dropout(self, tensor):
        """Return a new tensor with non-EOS ids replaced at rate ``1 - keep_prop``."""
        # Draw the mask on the input's own device so GPU tensors work too.
        keep_mask = torch.rand(tensor.shape, device=tensor.device) < self.keep_prop
        # EOS must always survive, so OR it into the keep mask. An AND with
        # `tensor != eos` would do the opposite and drop every EOS.
        keep_mask |= tensor == self.eos_token_id
        return torch.where(keep_mask, tensor, torch.full_like(tensor, self.replace_with))

    def __call__(self, sample):
        return self._apply_word_dropout(sample)

In [None]:
class ConversationDataset(Dataset):
    """One flat token-id tensor per conversation row of a DataFrame.

    The non-empty turns of each row are tokenized in column order. Each turn is
    prefixed with a speaker-marker id alternating 50257, 50258, 50257, ..., and
    the sequence is closed with the next marker in that alternation.

    Args:
        df: DataFrame whose columns are conversation turns in order
            (``context_3, context_2, context_1, response`` for ``data2.csv``).
            Missing turns may be ``""`` or NaN.
        cfg: config object; ``cfg.word_dropout`` > 0 enables word dropout.
    """

    # Marker ids alternated between turns. NOTE(review): presumably the
    # speaker tokens used in the chat prompt; confirm against the tokenizer
    # vocabulary.
    SPEAKER_IDS = (50257, 50258)

    def __init__(self, df, cfg):
        self.tokenizer = AutoTokenizer.from_pretrained(
            'tinkoff-ai/ruDialoGPT-small',
            padding_side='left'
        )
        self.word_dropout = ApplyWordDropout(
            replace_with=self.tokenizer(self.tokenizer.unk_token)['input_ids'][0],
            eos_token_id=self.tokenizer.eos_token_id,
            word_dropout=cfg.word_dropout,
        )
        # Dropout is drawn afresh in __getitem__ so each epoch sees a new mask.
        # Applying it once here would freeze the corruption for all epochs.
        # TODO(review): this class also builds the validation set, where word
        # dropout is usually disabled.
        self.apply_word_dropout = bool(cfg.word_dropout)
        self.samples = []
        for _, sentences in df.iterrows():
            conv = self._concat_conv(sentences, self.tokenizer)
            if conv.numel():  # skip rows without any usable turn
                self.samples.append(conv)

    def _concat_conv(self, sentences, tokenizer):
        """Tokenize and join the turns of one row into a 1-D id tensor.

        NaN/None and empty-string turns are skipped. Returns an empty tensor
        when no turn remains.
        """
        turns = [str(s) for s in sentences if not pd.isna(s) and s != ""]
        if not turns:
            return torch.empty(0, dtype=torch.long)
        pieces = []
        for i, turn in enumerate(turns):
            pieces.append(torch.tensor([self.SPEAKER_IDS[i % 2]]))
            pieces.append(tokenizer(turn, return_tensors="pt")["input_ids"].view(-1))
        # Closing marker continues the alternation after the last turn.
        pieces.append(torch.tensor([self.SPEAKER_IDS[len(turns) % 2]]))
        return torch.cat(pieces)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        sample = self.samples[item].to(torch.long)
        if self.apply_word_dropout:
            sample = self.word_dropout(sample)
        return sample

In [None]:
class ConversationDataModule(pl.LightningDataModule):
    """Train/validation split of the conversation DataFrame plus its DataLoaders.

    Args:
        data: conversation DataFrame (one row per conversation).
        cfg: config object providing ``val_size``, ``seed`` and ``batch_size``
            (plus whatever ``ConversationDataset`` reads).
    """

    def __init__(self, data, cfg):
        super().__init__()
        # An explicit random_state makes the split independent of how many times
        # earlier cells consumed the global NumPy RNG.
        self.train_data, self.val_data = train_test_split(
            data, test_size=cfg.val_size, random_state=cfg.seed
        )
        self.cfg = cfg
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        # Lightning calls setup() for every fit/validate run. Build the datasets
        # once so the tokenizer is not reloaded and the corpus re-tokenized.
        if self.train_dataset is None:
            self.train_dataset = ConversationDataset(self.train_data, self.cfg)
        if self.val_dataset is None:
            self.val_dataset = ConversationDataset(self.val_data, self.cfg)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.cfg.batch_size, shuffle=True, collate_fn=self._collate)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.cfg.batch_size, collate_fn=self._collate)

    def _collate(self, examples: list[torch.Tensor]) -> torch.Tensor:
        """Left-pad 1-D id tensors with 0 to a common length and stack them (B, L)."""
        max_length = max(len(ex) for ex in examples)
        return torch.stack([F.pad(ex, (max_length - len(ex), 0)) for ex in examples], dim=0)

# PyTorch Lightning module

In [None]:
class DialoTuner(pl.LightningModule):
    """Fine-tunes ``tinkoff-ai/ruDialoGPT-small`` as a causal language model.

    Batches are 2-D tensors of token ids, left-padded with ``0`` by the data
    module's collate function. The inputs double as targets, because the
    Hugging Face model shifts ``labels`` internally. Padding positions are
    excluded from attention and from the loss.

    Args:
        cfg: config object providing ``learning_rate``, ``warmup_steps`` and
            ``total_steps``.
    """

    # Value the collate function uses for left padding.
    _PAD_ID = 0
    # Label ignored by the Hugging Face LM loss and by the perplexity metric.
    _IGNORE_INDEX = -100

    def __init__(self, cfg):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained("tinkoff-ai/ruDialoGPT-small")
        self.cfg = cfg
        self.perplexity = torchmetrics.text.Perplexity(ignore_index=self._IGNORE_INDEX)

    def _prepare_batch(self, batch):
        """Derive the attention mask, position ids and loss labels for a padded batch.

        Assumes every sample starts with a non-zero id (the dataset prefixes each
        conversation with a speaker-marker id). Under that assumption, the run of
        zeros before the first non-zero id is exactly the left padding.
        """
        input_ids = batch
        attention_mask = (input_ids != self._PAD_ID).long().cumsum(dim=-1).clamp(max=1)
        # Number real tokens from 0 so left padding does not shift positions.
        position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
        labels = input_ids.masked_fill(attention_mask == 0, self._IGNORE_INDEX)
        return input_ids, attention_mask, position_ids, labels

    def _run_model(self, batch):
        """Run the LM with loss; return the model output and the labels used."""
        input_ids, attention_mask, position_ids, labels = self._prepare_batch(batch)
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
        )
        return outputs, labels

    def forward(self, batch):
        outputs, _ = self._run_model(batch)
        return outputs

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.cfg.learning_rate)
        # Linear warmup then linear decay, defined in optimizer steps.
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.cfg.warmup_steps,
            num_training_steps=self.cfg.total_steps,
        )
        # interval="step": Lightning's default is to step schedulers once per
        # epoch, which would keep the LR near zero for the whole warmup.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
        }

    def training_step(self, train_batch, batch_idx):
        outputs, _ = self._run_model(train_batch)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        outputs, labels = self._run_model(val_batch)
        loss = outputs.loss
        # Logits at position t predict token t + 1; align them before scoring.
        # Updating the metric and logging the object lets Lightning aggregate
        # perplexity over the whole validation epoch.
        self.perplexity.update(outputs.logits[:, :-1, :], labels[:, 1:])
        self.log("val_loss", loss)
        self.log("val_perplexity", self.perplexity)
        return loss

    def generate(self, **kwargs):
        """Delegate to the wrapped Hugging Face model's ``generate``."""
        return self.model.generate(**kwargs)

# Trainer

In [None]:
# Split the CSV into train/val and load the pretrained ruDialoGPT weights.
conversation_data_module = ConversationDataModule(data=data, cfg=cfg)
model = DialoTuner(cfg=cfg)

In [None]:
# Log every public, non-callable attribute of cfg as the W&B run config.
config_dict = {}
for name in dir(cfg):
    if name.startswith("__"):
        continue
    value = getattr(cfg, name)
    if not callable(value):
        config_dict[name] = value
logger = pl.loggers.WandbLogger(project='sandbox', config=config_dict, log_model=True)

In [None]:
# Use the GPU when cfg detected one, otherwise fall back to CPU instead of
# failing at Trainer construction.
_use_gpu = cfg.device == "cuda"
trainer = pl.Trainer(
    accelerator="gpu" if _use_gpu else "cpu",
    devices=[0] if _use_gpu else 1,
    max_epochs=cfg.max_epochs,
    logger=logger,
    log_every_n_steps=1,
    gradient_clip_val=1.0,
)

In [None]:
trainer.fit(model=model, datamodule=conversation_data_module)

In [None]:
trainer.validate(model=model, datamodule=conversation_data_module)

In [None]:
# Name the file after the settings actually used. The old hard-coded name
# claimed 10 epochs while cfg.max_epochs is 5.
# NOTE: this pickles the whole LightningModule. Loading it needs the class
# definitions above, and recent torch versions need torch.load(..., weights_only=False).
model_path = f"finetuned_model_{cfg.max_epochs}ep_{cfg.learning_rate:.0e}"
torch.save(model, model_path)

In [None]:
# Mark the W&B run started by the WandbLogger as finished and flush pending uploads.
wandb.finish()

In [None]:
# Interactive smoke test: five user turns, each answered by the fine-tuned model.
# The history is kept as text in the training format: user turns after
# "@@ПЕРВЫЙ@@", bot turns after "@@ВТОРОЙ@@".
chat_history = ""
tokenizer = AutoTokenizer.from_pretrained(
    'tinkoff-ai/ruDialoGPT-small',
    padding_side='left'
)
# Lightning may leave the module in train mode after fit/validate; disable dropout.
model.eval()

for _ in range(5):
    chat_history = chat_history + "@@ПЕРВЫЙ@@ " + input(">> User: ") + "@@ВТОРОЙ@@"
    # Put the prompt on whatever device the model ended up on.
    prompt = tokenizer(chat_history, return_tensors='pt').to(model.device)
    generated_token_ids = model.generate(
        **prompt,
        top_k=10,
        top_p=0.95,
        num_beams=1,
        num_return_sequences=1,
        do_sample=True,
        no_repeat_ngram_size=2,
        temperature=1.7,
        repetition_penalty=1.2,
        length_penalty=1.0,
        eos_token_id=50257,
        max_new_tokens=40,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the newly generated tokens. Slicing the decoded full text by
    # len(prompt) breaks whenever decode() does not reproduce the prompt exactly.
    prompt_len = prompt["input_ids"].shape[1]
    answer = tokenizer.decode(generated_token_ids[0, prompt_len:])
    # Keep only the bot's turn: cut at the first speaker marker.
    for marker in ("@@ПЕРВЫЙ@@", "@@ВТОРОЙ@@"):
        answer = answer.split(marker)[0]
    chat_history = chat_history + answer
    print("ruDialoGPT: ", answer)