In [None]:
import asyncio
from math import log, sin, cos, tan, exp, sqrt, pi
from time import time, sleep
from random import randrange
import torch
import numpy as np
from classroom import Classroom
from classroom import Student
from classroom import UTF8Dataset as Dataset
from classroom import MLPLM
from classroom import TransformerLM
from classroom import AdamW, Sonny, Floyd
from classroom import Plot
from classroom import FilteredList
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import CountFilter
from classroom import SumFilter
from classroom.util import numel

## Training

## Scheduling helpers

In [None]:
def constant(c):
    """Return a schedule that ignores the step index and always yields ``c``."""
    return lambda step: c


def linear_warmup_then_decay(lr, warmup):
    """Return a learning-rate schedule: linear warm-up, then 1/n decay.

    For step ``n < warmup`` the rate ramps linearly from 0 towards ``lr``
    (``lr * n / warmup``); from ``n == warmup`` on it decays as
    ``lr * warmup / n``. The two pieces meet at ``lr`` when ``n == warmup``.

    Args:
        lr: peak learning rate, reached at step ``warmup``.
        warmup: number of warm-up steps; must be positive.

    Returns:
        A callable ``schedule(n) -> float``.

    Raises:
        ValueError: if ``warmup`` is not positive. With ``warmup == 0`` the
            formula would divide 0 by 0 at step 0 and give a rate of exactly
            0 at every later step.
    """
    if warmup <= 0:
        raise ValueError(f"warmup must be positive, got {warmup!r}")

    def schedule(n):
        if n < warmup:
            return lr * (n / warmup)
        return lr * (warmup / n)

    return schedule

## MLP Language Model

In [None]:
def mlp_factory(n_ctx, device='cuda'):
    """Build a Student that trains an MLP language model.

    Args:
        n_ctx: context length handed to ``MLPLM``. Each training example is
            ``model.n_ctx + 1`` tokens long (presumably n_ctx inputs plus one
            next-token target -- confirm against Student/MLPLM).
        device: where the model is placed via ``.to(device)``. Defaults to
            ``'cuda'``, the value that used to be hard-coded, so existing
            ``mlp_factory(n_ctx)`` calls behave exactly as before.

    Returns:
        A ``Student`` bundling the model, an AdamW optimizer with a constant
        learning rate of 1e-4, a fresh ``Dataset()``, and batch size 512.
    """
    model = MLPLM(
        n_vocab_in=256,   # 256-symbol vocabulary (presumably raw bytes)
        n_vocab_out=256,
        n_ctx=n_ctx,
        d_model=32,
        d_hidden=8192*2,
        nonlinearity="sigmoid").to(device)
    optimizer = AdamW(
        parameters=model.parameters(),
        eps=constant(1e-4),
        lr=lambda n: 1e-4,  # constant learning rate
        beta1=constant(0.9),
        beta2=constant(0.999),
        weight_decay=constant(0.01),
        n=0)
    dataset = Dataset()
    batch_size = 512
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                   dataset=dataset, batch_size=batch_size,
                   example_length=example_length)

In [None]:
def trans_factory(n_ctx=64, device='cuda'):
    """Build a Student that trains a 2-layer Transformer language model.

    Args:
        n_ctx: context length handed to ``TransformerLM`` (default 64, the
            value that used to be hard-coded). Each training example is
            ``model.n_ctx + 1`` tokens long.
        device: where the model is placed via ``.to(device)``. Defaults to
            ``'cuda'``, the value that used to be hard-coded, so
            ``trans_factory()`` behaves exactly as before.

    Returns:
        A ``Student`` bundling the model, an AdamW optimizer, a fresh
        ``Dataset()``, and batch size 1.
    """
    model = TransformerLM(
        n_vocab_in=256,
        n_vocab_out=256,
        n_ctx=n_ctx,
        d_model=1024,
        d_k=32,
        d_v=32,
        n_heads=32,
        d_hidden=4096,
        n_layers=2,
        p_dropout_embedding=0.1,
        p_dropout_attn_mat=0.1,
        p_dropout_attn_out=0.1,
        p_dropout_mlp=0.1).to(device)

    optimizer = AdamW(
        parameters=model.parameters(),
        eps=constant(1e-4),
        # Periodic schedule: cycles between 0 and 1e-5 with a period of
        # 500*pi (~1571) steps, and is exactly 0 at n == 0.
        lr=lambda n: sin(n/500)**2 * 1e-5,
        beta1=constant(0.9),
        beta2=constant(0.999),
        weight_decay=constant(0.01),
        n=0)
    dataset = Dataset()
    batch_size = 1
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                   dataset=dataset, batch_size=batch_size,
                   example_length=example_length)

In [None]:
# Create the classroom, build one MLP student per context length (4, then 32),
# and enroll them in that order.
classroom = Classroom()
students = list(map(mlp_factory, (4, 32)))
for student in students:
    classroom.enroll(student)

## Autocompleting

In [None]:
# Generate 1024 symbols from the student at index 1 of the classroom.
# NOTE(review): before any rank_students() call this is presumably the
# n_ctx=32 MLP; ranking may reorder students. print() (not a bare expression)
# keeps newlines in the generated text readable.
sample_text = classroom.students[1].autocomplete(n_generate=1024)
print(sample_text)

## Training Visualization Plots

In [None]:
# One curve per current student: x is `times` passed through SumFilter
# (presumably a running total), y is `grades` smoothed by
# TwoWindowFilter(lag=4096). Curves are labelled by classroom index.
# (`time` is deliberately not re-imported: `import time` would shadow the
# `time()` function imported at the top of the notebook.)
plot_data = {
    f"{idx}": (
        FilteredList(student.times, SumFilter()),
        FilteredList(student.grades, TwoWindowFilter(lag=4096)),
    )
    for idx, student in enumerate(classroom.students)
}
Plot(**plot_data)

In [None]:
# Mean of the entries in the first student's `times` list (NaN if nothing has
# been recorded yet, instead of raising ZeroDivisionError). The list is copied
# once so sum and length come from the same snapshot.
first_student_times = list(classroom.students[0].times)
sum(first_student_times) / len(first_student_times) if first_student_times else float("nan")

In [None]:
# Mean of the entries in the last student's `times` list (NaN if nothing has
# been recorded yet, instead of raising ZeroDivisionError). The list is copied
# once so sum and length come from the same snapshot.
last_student_times = list(classroom.students[-1].times)
sum(last_student_times) / len(last_student_times) if last_student_times else float("nan")

In [None]:
# How many entries the first student's `times` list currently holds.
n_recorded_times = len(classroom.students[0].times)
n_recorded_times

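## Genetic algorithm

A background asyncio task repeatedly ranks the class, enrolls a mutated clone of a student, waits 10 seconds, ranks again, and graduates one student via `classroom.graduate()`, keeping that student's `times`/`grades` history in `graveyard` for plotting.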

In [None]:
# Re-running this cell would otherwise leave the previous background loop
# running alongside the new one, with both mutating `classroom`. Cancel it
# first. (If it was cancelled mid-sleep, the clone it had just enrolled stays
# enrolled.)
_previous_mauler = globals().get("task")
if isinstance(_previous_mauler, asyncio.Task) and not _previous_mauler.done():
    _previous_mauler.cancel()

graveyard = []  # (times, grades) of every student returned by classroom.graduate()


async def mauler():
    """Run the evolutionary loop forever as a background asyncio task.

    Each iteration ranks the class, enrolls a mutated clone of
    ``classroom.students[-1]``, sleeps 10 seconds (yielding to the event
    loop), re-ranks, and records the ``(times, grades)`` of the student
    returned by ``classroom.graduate()`` in ``graveyard``.
    """
    while True:
        classroom.rank_students()
        # NOTE(review): whether students[-1] is the best or the worst after
        # rank_students() depends on Classroom's ordering -- confirm.
        student = classroom.students[-1].clone()
        student.mutate()
        classroom.enroll(student)
        await asyncio.sleep(10.0)
        classroom.rank_students()
        student = classroom.graduate()
        graveyard.append((student.times, student.grades))
        # Drop this coroutine's reference so it does not keep the graduated
        # student alive while the next iteration clones a new one.
        del student


def _report_mauler_exit(finished):
    """Done-callback: print why the mauler task stopped, unless it was cancelled.

    Needed because `task` is kept in a global, so asyncio never logs
    "Task exception was never retrieved" and a crash would go unnoticed.
    """
    if finished.cancelled():
        return
    error = finished.exception()
    if error is not None:
        print(f"mauler stopped with {error!r}")


task = asyncio.create_task(mauler())
task.add_done_callback(_report_mauler_exit)

In [None]:
# One curve per student recorded in `graveyard`: x is its `times` passed
# through SumFilter (presumably a running total), y is its `grades` smoothed
# by TwoWindowFilter with its default lag. The legend is disabled; the
# graveyard keeps growing while the mauler task runs.
# NOTE(review): the live-student plot uses lag=4096; confirm the default lag
# here is intentional.
# (`time` is deliberately not re-imported: `import time` would shadow the
# `time()` function imported at the top of the notebook.)
plot_data = {
    f"{idx}": (
        FilteredList(times, SumFilter()),
        FilteredList(grades, TwoWindowFilter()),
    )
    for idx, (times, grades) in enumerate(graveyard)
}
Plot(legend=False, **plot_data)