In [1]:
from tokenizer.BPE import tokenize, tokenizer
import pickle as pkl
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import torchaudio.datasets as datasets
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup
from transformer import TransformerLanguageModel
import wandb
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


### Tokenized dataset creation

In [2]:
current_dir = os.getcwd()

merges_path = os.path.join(current_dir, "tokenizer", "merges.pkl")
vocab_path = os.path.join(current_dir, "tokenizer", "vocabulary.pkl")

# Загрузка merges.pkl
with open(merges_path, "rb") as f:
    merges = pkl.load(f)
    print("Загрузка merges.pkl успешна")

# Загрузка vocab.pkl
with open(vocab_path, "rb") as f:
    vocab = pkl.load(f)
    print("Загрузка vocabulary.pkl успешна")

text = 'HELLO MY NAME IS BILL'
tokens = [vocab[0]] + tokenize(text, merges) + [vocab[1]]
#print(tokens)
#print(tokenizer.convert_tokens_to_string(tokens))

Загрузка merges.pkl успешна
Загрузка vocabulary.pkl успешна


In [3]:
wandb.init(project='TransformerLM')
config = {
    'dim_feedforward': 64,
    'num_heads': 8,
    'num_layers': 8,
    'learning_rate': 0.001,
    'batch_size': 64,
    'epochs': 256,
    'embedding_dim': 64,
    'dataset': "LibriSpeech dev-clean",
    'vocab_size': len(vocab),
}
wandb.config.update(config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: roman-kuznetsov (roman-kuznetsov-bmstu-) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


### Data mining

In [4]:
token_to_id = {vocab[i]: i for i in range(len(vocab))}
id_to_token = {i: vocab[i] for i in range(len(vocab))}
PAD_ID = 2

data = datasets.LIBRISPEECH("../data", url="dev-clean", )
corpus = []
for i in range(2800):
    try:
        corpus.append(list(map(lambda x: token_to_id[x], [vocab[0]] + tokenize(data.__getitem__(i)[2], merges) + [vocab[1]])))
    except IndexError as err:
        break

max_length = max(len(seq) for seq in corpus)
print(max_length)
class TextDataset(Dataset):
    def __init__(self, data, max_len):
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = torch.tensor(self.data[idx], dtype=torch.int16)
        sample = sample[:self.max_len]
        length = sample.shape[-1]
        padding = torch.ones((self.max_len - sample.shape[-1])) * 2
        sample = torch.cat((sample, padding), dim=0)
        return torch.tensor(sample, dtype=torch.float), length

dataset = TextDataset(corpus, max_length)

train_indices, val_indices = train_test_split(list(range(len(dataset))), test_size=0.2)

# Создание тренировочного и валидационного датасетов
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Создание DataLoader-ов для тренировочного и валидационного датасетов
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

294


In [5]:
print(token_to_id)

{'<|startoftext|>': 0, '<|endoftext|>': 1, '<|padding|>': 2, "'": 3, 'A': 4, 'B': 5, 'C': 6, 'D': 7, 'E': 8, 'F': 9, 'G': 10, 'H': 11, 'I': 12, 'J': 13, 'K': 14, 'L': 15, 'M': 16, 'N': 17, 'O': 18, 'P': 19, 'Q': 20, 'R': 21, 'S': 22, 'T': 23, 'U': 24, 'V': 25, 'W': 26, 'X': 27, 'Y': 28, 'Z': 29, 'Ġ': 30, 'ĠT': 31, 'HE': 32, 'ĠA': 33, 'IN': 34, 'ĠTHE': 35, 'ĠW': 36, 'ĠS': 37, 'ĠO': 38, 'RE': 39, 'ND': 40, 'ĠH': 41, 'ER': 42, 'ĠB': 43, 'ĠM': 44, 'OU': 45, 'IT': 46, 'ĠF': 47, 'IS': 48, 'ĠC': 49, 'AT': 50, 'ED': 51, 'ĠAND': 52, 'ĠOF': 53, 'EN': 54, 'ON': 55, 'ING': 56, 'ĠTO': 57, 'ĠP': 58, 'OR': 59, 'ES': 60, 'ĠD': 61, 'ĠTH': 62, 'ĠL': 63, 'AN': 64, 'AS': 65, 'ĠIN': 66, 'AR': 67, 'LL': 68, 'ĠN': 69, 'ĠHE': 70, 'ĠG': 71, 'AD': 72, 'LE': 73, 'OM': 74, 'ĠE': 75, 'ĠBE': 76, 'OT': 77, 'UT': 78, 'IC': 79, 'OW': 80, 'LY': 81, 'SE': 82, 'ĠI': 83, 'ST': 84, 'VE': 85, 'ĠWAS': 86, 'LD': 87, 'ĠWH': 88, 'GH': 89, 'ĠIT': 90, 'ĠTHAT': 91, 'ĠON': 92, 'ĠU': 93, 'ENT': 94, 'AL': 95, 'THE': 96, 'ID': 97, 'IM

In [6]:
for i, el in enumerate(val_loader):
    if i < 4:
        print(el)
    else:
        break

  return torch.tensor(sample, dtype=torch.float), length


[tensor([[  0.,  67.,   7.,  ...,   2.,   2.,   2.],
        [  0.,  59.,  12.,  ...,   2.,   2.,   2.],
        [  0.,  96.,  61.,  ...,   2.,   2.,   2.],
        ...,
        [  0.,  32.,  86.,  ...,   2.,   2.,   2.],
        [  0.,  32., 116.,  ...,   2.,   2.,   2.],
        [  0.,  12.,  57.,  ...,   2.,   2.,   2.]]), tensor([ 83,  83,  57,  19, 113,  29,  47,  45,  47,  24,  33,  37,  19,  21,
         35, 117,  48, 148,  94, 170,  59,  34, 127,  23,  47,  15,  15,  67,
         44,  79,  22,  50,  45,  53,  21,  36, 128,  46, 196,  53,  51,  87,
         56,  46, 140,  50,  14,  85,  29,  28,  53,  76, 154,  12, 102,  32,
         58,  69,  40,  66,  49,  17,  74,  13])]
[tensor([[  0.,   5.,  28.,  ...,   2.,   2.,   2.],
        [  0.,  50.,  35.,  ...,   2.,   2.,   2.],
        [  0.,  96.,  39.,  ...,   2.,   2.,   2.],
        ...,
        [  0.,  55.,  44.,  ...,   2.,   2.,   2.],
        [  0.,   5.,   8.,  ...,   2.,   2.,   2.],
        [  0.,  12., 116.,  ...,   2

### prerequisites

In [7]:
def length_to_mask(inputs, lengths, dtype=None):
    batch_size = lengths.size(0)
    seq_len = inputs.size(1) if isinstance(inputs, torch.Tensor) else inputs
    # Causal mask [seq_len, seq_len]
    tgt_mask = torch.triu(torch.ones((seq_len, seq_len), device=device)).transpose(0, 1)
    tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float('-inf')).masked_fill(tgt_mask == 1, float(0.0))

    # Padding mask [batch_size, seq_len]
    key_padding_mask = (torch.arange(seq_len, device=device).expand(batch_size, seq_len) >= lengths.unsqueeze(1))

    if dtype is not None:
        key_padding_mask = key_padding_mask.to(dtype=dtype)

    return tgt_mask, key_padding_mask

In [8]:
# Параметры модели
vocab_size = config['vocab_size']
embedding_dim = config['embedding_dim']
dim_feedforward = config['dim_feedforward']
num_heads = config['num_heads']
num_layers = config['num_layers']
num_epochs = config['epochs']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerLanguageModel(vocab_size, embedding_dim, num_heads, dim_feedforward, num_layers).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=PAD_ID)
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])

total_steps = len(train_loader) * num_epochs

warmup_steps = 0.2 * total_steps

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

### Training functions

In [9]:
def train_one_epoch(epoch_index, model, training_loader, scheduler, optimizer, loss_fn):
    model.train()
    running_loss = 0.
    total_loss = 0.
    correct = 0
    total = 0

    for i, (data, lengths) in enumerate(training_loader):
        input_data = data.to(device)
        optimizer.zero_grad()
        lengths = lengths.to(device) - 1

        optimizer.zero_grad()

        input_ids = input_data[:, :-1].long()
        targets = input_data[:, 1:].long()

        seq_len = input_ids.size(1)
        tgt_mask, tgt_key_padding_mask = length_to_mask(seq_len, lengths)
        tgt_mask, tgt_key_padding_mask = tgt_mask.to(device), tgt_key_padding_mask.to(device)

        outputs = model(input_ids, tgt_mask=tgt_mask, lengths=lengths, tgt_key_padding_mask=tgt_key_padding_mask)

        outputs = outputs.reshape(-1, outputs.size(-1))  # [B * T, vocab_size]
        targets = targets.reshape(-1)                    # [B * T]

        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        total_loss += loss.item()
        running_loss += loss.item()

        # Accuracy
        preds = outputs.argmax(dim=1)
        mask = (targets != PAD_ID)
        correct += (preds[mask] == targets[mask]).sum().item()
        total += mask.sum().item()

        # Logging
        if i % 10 == 9:
            wandb.log({'Loss': running_loss / 10, "Train": epoch_index * len(training_loader) + i + 1})
            running_loss = 0.

    acc = correct / total
    wandb.log({'Accuracy': acc, "Train": epoch_index + 1})
    return total_loss / len(training_loader)


def validation(val_dataloader):
    val_loss = 0.
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for input_data, lengths in val_dataloader:
            input_data = input_data.to(device)
            lengths = lengths.to(device) - 1

            input_ids = input_data[:, :-1].long()
            targets = input_data[:, 1:].long()

            tgt_mask, tgt_key_padding_mask = length_to_mask(input_ids, lengths)
            tgt_mask, tgt_key_padding_mask = tgt_mask.to(device), tgt_key_padding_mask.to(device)

            outputs = model(input_ids, tgt_mask=tgt_mask, lengths=lengths, tgt_key_padding_mask=tgt_key_padding_mask)

            outputs = outputs.reshape(-1, outputs.size(-1))
            targets = targets.reshape(-1)

            loss = criterion(outputs, targets)
            val_loss += loss.item()

            preds = outputs.argmax(dim=1)
            mask = (targets != PAD_ID)
            correct += (preds[mask] == targets[mask]).sum().item()
            total += mask.sum().item()

    val_loss = val_loss / len(val_dataloader)
    acc = correct / total
    print("VAL LOSS =", val_loss)
    return val_loss, acc

### Training

In [10]:
best_vloss = 1_000_000.
counter = 0
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

for epoch_number in range(num_epochs):
    print('EPOCH {}:'.format(epoch_number + 1))

    model.train()
    train_loss = train_one_epoch(
        epoch_index=epoch_number,
        model=model,
        training_loader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        loss_fn=criterion,)

    model.eval()
    val_loss, val_acc = validation(val_loader)

    wandb.log({'Loss/valid': val_loss, 'Accuracy/valid': val_acc}, step=epoch_number + 1)

    if val_loss < best_vloss:
        best_vloss = val_loss
        model_path = os.path.join("../best_models/transformer", 'model_{}_{}'.format(epoch_number + 1, timestamp))
        torch.save(model.state_dict(), model_path)
    else:
        if counter > 10:
            break

wandb.finish()

EPOCH 1:


  return torch.tensor(sample, dtype=torch.float), length


VAL LOSS = 5.030080530378553
EPOCH 2:
VAL LOSS = 4.98148250579834
EPOCH 3:
VAL LOSS = 4.911378383636475
EPOCH 4:
VAL LOSS = 4.832472483317058
EPOCH 5:
VAL LOSS = 4.764636993408203
EPOCH 6:
VAL LOSS = 4.714720620049371
EPOCH 7:
VAL LOSS = 4.679715580410427
EPOCH 8:
VAL LOSS = 4.655336380004883
EPOCH 9:
VAL LOSS = 4.638492266337077
EPOCH 10:
VAL LOSS = 4.626918474833171
EPOCH 11:
VAL LOSS = 4.619131459130181
EPOCH 12:
VAL LOSS = 4.61301162507799
EPOCH 13:
VAL LOSS = 4.607271194458008
EPOCH 14:
VAL LOSS = 4.600317213270399
EPOCH 15:
VAL LOSS = 4.588482909732395
EPOCH 16:
VAL LOSS = 4.568347507052952
EPOCH 17:
VAL LOSS = 4.531885676913792
EPOCH 18:
VAL LOSS = 4.477660126156277
EPOCH 19:
VAL LOSS = 4.413697878519694
EPOCH 20:
VAL LOSS = 4.353170606825087
EPOCH 21:
VAL LOSS = 4.298386944664849
EPOCH 22:
VAL LOSS = 4.249417304992676
EPOCH 23:
VAL LOSS = 4.20613612069024
EPOCH 24:
VAL LOSS = 4.166478739844428
EPOCH 25:
VAL LOSS = 4.1301601197984485
EPOCH 26:
VAL LOSS = 4.095744927724202
EPOCH 

0,1
Accuracy,▁▁▂▂▂▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████████████
Loss,█▇▇▆▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▁▁▄▄▄▄▁▅▅▅▁▅▅▆▆▆▁▁▁▇▇█▁██▁

0,1
Accuracy,0.2851
Loss,3.20238
Train,256.0


### Test prompt

In [50]:
i = 8
base, b_l = train_dataset.__getitem__(i)[0].clone().detach().tolist(), train_dataset.__getitem__(i)[1]
init_one = []
for tok in base:
    init_one.append(id_to_token[tok])
init_one = ''.join(init_one[1:-1]).replace("Ġ", " ")
print(init_one)


  return torch.tensor(sample, dtype=torch.float), length


HE CAME TO HER SIDE AND SHE GAVE HIM NO GREETING<|endoftext|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|pa

In [51]:
model.eval()
input_data = train_dataset.__getitem__(i)[0].unsqueeze(0).to(device)
lengths = train_dataset.__getitem__(i)[1]
with torch.no_grad():
    input_data = input_data.to(device)
    lengths = (torch.tensor(lengths).to(device) - 1).unsqueeze(0).to(device)

    input_ids = input_data[:, :-1].long()
    targets = input_data[:, 1:].long()

    tgt_mask, tgt_key_padding_mask = length_to_mask(input_ids, lengths)
    tgt_mask, tgt_key_padding_mask = tgt_mask.to(device), tgt_key_padding_mask.to(device)

    outputs = model(input_ids, tgt_mask=tgt_mask, lengths=lengths, tgt_key_padding_mask=tgt_key_padding_mask)

    outputs = outputs.reshape(-1, outputs.size(-1))
    targets = targets.reshape(-1)

    preds = outputs.argmax(dim=1)
preds = preds.cpu().tolist()
init_one = []
for tok in preds:
    init_one.append(id_to_token[tok])
init_one = ''.join(init_one[1:-1]).replace("Ġ", " ")
print(init_one)

  return torch.tensor(sample, dtype=torch.float), length


 WASOUME TO THERSEHEE AND SHE SOOVE THEIM TOO LOATING AIEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE
