In [None]:
import asyncio
from math import log, log2, sin, cos, tan, exp, sqrt, pi
from random import randrange
from time import time, sleep

import numpy as np
import torch

from classroom import Classroom
from classroom import Student
from classroom import RandomDataset as Dataset
from classroom import MLPLM
from classroom import TransformerLM
from classroom import AdamW, Sonny, Floyd
from classroom import Plot, Histogram
from classroom import FilteredList
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import CountFilter
from classroom import SumFilter
from classroom.util import numel

## Training

## Scheduling helpers

In [None]:
def constant(c):
    """Return a schedule that ignores the step number and always yields ``c``."""
    def schedule(step):
        return c
    return schedule


def linear_warmup_then_decay(lr, warmup):
    """Return a schedule that ramps linearly to ``lr`` and then decays as 1/n.

    For ``n < warmup`` the value is ``lr * n / warmup``; from ``n == warmup``
    onward it is ``lr * warmup / n``, so the peak ``lr`` is hit at ``n == warmup``.
    """
    def schedule(n):
        if n < warmup:
            return lr * (n / warmup)
        return lr * (warmup / n)
    return schedule

## MLP Language Model

In [None]:
def mlp_factory(checkpoint_path='checkpoint_mlp.pt', device='cuda'):
    """Build a Student that resumes training the MLP language model from disk.

    Args:
        checkpoint_path: file holding a whole pickled model, as written by
            ``torch.save(student.model, ...)``.
        device: torch device the model is loaded onto and moved to.

    Returns:
        Student wrapping the loaded model, an AdamW optimizer (no weight decay),
        a RandomDataset, a batch size of 2048 and examples of length
        ``model.n_ctx + 1``.
    """
    # A fresh model could be built instead of resuming, e.g.
    #   MLPLM(n_vocab_in=256, n_vocab_out=256, n_ctx=64, d_model=64,
    #         d_hidden=8192*2, nonlinearity="GELU").to(device)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. On PyTorch >= 2.6 torch.load defaults to weights_only=True,
    # which rejects a whole pickled nn.Module; pass weights_only=False there.
    model = torch.load(checkpoint_path, map_location=device).to(device)
    optimizer = AdamW(
        parameters=model.parameters(),
        eps=constant(1e-4),
        # Zero for the first 1000 steps, then sin^2 cycles with a period of
        # 512 steps peaking at 1e-7.
        lr=lambda n: 1e-7 * sin(pi * n / 512) ** 2 if n > 1000 else 0.0,
        beta1=constant(0.9),
        beta2=constant(0.999),
        weight_decay=constant(0.00),
        n=0)

    dataset = Dataset()
    batch_size = 2048
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                   dataset=dataset, batch_size=batch_size,
                   example_length=example_length)

In [None]:
# Build the MLP student once for inspection. The classroom cell further down
# calls mlp_factory() again, which reloads the checkpoint and rebinds `student`.
student = mlp_factory()

In [None]:
# Show the model's repr.
student.model

In [None]:
def trans_factory(device='cuda'):
    """Build a Student that trains a freshly initialised TransformerLM.

    Args:
        device: torch device the model is moved to. The default ``'cuda'``
            matches the previous behaviour; pass ``'cpu'`` on a machine
            without a GPU.

    Returns:
        Student wrapping a 2-layer, 16-head TransformerLM (256-entry input and
        output vocabularies, context 64, dropout 0.1 everywhere), an AdamW
        optimizer with weight decay 0.01, a RandomDataset, a batch size of 64
        and examples of length ``model.n_ctx + 1``.
    """
    model = TransformerLM(
        n_vocab_in=256,
        n_vocab_out=256,
        n_ctx=64,
        d_model=256,
        d_k=16,
        d_v=16,
        n_heads=16,
        d_hidden=4096,
        n_layers=2,
        p_dropout_embedding=0.1,
        p_dropout_attn_mat=0.1,
        p_dropout_attn_out=0.1,
        p_dropout_mlp=0.1).to(device)

    optimizer = AdamW(
        parameters=model.parameters(),
        eps=constant(1e-4),
        # Starts at 0 and cycles with a period of 512*pi (~1608) steps, peaking
        # at 1e-5. NOTE(review): mlp_factory uses sin(pi*n/512), i.e. a 512-step
        # period -- confirm which period is intended here.
        lr=lambda n: sin(n/512)**2 * 1e-5,
        beta1=constant(0.9),
        beta2=constant(0.999),
        weight_decay=constant(0.01),
        n=0)
    dataset = Dataset()
    batch_size = 64
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                   dataset=dataset, batch_size=batch_size,
                   example_length=example_length)

In [None]:
# Enroll a single MLP student in a fresh Classroom. set_baseline receives a
# clone of the model as it was just loaded -- presumably the reference that
# `baseline_grades` is measured against (TODO confirm in Student.set_baseline).
classroom = Classroom()
student = mlp_factory()
student.set_baseline(student.model.clone())
classroom.enroll(student)

In [None]:
# Parameter count via classroom.util.numel (presumably the total number of
# elements across all parameter tensors).
numel(student.model)

In [None]:
# Concrete class of the loaded checkpoint (presumably MLPLM).
type(student.model)

In [None]:
# Show the model's repr.
student.model

## Autocompleting

In [None]:
# Generate from the first enrolled student (n_generate=1024) and print the result.
print(classroom.students[0].autocomplete(n_generate=1024))

## Training Visualization Plots

In [None]:
# Loss-shaping hook: currently the identity -- the first argument `x` is
# returned unchanged and `y` is ignored. The commented expressions are earlier
# shaping experiments kept for reference.
# student.loss_shaping = lambda x, y: (lambda z: torch.numel(z)*x/torch.sum(1.0/z).item()*(1.0/z))(torch.clamp(y,min=1e-2,max=1.0))
student.loss_shaping = lambda x, y: x  # alternative tried: torch.clamp(x,min=1e-2,max=0.99) / torch.clamp(y,min=1e-2,max=0.99)
# alternative tried: /(1-y)/y/4.0/torch.sum(2048.0/(4.0*y*(1.0-y))).item()

# alternative tried: (x / y) - (1 - x)/(1 - y) + 1.0

In [None]:
# Latest entry of student.shaped_losses (IndexError if nothing is recorded yet).
student.shaped_losses[-1]

In [None]:
class LogSumFilter:
    """Running-sum filter that reports the base-2 logarithm of the total.

    Each call adds ``x`` to an accumulator and returns ``log2(total)``.
    Raises ``ValueError`` when the running total is not positive, because the
    logarithm is undefined there.
    """

    def __init__(self):
        self.x = 0  # running total of every value passed in so far

    def __call__(self, x):
        self.x += x
        # math.log2 is exact for powers of two, unlike log(x)/log(2.0).
        return log2(self.x)
    
    

In [None]:
# Grades and baseline grades against cumulative `times`, each smoothed with
# TwoWindowFilter(lag=1024).
# No `import time` here: it would shadow the `time()` function imported at the
# top. The loop variable is not called `student`, which would rebind the
# notebook-global `student` to the last enrolled student.
plot_data = {}
for idx, enrolled in enumerate(classroom.students):
    elapsed = FilteredList(enrolled.times, SumFilter())
    plot_data[f"grades-{idx}"] = (
        elapsed, FilteredList(enrolled.grades, TwoWindowFilter(lag=1024)))
    plot_data[f"baseline-{idx}"] = (
        elapsed, FilteredList(enrolled.baseline_grades, TwoWindowFilter(lag=1024)))

Plot(**plot_data)

In [None]:
# Parameter name -> Parameter for the current model. The dict holds references,
# so in-place updates are visible, but parameters replaced by reassignment are
# not. (The name `pd` shadows the usual pandas alias; pandas is not used here.)
pd = dict(student.model.named_parameters())

In [None]:
# Histogram of the values of every parameter tensor, keyed by parameter name.
# Each entry is (bin_edges, counts); np.histogram returns one more edge than
# counts. NOTE(review): confirm Plot accepts the length mismatch.
H = {}
for name, param in pd.items():
    values = param.detach().cpu().numpy().reshape(-1)
    # Square-root rule for the bin count (previously computed via an
    # unimported `math` module). Keep at least one bin, since np.histogram
    # rejects bins=0 for empty tensors.
    n_bins = max(1, int(sqrt(values.size)))
    counts, edges = np.histogram(values, bins=n_bins)
    H[name] = (edges, counts)

In [None]:
# Pass every parameter histogram to a single Plot as a named series.
Plot(**H)

In [None]:
# List the parameter names that have a histogram, in the order `plots` uses.
for param_name in H.keys():
    print(param_name)

In [None]:
# One single-series Plot per parameter histogram, in the key order of `H`.
# NOTE(review): H holds raw np.histogram counts (no density=True), so the
# "pdf" axis label is only proportional to a density.
plots = [Plot(x="value", y="pdf", **{name: hist}) for name, hist in H.items()]

In [None]:
# Display the first six per-parameter plots one per cell (IndexError if the
# model has fewer than six parameter tensors).
plots[0]

In [None]:
plots[1]

In [None]:
plots[2]

In [None]:
plots[3]

In [None]:
plots[4]

In [None]:
plots[5]

## Modifying layer weights

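The first part of this section "canonicalizes" the model. Many weight settings compute the same function because of invariances from LayerNorm, algebraic identities, etc.; canonicalization replaces the weights with a preferred representative among them. `check_layer_stats()` runs before and after, so the change in weight statistics is visible.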

In [None]:
def check_layer_stats(params=None,
                      names=("LM.F.layers.0.weight",
                             "LM.F.layers.2.F.layers.0.weight")):
    """Print the mean and variance of selected parameter tensors.

    Args:
        params: mapping from parameter name to tensor. Defaults to the
            notebook-global ``pd`` (the model's named parameters).
        names: keys of ``params`` to report, in order. The defaults are the
            two weight matrices inspected around ``canonicalize()``.

    For each name, prints ``torch.mean`` then ``torch.var`` (unbiased),
    computed without tracking gradients. Raises ``KeyError`` for unknown names.
    """
    if params is None:
        params = pd
    with torch.no_grad():
        for name in names:
            weight = params[name]
            print(torch.mean(weight))
            print(torch.var(weight))

In [None]:
# Weight statistics before canonicalization.
check_layer_stats()

In [None]:
# Canonicalize the model; presumably it modifies the weights in place (the
# return value is only displayed). See the note at the top of this section.
student.model.canonicalize()

In [None]:
# Weight statistics after canonicalization, for comparison with the first call.
check_layer_stats()

## Old plots

In [None]:
# Relative grades and shaped losses against cumulative `times`, each smoothed
# with TwoWindowFilter(lag=1024).
# No `import time` here: it would shadow the `time()` function imported at the
# top. The loop variable is not called `student`, which would rebind the
# notebook-global `student` to the last enrolled student.
plot_data_2 = {}
for idx, enrolled in enumerate(classroom.students):
    elapsed = FilteredList(enrolled.times, SumFilter())
    plot_data_2[f"relative-grades-{idx}"] = (
        elapsed, FilteredList(enrolled.relative_grades, TwoWindowFilter(lag=1024)))
    plot_data_2[f"shaped-losses-{idx}"] = (
        elapsed, FilteredList(enrolled.shaped_losses, TwoWindowFilter(lag=1024)))
Plot(**plot_data_2)

In [None]:
# Latest grade recorded for the current `student`.
student.grades[-1]

In [None]:
# Sparse-update experiment: from step 1000 on, every T-th step uses
# lr = 1e-5 / T and all other steps use lr = 0. Both betas become 1 - 1/T.
# T is bound as a default argument so that reassigning the global `T` in a
# later cell does not silently change these schedules (late-binding closure).
T = 16
param_group = student.optimizer.param_groups[0]
param_group["lr"] = lambda n, T=T: 0.0 if n < 1000 or n % T != 0 else 1e-5 / T
# Variant tried for both betas: (1.0 - 1.0/T) if n % T != 0 else 0.0
param_group["beta1"] = lambda n, T=T: 1.0 - 1.0 / T
param_group["beta2"] = lambda n, T=T: 1.0 - 1.0 / T

In [None]:
# Override the learning-rate schedule: sin^2 cycles with a period of 512 steps,
# peaking at 1e-6. Uses math.pi rather than the 3.14159 approximation, whose
# slightly-off period drifts over long runs.
student.optimizer.param_groups[0]["lr"] = lambda n: 1e-6 * sin(pi * n / 512) ** 2

In [None]:
# Progress check: number of recorded grades, the student's `time` attribute,
# and the latest grade.
len(student.grades), student.time, student.grades[-1]

In [None]:
# Save the whole pickled model (not just a state_dict); mlp_factory() resumes
# from this file via torch.load.
# NOTE(review): this overwrites the previous checkpoint without a backup.
torch.save(student.model, f='checkpoint_mlp.pt')

In [None]:
# Number of grades recorded so far.
len(student.grades)