In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import math
import os
import re
import string
import json
from typing import List

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor

import youtokentome as yttm
from dotmap import DotMap

from dnnutils.generate import BeamGenerator, ensure_length
from dnnutils.nn_helper import init_random_seed, train_eval_loop, predict_with_model
from dnnutils.tensorboard import SummaryWriterHelper
from sklearn.model_selection import train_test_split

from transformers import MT5ForConditionalGeneration, AutoTokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer

from torch.utils.data import Dataset, DataLoader, TensorDataset

plt.rcParams["figure.figsize"] = [6, 4]
tqdm.pandas()

SEED = 42
init_random_seed(SEED)

2023-04-05 21:42:08.647264: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-05 21:42:09.462745: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-04-05 21:42:09.462836: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
device

  return torch._C._cuda_getDeviceCount() > 0


device(type='cpu')

In [4]:
config = DotMap()
config.MAX_INPUT_TEXT_SEQ_LEN = 512
config.SEQ_SEPARATOR = '<SEP>'

In [68]:
# build dataset sequences

df_dataset = pd.DataFrame(columns=['X', 'y'])

def parse_message(row, messages):
    row_idx = row.name
    message = messages[row_idx]
    row['message_id'] = message['id']
    row['date'] = pd.to_datetime(message['date'])
    if 'from' not in message:
        return row
    row['from'] = message['from']
    row['from_id'] = message['from_id']
    if 'reply_to_message_id' in message:
        row['reply_to_message_id'] = message['reply_to_message_id']
    else:
        row['reply_to_message_id'] = pd.NA
    message_text = ''
    if 'text_entities' in message:
        for text_entity in message['text_entities']:
            message_text += ' ' + text_entity['text']
    elif 'text' in message:
        if type(message['text']) == str:
            message_text = message['text']
        else:
            for text_entity in message['text']:
                if type(text_entity) == str:
                    message_text += ' ' + text_entity
                else:
                     message_text += ' ' + text_entity['text']
    row['message_text'] = message_text
    return row

def clean(row):
    line = row['message_text']
    line = re.sub(r'http\S+', '', line)
    line = re.sub(r'\s+', ' ', line)
    line = line.strip()
    
    return line

def find_replies(df_data, message_id) -> list:
    result = []
    reply_ids = df_data.loc[df_data.reply_to_message_id == message_id].message_id
    for reply_id in reply_ids:
        replies = find_replies(df_data, reply_id)
        if len(replies) == 0:
            replies.insert(0, reply_id)
            replies.insert(0, message_id)
            result.append(replies)
        else:
            for reply in replies:
                reply.insert(0, message_id)
                result.append(reply)
    return result


def build_sequences_for_message(row, df_data):
    message_id = row.message_id
    reply_sequences = find_replies(df_data, message_id)
    new_data = []
    for seq in reply_sequences:
        begin_from = 0
        for i in range(1, len(seq)):
            X_inidces = seq[begin_from:i]
            y_indices = seq[i]
            X = df_data[df_data.message_id.isin(X_inidces)].cleaned.str.cat(sep=config.SEQ_SEPARATOR)
            y = df_data[df_data.message_id == y_indices].cleaned.iloc[0]
            if len(X) > config.MAX_INPUT_TEXT_SEQ_LEN:
                begin_from += 1
            else:
               new_data.append([X, y])
    return new_data

def prepare_dataset_from_tg_channel(file_name: str, parse_json: bool = False):
    if parse_json:
        with open('data/tg/'+tg_filename+'.json', 'r') as json_file:
            print(f'Load json file {tg_filename}')
            js = json.load(json_file)
            print(f'Parse json file {tg_filename}')
            messages_len = len(js['messages'])
            df_data = pd.DataFrame(index=np.arange(messages_len), columns=['message_id'])
            df_data = df_data.progress_apply(parse_message, messages=js['messages'], axis=1)
            
        df_data.to_csv(f'data/tg/{tg_filename}.csv', index=False)
    else:
        df_data = pd.read_csv(f'data/tg/{tg_filename}.csv')
    print(f'Prepare dataset from file {tg_filename}')
    df_data = df_data[:100_000]
    df_data = df_data.dropna(subset=['message_text'])
    df_data = df_data.drop_duplicates(subset=['message_text'], keep='first')
    df_data = df_data[df_data.message_text.str.len() <= 256].copy()
    # df_data = df_data.drop_duplicates(subset=['reply_to_message_id'], keep='first')
    df_data['cleaned'] = df_data.progress_apply(clean, axis = 1)
    df_data = df_data.loc[df_data.cleaned.str.len() >= 2]
    # Find root messages and build sequences
    sr_sequences = df_data.loc[~df_data.reply_to_message_id.isin(df_data.message_id)].progress_apply(build_sequences_for_message, df_data=df_data, axis=1)
    sr_sequences = sr_sequences[sr_sequences.apply(lambda x: len(x) > 0)]
    sr_sequences = sr_sequences.explode().apply(lambda x: pd.Series([x[0], x[1]]))
    sr_sequences.columns = ['X', 'y']
    sr_sequences.drop_duplicates(subset=['X'], keep='first', inplace=True)
    sr_sequences.reset_index(drop=True, inplace=True)
    print(f'Created {len(sr_sequences)} sequences for file {tg_filename}')
    return sr_sequences
    

tg_filenames = ['vremya_valeri', 'israel_repatriation', 'better_data_community']
df_dataset = pd.DataFrame(columns=['X', 'y'])
for tg_filename in tg_filenames:
    df_dataset = pd.concat([df_dataset, prepare_dataset_from_tg_channel(tg_filename, parse_json=False)], ignore_index=True)
df_dataset.reset_index(drop=True, inplace=True)
df_dataset

Prepare dataset from file vremya_valeri


100%|██████████| 16218/16218 [00:00<00:00, 55358.03it/s]
100%|██████████| 5747/5747 [00:15<00:00, 370.96it/s]


Created 6099 sequences for file vremya_valeri
Prepare dataset from file israel_repatriation


100%|██████████| 80482/80482 [00:01<00:00, 79079.26it/s]
100%|██████████| 31181/31181 [02:02<00:00, 253.85it/s]


Created 34174 sequences for file israel_repatriation
Prepare dataset from file better_data_community


100%|██████████| 6697/6697 [00:00<00:00, 49456.31it/s]
100%|██████████| 3688/3688 [00:05<00:00, 665.74it/s]


Created 2286 sequences for file better_data_community


Unnamed: 0,X,y
0,Попробую добавить возможность комментировать,Не работает...
1,Попробую добавить возможность комментировать<S...,"Прекрасно, ждём"
2,data-driven team management,Это скорее для корпораций. А тут вопрос в прав...
3,Недавно я улетал из Лондона и прилетал в Москв...,Кажется что правильным было бы сравнивать два ...
4,Недавно я улетал из Лондона и прилетал в Москв...,В таком случае и классы домов должны быть сопо...
...,...,...
42554,Ещё есть страшные слова metric learning и trip...,Contrastive learning посмотри
42555,"Не нашел инфу об Eqvilent. Чем занимаетесь, ка...",вот же
42556,"у нас есть джун вакансии, можешь податься и пи...",С вами я уже в процессе)
42557,На дзене))),А видео на рутубе


In [6]:
%%script echo skipping

BPE_MODEL_FILENAME = f'data/tg/{tg_filename}.yttm'
TRAIN_TEXTS_FILENAME = f'data/tg/{tg_filename}.txt'

EOS_ID = 3
BOS_ID = 2

separator = '<SEP>'
separator = ''

def write_line(row, file):
    if row.cleaned:
        file.write(row.cleaned+'\n')
    return row

def decode_sample(input_ids) -> map:
    return {
        'encoder_text': self.decode_tensor(input_ids[0]['encoder_ids']),
        'decoder_text': self.decode_tensor(input_ids[0]['decoder_ids']),
        'target_text': self.decode_tensor(input_ids[1])
    }
    
def decode_tensor(ids) -> str:
    return self.tokenizer.decode(ids.cpu().detach().numpy().tolist())[0]      

if True:
    with open(TRAIN_TEXTS_FILENAME, 'w') as outf:
        # df_data.apply(write_line, file = outf, axis=1)
        df_questions.apply(write_line, file = outf, axis=1)
        df_answers.apply(write_line, file = outf, axis=1)
    yttm.BPE.train(data=TRAIN_TEXTS_FILENAME, vocab_size=8000, model=BPE_MODEL_FILENAME)
tokenizer = yttm.BPE(BPE_MODEL_FILENAME)

skipping


In [7]:
%%script echo skipping

class CharTokenizer():
    def __init__(self, text_file_path) -> None:
        with open(text_file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # here are all the unique characters that occur in this text
        chars = sorted(list(set(text)))
        self._vocab_size = len(chars)+1
        # create a mapping from characters to integers
        stoi = { ch:i+1 for i,ch in enumerate(chars) }
        itos = { i+1:ch for i,ch in enumerate(chars) }
        stoi['<PAD>'] = 0
        itos[0] = '<PAD>'
        self._encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
        self._decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
        
    # def encode(self, sentences: List[str], bos=True, eos=True):
    #     return [self._encode(sentence) for sentence in sentences]
    
    # def decode(self, sentences: List[str], bos=True, eos=True):
    #     return [self._decode(sentence) for sentence in sentences]
    
    def encode(self, sentence: str, bos=True, eos=True):
        return self._encode(sentence)
    
    def decode(self, sentence: str, bos=True, eos=True):
        return self._decode(sentence)
    
    def vocab_size(self):
        return self._vocab_size
        
tokenizer = CharTokenizer(TRAIN_TEXTS_FILENAME)
tokenizer.vocab_size()

skipping


In [8]:
config.mt5_model_name = "cointegrated/rut5-small"

tokenizer = T5Tokenizer.from_pretrained(config.mt5_model_name)

def decode_sample(input_ids, skip_special_tokens=False) -> map:
    return {
        'encoder_text': decode_tensor(input_ids[0]['encoder_ids'], skip_special_tokens),
        'decoder_text': decode_tensor(input_ids[0]['decoder_ids'], skip_special_tokens),
        'target_text': decode_tensor(input_ids[1], skip_special_tokens)
    }
    
def decode_tensor(ids, skip_special_tokens=False) -> str:
    return tokenizer.decode(torch.where(ids != -100, ids, tokenizer.pad_token_id), skip_special_tokens)

In [9]:
%%script echo skipping
questions_tokenized = tokenizer.batch_encode_plus(df_dataset.X.to_list(), return_attention_mask=False).input_ids
answers_tokenized = tokenizer.batch_encode_plus(df_dataset.y.to_list(), return_attention_mask=False).input_ids
questions_lengs = [len(tokenized) for tokenized in questions_tokenized]
answers_lengs = [len(tokenized) for tokenized in answers_tokenized]
_, axes = plt.subplots(nrows=1, ncols=2, sharex=False, figsize=(15, 5))
axes[0].hist(questions_lengs, bins=100);
axes[0].set_xticks(np.arange(min(questions_lengs), max(questions_lengs)+1, 50.0));
axes[1].hist(answers_lengs, bins=100);
axes[1].set_xticks(np.arange(min(questions_lengs), max(questions_lengs)+1, 50.0));

skipping


In [69]:
# %%script echo scipping

config.ENCODER_SEQ_LEN = 200
config.DECODER_SEQ_LEN = 100
class LanguageModelDataset(Dataset):
    
    def __init__(self, tokenizer, df_dataset: pd.DataFrame, is_hf=True) -> None:
        super().__init__()
        self.is_hf = is_hf
        self.df_dataset = df_dataset
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.df_dataset)

    def __getitem__(self, index):
        sample = self.df_dataset.iloc[index]
        
        question_text = sample.X
        answer_text = sample.y
        
        if self.is_hf:
            question_tokens = self.tokenizer.encode(question_text, max_length= config.ENCODER_SEQ_LEN, truncation=True, padding='max_length',return_tensors='pt')
            answer_tokens = self.tokenizer.encode('<pad>'+answer_text, max_length= config.DECODER_SEQ_LEN, truncation=True, padding='max_length',return_tensors='pt',add_special_tokens=False)
            result_tokens = self.tokenizer.encode(answer_text, max_length= config.DECODER_SEQ_LEN, truncation=True, padding='max_length',return_tensors='pt')
            src_ids = question_tokens.squeeze(0)
            trg_ids = answer_tokens.squeeze(0)
            res_ids = result_tokens.squeeze(0)
            res_ids[res_ids == self.tokenizer.pad_token_id] = -100
            return (
                {'encoder_ids': src_ids, 'decoder_ids': trg_ids},
                res_ids
            )
        else:
            question_tokens = self.tokenizer.encode(question_text, bos=False, eos=False)
            answer_tokens = self.tokenizer.encode(answer_text, bos=True, eos=True)
            trg_tokens = answer_tokens[:-1]
            res_tokens = answer_tokens[1:]
            src_ids, trg_ids, res_ids = self.ensure_length(question_tokens, config.SEQ_LEN), self.ensure_length(trg_tokens, config.SEQ_LEN), self.ensure_length(res_tokens, config.SEQ_LEN)
            src_ids = torch.tensor(src_ids)
            trg_ids = torch.tensor(trg_ids)
            res_ids = torch.tensor(res_ids)
            return (
                {'encoder_ids': src_ids, 'decoder_ids': trg_ids},
                res_ids
            )
        
train, val = train_test_split(df_dataset, test_size=0.2, random_state=SEED)
train.reset_index(drop=True, inplace=True)
val.reset_index(drop=True, inplace=True)
train_dataset = LanguageModelDataset(tokenizer, train)
val_dataset = LanguageModelDataset(tokenizer, val)
sample = train_dataset[np.random.randint(0, train_dataset.__len__())]
print(sample)
print(decode_sample(sample))

({'encoder_ids': tensor([  259,  7472,  6334,  4013,   917,  5627, 10791, 17248,  3330,   259,
         7236,  6131,   279,   261,   310,  5375,   833,   805,  7955,  3604,
        11598,  5688,   433,   260, 11896,   995,  5907,   259,  9079,  8110,
          833,  7423,  1965,  1066,   315,  5145, 12904,   261,   922,   259,
          279,   315, 12314,  3725,  1625,   260,  1051,  5441,   315,  6390,
         1323,  6136,  8371,   259,   279, 11029,  8371, 14425, 19765,   433,
          267,  1617,   261,   966,   261,  4071,   259,   279,  5401, 16600,
         6813,  2709,   399,   478,   559,   669, 16488,  1348,   388, 13373,
          324,  3942,  2766,   262,   348,   371,   737,  2622,   419,   310,
        12420,   291,   688,  5172, 14301,   892,   259,  1802,  5313,   966,
          992,   259,  5627,   259,  8072,  5029,   587,   374, 10232,  2460,
            1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     

In [11]:
# %%script echo skipping

class TransformerModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.backbone = T5ForConditionalGeneration.from_pretrained(config.mt5_model_name)

    def forward(self, src: Tensor) -> Tensor:
        src_ids = src['encoder_ids']
        trg_ids = src['decoder_ids']
        
        src_padding_mask = (src_ids != 0).type(torch.long)
        trg_padding_mask = (trg_ids != 0).type(torch.long)
        trg_padding_mask[:, 0] = 1
        
        output = self.backbone(input_ids=src_ids, attention_mask=src_padding_mask, decoder_input_ids=trg_ids, decoder_attention_mask=trg_padding_mask)

        return output.logits
    
    def generate(self, seed_token_ids, max_len=40, return_hypotheses_n=5, beamsize=5):
        output = []
        result = self.backbone.generate(
                    inputs=seed_token_ids['encoder_ids'].unsqueeze(0), 
                    num_return_sequences=return_hypotheses_n,  
                    num_beams=beamsize, 
                    max_length=max_len,
                    do_sample=True, 
                    return_dict_in_generate=True, 
                    output_scores=True, 
                )
        for i in range(return_hypotheses_n):
            output.append((result.sequences_scores[i].item(), result.sequences[i]))
        return output

model = TransformerModel().to(device)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


64.64448 M parameters


In [12]:
%%script echo scipping

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, embedding_dim: int, nhead: int, d_hid: int,
                 nlayers: int, seq_len: int, dropout: float = 0.2):
        super().__init__()
        self.seq_len = seq_len
        self.nhead = nhead
        self.embedding_dim = embedding_dim
        self.token_embedding = nn.Embedding(ntoken, embedding_dim, padding_idx=0)
        self.pos_encoder = nn.Embedding(seq_len, embedding_dim)
        
        
        encoder_norm = nn.LayerNorm(embedding_dim)
        encoder_layers = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=nhead, dim_feedforward=d_hid, dropout=dropout, batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers, norm=encoder_norm)
        
        decoder_norm = nn.LayerNorm(embedding_dim)
        decoder_layers = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=nhead, dim_feedforward=d_hid, dropout=dropout, batch_first=True, norm_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, nlayers, norm=decoder_norm)
        
        self.lm_head = nn.Linear(embedding_dim, ntoken)
        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def src_mask(self, seq_len):
        full_mask = torch.ones(seq_len, seq_len)
        ignore_mask = torch.tril(full_mask) < 1
        full_mask.masked_fill_(ignore_mask, float('-inf'))
        full_mask.masked_fill_(~ignore_mask, 0)
        return full_mask.to(device)
    
    def make_positional_encoding(self, max_length, embedding_size):
        time = np.pi * torch.arange(0, max_length).float()
        freq_dividers = torch.arange(1, embedding_size // 2 + 1).float()
        inputs = time[:, None] / freq_dividers[None, :]
        
        result = torch.zeros(max_length, embedding_size)
        result[:, 0::2] = torch.sin(inputs)
        result[:, 1::2] = torch.cos(inputs)
        return result.to(device)

    def threeD_src_mask(self, n_heads, source_len, target_len, answers_start):
        answers_start = answers_start.repeat_interleave(n_heads)
        B = answers_start.shape[0]
        full_mask = torch.ones(B, target_len, source_len)
        ignore_mask = torch.stack([torch.tril(x, int(y)) for x,y in zip(full_mask, answers_start)]) < 1
        full_mask.masked_fill_(ignore_mask, float('-inf'))
        full_mask.masked_fill_(~ignore_mask, 0)
        return full_mask
    
    def forward(self, src: Tensor) -> Tensor:
        src_ids = src['encoder_ids']
        trg_ids = src['decoder_ids']
        srcB, srcT = src_ids.shape
        trgB, trgT = trg_ids.shape
        src_padding_mask = src_ids == 0
        trg_padding_mask = trg_ids == 0
        src_ids = self.token_embedding(src_ids) # * math.sqrt(self.embedding_dim)
        trg_ids = self.token_embedding(trg_ids) # * math.sqrt(self.embedding_dim)
        src_ids = src_ids + self.pos_encoder(torch.arange(srcT, device=device))
        trg_ids = trg_ids + self.pos_encoder(torch.arange(trgT, device=device))
        trg_attention_mask = self.src_mask(trgT)
        encoder_mem = self.transformer_encoder(src_ids, src_key_padding_mask = src_padding_mask)
        input_ids = self.transformer_decoder(tgt=trg_ids, memory = encoder_mem, tgt_mask = trg_attention_mask, tgt_key_padding_mask = trg_padding_mask, memory_key_padding_mask = src_padding_mask)
        logits = self.lm_head(input_ids)

        return logits

seq_len = config.SEQ_LEN
ntokens = tokenizer.vocab_size()  # size of vocabulary
emsize = 480  # embedding dimension
d_hid = 480  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 8  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability

model = TransformerModel(ntoken = ntokens, embedding_dim=emsize, nhead=nhead, d_hid=d_hid, nlayers=nlayers, seq_len=seq_len, dropout=dropout)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

scipping


In [73]:
sample_indices = np.random.randint(0, len(val_dataset), 3, dtype=np.int32)
samples = []
for sample_idx in sample_indices:
    sample = val_dataset[sample_idx]
    sample_text = decode_sample(sample, skip_special_tokens=True)
    samples.append(
        {
            'encoder_ids': sample[0]['encoder_ids'],
            'decoder_ids': sample[0]['decoder_ids'],
            'target_ids': sample[1],
            'encoder_text': sample_text['encoder_text'],
            'target_text': sample_text['target_text'],
        }
    )

def generate_sample(epoch_i, model, tensorboard: SummaryWriterHelper=None):
    for sample in samples:
        encoder_ids = sample['encoder_ids'].to(device)
        decoder_ids = sample['decoder_ids'][:5].to(device)
        seed_token_ids = { 'encoder_ids' : encoder_ids, 'decoder_ids' : decoder_ids }
        hypothesis_list = model.generate(seed_token_ids=seed_token_ids, max_len=config.DECODER_SEQ_LEN, return_hypotheses_n=2, beamsize=5)
        for score, hypothesis in hypothesis_list:
            predicted_text = decode_tensor(hypothesis, skip_special_tokens=True)
            sample_table = {
                'epoch_i': epoch_i,
                'score': score,
                'encoder_text' : sample['encoder_text'],
                'predicted_text' : predicted_text,
                'target_text' : sample['target_text'],
            }
            print(f"Sample generation={sample_table}")
            if tensorboard:
                tensorboard.add_text("sample", str(sample_table))
generate_sample(-1, model, None)

Sample generation={'epoch_i': -1, 'score': -1.0803672075271606, 'encoder_text': 'Может там какая-то нейросеть ответы шлет ))<SEP>Мне вот тоже интересно) а если вручную, то думаю анкету убрать. У меня там с ней 110 страниц получилось😹<SEP>110 это жестко) я джипегами приложил отдельно 9 документов, потом уже когда отправил понял что некоторые доки забыл, потом читаю люди пдфами шлют, вот и думаю че делать ждать или повторно слать<SEP>Ну вот и я думаю. Там наверное устанут их листать😂 их них 36 вроде анкета. Остальное доки', 'predicted_text': 'Да, я так понимаю, что если я не правильно поняла, то я не уверена что они не уверены, что они могут сделать анкету, а потом отправить и повторно слать и т.п. 🤷 ♀️🤷 ♀️🤷 ♀️🤷 ��', 'target_text': '74 дока да они в шоке там'}
Sample generation={'epoch_i': -1, 'score': -1.2391160726547241, 'encoder_text': 'Может там какая-то нейросеть ответы шлет ))<SEP>Мне вот тоже интересно) а если вручную, то думаю анкету убрать. У меня там с ней 110 страниц получилос

In [None]:
config.BATCH_SIZE = 6

def cross_entropy_loss(predict, target):
    B, T, C = predict.shape
    predict = predict.view(B*T, C)
    target = target.view(B*T)
    loss = F.cross_entropy(predict, target, ignore_index=-100)
    return loss

writer = SummaryWriterHelper()

def scheduler(optim): return \
    torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=3, factor=0.5, verbose=True)

best_val_loss, best_model = train_eval_loop(model=model,
                                            device=device,
                                            train_dataset=train_dataset,
                                            val_dataset=val_dataset,
                                            criterion=cross_entropy_loss,
                                            lr=1e-4,
                                            epoch_n=200,
                                            batch_size=config.BATCH_SIZE,
                                            l2_reg_alpha=1e-2,
                                            lr_scheduler_ctor=scheduler,
                                            shuffle_train=True,
                                            tensorboard=writer,
                                            on_epoch_cb=generate_sample
                                            )
print("saving best model...")
torch.save(best_model.state_dict(), 'data/model.pth')
writer.close()

In [72]:
model.load_state_dict(torch.load(f'data/model.pth', map_location=device))
best_model = model

In [80]:
input_ids = tokenizer.encode('Странная модель получилась', max_length= config.ENCODER_SEQ_LEN, truncation=True, padding='max_length',return_tensors='pt').to(device)
output_ids = model.backbone.generate(
                    inputs=input_ids, 
                    num_return_sequences=1,  
                    num_beams=5, 
                    max_length=config.DECODER_SEQ_LEN,
                    do_sample=True,
                )
tokenizer.decode(output_ids[0])

'<pad> Да, вот и я думал, что это новая версия, а я и не уверена, что может быть ещё и не сильно предсказуемая.</s>'