In [None]:
import asyncio
import math
from math import log, sin, cos, tan, exp, sqrt, pi
import time
from random import randrange
import torch
import numpy as np
from classroom import Classroom
from classroom import Student
from classroom import BytesDataset
from classroom import RandomTokensDataset
from classroom import MLPLM
from classroom import MLPLM2
from classroom import TransformerLM
from classroom import AdamW
from classroom import Sonny
from classroom import Floyd
from classroom import Plot
from classroom import Fun
from classroom import Count
from classroom import Sum
from classroom import Diff
from classroom import Log2
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import numel

In [None]:
def student_factory(checkpoint=None, device='cuda',
                    dataset_path='/home/sharker/data/gutenberg.utf8'):
    """Assemble a Student from a model, an AdamW optimizer and a byte dataset.

    Args:
        checkpoint: Optional path to a model previously written with
            ``torch.save(model, ...)``. When ``None`` a new MLPLM is built.
        device: Device the model is moved to. Defaults to ``'cuda'``, which
            is what this notebook originally hard-coded.
        dataset_path: File handed to ``BytesDataset``. The default is the
            original author's local corpus; override it on other machines.

    Returns:
        A ``Student`` wired with the model, optimizer, dataset, a batch size
        of 1024 and an example length of ``model.n_ctx + 1``.
    """
    if checkpoint is None:
        model = MLPLM(
            n_vocab_in=256,
            n_vocab_out=256,
            n_ctx=32,
            d_model=32,
            d_hidden=8192,
            nonlinearity="GELU").to(device)
    else:
        # NOTE(review): torch.load unpickles the whole module, which can run
        # arbitrary code -- only load checkpoints you produced yourself.
        # Recent PyTorch releases may also need weights_only=False here.
        # map_location lets a checkpoint saved on another device load cleanly.
        model = torch.load(checkpoint, map_location=device).to(device)

    # Every hyperparameter is passed as a callable of `n` (presumably the
    # optimizer's step counter -- confirm against classroom.AdamW); all of
    # them are constant schedules here.
    optimizer = AdamW(
        parameters=model.parameters(),
        lr=lambda n: 1e-7,
        alpha=lambda n: 0.0,  # was `0.0 if n == 0 else 0.0`: both arms identical
        beta1=lambda n: 0.9,
        beta2=lambda n: 0.999,
        weight_decay=lambda n: 0.01,
        n=0)

    dataset = BytesDataset(path=dataset_path)
    batch_size = 1024
    # One element longer than the context window; presumably the extra
    # position supplies the next-token target -- confirm with Student.
    example_length = model.n_ctx + 1
    return Student(
        model=model,
        optimizer=optimizer,
        dataset=dataset,
        batch_size=batch_size,
        example_length=example_length)

In [None]:
# Relative path of the model checkpoint file.
path = 'checkpoint.pt'

In [None]:
# Create a Classroom, build a student from scratch (no checkpoint is passed,
# so student_factory constructs a new model) and enroll it.
classroom = Classroom()
student = student_factory()
classroom.enroll(student)

In [None]:
# Display the model's repr next to numel(model) -- presumably its total
# parameter count; confirm against classroom.numel.
student.model, numel(student.model)

## Autocompleting

In [None]:
# Ask the student to generate with n_generate=1024 and print the result
# (print, rather than a bare expression, so newlines render as text).
print(student.autocomplete(n_generate=1024))

In [None]:
# NOTE(review): Student.push() is defined outside this notebook; presumably it
# snapshots the student's current state -- confirm before relying on it.
student.push()

## Training Visualization Plots

In [None]:
# Smoothed grade and baseline-grade curves for every enrolled student.
#
# The loop variable is deliberately NOT named `student`: a module-level
# for-loop rebinds its variable, which used to silently replace the
# notebook-level `student` that other cells read.
plot_data = {}
for idx, enrolled in enumerate(classroom.students):
    # Shared x-axis: Sum() over the student's times (presumably a running
    # total, i.e. cumulative training time -- confirm against classroom.Sum).
    elapsed = Fun(Sum(), enrolled.times)
    grades = Fun(KalmanFilter1D(), enrolled.grades)
    baseline = Fun(KalmanFilter1D(), enrolled.baseline_grades)
    plot_data[f"grades-{idx}"] = (elapsed, grades)
    plot_data[f"baseline-{idx}"] = (elapsed, baseline)

Plot(**plot_data)

In [None]:
# Smoothed relative-grade curve for every enrolled student.
#
# The loop variable is deliberately NOT named `student`: a module-level
# for-loop rebinds its variable, which used to silently replace the
# notebook-level `student` that other cells read.
plot_data = {}
for idx, enrolled in enumerate(classroom.students):
    # x-axis: Sum() over the student's times (presumably a running total,
    # i.e. cumulative training time -- confirm against classroom.Sum).
    elapsed = Fun(Sum(), enrolled.times)
    rate = Fun(KalmanFilter1D(), enrolled.relative_grades)
    plot_data[f"rate-{idx}"] = (elapsed, rate)

Plot(**plot_data)

In [None]:
# Ratio of the student's step counter to its time counter (presumably the
# average number of training steps per second -- confirm the units).
student.step/student.time

In [None]:
async def evolution(classroom):
    """Endlessly cull the weakest student and breed a mutant of the strongest.

    Every 60 seconds the students currently in ``classroom`` are ordered by
    the most recent output of a ``TwoWindowFilter(lag=1024)`` applied to their
    ``relative_grades``. If the first and last students in that ordering
    differ, the first (lowest score) is graduated. The last (highest score)
    is then cloned, the clone is mutated and enrolled, and a filter is
    attached to it. The clone is enrolled even when nobody was graduated.

    The coroutine contains no exit path; it runs until cancelled or until an
    exception escapes.
    """
    def track(pupil):
        # Fresh filtered view over one student's relative grades.
        return Fun(TwoWindowFilter(lag=1024), pupil.relative_grades)

    def latest_score(pupil):
        # Most recent filtered value; this is the ranking key.
        return trackers[pupil].output[-1]

    trackers = {pupil: track(pupil) for pupil in classroom.students}

    while True:
        await asyncio.sleep(60.0)

        ranking = sorted(classroom.students, key=latest_score)
        best, worst = ranking[-1], ranking[0]

        if worst != best:
            classroom.graduate(worst)
            trackers.pop(worst)

        offspring = best.clone()
        offspring.mutate()
        classroom.enroll(offspring)
        trackers[offspring] = track(offspring)


In [None]:
# Per-parameter learning-rate overrides: each listed parameter of the model
# gets its own schedule that defers to `brakes`.
pd = dict(student.model.named_parameters())

# Constant schedule. Alternative kept for reference:
#   1e-6*sin(3.14159*n/8192)**2
brakes = lambda n: 1e-8

for _braked_name in (
    'LM.F.layers.0.weight',
    'LM.F.layers.2.F.layers.0.weight',
    'LM.F.layers.2.F.layers.1.weight',
    'LM.F.layers.2.F.layers.1.bias',
    'LM.F.layers.2.F.layers.3.weight',
    'LM.F.layers.2.F.layers.3.bias',
):
    # The wrapper lambda looks `brakes` up at call time, so rebinding
    # `brakes` later changes every schedule installed here.
    student.optimizer.lr[pd[_braked_name]] = lambda n: brakes(n)

In [None]:
# Persist the whole model object (not just a state_dict) to the checkpoint
# location configured in `path` near the top of the notebook, instead of
# repeating the file name here.
torch.save(student.model, f=path)