In [1]:
import os

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from transformers import T5TokenizerFast, T5ForConditionalGeneration, MT5ForConditionalGeneration
from transformers.optimization import Adafactor, AdafactorSchedule

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing later when tensors are moved.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def cyrillic_to_latin(text, table_path='../data/kbd cyrillic-latin alphabet table.txt'):
    """Transliterate ``text`` with a ``key:value`` replacement table read from disk.

    The table is a UTF-8 text file with one mapping per line. Replacements are
    applied one after another in file order, so the order of the lines in the
    table matters when one key is a substring of another.

    Args:
        text: The string to transliterate.
        table_path: Path to the mapping file. Defaults to the table the
            notebook has always used.

    Returns:
        ``text`` with every occurrence of each key replaced by its value.

    Raises:
        ValueError: If a non-blank line contains no ``':'`` separator.
    """
    with open(table_path, 'r', encoding='utf-8') as alphabet_table:
        for line_number, line in enumerate(alphabet_table, start=1):
            line = line.rstrip('\n')
            # A blank line (typically a trailing one) used to crash the tuple
            # unpacking below; it carries no mapping, so skip it.
            if not line.strip():
                continue
            if ':' not in line:
                raise ValueError(
                    'Malformed mapping on line %d of %r: %r' % (line_number, table_path, line))
            # Split on the first colon only so a value may itself contain ':'.
            key, value = line.split(':', 1)
            # str.replace('', value) would insert `value` between every
            # character, so an empty key is ignored.
            if not key:
                continue
            text = text.replace(key, value)
    return text


def replace_cyrillic_with_latin(file_path):
    """Write a transliterated copy of a UTF-8 text file next to the original.

    The whole file is read, passed through ``cyrillic_to_latin`` and written
    to a sibling file whose name is the original name prefixed with
    ``latin_``. The source file is left untouched.

    Args:
        file_path: Path of the file to transliterate.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    text = cyrillic_to_latin(text)

    _path, _filename = os.path.split(file_path)
    # os.path.join keeps a bare file name relative to the working directory.
    # Concatenating '' + '/latin_' + name pointed at the filesystem root.
    new_file_path = os.path.join(_path, 'latin_' + _filename)
    with open(new_file_path, 'w', encoding='utf-8') as file:
        file.write(text)

### Dataset

In [4]:
# Load all files matching the glob into a single "train" split using the
# generic "text" loader. Each `text` item is later split on '😀' into a
# (source, target) sentence pair by `encode_batch`.
dataset = load_dataset("text", data_files={"train": '../data/books/kbd-ru*.txt'})

Using custom data configuration default-e417feb4da9299e9
Reusing dataset text (C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
# Number of examples in the training split.
len(dataset['train'])

17754

In [6]:
def clear_cache():
    """Free memory that is no longer in use.

    Runs Python's garbage collector first, then asks PyTorch to release the
    unoccupied memory held by its CUDA caching allocator.
    """
    import gc
    import torch

    gc.collect()
    torch.cuda.empty_cache()

In [7]:
# Upper bound, in tokens, applied to every encoded sentence (longer ones are truncated).
max_tokenization_length = 128


def encode(text, _tokenizer, is_input):
    """Tokenize a list of sentences into padded PyTorch tensors.

    Args:
        text: Iterable of sentence strings.
        _tokenizer: Callable tokenizer; it receives the list of strings plus
            the padding / truncation keyword options set below.
        is_input: When true, every sentence is prefixed with the task tag
            ``'kbd->ru: '``; target sentences are passed through unchanged.

    Returns:
        Whatever ``_tokenizer`` returns for the batch.
    """
    task_prefix = 'kbd->ru: ' if is_input else ''
    prefixed_sentences = [task_prefix + sentence for sentence in text]

    tokenizer_options = {
        'padding': 'longest',  # pad to the longest sentence of this batch
        'max_length': max_tokenization_length,
        'truncation': True,
        'return_attention_mask': True,
        'return_tensors': "pt",
    }
    return _tokenizer(prefixed_sentences, **tokenizer_options)


def encode_batch(raw_batch, _tokenizer):
    """Turn a raw batch of sentence-pair lines into tensors on ``device``.

    Args:
        raw_batch: Mapping with a ``'text'`` entry holding a list of strings.
            Each string is a source and a target sentence separated by '😀'.
        _tokenizer: Tokenizer forwarded to ``encode``.

    Returns:
        Tuple ``(input_ids, input_attention_mask, target_ids,
        target_attention_mask)``, each moved to ``device``.

    Raises:
        ValueError: If a line does not contain exactly one '😀' separator.
    """
    input_texts = []
    target_texts = []

    for translation_pair in raw_batch['text']:
        # kbd -> ru: the source sentence is on the left of the separator.
        # (For ru -> kbd, swap the two names on the left-hand side.)
        input_text, target_text = translation_pair.split('😀')
        # TODO: transliterate the source side once the Latin-script data is
        # in use: input_text = cyrillic_to_latin(input_text)

        input_texts.append(input_text)
        target_texts.append(target_text)

    encoding = encode(input_texts, _tokenizer, is_input=True)
    target_encoding = encode(target_texts, _tokenizer, is_input=False)

    # Tensor.to() returns a new tensor and leaves the original where it was,
    # so the moved tensors must be the ones returned. The previous code
    # discarded them and handed back the unmoved originals.
    return (
        encoding.input_ids.to(device),
        encoding.attention_mask.to(device),
        target_encoding.input_ids.to(device),
        target_encoding.attention_mask.to(device),
    )


def get_data_generator(epoch_id, _dataset, _tokenizer, _batch_size=32):
    """Yield encoded batches covering one shuffled pass over the train split.

    Args:
        epoch_id: Epoch number; ``epoch_id + 1`` is the shuffle seed, so each
            epoch gets its own reproducible ordering.
        _dataset: Object whose ``shuffle(seed=...)`` result exposes a
            ``'train'`` split that supports ``len`` and slicing.
        _tokenizer: Tokenizer forwarded to ``encode_batch``.
        _batch_size: Number of examples per batch; the final batch may be
            shorter.

    Yields:
        The value of ``encode_batch`` for each consecutive slice.
    """
    # For streaming datasets, set_epoch(epoch_id) would be the way to reshuffle:
    # https://huggingface.co/docs/datasets/dataset_streaming.html#reshuffle-the-dataset-at-each-epoch
    train_split = _dataset.shuffle(seed=epoch_id + 1)['train']
    total_examples = len(train_split)

    for start in range(0, total_examples, _batch_size):
        stop = start + _batch_size
        yield encode_batch(train_split[start:stop], _tokenizer)

### Model path

In [8]:
# Checkpoint identifier used to load both the tokenizer and the model; it is
# also reused as the sub-directory name under 'models/' for saved epochs.
# model_path = 'anzorq/kbd_lat-835k_ru-3M_t5-small'
model_path = 'google/mt5-small'

### Tokenizer

In [9]:
# Tokenizer loaded from the same checkpoint as the model.
# NOTE(review): extra_ids=0 presumably turns off the extra sentinel tokens the
# T5 tokenizer adds by default - confirm this matches the checkpoint's vocabulary.
tokenizer = T5TokenizerFast.from_pretrained(model_path, extra_ids=0)

### Model

In [10]:
# Load the pretrained weights and move the model to the selected device.
# `model.to(device)` is the last expression of the cell, so the notebook
# displays the module structure as the cell output.
model = MT5ForConditionalGeneration.from_pretrained(model_path)
model.to(device)

MT5ForConditionalGeneration(
  (shared): Embedding(250112, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(250112, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedGeluDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (w

### Training

In [11]:
# Training configuration
n_epochs = 30     # passes over the training split
batch_size = 12   # examples per batch
print_freq = 10   # batches between two TensorBoard log points
lr = 1e-4         # learning rate handed to the optimizer
# checkpoint_freq = 2500

# Batches per epoch, rounded up so that a final partial batch is counted.
n_batches = (len(dataset['train']) + batch_size - 1) // batch_size

In [12]:
# Adafactor driven by an explicit learning rate: `relative_step`,
# `scale_parameter` and `warmup_init` are all switched off, so the `lr`
# constant above is what the optimizer receives.
# NOTE(review): the remaining values (eps, clip_threshold, decay_rate, beta1,
# weight_decay) look like the library defaults spelled out - confirm.
optimizer = Adafactor(
    model.parameters(),
    lr=lr,
    eps=(1e-30, 1e-3),
    clip_threshold=1.0,
    decay_rate=-0.8,
    beta1=None,
    weight_decay=0.0,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False,
)

In [13]:
# %load_ext tensorboard
# %tensorboard --logdir 'runs'

In [14]:
# Per-batch training losses, kept for the running average that is logged.
losses = []

# `from_pretrained` returns the model in evaluation mode (dropout disabled);
# switch to training mode before optimising.
model.train()

# Every epoch's checkpoint is written below this directory.
path = 'models/' + model_path
os.makedirs(path, exist_ok=True)

writer = SummaryWriter()

try:
    pbar = tqdm(range(0, n_epochs), total=n_epochs)
    for epoch_idx in pbar:

        # Fresh generator per epoch: the shuffle seed depends on the epoch index.
        data_generator = get_data_generator(epoch_idx, dataset, tokenizer, batch_size)

        for batch_idx, (input_batch, input_attention_mask_batch, target_batch, target_attention_mask_batch) in tqdm(
                enumerate(data_generator), total=n_batches):

            optimizer.zero_grad()

            # Padding positions must not contribute to the loss: label id -100
            # is ignored by the model's cross-entropy. The target attention
            # mask is 0 exactly on the padded positions.
            labels = target_batch.masked_fill(target_attention_mask_batch == 0, -100)

            # Forward pass. Call the module itself rather than `.forward`
            # so that registered hooks run.
            model_out = model(
                input_ids=input_batch.to(device),
                labels=labels.to(device),
                attention_mask=input_attention_mask_batch.to(device),
                decoder_attention_mask=target_attention_mask_batch.to(device)
            )

            # Calculate loss and update weights
            loss = model_out.loss
            loss_value = loss.item()  # read once; reused for logging below
            losses.append(loss_value)
            loss.backward()
            optimizer.step()

            # Log a running average every `print_freq` batches
            if (batch_idx + 1) % print_freq == 0:
                global_step = batch_idx + 1 + epoch_idx * n_batches
                avg_loss = np.mean(losses[-print_freq:])
                writer.add_scalar('Loss/train', avg_loss, global_step)
                # The optimizer runs with relative_step=False and
                # scale_parameter=False, so the effective learning rate is the
                # value stored on the param group. Reading it directly avoids
                # building a new AdafactorSchedule on every logging step.
                writer.add_scalar('lr/train', optimizer.param_groups[0]['lr'], global_step)

            pbar.set_description("loss = %s" % loss_value)

            # Release unused memory after every batch (batch shapes vary
            # because padding is per-batch).
            clear_cache()

        print('Saving model for epoch %d' % epoch_idx)
        model.save_pretrained(path + '/epoch_%d' % epoch_idx)
finally:
    # Flush and close the event file even when training is interrupted.
    writer.close()

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-6e071bd2d795e9a5.arrow


Saving model for epoch 0


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-5a6ad6b9ad6f9503.arrow


Saving model for epoch 1


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-3ff74d5b11893657.arrow


Saving model for epoch 2


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-5492f2cedf1dae8c.arrow


Saving model for epoch 3


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-208914c09e727aba.arrow


Saving model for epoch 4


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-83f5b3f0764a8c99.arrow


Saving model for epoch 5


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-2306662a118260ca.arrow


Saving model for epoch 6


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-127b47c12e1d4941.arrow


Saving model for epoch 7


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-f5bfd3426fc11a60.arrow


Saving model for epoch 8


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-44a37123d8a25216.arrow


Saving model for epoch 9


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-d68c3ab575f33c3a.arrow


Saving model for epoch 10


  0%|          | 0/1480 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-e417feb4da9299e9\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-d96e8e3540605682.arrow


KeyboardInterrupt: 

In [None]:
# cyrillic_to_latin('Йоплъ, сабий, дадэ, мажэ, йощакӀуэ, Ӏэхъуэ, щакӀуэ, йоджэ, еджакӀуэ, къопс, дыгъэ, вагъуэ, мэлыд, ехь.')

In [None]:
# replace_cyrillic_with_latin('data/for_tokenizer/tokenizer_data_kbd.txt')

In [None]:
# tokenizer.save_pretrained('./models/kbd_lat-835k_ru-3M_t5-small')