In [1]:
import asyncio
import math
from math import log, sin, cos, tan, exp, sqrt, pi
import time
from random import randrange
import torch
import numpy as np
from classroom import Classroom
from classroom import Student
from classroom import BytesDataset
from classroom import GutenbergBytesDataset
from classroom import GutenbergBitsDataset
from classroom import GutenbergGPT2Dataset
from classroom import MLPLM, MyLM, ABPCNLM
from classroom import TransformerLM
from classroom import AdamW
from classroom import Sonny
from classroom import Floyd
from classroom import Plot
from classroom import Fun
from classroom import Count
from classroom import Sum
from classroom import Diff
from classroom import Log2Sum
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import numel
from classroom import utf8decode, utf8encode, gpt2decode, gpt2encode
from classroom import utf8bitsdecode, utf8bitsencode
from pathlib import Path
import numba

In [2]:
# Teacher network: this object is what the training loop passes as `source` to
# vulkanmindmeld_study. torch.load unpickles the whole module object, so only
# load checkpoints from a trusted origin.
# NOTE(review): the trailing "1024" presumably records the teacher's context length — confirm.
small_model = torch.load("small.pt").to('cuda') # 1024

In [3]:
# Network being trained: it is handed to AdamW and Student in the next cell.
# torch.load unpickles the whole module object, so only load checkpoints from
# a trusted origin.
# NOTE(review): the trailing "512" presumably records this model's context length — confirm.
model = torch.load("2021-12-06-0645.pt").to('cuda') # 512

In [4]:
# Optimiser over the model's named parameters, and the token source on the GPU.
optimizer = AdamW(parameters=model.named_parameters())
dataset = GutenbergGPT2Dataset(device='cuda')
#dataset = GutenbergBytesDataset()

# No batch size is chosen yet (a later cell sets student.batch_size); each
# example is one token longer than the model's context window.
batch_size = None
example_length = model.n_ctx + 1

# Bundle model, optimiser, data and batch geometry into the Student that the
# rest of the notebook drives.
student = Student(model=model,
                  optimizer=optimizer,
                  dataset=dataset,
                  batch_size=batch_size,
                  example_length=example_length,
                  device='cuda')

In [45]:
# Batch geometry read by the training step.
student.batch_size = 8
student.example_length = 513

# Schedule constants. The callables installed below look these module-level
# names up when they are called, so rebinding one of them later also changes
# the schedules that are already installed.
batch_multiplier = 10
lr_base = 1e-5
warm_up = 0


def lr(n):
    """Learning rate at step `n`.

    Zero while n < warm_up; afterwards a ramp that climbs from lr_base/100 to
    lr_base over every 100 steps and then restarts.
    """
    return 0 if n < warm_up else lr_base *(1 + (n%100))/100


# Install one callable per hyper-parameter on every named parameter; each
# callable receives the step number `n`.
for pn, _param in student.model.named_parameters():
    hyper = student.optimizer.state[pn]
    hyper["lr"]           = lambda n: lr(n)
    hyper["beta1"]        = lambda n: 0.9
    hyper["beta2"]        = lambda n: 0.999
    hyper["weight_decay"] = lambda n: 0.001
    # True during warm-up and on every batch_multiplier-th step.
    # NOTE(review): presumably this gates when the optimiser applies an update — confirm in AdamW.
    hyper["update"]       = lambda n: (n < warm_up) or (n%batch_multiplier == 0)

In [6]:
# Per-step distillation losses; vulkanmindmeld_study appends one value per step.
exp_losses = list()
# Wall-clock reference used by the progress report cell.
t_start = time.time()

In [7]:
# Disabled by the constant-False guard, so nothing here runs as written.
# When the guard is flipped to True this resets the student's baseline,
# records the current length of student.times, restarts the t_start wall
# clock and sets the baseline time offset to 0.
if False:
    student.reset_baseline()
    n_of_last_baseline = len(student.times)
    t_start = time.time()
    t_of_last_baseline = 0

In [8]:
def vulkanmindmeld_study(self, source, extra_context=512):
    """
    Run one distillation step that trains `self.model` to imitate `source`.

    A batch of `self.batch_size` rows, each `self.example_length - 1 + extra_context`
    tokens long, is drawn from `self.dataset`. The teacher `source` reads every
    full row and only its distribution at the last position is kept. The
    student `self.model` reads the same rows with the first `extra_context`
    tokens dropped. The loss is the cross-entropy of the student's
    distribution under the teacher's, converted to bits and divided by
    `self.batch_size`.

    The step itself is performed by `self.optimizer.step(closure)`; the
    resulting loss value is appended to the module-level list `exp_losses`.
    Nothing is recorded on `self`.

    Args:
        self: student-like object providing `dataset`, `batch_size`,
            `example_length`, `model` and `optimizer`.
        source: teacher model exposing `language_model.module` and
            `language_model.softmax`.
        extra_context: number of leading tokens that only the teacher sees.
            Defaults to 512, the value previously hard-coded here.
    """
    def closure():
        batch = self.dataset.batch(batch_size=self.batch_size,
                                   example_length=self.example_length - 1 + extra_context,
                                   offset=None)
        # The student only sees the tail of each row.
        small_batch = batch[:, extra_context:]
        q = self.model.language_model.softmax(self.model.language_model.module(small_batch))

        # Teacher targets need no gradient.
        with torch.no_grad():
            ps = []
            # One row per forward pass.
            # NOTE(review): presumably done to bound teacher memory — confirm.
            for idx in range(self.batch_size):
                teacher_probs = source.language_model.softmax(
                    source.language_model.module(batch[idx:idx+1, :]))
                # Keep only the last position; -1 infers the vocabulary size
                # instead of assuming a fixed 50257-token vocabulary.
                ps.append(teacher_probs[0, -1, :].view(1, 1, -1))
            p = torch.cat(ps)  # (batch, 1, vocab)

        # NOTE(review): p broadcasts against q. If the student emits one
        # distribution per position, the teacher's last-position target is
        # applied to every position — confirm this is intended.
        # The 1e-12 guards log(0); dividing by log(2) converts nats to bits.
        loss = -torch.sum(p * torch.log(q + 1e-12)) / math.log(2) / self.batch_size
        loss.backward()
        return loss

    loss = self.optimizer.step(closure)
    exp_losses.append(loss.item())


In [9]:
import asyncio
async def train(student):
    """
    Endlessly run distillation steps on `student` against `small_model`.

    After every step the coroutine sleeps for 1e-4 seconds, which hands
    control back to the event loop between steps. The loop has no exit
    condition of its own; it ends only when the task is cancelled or a step
    raises.
    """
    pause_seconds = 1e-4
    while True:
        vulkanmindmeld_study(student, small_model)
        await asyncio.sleep(pause_seconds)
        

In [10]:
# Schedule the endless train() coroutine as a background task.
# asyncio.create_task needs an already-running event loop. The handle is kept
# so the loop can be inspected or cancelled from later cells.
training_task = asyncio.create_task(train(student))

In [40]:
# Bare expression: displays the task's current state. Per the trailing
# comment, append `.cancel()` to stop the background training loop.
training_task #.cancel()

<Task pending name='Task-3' coro=<train() running at /tmp/ipykernel_726165/2259477027.py:7> wait_for=<Future pending cb=[<TaskWakeupMethWrapper object at 0x7fc8e730d190>()]>>

In [63]:
import time
# Window parameter handed to TwoWindowFilter.
lag = 20000

# Both series are derived from the recorded distillation losses.
# NOTE(review): presumably X is the step index and Y the filtered loss — confirm against classroom.Fun.
X = Fun(Count(), exp_losses)
Y = Fun(TwoWindowFilter(lag=lag), exp_losses)

# A single named curve; Plot receives it as the keyword argument `losses`.
plot_data_2 = {"losses": (X, Y)}
Plot(**plot_data_2)



In [65]:
import time

# Progress summary.
#   n: number of recorded steps
#   t: seconds elapsed since t_start
#   y: mean of the most recent 20000 recorded losses
#   z: fixed reference loss that y is compared against
n = len(exp_losses)
t = time.time() - t_start
recent_losses = np.array(exp_losses[-20000:])
y = np.mean(recent_losses)
z = 6.81 # np.mean(np.array(exp_losses))

# y and z are truncated (not rounded) to four decimals for display.
# (z-y)/t is the average loss drop per second; y*t/(z-y) extrapolates that
# rate to the time needed for y to reach zero, shown in whole hours.
report_lines = [
    f"batch_size            = {student.batch_size}",
    f"example_length        = {student.example_length}",
    f"y                     = {int(y*1e4)/1e4}",
    f"z                     = {int(z*1e4)/1e4}",
    f"n                     = {n} steps",
    f"t                     = {int(t)} seconds",
    f"n/t                   = {n/t}",
    f"(z-y)/t               = {(z-y)/t}",
    f"y*t/(z-y)             = {y*t/(z-y)//3600} hours",
]
message = '\n'.join(report_lines)
print(message)

batch_size            = 8
example_length        = 513
y                     = 6.7404
z                     = 6.81
n                     = 58679 steps
t                     = 32911 seconds
n/t                   = 1.7829158878580293
(z-y)/t               = 2.1139860349199243e-06
y*t/(z-y)             = 885.0 hours


In [None]:
# Re-baseline during a run. The constant-True guard is a manual on/off switch.
# After resetting the student's baseline, remember the index of the last entry
# in student.times and the seconds elapsed since t_start. Statement order
# matters: the clock is read after reset_baseline() has returned.
if True:
    student.reset_baseline()
    n_of_last_baseline = len(student.times)-1
    t_of_last_baseline = time.time() - t_start

In [66]:
# Output checkpoint file: the name of the checkpoint loaded into `model`, with
# "-vulkanized" added. A relative path resolves against the kernel's working directory.
path = "2021-12-06-0645-vulkanized.pt"

In [67]:
# Pickle the whole module object (not just a state_dict) to `path`. Loading it
# later needs the same model class definitions to be importable. An existing
# file at `path` is overwritten.
torch.save(student.model, f=path)