In [None]:
import asyncio
import html
from math import log, sin, cos, tan, exp, sqrt, pi
from random import randrange
from time import time, sleep

import numpy as np
import torch
from IPython.display import HTML, clear_output, display

from testbed import ShortDataset, ByteDataset, SeqByteDataset, Trainer, Net0, Net1, Net2, Net3, Net4, Transformer
from testbed.optimizer import Sonny
from testbed.util import default_device, numel
from testbed.gui import Plot, StatsTicker, ParameterInspector, Histogram, SmoothPlot, LinePlot

In [None]:
# Experiment configuration pairing Net0 with the Sonny optimizer and ByteDataset.
# Each section names the class to instantiate ("type") and its keyword
# arguments ("kwargs").
utf8_net0_config = {}

utf8_net0_config["model"] = {
    "type": Net0,
    "kwargs": {
        "n_vocab_in": 256,
        "n_vocab_out": 256,
        "n_ctx": 64,
        "d_model": 64,
        "d_ff": 8192,
        "nonlinearity": "GELU",
    },
}

utf8_net0_config["optimizer"] = {
    "type": Sonny,
    "kwargs": {
        "eps": 1e-4,
        "lr": 5e-5,
        "beta1": 0.9,
        "beta2": 0.999,
        "weight_decay": 0.01,
        "warmup": 10000,
    },
}

utf8_net0_config["dataset"] = {
    "type": ByteDataset,
    "kwargs": {
        "batch_size": 8192,
        # Always one more than the model's n_ctx, so the two stay in sync.
        # Presumably n_ctx inputs plus one next-symbol target -- TODO confirm.
        "example_length": utf8_net0_config["model"]["kwargs"]["n_ctx"] + 1,
    },
}

In [None]:
# Experiment configuration pairing Transformer with the Sonny optimizer and
# SeqByteDataset. Each section names the class to instantiate ("type") and its
# keyword arguments ("kwargs").
utf8_transformer_config = {}

utf8_transformer_config["model"] = {
    "type": Transformer,
    "kwargs": {
        "n_vocab_in": 256,
        "n_vocab_out": 256,
        "max_ctx": 128,
        "d_model": 256,
        "d_k": 16,
        "d_v": 16,
        "n_heads": 16,
        "d_ff": 256,
        "n_layers": 8,
        # All dropout probabilities are zero, i.e. dropout is disabled.
        "p_dropout_embedding": 0.0,
        "p_dropout_attn_mat": 0.0,
        "p_dropout_attn_out": 0.0,
        "p_dropout_ff": 0.0,
    },
}

utf8_transformer_config["optimizer"] = {
    "type": Sonny,
    "kwargs": {
        "eps": 1e-4,
        "lr": 1e-4,
        "beta1": 0.9,
        "beta2": 0.999,
        "weight_decay": 0.01,
        "warmup": 10000,
    },
}

utf8_transformer_config["dataset"] = {
    "type": SeqByteDataset,
    "kwargs": {
        "batch_size": 256,
        # NOTE(review): unlike utf8_net0_config, this is not derived from the
        # model's context size (max_ctx is 128 here) -- confirm 33 is intended.
        "example_length": 32 + 1,
    },
}

In [None]:
# Choose the experiment configuration the Trainer is built from; assign
# utf8_transformer_config here instead to run the Transformer setup.
config = utf8_net0_config

In [None]:
# Build the trainer from the selected configuration.
trainer = Trainer(config=config)

In [None]:
# Begin training. Presumably returns without blocking, since later cells read
# trainer.metrics and call trainer.update() -- TODO confirm.
trainer.start()

In [None]:
# Line chart of mean_loss against step, fed from trainer.metrics.
ticker = StatsTicker(trainer.metrics,  x='step', y='mean_loss', kind='line')
# Bare expression on the last line: the widget is rendered as the cell output.
ticker

In [None]:
# Manual tuning step: pass a new learning rate to the existing trainer's
# optimizer (utf8_net0_config starts from lr=5e-5).
trainer.update("optimizer", lr=5e-04)

In [None]:
# Manual tuning step: pass a batch size to the existing trainer's dataset
# (8192 equals the value already set in utf8_net0_config).
trainer.update("dataset", batch_size=8192)

In [None]:
# Global accumulator for generated text; foo() appends to it on every call.
# Re-run this cell to clear the accumulated output.
result = ""

In [None]:
async def foo():
    """Generate a new sample and append it to the global ``result``.

    Calls ``trainer.autocomplete(n_generate=256, max_ctx=128)``, joins the
    string pieces it yields, and extends the module-level ``result`` with them.

    Returns:
        The full accumulated ``result`` string, including the new sample.
    """
    global result
    pieces = trainer.autocomplete(n_generate=256, max_ctx=128)
    result = result + ''.join(pieces)
    return result

In [None]:
for _ in range(2400):
    sleep(15)
    t = asyncio.create_task(foo())
    await t
    from IPython.display import display,HTML,clear_output
    clear_output(wait=True)
    display(HTML(f'<pre>{t.result()}</pre>'))

In [None]:
# Show the most recent metrics record.
trainer.metrics[-1]

In [None]:
# Save via the trainer to "checkpoint.pt" (a relative path, so it is resolved
# against the current working directory). What gets stored is up to
# Trainer.save.
trainer.save("checkpoint.pt")

### SmoothPlot

In [None]:
# Collect (step, mean_loss) pairs from the recorded metrics into one 2-D array.
L = np.array([(record['step'], record['mean_loss']) for record in trainer.metrics])
# Column 0 holds the steps, column 1 the matching mean losses.
X, Y = L[:, 0], L[:, 1]
def smoother(data, lag):
    """Return the trailing moving average of ``data`` over ``lag`` samples.

    Output element ``i`` is the mean of ``data[i + 1 : i + lag + 1]``, computed
    from differences of the cumulative sum. The result therefore has
    ``len(data) - lag`` elements.

    Args:
        data: 1-D sequence or array of numbers.
        lag: Window length in samples.

    Returns:
        NumPy array of window means.
    """
    running_total = np.cumsum(data)
    window_sums = running_total[lag:] - running_total[:-lag]
    return window_sums / lag

class SmoothPlot(LinePlot):
    """LinePlot of a series smoothed with a trailing moving average.

    Note: this definition shadows the ``SmoothPlot`` imported from
    ``testbed.gui`` at the top of the notebook.
    """

    def __init__(self, X=None, Y=None, lag=100, log=None):
        """Smooth the series and hand it to ``LinePlot``.

        Args:
            X: x values, or the y values when ``Y`` is omitted. If ``None``,
                ``X`` and ``Y`` are passed through to ``LinePlot`` untouched.
            Y: y values. When ``None``, ``X`` is treated as the y series and
                the x axis becomes ``0 .. len(X) - 1``.
            lag: Moving-average window length; the first ``lag`` x values are
                dropped so x and the smoothed y have equal length.
            log: If truthy, the x axis is converted to log base 2.
        """
        if X is not None:
            if Y is None:
                Y = np.array(X)
                X = np.array(range(len(X)))
            else:
                X = np.array(X)
                Y = np.array(Y)
            # smoother() returns len(Y) - lag points, so trim X to match.
            X = X[lag:]
            Y = smoother(Y, lag)
            if log:
                # np.log2 instead of np.log(X)/math.log(2): `math` is never
                # imported in this notebook, and the `log` parameter shadows
                # the `log` imported from math, so the old form raised
                # NameError whenever log was truthy.
                X = np.log2(X)
        super().__init__(X, Y)
# Plot mean loss against step with a 50-sample moving average and a linear
# x axis; as the last expression of the cell, the plot is the cell output.
SmoothPlot(X, Y, lag=50, log=False)