https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb

https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

https://github.com/pytorch/examples/blob/main/word_language_model/model.py

In [85]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np
import random

In [None]:
from transformer import Transformer

In [45]:
from net import Net_CBOW
import numpy as np
version = "april27_WT2_nodatalim_10epoch_128dim_100minf"

vocab = torch.load(f"saves/vocab_{version}.pt")
len(vocab)

2156

In [46]:
def lookup_id(word, vocab=vocab):
    if word not in vocab:
        return vocab["<unk>"]
    return vocab[word]
def lookup_token(word_id, vocab=vocab):
    for word in vocab:
        if vocab[word] == word_id:
            return word
    return None

In [107]:
from datasets import load_dataset
wikitext2 = load_dataset("wikitext", "wikitext-2-v1")
text_train = wikitext2["train"]['text']
text_train = [item.lower().strip() for item in text_train if len(item) > 0]
text_test = wikitext2["test"]['text']
text_test = [item.lower().strip() for item in text_test if len(item) > 0]
len(text_test)
text_train = [item.split(" ") + ["\n"] for item in text_train if "=" not in item]
text_test = [item.split(" ") + ["\n"] for item in text_test if "=" not in item]

max_seq_length = 128
start_i = 20

x_train = [[lookup_id(word) for word in paragraph[start_i:max_seq_length]] for paragraph in text_train if len(paragraph) >= max_seq_length + start_i]
y_train = [[word for word in paragraph] for paragraph in x_train]
x_test = [[lookup_id(word) for word in paragraph[start_i:max_seq_length]] for paragraph in text_test if len(paragraph) >= max_seq_length + start_i]
y_test = [[word for word in paragraph] for paragraph in x_test]
# print([[word for word in paragraph[start_i:max_seq_length]] for paragraph in text_train if len(paragraph) >= max_seq_length + start_i])

In [110]:
src_vocab_size = len(vocab)
tgt_vocab_size = len(vocab)
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 128
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

print(len(x_train))
x_train = x_train
y_train = y_train
src_data = torch.tensor(x_train)

unk_id = lookup_id("<unk>")
count_unk = sum(sum((i == unk_id) for i in paragraph) for paragraph in src_data).item()
count_total = sum(sum(1 for i in paragraph) for paragraph in src_data)
print(count_unk, count_total, count_unk/count_total)

# a = torch.randint(1, src_vocab_size, (64, max_seq_length))
# print(a[:, 1:])
# print(a[:, :-1])
len(x_train)


5319
128605 574452 0.22387423144144333


5319

In [111]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

batch_size = 100
num_batches = len(x_train) // batch_size
print("Batches per epoch:", num_batches)

percent_data_per_epoch = 0.5


indices = list(range(len(x_train)))
for epoch in range(20):
    epoch_loss = 0
    x_train_copy = [x_train[indices[i]] for i in range(len(indices))]
    y_train_copy = [y_train[indices[i]] for i in range(len(indices))]
    for batch in range(int(num_batches * percent_data_per_epoch)):
        src_data = torch.tensor(x_train_copy[batch*batch_size:(batch+1)*batch_size])  # (batch_size, seq_length)
        tgt_data = torch.tensor(y_train_copy[batch*batch_size:(batch+1)*batch_size])  # (batch_size, seq_length)
        optimizer.zero_grad()
        output = transformer(src_data, tgt_data[:, :-1])
        loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()
        print("|", end="")
        epoch_loss += loss.item()
    epoch_loss /= int(num_batches * percent_data_per_epoch)
    random.shuffle(indices)
    print(f" -> Epoch: {epoch+1}, Loss: {epoch_loss}")

torch.save(transformer, "saves/model_transformer_apr29_1200pm.pt")

Batches per epoch: 53
|||||||||||||||||||||||||| -> Epoch: 1, Loss: 5.632079216150137
|||||||||||||||||||||||||| -> Epoch: 2, Loss: 5.072457240178035
|||||||||||||||||||||||||| -> Epoch: 3, Loss: 4.915947877443754
|||||||||||||||||||||||||| -> Epoch: 4, Loss: 4.761176622830904
|||||||||||||||||||||||||| -> Epoch: 5, Loss: 4.634160188528208
|||||||||||||||||||||||||| -> Epoch: 6, Loss: 4.515117828662579
|||||||||||||||||||||||||| -> Epoch: 7, Loss: 4.42727437386146
|||||||||||||||||||||||||| -> Epoch: 8, Loss: 4.333838829627404
|||||||||||||||||||||||||| -> Epoch: 9, Loss: 4.249075962946965
|||||||||||||||||||||||||| -> Epoch: 10, Loss: 4.1785673728356
|||||||||||||||||||||||||| -> Epoch: 11, Loss: 4.112065058488112
|||||||||||||||||||||||||| -> Epoch: 12, Loss: 3.9986279744368334
|||||||||||||||||||||||||| -> Epoch: 13, Loss: 3.891067523222703
|||||||||||||||||||||||||| -> Epoch: 14, Loss: 3.7791364192962646
|||||||||||||||||||||||||| -> Epoch: 15, Loss: 3.6912831709935117
||||||||

KeyboardInterrupt: 

In [112]:
torch.save(transformer, "saves/model_transformer_apr29_1200pm.pt")

In [113]:
o = transformer(src_data[1:2], tgt_data[1:2, :-1])
sm = np.array(torch.softmax(o, 1)[0].detach())
ids = [list(v).index(max(v)) for v in sm]
words = [lookup_token(i) for i in ids]
print([lookup_token(i) for i in src_data[0]])
print(words)

['<unk>', '<unk>', '<unk>', '(', 'jack', '<unk>', ')', 'and', '<unk>', '(', 'james', '<unk>', '<unk>', ')', '.', 'the', 'captain', '<unk>', 'that', 'he', 'was', 'on', 'the', '<unk>', '<unk>', ',', 'returning', 'from', '<unk>', 'space', 'with', 'the', '<unk>', '<unk>', 'of', 'time', '.', 'they', 'had', '<unk>', 'up', 'a', '<unk>', 'on', 'the', 'way', ',', 'a', 'human', 'called', '<unk>', '<unk>', '.', '<unk>', 'the', 'ship', 'found', 'itself', 'some', '200', 'light', 'years', 'away', 'from', 'its', 'previous', 'location', 'and', 'a', 'hundred', 'years', 'in', 'the', 'past', ',', 'near', 'deep', 'space', 'station', '<unk>', 'and', 'found', 'the', '<unk>', '<unk>', 'in', '<unk>', '.', 'they', '<unk>', 'that', 'the', '<unk>', '@-@', '<unk>', 'was', '<unk>', '<unk>', '(', '<unk>', '<unk>', ')', ',', 'a', '<unk>', '<unk>', 'who', 'had']
['march', 'earliest', '@-@', 'yard', 'minor', '(', '.', 'according', '0', 'km', 'mexico', 'first', 'ft', 'operated', 'km', ')', 'been', 'longer', '@.@', "'s"

In [151]:
text = "you are a helpful assistant . Answer the following question . "

text = [lookup_id(word.lower()) for word in text.strip().split(" ")]

for count in range(50):
    i = torch.tensor([text])
    o = transformer(i, i)
    sm = np.array(torch.softmax(o, 2)[0].detach())
    top5 = [np.zeros(5) for _ in range(len(sm))]
    top5p = [np.zeros(5) for _ in range(len(sm))]
    for vi in range(len(sm)-1,len(sm)):
        v = sm[vi]
        for item in v:
            m = top5[vi][list(top5[vi]).index(min(top5[vi]))]
            if lookup_token(list(v).index(item)) != "<unk>":
                top5[vi][list(top5[vi]).index(min(top5[vi]))] = max(m, item)
        top5[vi] = [list(v).index(i) for i in top5[vi]]
        top5p[vi] = [v[i] for i in top5[vi]]
    # ids = [list(v).index(max(v)) for v in sm]
    words = [[lookup_token(i) for i in w] for w in top5][-1]
    # print([lookup_token(i) for i in src_data[0]])
    # print(words)
    print(' '.join([lookup_token(i) for i in text]))
    print(words)
    print(top5p[-1])
    chosen_word = words[int(input("Choose a word: "))-1]
    text.append(lookup_id(chosen_word))
# text


you are a <unk> <unk> . <unk> the following <unk> .
['after', 'it', 'are', 'a', 'this']
[0.008957216, 0.013795887, 0.01188761, 0.11148329, 0.016574614]
you are a <unk> <unk> . <unk> the following <unk> . a
['a', 'few', 'large', 'small', '"']
[0.014941752, 0.004584193, 0.005560547, 0.005065698, 0.0033910982]
you are a <unk> <unk> . <unk> the following <unk> . a few
['can', 'a', '.', 'are', 'is']
[0.0043566437, 0.009275039, 0.0067396043, 0.048805617, 0.008312006]
you are a <unk> <unk> . <unk> the following <unk> . a few can
['have', 'are', 'do', 'a', 'be']
[0.057717957, 0.035070132, 0.0101899505, 0.023004128, 0.12751704]
you are a <unk> <unk> . <unk> the following <unk> . a few can have
['a', 'been', 'also', 'are', 'have']
[0.15154977, 0.12867396, 0.006671486, 0.024408525, 0.011477391]
you are a <unk> <unk> . <unk> the following <unk> . a few can have a
['small', 'a', 'large', 'are', 'few']
[0.012695176, 0.0087218555, 0.007897064, 0.006415769, 0.008232827]
you are a <unk> <unk> . <unk> t

ValueError: invalid literal for int() with base 10: ''