In [1]:
from basic import Transf
import torch
import torch.nn as nn
import json
import random

In [2]:
with open("../train.json") as f:
    train_data = json.load(f)
with open("../test.json") as f:
    test_data = json.load(f)

In [3]:
vocab = set()
for dataset in (train_data, test_data):
    for dp in dataset:
        vocab.update(dp["human"])
        vocab.update(dp["machine"])

In [4]:
vocab = sorted(list(vocab))
vocab = ['<eos>', '<system>', '<user>'] + vocab

id_to_char = {id: char for id, char in enumerate(vocab)}
char_to_id = {char: id for id, char in enumerate(vocab)}

In [5]:
def encode(text):
    return [char_to_id[c] for c in text]

def decode(tokens):
    return "".join([id_to_char[t] for t in tokens])

In [6]:
max_seq_len = 45
training_data = []

for dp in train_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    training_data.append(tokens)

In [7]:
testing_data = []

for dp in test_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    testing_data.append(tokens)

In [8]:
def get_test_loss():
    model.eval()
    x = torch.tensor(testing_data).to(device)
    logits = model(x)

    targets = x[:, 1:]
    logits = logits[:, :-1, :]

    targets = targets.reshape(-1)
    logits = logits.reshape(-1, logits.shape[-1])

    loss = loss_fn(logits, targets)
    return loss.item()

In [9]:
device = "cuda:1"

model = Transf(len(vocab), max_seq_len).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)

batch_size = 64
epoch = 25

In [10]:
for e in range(epoch):
    model.train()
    print("epoch: ", e)
    total_loss = 0
    
    random.shuffle(training_data)
    for i in range(0, len(training_data), batch_size):
        x = torch.tensor(training_data[i:i+batch_size]).to(device)
        
        logits = model(x)
        
        targets = x[:, 1:]
        logits = logits[:, :-1, :]
        
        targets = targets.reshape(-1)
        logits = logits.reshape(-1, logits.shape[-1])
        
        loss = loss_fn(logits, targets)
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"training loss: ", total_loss)
    print(f"test loss: ", get_test_loss())

epoch:  0
training loss:  551.5218180418015
test loss:  0.5036930441856384
epoch:  1
training loss:  376.1496481001377
test loss:  0.4441390335559845
epoch:  2
training loss:  347.49298372864723
test loss:  0.4287894070148468
epoch:  3
training loss:  334.11578699946404
test loss:  0.41549453139305115
epoch:  4
training loss:  324.5137833058834
test loss:  0.40112394094467163
epoch:  5
training loss:  316.6678135693073
test loss:  0.3929485082626343
epoch:  6
training loss:  311.2699226140976
test loss:  0.38605552911758423
epoch:  7
training loss:  307.67882242798805
test loss:  0.38026532530784607
epoch:  8
training loss:  304.87169539928436
test loss:  0.3829655051231384
epoch:  9
training loss:  303.7044238448143
test loss:  0.37791192531585693
epoch:  10
training loss:  298.7654845416546
test loss:  0.3727497458457947
epoch:  11
training loss:  296.89222145080566
test loss:  0.37366777658462524
epoch:  12
training loss:  295.6288547217846
test loss:  0.370231032371521
epoch:  13
t

In [None]:
# layernorm = 0.33
# embd = 32: 0.34
# embd = 16: 0.39   
# residual:  0.58
# normal: 0.61

In [16]:
model.eval()

date = "26 jan 2006"
tokens = [2] + encode(date) + [1]
next_token = -1

while next_token != 0 and len(tokens) < max_seq_len:
    out = model(torch.tensor([tokens]).to(device))
    next_token = out[:, -1].argmax(dim=-1).item()
    tokens.append(next_token)

print(decode(tokens))

<user>26 jan 2006<system>2006-01-26<eos>
