In [1]:
from typing import List, Dict, Tuple, Any
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import json
from transformers import T5ForConditionalGeneration, T5Tokenizer
from pathlib import Path
from utils.read_jsonl_data import read_jsonl_data
from tqdm.auto import tqdm

SEED = 12345
val_size = 0.05

random.seed(SEED)

# Train data
TRAIN_DATA_DIR = Path("data/Толока Персона Чат")
gk_1_500_path = TRAIN_DATA_DIR / "TolokaPersonaChat_gk_1_500.jsonl"
gk_test_1_500_path = TRAIN_DATA_DIR / "TolokaPersonaChat_1_500_gk_test.jsonl"
test_stipa_path = TRAIN_DATA_DIR / "gk(test)Stipa.jsonl"
genderized_gk_test_v2_path = TRAIN_DATA_DIR / "TolokaPersonaChat_genderized_gk(test)v2.jsonl"

# Test data
TEST_DATA_DIR = Path("data/test")
all_dialogs_path = TEST_DATA_DIR / "all_dialogs.jsonl"

In [2]:
tokenizer = T5Tokenizer.from_pretrained("cointegrated/rut5-small-chitchat")
model = T5ForConditionalGeneration.from_pretrained("cointegrated/rut5-small-chitchat")

In [3]:
def train_val_split(
    data: List[Dict[str, str]],
    val_size: int = val_size,
    shuffle: bool = True,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    if shuffle:
        random.shuffle(data)
    val_len = int(val_size * len(data))
    data_train, data_val = data[:-val_len], data[-val_len:]
    return data_train, data_val

In [4]:
gk_1_500_data = read_jsonl_data(gk_1_500_path)
gk_test_1_500_data = read_jsonl_data(gk_test_1_500_path)
test_stipa_data = read_jsonl_data(test_stipa_path)
genderized_gk_test_v2_data = read_jsonl_data(genderized_gk_test_v2_path)

In [5]:
gk_1_500_data_train, gk_1_500_data_val = train_val_split(gk_1_500_data, val_size, True)
gk_test_1_500_data_train, gk_test_1_500_data_val = train_val_split(gk_test_1_500_data, val_size, True)
test_stipa_data_train, test_stipa_data_val = train_val_split(test_stipa_data, val_size, True)
genderized_gk_test_v2_data_train, genderized_gk_test_v2_data_val = train_val_split(genderized_gk_test_v2_data, val_size, True)

data_train = gk_1_500_data_train + gk_test_1_500_data_train + test_stipa_data_train + genderized_gk_test_v2_data_train
data_val = gk_1_500_data_val + gk_test_1_500_data_val + test_stipa_data_val + genderized_gk_test_v2_data_val

In [6]:
len(data_train), len(data_val)

(2838, 147)

In [7]:
def process_text(text: str) -> str:
    text = text.strip()
    return text

def compress_consecutive_statements(dialog: List[Dict[str, Any]]):
    # Сжимаем все подряд идущие высказывания одного спикера
    compressed_dialog: List[Dict[str, Any]] = list()
    
    last_person: int = dialog[0]['person']
    whole_text = [dialog[0]['text']]
    for message in dialog[1:]:
        text, person = message['text'], message['person']

        if last_person == person:
            whole_text.append(text)
        else:
            new_message = {
                "person": last_person,
                "text": " ".join(whole_text)
            }
            compressed_dialog.append(new_message)
            last_person = person
            whole_text = [text]
    
    new_message = {
        "person": last_person,
        "text": " ".join(whole_text)
    }
    compressed_dialog.append(new_message)
    
    return compressed_dialog

def make_pairs(
    data: List[Dict[str, List[Dict[str, Any]]]],
    tokenizer: T5Tokenizer,
    max_history_tokens: int,
    max_history_messages: int = 3,
    # max_target_tokens: int,
) -> List[Tuple[str, str]]:
    # Все пары "история общения -> ответ"
    pairs: List[Tuple[str, str]] = list()
    
    for data_item in tqdm(data):
        # Пары "история общения -> ответ" в рамках одного диалога
        dialog_pairs: List[Tuple[List[str], str]] = list()

        # Сжимаем все подряд идущие высказывания одного спикера
        dialog = compress_consecutive_statements(data_item['dialog'])
        
        historical_text = [dialog[0]['text']]
        for message in dialog[1:]:
            text = message['text']
            for history_messages_len in range(1, max_history_messages+1):
                if len(historical_text) >= history_messages_len:
                    dialog_pairs.append((historical_text[-history_messages_len:], text))
            
            offset = 0
            historical_text = dialog_pairs[-1][0][offset:] + [text]
            # historical_text = "</s>".join(historical_text)
            
            while len(tokenizer("</s>".join(historical_text)).input_ids) > max_history_tokens:
                offset += 1
                historical_text = dialog_pairs[-1][0][offset:] + [text]
                # historical_text = "</s>".join(historical_text)
        
        pairs.extend(dialog_pairs)
    
    return pairs

In [8]:
pairs_train = make_pairs(data_train, tokenizer, 512, 4)
pairs_val = make_pairs(data_val, tokenizer, 512, 4)
print(len(pairs_train), len(pairs_val))

  0%|          | 0/2838 [00:00<?, ?it/s]

  0%|          | 0/147 [00:00<?, ?it/s]

161496 7962


In [9]:
import pandas as pd
pairs_train = pd.DataFrame([("</s>".join(p[0]), p[1]) for p in pairs_train]) # .drop_duplicates()
pairs_val = pd.DataFrame([("</s>".join(p[0]), p[1]) for p in pairs_val]) # .drop_duplicates()

pairs_train = pairs_train[~pairs_train.duplicated()]
pairs_train = pairs_train[~(pairs_train[0].isin(pairs_val[0]) & pairs_train[1].isin(pairs_val[1]))]
pairs_val = pairs_val.drop_duplicates()

In [10]:
import torch 
from transformers import T5ForConditionalGeneration, T5Tokenizer
# raw_model = 'cointegrated/rut5-base-multitask' 
# model = T5ForConditionalGeneration.from_pretrained(raw_model).cuda();
# tokenizer = T5Tokenizer.from_pretrained(raw_model)

device = "cuda:0"

raw_model = "cointegrated/rut5-small-chitchat"
tokenizer = T5Tokenizer.from_pretrained(raw_model)
model = T5ForConditionalGeneration.from_pretrained(raw_model).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.2)

In [11]:
from tqdm.auto import trange
import random
import numpy as np

batch_size = 24 # сколько примеров показываем модели за один шаг
report_steps = 200  # раз в сколько шагов печатаем результат
epochs = 15  # сколько раз мы покажем данные модели

@torch.no_grad()
def eval(pairs, tokenizer, model) -> float:
    eval_losses = list()
    model.eval()

    pairs = pairs.sample(frac=1)
    for i in range(0, int(len(pairs) / batch_size)):
        batch = pairs.values[i * batch_size: (i + 1) * batch_size]
        # кодируем вопрос и ответ 
        x = tokenizer([p[0] for p in batch], return_tensors='pt', padding="longest").to(model.device)
        y = tokenizer([p[1] for p in batch], return_tensors='pt', padding="longest").to(model.device)
        # -100 - специальное значение, позволяющее не учитывать токены
        y.input_ids[y.input_ids == 0] = -100
        # вычисляем функцию потерь
        loss = model(
            input_ids=x.input_ids,
            attention_mask=x.attention_mask,
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
        eval_losses.append(loss.item())
    
    return np.mean(eval_losses)



model.train()
losses = []

best_model = None
best_loss = 1000000

for epoch in range(epochs):
    print('EPOCH', epoch)
    pairs_train = pairs_train.sample(frac=1)
    for i in trange(0, int(len(pairs_train) / batch_size)):
        batch = pairs_train.values[i * batch_size: (i + 1) * batch_size]
        # кодируем вопрос и ответ 
        x = tokenizer([p[0] for p in batch], return_tensors='pt', padding="longest").to(model.device)
        y = tokenizer([p[1] for p in batch], return_tensors='pt', padding="longest").to(model.device)
        # -100 - специальное значение, позволяющее не учитывать токены
        y.input_ids[y.input_ids == 0] = -100
        # вычисляем функцию потерь
        loss = model(
            input_ids=x.input_ids,
            attention_mask=x.attention_mask,
            labels=y.input_ids,
            decoder_attention_mask=y.attention_mask,
            return_dict=True
        ).loss
        # делаем шаг градиентного спуска
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # печатаем скользящее среднее значение функции потерь
        losses.append(loss.item())
        if i % report_steps == 0:
            val_loss = eval(pairs_val, tokenizer, model)
            print('step', i, '| train loss', np.round(np.mean(losses[-report_steps:]), 3), '| val loss', np.round(val_loss, 3))
            if val_loss < best_loss:
                best_model = model
                best_loss = val_loss
            model.train()

EPOCH 0


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.824 | val loss 3.831
step 200 | train loss 3.985 | val loss 3.628
step 400 | train loss 3.867 | val loss 3.544
step 600 | train loss 3.812 | val loss 3.496
step 800 | train loss 3.772 | val loss 3.468
step 1000 | train loss 3.717 | val loss 3.442
step 1200 | train loss 3.71 | val loss 3.422
step 1400 | train loss 3.664 | val loss 3.406
step 1600 | train loss 3.643 | val loss 3.391
step 1800 | train loss 3.637 | val loss 3.38
step 2000 | train loss 3.642 | val loss 3.365
EPOCH 1


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.635 | val loss 3.362
step 200 | train loss 3.603 | val loss 3.352
step 400 | train loss 3.602 | val loss 3.344
step 600 | train loss 3.582 | val loss 3.333
step 800 | train loss 3.559 | val loss 3.325
step 1000 | train loss 3.567 | val loss 3.318
step 1200 | train loss 3.56 | val loss 3.313
step 1400 | train loss 3.529 | val loss 3.302
step 1600 | train loss 3.554 | val loss 3.296
step 1800 | train loss 3.535 | val loss 3.292
step 2000 | train loss 3.514 | val loss 3.285
EPOCH 2


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.51 | val loss 3.283
step 200 | train loss 3.516 | val loss 3.278
step 400 | train loss 3.508 | val loss 3.275
step 600 | train loss 3.486 | val loss 3.27
step 800 | train loss 3.489 | val loss 3.264
step 1000 | train loss 3.472 | val loss 3.257
step 1200 | train loss 3.478 | val loss 3.252
step 1400 | train loss 3.452 | val loss 3.254
step 1600 | train loss 3.465 | val loss 3.245
step 1800 | train loss 3.448 | val loss 3.241
step 2000 | train loss 3.433 | val loss 3.237
EPOCH 3


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.438 | val loss 3.237
step 200 | train loss 3.439 | val loss 3.232
step 400 | train loss 3.425 | val loss 3.228
step 600 | train loss 3.425 | val loss 3.224
step 800 | train loss 3.441 | val loss 3.223
step 1000 | train loss 3.403 | val loss 3.222
step 1200 | train loss 3.42 | val loss 3.218
step 1400 | train loss 3.4 | val loss 3.211
step 1600 | train loss 3.38 | val loss 3.212
step 1800 | train loss 3.402 | val loss 3.207
step 2000 | train loss 3.411 | val loss 3.207
EPOCH 4


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.406 | val loss 3.204
step 200 | train loss 3.387 | val loss 3.207
step 400 | train loss 3.384 | val loss 3.201
step 600 | train loss 3.393 | val loss 3.198
step 800 | train loss 3.357 | val loss 3.194
step 1000 | train loss 3.371 | val loss 3.194
step 1200 | train loss 3.364 | val loss 3.19
step 1400 | train loss 3.364 | val loss 3.19
step 1600 | train loss 3.354 | val loss 3.189
step 1800 | train loss 3.357 | val loss 3.18
step 2000 | train loss 3.332 | val loss 3.181
EPOCH 5


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.34 | val loss 3.179
step 200 | train loss 3.329 | val loss 3.18
step 400 | train loss 3.326 | val loss 3.179
step 600 | train loss 3.334 | val loss 3.177
step 800 | train loss 3.336 | val loss 3.173
step 1000 | train loss 3.339 | val loss 3.174
step 1200 | train loss 3.308 | val loss 3.17
step 1400 | train loss 3.314 | val loss 3.167
step 1600 | train loss 3.317 | val loss 3.165
step 1800 | train loss 3.32 | val loss 3.164
step 2000 | train loss 3.329 | val loss 3.163
EPOCH 6


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.33 | val loss 3.165
step 200 | train loss 3.295 | val loss 3.163
step 400 | train loss 3.303 | val loss 3.161
step 600 | train loss 3.28 | val loss 3.161
step 800 | train loss 3.31 | val loss 3.159
step 1000 | train loss 3.29 | val loss 3.154
step 1200 | train loss 3.302 | val loss 3.157
step 1400 | train loss 3.278 | val loss 3.156
step 1600 | train loss 3.278 | val loss 3.153
step 1800 | train loss 3.291 | val loss 3.151
step 2000 | train loss 3.276 | val loss 3.151
EPOCH 7


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.272 | val loss 3.152
step 200 | train loss 3.27 | val loss 3.148
step 400 | train loss 3.29 | val loss 3.145
step 600 | train loss 3.238 | val loss 3.149
step 800 | train loss 3.255 | val loss 3.146
step 1000 | train loss 3.266 | val loss 3.144
step 1200 | train loss 3.267 | val loss 3.144
step 1400 | train loss 3.234 | val loss 3.142
step 1600 | train loss 3.26 | val loss 3.141
step 1800 | train loss 3.266 | val loss 3.137
step 2000 | train loss 3.245 | val loss 3.138
EPOCH 8


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.247 | val loss 3.139
step 200 | train loss 3.239 | val loss 3.14
step 400 | train loss 3.232 | val loss 3.138
step 600 | train loss 3.235 | val loss 3.138
step 800 | train loss 3.238 | val loss 3.132
step 1000 | train loss 3.237 | val loss 3.134
step 1200 | train loss 3.211 | val loss 3.135
step 1400 | train loss 3.199 | val loss 3.136
step 1600 | train loss 3.242 | val loss 3.13
step 1800 | train loss 3.235 | val loss 3.132
step 2000 | train loss 3.23 | val loss 3.132
EPOCH 9


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.221 | val loss 3.13
step 200 | train loss 3.216 | val loss 3.126
step 400 | train loss 3.209 | val loss 3.129
step 600 | train loss 3.238 | val loss 3.127
step 800 | train loss 3.204 | val loss 3.125
step 1000 | train loss 3.204 | val loss 3.128
step 1200 | train loss 3.194 | val loss 3.126
step 1400 | train loss 3.186 | val loss 3.125
step 1600 | train loss 3.186 | val loss 3.125
step 1800 | train loss 3.206 | val loss 3.125
step 2000 | train loss 3.195 | val loss 3.12
EPOCH 10


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.199 | val loss 3.125
step 200 | train loss 3.187 | val loss 3.122
step 400 | train loss 3.186 | val loss 3.125
step 600 | train loss 3.187 | val loss 3.124
step 800 | train loss 3.188 | val loss 3.117
step 1000 | train loss 3.196 | val loss 3.118
step 1200 | train loss 3.171 | val loss 3.123
step 1400 | train loss 3.157 | val loss 3.12
step 1600 | train loss 3.181 | val loss 3.12
step 1800 | train loss 3.163 | val loss 3.123
step 2000 | train loss 3.18 | val loss 3.118
EPOCH 11


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.18 | val loss 3.115
step 200 | train loss 3.18 | val loss 3.116
step 400 | train loss 3.165 | val loss 3.115
step 600 | train loss 3.147 | val loss 3.118
step 800 | train loss 3.169 | val loss 3.117
step 1000 | train loss 3.147 | val loss 3.116
step 1200 | train loss 3.151 | val loss 3.116
step 1400 | train loss 3.157 | val loss 3.118
step 1600 | train loss 3.138 | val loss 3.113
step 1800 | train loss 3.135 | val loss 3.112
step 2000 | train loss 3.173 | val loss 3.112
EPOCH 12


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.176 | val loss 3.112
step 200 | train loss 3.131 | val loss 3.113
step 400 | train loss 3.128 | val loss 3.113
step 600 | train loss 3.16 | val loss 3.113
step 800 | train loss 3.14 | val loss 3.11
step 1000 | train loss 3.148 | val loss 3.113
step 1200 | train loss 3.134 | val loss 3.114
step 1400 | train loss 3.119 | val loss 3.113
step 1600 | train loss 3.15 | val loss 3.107
step 1800 | train loss 3.115 | val loss 3.108
step 2000 | train loss 3.113 | val loss 3.11
EPOCH 13


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.114 | val loss 3.107
step 200 | train loss 3.107 | val loss 3.107
step 400 | train loss 3.13 | val loss 3.105
step 600 | train loss 3.128 | val loss 3.108
step 800 | train loss 3.122 | val loss 3.105
step 1000 | train loss 3.102 | val loss 3.107
step 1200 | train loss 3.118 | val loss 3.109
step 1400 | train loss 3.127 | val loss 3.105
step 1600 | train loss 3.096 | val loss 3.106
step 1800 | train loss 3.119 | val loss 3.105
step 2000 | train loss 3.092 | val loss 3.106
EPOCH 14


  0%|          | 0/2027 [00:00<?, ?it/s]

step 0 | train loss 3.101 | val loss 3.104
step 200 | train loss 3.094 | val loss 3.104
step 400 | train loss 3.112 | val loss 3.106
step 600 | train loss 3.1 | val loss 3.106
step 800 | train loss 3.096 | val loss 3.106
step 1000 | train loss 3.083 | val loss 3.106
step 1200 | train loss 3.101 | val loss 3.102
step 1400 | train loss 3.085 | val loss 3.105
step 1600 | train loss 3.073 | val loss 3.102
step 1800 | train loss 3.093 | val loss 3.099
step 2000 | train loss 3.095 | val loss 3.102


In [40]:
step = epoch * int(len(pairs_train) / batch_size) + i
save_dir = Path("experiments/exp1-t5-small-chitchat-finetuning") / f"checkpoints/{step}_steps/"
model.save_pretrained(save_dir)

In [37]:
save_dir = Path("experiments/exp1-t5-small-chitchat-finetuning") / "checkpoints/best_model"
best_model.save_pretrained(save_dir)

In [12]:
best_loss

3.0993898600435106

In [33]:
import time

@torch.no_grad()
def answer(history_text: str, model) -> str:
    model.eval()

    inputs = tokenizer(history_text, return_tensors='pt')
    hypotheses = model.generate(
        **{k: v.to(model.device) for k, v in inputs.items()},
        do_sample=True,
        top_p=0.5,
        num_return_sequences=1,
        repetition_penalty=1.5,
        max_length=1024,
    )
    return tokenizer.decode(hypotheses[0], skip_special_tokens=True)


history_text = ["Привет!"]
print("bot1:", history_text[0])

for idx in range(10):
    history_tmp = "</s>".join(history_text[-3:])
    text = answer(history_tmp, best_model).replace("<pad>", "").strip()
    print(f"bot{idx % 2}:", text)
    
    history_text.append(text)
    time.sleep(0.7)

bot1: Привет!
bot0: Привет, как тебя зовут?
bot1: Меня зовут Евгений. Я учитель
bot0: Чем занимаешься?
bot1: Я работаю инженером. Чем увлекаешься?
bot0: Я люблю готовить, а ты?
bot1: Я люблю читать. А ты?
bot0: Я очень люблю готовить, но это не интересно. А у тебя есть хобби?
bot1: Да, я люблю готовить. Любишь путешествовать?
bot0: Да, я люблю путешествовать. А ты?
bot1: Это здорово. Я тоже люблю путешествовать
