In [1]:
from datasets import load_dataset
import sentencepiece as sentencepiece
from transformers import MT5Tokenizer, MT5Config, MT5ForConditionalGeneration
from transformers.optimization import Adafactor, AdafactorSchedule

import os
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

In [2]:
device = torch.device('cuda')

### Dataset

In [3]:
dataset = load_dataset("text", data_files={"train": ['data/merged_delete_later.txt', 'data/Haji-Murat_kbd-ru.txt']})

Using custom data configuration default-7d8fd473ca3f8090
Reusing dataset text (C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
max_tokenization_length = 64


def encode(text, _tokenizer):
    return _tokenizer(text,
                      padding='longest',
                      max_length=max_tokenization_length,
                      truncation=True,
                      return_tensors="pt")


def encode_batch(raw_batch, _input_tokenizer, _output_tokenizer):
    input_texts = []
    target_texts = []

    for translation_pair in raw_batch['text']:
        target_text, input_text = translation_pair.split('😀')
        input_texts.append(input_text)
        target_texts.append(target_text)

    encoding = encode(input_texts, _input_tokenizer)
    input_ids, input_attention_masks_ids = encoding.input_ids, encoding.attention_mask

    target_encoding = encode(target_texts, _output_tokenizer)
    target_ids, target_attention_masks_ids = target_encoding.input_ids, target_encoding.attention_mask

    input_ids.to(device)
    input_attention_masks_ids.to(device)
    target_ids.to(device)
    target_attention_masks_ids.to(device)

    return input_ids, input_attention_masks_ids, target_ids, target_attention_masks_ids


def get_data_generator(_dataset, _input_tokenizer, _output_tokenizer, _batch_size=32):
    shuffled_dataset = _dataset.shuffle()[
        'train']  #https://huggingface.co/docs/datasets/dataset_streaming.html#reshuffle-the-dataset-at-each-epoch
    for i in range(0, len(shuffled_dataset), _batch_size):
        raw_batch = shuffled_dataset[i:i + _batch_size]
        yield encode_batch(raw_batch, _input_tokenizer, _output_tokenizer)

### Tokenizer

##### Pretraining

In [5]:
# def pretrain_tokenizer(input_files, vocab_size, model_prefix, save_directory):
#     sentencepiece.SentencePieceTrainer.Train('--input=' + input_files +
#                                              ' --model_prefix=' + model_prefix +
#                                              ' --vocab_size=' + str(vocab_size) +
#                                              ' --model_type=unigram'
#                                              ' --character_coverage=1.0'
#                                              ' --pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1'
#                                              ' --pad_piece=<pad> --eos_piece=</s> --unk_piece=<unk>')  #--bos_piece=<s>
#
#     mt5_tokenizer = MT5Tokenizer.from_pretrained(model_prefix + '.model', extra_ids=0)
#     mt5_tokenizer.save_pretrained(save_directory)
#
#     # delete temporary files
#     if os.path.exists(model_prefix + '.model'):
#         os.remove(model_prefix + '.model')
#     if os.path.exists(model_prefix + '.vocab'):
#         os.remove(model_prefix + '.vocab')

In [6]:
input_lang = 'ru'
output_lang = 'kbd'

input_tokenizer_path = 'tokenizers/' + input_lang
output_tokenizer_path = 'tokenizers/' + output_lang

# pretrain_tokenizer(input_files='data/for_tokenizer/tokenizer_data_ru.txt', model_prefix=input_lang, vocab_size=32000,
#                    save_directory=input_tokenizer_path)
# pretrain_tokenizer(input_files='data/for_tokenizer/tokenizer_data_kbd.txt', model_prefix=output_lang, vocab_size=500,
#                    save_directory=output_tokenizer_path)

##### Using pretrained

In [7]:
input_tokenizer = MT5Tokenizer.from_pretrained(input_tokenizer_path + '/spiece.model', extra_ids=0)
output_tokenizer = MT5Tokenizer.from_pretrained(output_tokenizer_path + '/spiece.model', extra_ids=0)



In [8]:
# spm = input_tokenizer
# print(spm.convert_ids_to_tokens(spm.encode('сэ сощIэ')))
# print(spm.convert_ids_to_tokens(spm.encode('уэ уощIэ')))
# print(spm.convert_ids_to_tokens(spm.encode('зэрыфщIэщи')))
# print(spm.convert_ids_to_tokens(spm.encode('дызэрыщIэнщ')))
# print(spm.convert_ids_to_tokens(spm.encode('дызэрыщIэнущ')))
# print(spm.convert_ids_to_tokens(spm.encode('дызэрыщIащ')))
# # print(spm.convert_ids_to_tokens(spm.encode('мы знали друг друга')))
# # print(spm.convert_ids_to_tokens(spm.encode('откуда тебе знать')))
# print(spm.convert_ids_to_tokens(spm.encode('джэд')))
# print(spm.convert_ids_to_tokens(spm.encode('джэдыкIэ')))
# print(spm.convert_ids_to_tokens(spm.encode('IуэхущIапIэ')))
# print(spm.convert_ids_to_tokens(spm.encode('лэжьапIэ')))
# print(spm.convert_ids_to_tokens(spm.encode('зыбгъэдэлъ щIэныгъэлIт')))
# print(spm.convert_ids_to_tokens(spm.encode('Лос-Анджелес къалэм щопсэу')))
# print(output_tokenizer.convert_ids_to_tokens(output_tokenizer.encode('Мухаммед')))

### Model

In [9]:
config = MT5Config.from_pretrained("google/mt5-small")
config.vocab_size = len(input_tokenizer.get_vocab()) + len(output_tokenizer.get_vocab()) - len(
    output_tokenizer.special_tokens_map)  # subtract special tokens once so they're not counted twice
config.max_length = 64

In [10]:
model = MT5ForConditionalGeneration(config)
# model = MT5ForConditionalGeneration.from_pretrained('models/epoch_0')
model.to(device)

MT5ForConditionalGeneration(
  (shared): Embedding(32497, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32497, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedGeluDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo)

### Training

In [11]:
# Constants
n_epochs = 30
batch_size = 116
print_freq = 10
# checkpoint_freq = 2500
lr = 1e-4
n_batches = int(np.ceil(len(dataset['train']) / batch_size))

In [12]:
optimizer = Adafactor(
    model.parameters(),
    lr=lr,
    eps=(1e-30, 1e-3),
    clip_threshold=1.0,
    decay_rate=-0.8,
    beta1=None,
    weight_decay=0.0,
    relative_step=False,
    scale_parameter=False,
    warmup_init=False,
)

In [13]:
# %load_ext tensorboard
# %tensorboard --logdir 'runs'

In [14]:
from torch.utils.tensorboard import SummaryWriter

losses = []

writer = SummaryWriter()

pbar = tqdm(range(0, n_epochs), total=n_epochs)
for epoch_idx in pbar:

    # Get data batch
    data_generator = get_data_generator(dataset, input_tokenizer, output_tokenizer, batch_size)

    for batch_idx, (input_batch, input_attention_mask_batch, target_batch, target_attention_mask_batch) in tqdm(
            enumerate(data_generator), total=n_batches):

        optimizer.zero_grad()

        # Forward pass
        model_out = model.forward(
            input_ids=input_batch.to(device),
            labels=target_batch.to(device),
            attention_mask=input_attention_mask_batch.to(device),
            decoder_attention_mask=target_attention_mask_batch.to(device)
        )

        # Calculate loss and update weights
        loss = model_out.loss
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        # scheduler.step()

        # Print training update info
        if (batch_idx + 1) % print_freq == 0:
            avg_loss = np.mean(losses[-print_freq:])
            writer.add_scalar('Loss/train', avg_loss, batch_idx + 1 + epoch_idx * n_batches)
            # writer.add_scalar('lr/train', scheduler.get_last_lr()[0], batch_idx + 1 + epoch_idx * n_batches)
            writer.add_scalar('lr/train', AdafactorSchedule(optimizer).get_lr()[0],
                              batch_idx + 1 + epoch_idx * n_batches)

        pbar.set_description("loss = %s" % loss.item())

    print('Saving model for epoch %d' % epoch_idx)
    model.save_pretrained('models/epoch_%d' % epoch_idx)

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/51 [00:00<?, ?it/s]

Saving model for epoch 0


  0%|          | 0/51 [00:00<?, ?it/s]

Saving model for epoch 1


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 2


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 3


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 4


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 5


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 6


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 7


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 8


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 9


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 10


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 11


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 12


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 13


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 14


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 15


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 16


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 17


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 18


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 19


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 20


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 21


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 22


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 23


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 24


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 25


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 26


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 27


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 28


  0%|          | 0/51 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at C:\Users\anzor\.cache\huggingface\datasets\text\default-7d8fd473ca3f8090\0.0.0\e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5\cache-96c491c21127f319.arrow


Saving model for epoch 29
