In [1]:
from basic import Transf
import torch
import torch.nn as nn
import json
import random

In [2]:
with open("../train.json") as f:
    train_data = json.load(f)
with open("../test.json") as f:
    test_data = json.load(f)

In [3]:
vocab = set()
for dataset in (train_data, test_data):
    for dp in dataset:
        vocab.update(dp["human"])
        vocab.update(dp["machine"])

In [4]:
vocab = sorted(list(vocab))
vocab = ['<eos>', '<system>', '<user>'] + vocab

id_to_char = {id: char for id, char in enumerate(vocab)}
char_to_id = {char: id for id, char in enumerate(vocab)}

In [5]:
def encode(text):
    return [char_to_id[c] for c in text]

def decode(tokens):
    return "".join([id_to_char[t] for t in tokens])

In [6]:
max_seq_len = 45
training_data = []

for dp in train_data[:1000]:
    tokens = [2] + encode(dp["human"]) + [1] + encode(dp["machine"]) + [0]
    tokens = tokens + [0] * (max_seq_len - len(tokens))
    training_data.append(tokens)

In [7]:
model = Transf(len(vocab), max_seq_len)

batch_size = 32
epoch = 10

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [8]:
for e in range(epoch):
    print("training for epoch: ", e)
    
    random.shuffle(training_data)
    for i in range(0, len(training_data), batch_size):
        x = torch.tensor(training_data[i:i+batch_size])
        
        logits = model(x)
        
        targets = x[:, 1:]
        logits = logits[:, :-1, :]
        
        targets = targets.reshape(-1)
        logits = logits.reshape(-1, logits.shape[-1])
        
        loss = loss_fn(logits, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"loss at step {i}: ", loss.item())

training for epoch:  0
loss at step 992:  1.7269295454025269
training for epoch:  1
loss at step 992:  1.6334483623504639
training for epoch:  2
loss at step 992:  1.3470579385757446
training for epoch:  3
loss at step 992:  1.293759822845459
training for epoch:  4
loss at step 992:  1.3505126237869263
training for epoch:  5
loss at step 992:  0.9953502416610718
training for epoch:  6
loss at step 992:  1.1052472591400146
training for epoch:  7
loss at step 992:  1.2807233333587646
training for epoch:  8
loss at step 992:  0.9422165155410767
training for epoch:  9
loss at step 992:  0.9453085064888


In [12]:
model.eval()

date = "october 22 1999"
tokens = [2] + encode(date) + [1]
next_token = -1

while next_token != 0 and len(tokens) < max_seq_len:
    out = model(torch.tensor([tokens]))
    next_token = out[:, -1].argmax(dim=-1).item()
    tokens.append(next_token)

print(decode(tokens))

<user>october 22 1999<system>200-01-0-0-0<eos>
