In [35]:
!pip install wandb -q

## Init wandb for logging

In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkwargs[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Import all libraries and set seeds

In [45]:
import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import random
import os

def seed_everything(seed: int):   
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(42)

# Data

## Make custom dataset class

In [4]:
from itertools import product


class NumbersDataset(Dataset):
    def __init__(self, left_len, right_len, tokenizer):
        super(NumbersDataset, self).__init__()
        self.left_len = left_len
        self.right_len = right_len
        self.tokenizer = tokenizer
        self._build()
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

    def _build(self):
        self.X = []
        self.y = []
        left_range = range(10**self.left_len)
        right_range = range(10**self.right_len)
        for i in product(left_range, right_range):
            self.X.append(self.tokenizer(f'{i[0]}+{i[1]}')) # I'm really sorry for this mess
            self.y.append([i[0]+i[1]])
        self.X = self.X
        self.y = self.y

## Simple tokenizer, nothing more

In [5]:
def tokenizer(string):
    subs = {
         "+": 10, 
    }
    return [int(subs.get(item, item)) for item in string]

tokenizer('2+22')

[2, 10, 2, 2]

## Assemble average-sized dataset

In [6]:
%%time
nums = NumbersDataset(2, 2, tokenizer)

CPU times: user 26.6 ms, sys: 3.95 ms, total: 30.5 ms
Wall time: 36.1 ms


In [7]:
len(nums)

10000

## Get dataloaders

In [8]:
def collate_fn(items):
    tokens = [torch.tensor(i[0]) for i in items]
    labels = [torch.tensor(i[1]) for i in items]
    packed_tokens = torch.nn.utils.rnn.pack_sequence(tokens, enforce_sorted=False)
    packed_labels = torch.nn.utils.rnn.pack_sequence(labels, enforce_sorted=False)
    return packed_tokens, torch.stack(labels)

train, test = torch.utils.data.random_split(nums, [0.8, 0.2])
train_loader = DataLoader(train, num_workers=2, batch_size=4, drop_last=True, shuffle=True, pin_memory=True, collate_fn=collate_fn) 
test_loader = DataLoader(test, num_workers=2, batch_size=4, drop_last=True, shuffle=True, pin_memory=True, collate_fn=collate_fn)
ex = next(iter(train_loader))
# ex[0], ex[1]

In [9]:
%%time
nums[0]

CPU times: user 6 µs, sys: 2 µs, total: 8 µs
Wall time: 13.8 µs


([0, 10, 0], [0])

In [10]:
%%time
next(iter(train_loader));

CPU times: user 7.24 ms, sys: 80.4 ms, total: 87.6 ms
Wall time: 215 ms


[PackedSequence(data=tensor([ 3,  9,  1,  0,  3,  6,  8, 10, 10, 10, 10,  7,  6,  2,  3,  2,  0,  6]), batch_sizes=tensor([4, 4, 4, 4, 2]), sorted_indices=tensor([0, 3, 1, 2]), unsorted_indices=tensor([0, 2, 3, 1])),
 tensor([[ 93],
         [ 21],
         [ 72],
         [122]])]

# Train

## Assemble model
That is embeddings layer + rnn + linear + relu + linear

In [11]:
class LSTMCalc(nn.Module):
    def __init__(self, hidden_size: int = 128, num_layers: int = 4, dropout: float = 0.1):
        super().__init__()
        # 11 stands for 0-9 and +
        self.embeddings = nn.Embedding(11, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers, dropout=dropout)
        self.ln = nn.LayerNorm(hidden_size)
        self.lin1 = nn.Linear(hidden_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, length_inputs):
        """
        Args:
            input_ids: Torch.Tensor, shape: (seq_length, batch_size, hidden_size)
        """
        embs = self.embeddings(input_ids)

        packed_sequences = torch.nn.utils.rnn.pack_padded_sequence(embs, length_inputs, enforce_sorted=False)
        rnn_outputs, _ = self.rnn(packed_sequences)
        unpacked_sequences, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_outputs)
        unpacked_sequences = unpacked_sequences.permute(1, 0, 2)
        out = self.ln(unpacked_sequences)
        out = self.lin1(out)
        out = self.relu1(out)
        out = self.head(unpacked_sequences)
        return out.mean(dim=1)

In [12]:
config = dict(
    hidden_size=256,
    num_layers=4,
    dropout=0.1
)
model = LSTMCalc(**config)

## Run train loop

In [13]:
# Set number of epochs
epochs = 20

device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')

loss_fn = nn.L1Loss()

optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

wandb.init(
    # set the wandb project where this run will be logged
    project="llmcalc",
    
    # track hyperparameters and run metadata
    config={
        "group":"e",
        "model":"lstm",
        "hidden_size":config['hidden_size'],
        "num_layers":config['num_layers'],
        "left": 2,
        "right": 2,
        "bs":4
    }
)

model = model.to(device, non_blocking=True)
for epoch in tqdm(range(epochs)):
    model.train()
    losses = []
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        inputs, labels = batch
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True) 
        inputs, inputs_length = torch.nn.utils.rnn.pad_packed_sequence(inputs)
        
        outputs = model(inputs, inputs_length)
        loss = loss_fn(outputs, labels)
        losses.append(loss)
        loss.backward()

        optimizer.step()

    model.eval()
    test_losses = []
    for i, batch in enumerate(test_loader):
        with torch.no_grad():
            optimizer.zero_grad()
            inputs, labels = batch
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True) 
            
            inputs, inputs_length = torch.nn.utils.rnn.pad_packed_sequence(inputs)
            
            outputs = model(inputs, inputs_length)
            test_loss = loss_fn(outputs, labels,)
            test_losses.append(test_loss)

    wandb.log({
        "epoch":epoch,
        "train/loss": sum(losses)/len(losses),
        "test/loss": sum(test_losses)/len(test_losses)
        })
    print(f"Epoch: {epoch} | Loss: {sum(losses)/len(losses):.5f} | Test loss:{sum(test_losses)/len(test_losses):.5f}") 
wandb.finish()

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0 | Loss: 34.85438 | Test loss:24.48872
Epoch: 1 | Loss: 8.71373 | Test loss:4.09806
Epoch: 2 | Loss: 3.83348 | Test loss:3.11417
Epoch: 3 | Loss: 3.11059 | Test loss:3.14050
Epoch: 4 | Loss: 2.68024 | Test loss:2.55691
Epoch: 5 | Loss: 2.37404 | Test loss:2.33686
Epoch: 6 | Loss: 2.14232 | Test loss:2.10334
Epoch: 7 | Loss: 1.92731 | Test loss:1.90592
Epoch: 8 | Loss: 1.78601 | Test loss:1.77497
Epoch: 9 | Loss: 1.66357 | Test loss:1.69086
Epoch: 10 | Loss: 1.58841 | Test loss:1.78035
Epoch: 11 | Loss: 1.53379 | Test loss:1.64932
Epoch: 12 | Loss: 1.44270 | Test loss:1.47186
Epoch: 13 | Loss: 1.38389 | Test loss:1.43398
Epoch: 14 | Loss: 1.31618 | Test loss:1.38979
Epoch: 15 | Loss: 1.30379 | Test loss:1.35755
Epoch: 16 | Loss: 1.24413 | Test loss:1.31339
Epoch: 17 | Loss: 1.17736 | Test loss:1.35129
Epoch: 18 | Loss: 1.15963 | Test loss:1.20077
Epoch: 19 | Loss: 1.10976 | Test loss:1.35080


VBox(children=(Label(value='0.001 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.110090…

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
test/loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19.0
test/loss,1.3508
train/loss,1.10976


# Save/Load model

In [28]:
if not os.path.exists('lstm'):
    torch.save(model.to('cpu').state_dict(), 'lstm')

In [29]:
if os.path.exists('lstm'):
    model.load_state_dict(torch.load('lstm'))
else:
    print('No saved model!')

# Inference

In [39]:
# Once again, sorry for this mess
inputs = torch.tensor(test[0][0])
print(inputs)
labels = test[0][1]
print(labels)
def collate_fn(items):
    tokens = [torch.tensor(i[0]) for i in items]
    labels = [torch.tensor(i[1]) for i in items]
    packed_tokens = torch.nn.utils.rnn.pack_sequence(tokens, enforce_sorted=False)
    packed_labels = torch.nn.utils.rnn.pack_sequence(labels, enforce_sorted=False)
    return packed_tokens, torch.stack(labels)
inputs, labels = collate_fn([[inputs, labels]])
inputs, inputs_length = torch.nn.utils.rnn.pad_packed_sequence(inputs)
(
model(inputs, inputs_length),
test[0][0], 
test[0][1],
loss_fn(
        model(inputs, inputs_length),
        torch.tensor(test[0][1])
       )
)

tensor([ 4,  8, 10,  9,  5])
[143]


  tokens = [torch.tensor(i[0]) for i in items]
  return F.l1_loss(input, target, reduction=self.reduction)


(tensor([[142.6019]], grad_fn=<MeanBackward1>),
 [4, 8, 10, 9, 5],
 [143],
 tensor(0.3981, grad_fn=<MeanBackward0>))

# Some frontend


In [58]:
text = '48+95' #@param {type:"string"}
label = 143 #@param {type:"number"}
tokenized = torch.tensor(tokenizer(text))
label = [label]
def collate_fn(items):
    tokens = [torch.tensor(i[0]) for i in items]
    labels = [torch.tensor(i[1]) for i in items]
    packed_tokens = torch.nn.utils.rnn.pack_sequence(tokens, enforce_sorted=False)
    packed_labels = torch.nn.utils.rnn.pack_sequence(labels, enforce_sorted=False)
    return packed_tokens, torch.stack(labels)
inputs, labels = collate_fn([[tokenized, label]])
inputs, inputs_length = torch.nn.utils.rnn.pad_packed_sequence(inputs)

out = model(inputs, inputs_length)[0][0]

print(f"Model answer: ~{out.round()} ({out}),\nTrue answer is {label[0]}")
if label != None:
    print("MAE is",
        loss_fn(
                model(inputs, inputs_length),
                torch.tensor(label)
            ).item()
    )

Model answer: ~143.0 (142.60191345214844),
True answer is 143
MAE is 0.3980865478515625
