In [1]:
from basic import Transf
import torch
import torch.nn as nn
import json
import random

In [2]:
with open("../train.json") as f:
    train_data = json.load(f)
with open("../test.json") as f:
    test_data = json.load(f)

In [3]:
vocab = set()
for dataset in (train_data, test_data):
    for dp in dataset:
        vocab.update(dp["human"])
        vocab.update(dp["machine"])

In [4]:
vocab = sorted(list(vocab))
vocab = ['<eos>', '<system>', '<user>'] + vocab

id_to_char = {id: char for id, char in enumerate(vocab)}
char_to_id = {char: id for id, char in enumerate(vocab)}

In [5]:
def encode(text):
    return [char_to_id[c] for c in text]

def decode(tokens):
    return "".join([id_to_char[t] for t in tokens])

In [6]:
max_seq_len = 45
training_data = []

for dp in train_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    training_data.append(tokens)

In [7]:
testing_data = []

for dp in test_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    testing_data.append(tokens)

In [8]:
def get_test_loss():
    model.eval()
    x = torch.tensor(testing_data).to(device)
    logits = model(x)

    targets = x[:, 1:]
    logits = logits[:, :-1, :]

    targets = targets.reshape(-1)
    logits = logits.reshape(-1, logits.shape[-1])

    loss = loss_fn(logits, targets)
    return loss.item()

In [9]:
device = "cuda:1"

model = Transf(len(vocab), max_seq_len).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)

batch_size = 64
epoch = 25

In [10]:
for e in range(epoch):
    model.train()
    print("epoch: ", e)
    total_loss = 0
    
    random.shuffle(training_data)
    for i in range(0, len(training_data), batch_size):
        x = torch.tensor(training_data[i:i+batch_size]).to(device)
        
        logits = model(x)
        
        targets = x[:, 1:]
        logits = logits[:, :-1, :]
        
        targets = targets.reshape(-1)
        logits = logits.reshape(-1, logits.shape[-1])
        
        loss = loss_fn(logits, targets)
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"training loss: ", total_loss)
    print(f"test loss: ", get_test_loss())

epoch:  0
training loss:  545.0939260721207
test loss:  0.5213320255279541
epoch:  1
training loss:  395.15505388379097
test loss:  0.472167044878006
epoch:  2
training loss:  371.17029666900635
test loss:  0.4532800018787384
epoch:  3
training loss:  356.1808550655842
test loss:  0.43435367941856384
epoch:  4
training loss:  345.6050090789795
test loss:  0.4230087399482727
epoch:  5
training loss:  336.89535304903984
test loss:  0.41967064142227173
epoch:  6
training loss:  331.0449362695217
test loss:  0.41087567806243896
epoch:  7
training loss:  326.28539139032364
test loss:  0.40485554933547974
epoch:  8
training loss:  320.8095515370369
test loss:  0.3968316316604614
epoch:  9
training loss:  314.4603282511234
test loss:  0.38931307196617126
epoch:  10
training loss:  310.3669308125973
test loss:  0.38659921288490295
epoch:  11
training loss:  307.912470638752
test loss:  0.38443049788475037
epoch:  12
training loss:  306.09815445542336
test loss:  0.3813968300819397
epoch:  13
t

In [11]:
model.eval()

date = "26 jan 2006"
tokens = [2] + encode(date) + [1]
next_token = -1

while next_token != 0 and len(tokens) < max_seq_len:
    out = model(torch.tensor([tokens]).to(device))
    next_token = out[:, -1].argmax(dim=-1).item()
    tokens.append(next_token)

print(decode(tokens))

<user>26 jan 2006<system>2006-01-26<eos>
