In [1]:
from basic import TasnsfModel
import torch
import torch.nn as nn
import json
import random

In [2]:
with open("../train.json") as f:
    train_data = json.load(f)
with open("../test.json") as f:
    test_data = json.load(f)

In [3]:
vocab = set()
for dataset in (train_data, test_data):
    for dp in dataset:
        vocab.update(dp["human"])
        vocab.update(dp["machine"])

In [4]:
vocab = sorted(list(vocab))
vocab = ['<eos>', '<system>', '<user>'] + vocab

id_to_char = {id: char for id, char in enumerate(vocab)}
char_to_id = {char: id for id, char in enumerate(vocab)}

In [5]:
def encode(text):
    return [char_to_id[c] for c in text]

def decode(tokens):
    return "".join([id_to_char[t] for t in tokens])

In [6]:
max_seq_len = 45
training_data = []

for dp in train_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    training_data.append(tokens)

In [7]:
testing_data = []

for dp in test_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    testing_data.append(tokens)

In [8]:
def get_test_loss():
    model.eval()
    x = torch.tensor(testing_data).to(device)
    logits = model(x)

    targets = x[:, 1:]
    logits = logits[:, :-1, :]

    targets = targets.reshape(-1)
    logits = logits.reshape(-1, logits.shape[-1])

    loss = loss_fn(logits, targets)
    return loss.item()

In [9]:
device = "cuda:3"

model = TasnsfModel(len(vocab), max_seq_len, 64, 4, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)

batch_size = 128
epoch = 25

In [19]:
print("Model Parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

Model Parameters:  107584


In [10]:
for e in range(epoch):
    model.train()
    print("epoch: ", e)
    total_loss = 0
    
    random.shuffle(training_data)
    for i in range(0, len(training_data), batch_size):
        x = torch.tensor(training_data[i:i+batch_size]).to(device)
        
        logits = model(x)
        
        targets = x[:, 1:]
        logits = logits[:, :-1, :]
        
        targets = targets.reshape(-1)
        logits = logits.reshape(-1, logits.shape[-1])
        
        loss = loss_fn(logits, targets)
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"training loss: ", total_loss)
    print(f"test loss: ", get_test_loss())

epoch:  0
training loss:  220.85991349816322
test loss:  0.3237839937210083
epoch:  1
training loss:  130.5960134267807
test loss:  0.31441500782966614
epoch:  2
training loss:  126.37815243005753
test loss:  0.31046199798583984
epoch:  3
training loss:  124.86842972040176
test loss:  0.31184518337249756
epoch:  4
training loss:  124.15609988570213
test loss:  0.3080276846885681
epoch:  5
training loss:  123.43970802426338
test loss:  0.3077937066555023
epoch:  6
training loss:  122.9217374920845
test loss:  0.3076123893260956
epoch:  7
training loss:  122.27702531218529
test loss:  0.3062433898448944
epoch:  8
training loss:  121.89008575677872
test loss:  0.305796355009079
epoch:  9
training loss:  121.5611073076725
test loss:  0.30617791414260864
epoch:  10
training loss:  121.37296950817108
test loss:  0.3059357702732086
epoch:  11
training loss:  121.28812265396118
test loss:  0.3048659861087799
epoch:  12
training loss:  121.25072211027145
test loss:  0.30582693219184875
epoch:  

In [11]:
# layernorm = 0.33
# embd = 32: 0.34
# embd = 16: 0.39   
# residual:  0.58
# normal: 0.61

In [21]:
model.eval()

date = "10-05-31"
tokens = [2] + encode(date) + [1]
next_token = -1

while next_token != 0 and len(tokens) < max_seq_len:
    out = model(torch.tensor([tokens]).to(device))
    next_token = out[:, -1].argmax(dim=-1).item()
    tokens.append(next_token)

print(decode(tokens))

<user>10-05-31<system>2010-05-31<eos>
