In [None]:
import asyncio
from math import log, sin, cos, tan, exp, sqrt, pi
from time import time, sleep
from random import randrange
import torch
import numpy as np
from testbed import ShortDataset, ByteDataset, SeqByteDataset, Trainer, Net0, Net1, Net2, Net3, Net4, Transformer
from testbed.optimizer import Sonny
from testbed.util import default_device, numel
from testbed.gui import Plot, StatsTicker, ParameterInspector, Histogram, SmoothPlot, LinePlot

In [None]:
# Configuration bundle for a Net0 model with a 256-entry input/output vocabulary,
# trained with the Sonny optimizer on batches drawn from ByteDataset.
utf8_net0_config = dict(
    model=dict(
        type=Net0,
        kwargs=dict(
            n_vocab_in=256,
            n_vocab_out=256,
            n_ctx=64,
            d_model=64,
            d_ff=8192,
            nonlinearity="GELU",
        ),
    ),
    optimizer=dict(
        type=Sonny,
        kwargs=dict(
            eps=1e-4,
            lr=5e-5,
            beta1=0.9,
            beta2=0.999,
            weight_decay=0.01,
            warmup=10000,
        ),
    ),
    dataset=dict(
        type=ByteDataset,
        kwargs=dict(batch_size=256),
    ),
)
# The dataset's example length is tied to the model: one element more than n_ctx.
# NOTE(review): presumably the extra element is the prediction target -- confirm.
utf8_net0_config["dataset"]["kwargs"]["example_length"] = (
    utf8_net0_config["model"]["kwargs"]["n_ctx"] + 1
)

In [None]:
# Configuration bundle for a Transformer model with a 256-entry input/output
# vocabulary, trained with the Sonny optimizer on batches drawn from
# SeqByteDataset. All dropout probabilities are set to zero.
utf8_transformer_config = dict(
    model=dict(
        type=Transformer,
        kwargs=dict(
            n_vocab_in=256,
            n_vocab_out=256,
            max_ctx=128,
            d_model=256,
            d_k=16,
            d_v=16,
            n_heads=16,
            d_ff=256,
            n_layers=8,
            p_dropout_embedding=0.0,
            p_dropout_attn_mat=0.0,
            p_dropout_attn_out=0.0,
            p_dropout_ff=0.0,
        ),
    ),
    optimizer=dict(
        type=Sonny,
        kwargs=dict(
            eps=1e-4,
            lr=1e-4,
            beta1=0.9,
            beta2=0.999,
            weight_decay=0.01,
            warmup=10000,
        ),
    ),
    dataset=dict(
        type=SeqByteDataset,
        kwargs=dict(
            batch_size=256,
            # 32 + 1 elements per example; note this is shorter than max_ctx (128).
            example_length=32 + 1,
        ),
    ),
)

In [None]:
# Choose which of the configuration dicts defined above the Trainer is built
# from; swap in utf8_transformer_config to use the Transformer setup instead.
config = utf8_net0_config

In [None]:
# Build the Trainer from the selected configuration. Every later cell in this
# notebook operates on this single `trainer` object.
trainer = Trainer(config=config)

In [None]:
# Kick off training.
# NOTE(review): the meaning of the argument 2000 is not visible in this
# notebook -- confirm against Trainer.start.
trainer.start(2000)

In [None]:
# Create a StatsTicker over the trainer's 'step' (x) and 'mean_loss' (y) fields.
# The bare `ticker` expression on the last line makes it this cell's output.
ticker = StatsTicker(trainer,  x='step', y='mean_loss')
ticker

In [None]:
# Override the optimizer's lr with 1e-3 (the configs above start from 5e-5 or 1e-4).
# NOTE(review): whether this takes effect on a run already in progress depends
# on Trainer.update -- confirm.
trainer.update("optimizer", lr=1e-03)

In [None]:
# Override the dataset's batch_size with 8192 (the configs above use 256).
trainer.update("dataset", batch_size=8192)

In [None]:
# Module-level accumulator for generated text: foo() below appends to it, and
# the autocomplete / file-writing cells read from it.
result = ""


In [None]:
# Ask the trainer for 256 generated symbols, prompted with the first 128
# characters of `result`, concatenate them into one string and print it.
more = ''.join(
    trainer.autocomplete(result[:128], n_generate=256, max_ctx=128)
)
print(more)

In [None]:
# (Re)create gibberish.txt so that it holds exactly the text accumulated in
# `result` so far; any earlier contents of the file are discarded.
with open('gibberish.txt', mode='w') as transcript:
    transcript.write(result)

In [None]:
async def foo():
    """Generate one more chunk of text and record it.

    Requests 256 generated symbols from the module-level `trainer`, appends the
    resulting string to the module-level `result`, and appends the same string
    to gibberish.txt.

    Returns:
        The updated value of `result` (all text accumulated so far).
    """
    global result
    generated = ''.join(trainer.autocomplete(n_generate=256, max_ctx=128))
    result = result + generated
    # Append mode: earlier chunks already in the file are kept.
    with open('gibberish.txt', 'a') as transcript:
        transcript.write(generated)
    return result

In [None]:
for _ in range(2400):
    sleep(15)
    t = asyncio.create_task(foo())
    await t

In [None]:
# Display the most recent metrics entry. The plotting classes further down
# read the 'step' and 'mean_loss' keys from these entries.
trainer.metrics[-1]

In [None]:
# Save the trainer to checkpoint.pt in the current working directory.
# NOTE(review): exactly what is persisted is decided by Trainer.save -- confirm.
trainer.save("checkpoint.pt")

In [None]:
# Load checkpoint.pt (the file name used by the save cell above) back into the
# existing trainer object.
trainer.load('checkpoint.pt')

### SmoothPlot

In [None]:
import scipy.ndimage

def smoother(X, Y, lag):
    """Trailing moving average of Y over a window of `lag` samples.

    Args:
        X: sequence of x-coordinates, same length as Y.
        Y: sequence of values to average.
        lag: window length in samples.

    Returns:
        A pair (X[lag:], averages), where averages[i] is the mean of the `lag`
        values of Y ending at position lag + i.
    """
    running_total = np.cumsum(Y)
    # The difference of two cumulative sums `lag` apart is the sum over a window.
    window_sums = running_total[lag:] - running_total[:-lag]
    return X[lag:], window_sums / lag

def gsmoother(X, Y, lag):
    """Gaussian-smooth Y and trim `lag` samples from both ends of X and Y.

    Args:
        X: sequence of x-coordinates, same length as Y.
        Y: sequence of values to smooth.
        lag: used both as the Gaussian sigma (in samples) and as the number of
            samples removed from each end of the result.

    Returns:
        A tuple (X_trimmed, Y_smoothed_trimmed).
    """
    interior = slice(lag, -lag)
    blurred = scipy.ndimage.gaussian_filter1d(Y, sigma=lag)
    return X[interior], blurred[interior]

class SmoothPlot(LinePlot):
    """Line plot of the Gaussian-smoothed mean loss against the training step.

    NOTE(review): this definition rebinds the name `SmoothPlot`, which the
    notebook's first cell also imports from testbed.gui.
    """

    def __init__(self, trainer, lag=100, log=None):
        """Build the plot from `trainer.metrics`.

        Args:
            trainer: object whose `metrics` is a sequence of mappings with
                'step' and 'mean_loss' entries.
            lag: forwarded to gsmoother, which uses it as the Gaussian sigma
                and trims that many points from each end of the curve.
            log: if truthy, the x-axis shows log2(step) instead of step.
        """
        history = np.array([[m['step'], m['mean_loss']] for m in trainer.metrics])
        steps = history[:, 0]
        losses = history[:, 1]
        # Subsampling stride, computed from the untrimmed length so that at
        # most about 1000 points reach the plot.
        stride = len(steps) // 1000 + 1
        steps, losses = gsmoother(steps, losses, lag)
        steps = steps[::stride]
        losses = losses[::stride]
        if log:
            # Use NumPy's base-2 log: the `math` module itself is never
            # imported in this notebook (only names from it), and the `log`
            # parameter shadows the imported math.log inside this method.
            steps = np.log2(steps)
        super().__init__(steps, losses)

class GaussianSmoothedLossRate(LinePlot):
    """Line plot of the Gaussian-smoothed per-entry decrease in mean loss."""

    def __init__(self, trainer, lag=100, log=None):
        """Build the plot from `trainer.metrics`.

        Args:
            trainer: object whose `metrics` is a sequence of mappings with
                'step' and 'mean_loss' entries.
            lag: forwarded to gsmoother, which uses it as the Gaussian sigma
                and trims that many points from each end of the curve.
            log: if truthy, the x-axis shows log2(step) instead of step.
        """
        history = np.array([[m['step'], m['mean_loss']] for m in trainer.metrics])
        steps = history[1:, 0]
        # Previous loss minus current loss: positive where the loss went down.
        loss_drop = history[:-1, 1] - history[1:, 1]
        steps, loss_drop = gsmoother(steps, loss_drop, lag)
        if log:
            # Use NumPy's base-2 log: the `math` module itself is never
            # imported in this notebook (only names from it), and the `log`
            # parameter shadows the imported math.log inside this method.
            steps = np.log2(steps)
        super().__init__(steps, loss_drop)

In [None]:
# Show the smoothed loss curve with lag=10 on a linear step axis (log=False).
SmoothPlot(trainer, lag=10, log=False)

In [None]:
# Show the smoothed loss-decrease curve with a very wide kernel (lag=10000).
# gsmoother trims `lag` points from each end, so this plot is empty unless
# trainer.metrics holds well over 20000 entries.
GaussianSmoothedLossRate(trainer, lag=10000, log=False)