In [None]:
import asyncio
import math
from math import log, sin, cos, tan, exp, sqrt, pi
import time
from random import randrange
import torch
import numpy as np
from classroom import Classroom
from classroom import Student
from classroom import RandomDataset as Dataset
from classroom import MLPLM
from classroom import MLPLM2
from classroom import TransformerLM
from classroom import AdamW, Sonny, Floyd
from classroom import Plot, Histogram
from classroom import FilteredList
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import CountFilter
from classroom import SumFilter
from classroom.util import numel

In [None]:
def constant(c):
    """Schedule factory: returns a schedule that yields ``c`` at every step."""
    def schedule(step):
        return c
    return schedule


def linear_warmup_then_decay(lr, warmup):
    """Schedule factory: ramp linearly up to ``lr`` over ``warmup`` steps,
    then decay as ``lr * warmup / n``.
    """
    def schedule(n):
        if n < warmup:
            # Warm-up phase: grows linearly from 0 towards lr.
            return lr * (n / warmup)
        # Decay phase: inverse-proportional to the step count.
        return lr * (warmup / n)
    return schedule

In [None]:
def mlp_factory(checkpoint=None, device='cuda'):
    """Build a Student wrapping an MLP language model with a 256-symbol vocabulary.

    Args:
        checkpoint: path to a model saved with ``torch.save(model, path)``.
            When None, a freshly initialised ``MLPLM`` is created instead.
        device: device the model is placed on (default ``'cuda'``).

    Returns:
        A ``Student`` bundling the model, an ``AdamW`` optimizer, a dataset
        and the batch settings.

    NOTE(review): the optimizer is built with ``n=0`` even when resuming;
    presumably ``n`` is its step counter, so the schedules below restart
    from step 0 — confirm against AdamW.
    """
    if checkpoint is None:
        model = MLPLM(
            n_vocab_in=256,
            n_vocab_out=256,
            n_ctx=32,
            d_model=32,
            d_hidden=8192,
            nonlinearity="GELU").to(device)
    else:
        # torch.load unpickles the whole module: only load trusted files.
        # map_location lets a checkpoint written on one device load on another.
        # NOTE: torch>=2.6 defaults to weights_only=True, which rejects pickled
        # modules; pass weights_only=False there (trusted checkpoints only).
        model = torch.load(checkpoint, map_location=device).to(device)
    optimizer = AdamW(
        parameters=model.parameters(),
        # sin^2 ramp with a 512-step period, peaking at 1e-6.
        lr=lambda n: 1e-6 * sin(pi * n / 512) ** 2,
        alpha=lambda n: 0.0,
        beta1=lambda n: 0.9,
        beta2=lambda n: 0.999,
        weight_decay=lambda n: 0.01,
        n=0)
    # Alternative optimizer configuration, kept for reference:
    # optimizer = Sonny(
    #     parameters=model.parameters(),
    #     lr=lambda n: 0.0 if n < 10000 else 1e-5,
    #     alpha=lambda n: 0.0 if n < 10000 else 0.9,
    #     n=0)
    dataset = Dataset()
    batch_size = 1024
    # One symbol longer than the context window — presumably n_ctx inputs
    # plus a next-symbol target (TODO confirm against Student).
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                   dataset=dataset, batch_size=batch_size,
                   example_length=example_length)

In [None]:
classroom = Classroom()
# Resume from the saved checkpoint when it exists; otherwise start from a
# freshly initialised model so the notebook still runs on a clean checkout.
try:
    student = mlp_factory('checkpoint.pt')
except FileNotFoundError:
    student = mlp_factory()
# Snapshot of the starting model; presumably what `baseline_grades` in the
# plotting cell is measured against (TODO confirm).
student.set_baseline(student.model.clone())
classroom.enroll(student)

In [None]:
# classroom = Classroom()
# student = mlp_factory()
# classroom.enroll(student)

In [None]:
# Size of the model as reported by classroom.util.numel
# (presumably the total parameter count — TODO confirm).
numel(student.model)

In [None]:
# Check which class was actually loaded (fresh MLPLM or the unpickled checkpoint).
type(student.model)

In [None]:
# Display the model's repr.
student.model

## Autocompleting

In [None]:
# Generate 1024 symbols from the first enrolled student's model. print() is used
# so that a returned string (presumably) renders without quotes/escapes.
print(classroom.students[0].autocomplete(n_generate=1024))

## Training Visualization Plots

In [None]:
# Loss-shaping hook: currently the identity — the first argument is returned
# unchanged and the second is ignored.
# NOTE(review): presumably x is the raw loss and y a reference quantity in
# (0, 1) (the experiments below clamp it there) — confirm against Student.
# Earlier experiments, kept for reference:
# student.loss_shaping = lambda x, y: (lambda z: torch.numel(z)*x/torch.sum(1.0/z).item()*(1.0/z))(torch.clamp(y,min=1e-2,max=1.0))
student.loss_shaping = lambda x, y: x # torch.clamp(x,min=1e-2,max=0.99) / torch.clamp(y,min=1e-2,max=0.99)
# /(1-y)/y/4.0/torch.sum(2048.0/(4.0*y*(1.0-y))).item()

#(x / y) - (1 - x)/(1 - y) + 1.0

In [None]:
# Most recent shaped loss. Raises IndexError until something has been recorded
# (assumes shaped_losses is list-like — TODO confirm).
student.shaped_losses[-1]

In [None]:
class LogSumFilter:
    """Running-sum filter that reports the base-2 logarithm of the total.

    Each call adds ``x`` to the running total and returns ``log2(total)``.
    Raises ValueError (math domain error) while the total is not positive.
    """

    def __init__(self):
        # Running total of every value passed in so far.
        self.x = 0

    def __call__(self, x):
        self.x += x
        # math.log2 is exact for powers of two, unlike log(x)/log(2.0),
        # which returns e.g. 29.000000000000004 for 2**29.
        return math.log2(self.x)
    
    

In [None]:
# Grade curves for every enrolled student and its baseline, filtered with
# TwoWindowFilter(lag=64) and plotted against SumFilter(times) (presumably
# cumulative training time — TODO confirm).
plot_data = {}
# `pupil`, not `student`: the loop must not rebind the global `student`
# that later cells keep using.
for (idx, pupil) in enumerate(classroom.students):
    X = FilteredList(pupil.times, SumFilter())
    Y = FilteredList(pupil.grades, TwoWindowFilter(lag=64))
    Z = FilteredList(pupil.baseline_grades, TwoWindowFilter(lag=64))
    plot_data[f"grades-{idx}"] = (X, Y)
    plot_data[f"baseline-{idx}"] = (X, Z)

Plot(**plot_data)

In [None]:
# Retune the live optimizer in place: sin^2 schedules with a 512-step period
# for the learning rate (peak 1e-7) and weight decay (peak 1e-2).
student.optimizer.param_groups[0]["lr"] = lambda n: 1e-7 * sin(pi * n / 512) ** 2
student.optimizer.param_groups[0]["alpha"] = lambda n: 0.0
student.optimizer.param_groups[0]["weight_decay"] = lambda n: 1e-2 * sin(pi * n / 512) ** 2

student.batch_size = 1024

In [None]:
# Number of entries recorded in student.grades so far.
len(student.grades)

In [None]:
# Parameter name -> parameter tensor. (Note: `pd` is not pandas here.)
pd = dict(student.model.named_parameters())

In [None]:
# Empirical density of every parameter tensor, keyed by parameter name.
# Each entry is (bin_centers, density); np.histogram returns bins+1 edges,
# so edges are collapsed to centers to give two arrays of equal length.
H = {}
for (name, p) in pd.items():
    values = p.detach().cpu().numpy().reshape(-1)
    # ~sqrt(n) bins, at least 1 (np.histogram rejects bins=0).
    n_bins = max(1, math.floor(math.sqrt(values.size)))
    density, edges = np.histogram(values, bins=n_bins, density=True)
    centers = 0.5 * (edges[:-1] + edges[1:])
    H[name] = (centers, density)

In [None]:
# One parameter name per line (nothing printed when H is empty).
if H:
    print(*H, sep="\n")

In [None]:
# One density plot per parameter tensor, in H's insertion order.
plots = [Plot(x="value", y="pdf", **{name: hist}) for (name, hist) in H.items()]

In [None]:
# Show every per-parameter histogram. Hard-coding plots[0]..plots[5] raised
# IndexError for models with fewer than six parameter tensors and silently
# skipped any beyond the sixth. `display` is IPython's built-in.
for plot in plots:
    display(plot)

In [None]:
# Persist the whole module (not just a state_dict) because mlp_factory reloads
# it with torch.load(...) directly. Overwrites any existing checkpoint.pt.
torch.save(student.model, 'checkpoint.pt')