In [1]:
from basic import TasnsfModel
import torch
import torch.nn as nn
import json
import random

In [2]:
with open("../train.json") as f:
    train_data = json.load(f)
with open("../test.json") as f:
    test_data = json.load(f)

In [3]:
vocab = set()
for dataset in (train_data, test_data):
    for dp in dataset:
        vocab.update(dp["human"])
        vocab.update(dp["machine"])

In [4]:
vocab = sorted(list(vocab))
vocab = ['<eos>', '<system>', '<user>'] + vocab

id_to_char = {id: char for id, char in enumerate(vocab)}
char_to_id = {char: id for id, char in enumerate(vocab)}

In [5]:
def encode(text):
    return [char_to_id[c] for c in text]

def decode(tokens):
    return "".join([id_to_char[t] for t in tokens])

In [6]:
max_seq_len = 45
training_data = []

for dp in train_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    training_data.append(tokens)

In [7]:
testing_data = []

for dp in test_data:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    testing_data.append(tokens)

In [8]:
def get_test_loss():
    model.eval()
    x = torch.tensor(testing_data).to(device)
    logits = model(x)

    targets = x[:, 1:]
    logits = logits[:, :-1, :]

    targets = targets.reshape(-1)
    logits = logits.reshape(-1, logits.shape[-1])

    loss = loss_fn(logits, targets)
    return loss.item()

In [9]:
device = "cuda:3"

model = TasnsfModel(len(vocab), max_seq_len, 64, 4, 2).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)

batch_size = 128
epoch = 25

In [10]:
print("Model Parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

Model Parameters:  107584


In [11]:
for e in range(epoch):
    model.train()
    print("epoch: ", e)
    total_loss = 0
    
    random.shuffle(training_data)
    for i in range(0, len(training_data), batch_size):
        x = torch.tensor(training_data[i:i+batch_size]).to(device)
        
        logits = model(x)
        
        targets = x[:, 1:]
        logits = logits[:, :-1, :]
        
        targets = targets.reshape(-1)
        logits = logits.reshape(-1, logits.shape[-1])
        
        loss = loss_fn(logits, targets)
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"training loss: ", total_loss)
    print(f"test loss: ", get_test_loss())

epoch:  0
training loss:  224.95236161351204
test loss:  0.32465335726737976
epoch:  1
training loss:  131.09185615181923
test loss:  0.3132728040218353
epoch:  2
training loss:  126.56800454854965
test loss:  0.31135252118110657
epoch:  3
training loss:  124.96015945076942
test loss:  0.3097771108150482
epoch:  4
training loss:  124.01046970486641
test loss:  0.3068845272064209
epoch:  5
training loss:  123.36487957835197
test loss:  0.30769580602645874
epoch:  6
training loss:  122.87485787272453
test loss:  0.3071589469909668
epoch:  7
training loss:  122.64819866418839
test loss:  0.3065885901451111
epoch:  8
training loss:  122.22549682855606
test loss:  0.3072598874568939
epoch:  9
training loss:  121.94286334514618
test loss:  0.3056713938713074
epoch:  10
training loss:  121.6965734064579
test loss:  0.30629098415374756
epoch:  11
training loss:  121.33313184976578
test loss:  0.3057418465614319
epoch:  12
training loss:  121.30527094006538
test loss:  0.304918110370636
epoch: 

In [12]:
model.eval()

date = "10-05-31"
tokens = [2] + encode(date) + [1]
next_token = -1

while next_token != 0 and len(tokens) < max_seq_len:
    out = model(torch.tensor([tokens]).to(device))
    next_token = out[:, -1].argmax(dim=-1).item()
    tokens.append(next_token)

print(decode(tokens))

<user>10-05-31<system>2010-05-31<eos>
