In [None]:
import asyncio
from math import log, sin, cos, tan, exp, sqrt, pi
from time import time, sleep
from random import randrange
import torch
import numpy as np
from classroom import UTF8Dataset, MLPLM, TransformerLM
from classroom import AdamW, Learner, Plot, Filter
from classroom import KalmanFilter1D, MedianFilter, TwoWindowFilter

## Training

In [None]:
# Guards updates to the compute time shared by every training task.
lock = asyncio.Lock()
# Seconds spent inside learner.step(), summed over all training tasks.
total_time = 0.0
# Strong references to in-flight training tasks. The event loop only keeps weak
# references to tasks, so a task nobody else holds may be garbage-collected
# before it finishes.
_training_tasks = set()

def start_training(config, hyp):
    """Launch training as a background asyncio task and return live progress handles.

    Must be called while an event loop is running (as in a Jupyter cell),
    because the work is scheduled with ``asyncio.create_task``.

    Args:
        config: passed unchanged to ``Learner(config=config)``.
        hyp: keyword arguments for the training coroutine; must provide
            ``n_steps``, ``batch_size`` and ``example_length``.

    Returns:
        ``(learner, X, Y)``. ``X`` and ``Y`` are lists that the background task
        keeps appending to: ``X`` holds this task's cumulative time spent in
        ``learner.step`` (event-loop clock, seconds) and ``Y`` the matching
        value of ``sum(step losses) / batch_size``.

    Exceptions raised while training are printed and end the task; they are
    not propagated to the caller.
    """
    learner = Learner(config=config)
    X = []
    Y = []
    elapsed = 0.0

    async def train(n_steps, batch_size, example_length):
        global total_time
        nonlocal elapsed
        # Monotonic clock of the loop that is running this coroutine.
        clock = asyncio.get_running_loop().time
        try:
            for step in range(n_steps):
                start = clock()
                # learner.step runs synchronously and blocks the loop meanwhile.
                mean_loss = np.sum(learner.step(batch_size, example_length))/batch_size
                end = clock()
                elapsed += end - start
                async with lock:
                    total_time += end-start
                # Record the point before yielding so readers of X/Y see the
                # finished step immediately.
                X.append(elapsed)
                Y.append(mean_loss)
                # Cooperative fairness: a task that has used more than 51% of
                # the shared compute time backs off for twice its last step;
                # otherwise it yields only briefly. The total_time guard avoids
                # a ZeroDivisionError when the clock did not advance.
                if total_time > 0 and elapsed / total_time > .51:
                    await asyncio.sleep(2*(end-start))
                else:
                    await asyncio.sleep(0.01)
        except Exception as e:
            print(e)
            return e
        return None

    task = asyncio.create_task(train(**hyp))
    # Keep the task alive until it completes, then drop the reference.
    _training_tasks.add(task)
    task.add_done_callback(_training_tasks.discard)
    return (learner, X, Y)

## Scheduling helpers

In [None]:
def constant(c):
    """Return a schedule that gives ``c`` for every step."""
    def schedule(step):
        return c
    return schedule


def linear_warmup_then_decay(lr, warmup):
    """Return a schedule that ramps up linearly, then decays as ``warmup/n``.

    For ``n < warmup`` the value is ``lr * n / warmup``; from ``n == warmup``
    onwards it is ``lr * warmup / n``, so the peak value ``lr`` is reached at
    ``n == warmup``.
    """
    def schedule(n):
        if n < warmup:
            return lr*(n/warmup)
        return lr*(warmup/n)
    return schedule

## MLP Language Model

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell does not fail outright on machines without CUDA.
# NOTE(review): assumes Learner moves its batches to the model's device — confirm.
mlp_device = 'cuda' if torch.cuda.is_available() else 'cpu'

mlp_model = MLPLM(n_vocab_in=256, n_vocab_out=256, n_ctx=32,
              d_model=32, d_hidden=8192, nonlinearity="sigmoid").to(mlp_device)
# Every optimizer hyperparameter is a schedule (a function of the step);
# here they are all constant.
mlp_optimizer = AdamW(parameters=mlp_model.parameters(), eps=constant(1e-4), 
                  lr=constant(1e-3), 
                  beta1=constant(0.9), beta2=constant(0.999),
                  weight_decay=constant(0.01), initial_step=0)
mlp_dataset = UTF8Dataset()
mlp_config = {"model": mlp_model, "optimizer": mlp_optimizer, "dataset": mlp_dataset}
# example_length is n_ctx + 1 — presumably the context plus one target; verify.
mlp_hyp = {"n_steps": 2**20, "batch_size": 256, "example_length": 33}
# Training runs in the background; mlp_times / mlp_losses keep growing while it does.
(mlp_learner, mlp_times, mlp_losses) = start_training(mlp_config, mlp_hyp)
mlp_plot = {"MLP": (mlp_times, mlp_losses)}

## Transformer Language Model

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell does not fail outright on machines without CUDA.
# NOTE(review): assumes Learner moves its batches to the model's device — confirm.
trans_device = 'cuda' if torch.cuda.is_available() else 'cpu'

trans_model = TransformerLM(n_vocab_in=256, n_vocab_out=256, n_ctx=32, d_model=256,
                      d_k=16, d_v=16, n_heads=16, d_hidden=8192, n_layers=8, p_dropout_embedding=0.1,
                      p_dropout_attn_mat=0.1, p_dropout_attn_out=0.1, p_dropout_mlp=0.1).to(trans_device)

# The learning rate ramps up linearly over the first 100 steps and then decays
# as warmup/step; the remaining hyperparameters are constant schedules.
trans_optimizer = AdamW(parameters=trans_model.parameters(), eps=constant(1e-4),
                  lr=linear_warmup_then_decay(lr=1e-4,warmup=100), 
                  beta1=constant(0.9), beta2=constant(0.999), weight_decay=constant(0.01),
                  initial_step=0)

trans_dataset = UTF8Dataset()

trans_config = {"model": trans_model, "optimizer": trans_optimizer, "dataset": trans_dataset}
# example_length is n_ctx + 1 — presumably the context plus one target; verify.
trans_hyp = {"n_steps": 2**20, "batch_size": 256, "example_length": 33}
# Training runs in the background; trans_times / trans_losses keep growing while it does.
(trans_learner, trans_times, trans_losses) = start_training(trans_config, trans_hyp)
trans_plot = {"Transformer": (trans_times, trans_losses)}

## Autocompleting

In [None]:
# Print the MLP learner's autocomplete output for n_generate=1024.
# NOTE(review): presumably this samples from the model in its current state,
# which may still be training in the background — confirm.
print(mlp_learner.autocomplete(n_generate=1024))

In [None]:
# Print the Transformer learner's autocomplete output for n_generate=1024.
# NOTE(review): presumably this samples from the model in its current state,
# which may still be training in the background — confirm.
print(trans_learner.autocomplete(n_generate=1024))

## Training Visualization Plots

In [None]:
# One (times, losses) pair per model, passed to Plot as keyword arguments
# named after the model.
plot_data = {
    "MLP": (mlp_times, mlp_losses),
    "Transformer": (trans_times, trans_losses),
}
Plot(**plot_data)

In [None]:
# Raw MLP series plus three filtered versions of the same losses, all sharing
# the raw time axis.
X, Y = mlp_plot["MLP"]
KY = Filter(Y, KalmanFilter1D())
MY = Filter(Y, MedianFilter())
TWY = Filter(Y, TwoWindowFilter())
# Rebinds plot_data: from here on it describes only the MLP curves.
plot_data = {
    "MLP": (X, Y),
    "MLP-kalman": (X, KY),
    "MLP-median": (X, MY),
    "MLP-twowindow": (X, TWY),
}

In [None]:
# Pass whatever plot_data currently holds to Plot, one keyword per series.
# In a top-to-bottom run that is the raw and filtered MLP series assigned to
# plot_data in the preceding cell.
Plot(**plot_data)