In [4]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import numpy as np
from tqdm import trange
from torch.nn.functional import pad

from transformer_model import Transformer, SelfAttention, MLP

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
device = "cuda"

In [2]:
dims = 5
seq_len = 20

In [69]:
def generate_data(batch_size, seq_len, dims):
    w = torch.normal(0, 1, size=(batch_size, dims), device=device)
    x = torch.normal(0, 1, size=(batch_size, seq_len, dims), device=device)
    y = torch.einsum("bp, blp -> bl", w, x)
    y = pad(input=y.unsqueeze(-1), pad=(0, dims-1), mode="constant", value=0.0)
    assert y.shape == (batch_size, seq_len, dims)
    interweaved = torch.stack((x, y), dim=2).view(batch_size, 2 * seq_len, dims)
    return x, y, w, interweaved

In [72]:
transformer = Transformer(
    num_layers=4,
    input_dim=dims,
    num_heads=1,
    mlp_hidden_dim=16,
    is_causal=True,
    activation="softmax"
).to(device)

optimizer = torch.optim.SGD(transformer.parameters(), lr=0.01)
for _ in trange(200):
    x, y, w, interweaved = generate_data(batch_size=8, seq_len=seq_len, dims=dims)
    output = transformer(interweaved[:, :-1, :])
    pred = output[:, ::2, 0]
    loss = torch.mean((pred - [:, :, 0]) ** 2)
    assert not torch.isnan(loss), loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:00<00:00, 246.40it/s]


In [75]:
transformer.eval()
with torch.no_grad():
    x, y, w, interweaved = generate_data(batch_size=8, seq_len=seq_len, dims=dims)
    output = transformer(interweaved[:, :-1, :])
    pred = output[:, ::2, 0]
    loss = torch.mean((pred - y[:, :, 0]) ** 2)
    print(pred)
    print(y)
    print(loss)

tensor([[[ 9.1009e-02, -5.5148e-01,  1.7046e-01,  1.8366e-01,  1.0827e-01],
         [ 1.6617e-01,  1.1745e-01, -3.7107e-01,  5.2386e-01, -3.1543e-01],
         [ 1.4724e+00, -4.7047e-01,  6.5626e-01, -9.9991e-01, -4.8115e-01],
         [ 3.5369e-02,  4.2125e-03,  7.5161e-01, -1.5519e-01, -7.9915e-01],
         [ 2.8118e-01, -9.1933e-01, -5.9567e-02,  2.3553e-01, -3.9159e-01],
         [ 8.3741e-01, -4.2605e-01,  3.3581e-01,  8.3889e-01, -5.3482e-01],
         [ 4.0967e-01,  8.5312e-01,  1.3412e+00, -2.3243e-01, -4.5264e-01],
         [-9.2227e-01, -5.0051e-01,  1.2273e-01,  4.6618e-01, -7.9138e-01],
         [ 1.6432e-01,  4.3471e-01,  2.8724e-01, -6.6718e-01,  5.6889e-01],
         [-4.1473e-01, -6.0615e-01, -7.5731e-01, -1.4083e-01,  3.5127e-01],
         [-2.7257e-01,  8.5586e-02,  2.8208e-01, -6.8479e-01,  3.4827e-02],
         [-8.7636e-01,  2.8781e-01,  1.0884e-01, -4.2279e-01,  5.5273e-01],
         [ 9.6939e-01,  1.5806e-01, -1.3531e-01,  2.9140e-01, -1.0784e+00],
         [ 7