# A Replication of "The Quantization Model of Neural Scaling"
Paper: [arXiv:2303.13506](https://arxiv.org/abs/2303.13506)

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from matplotlib import pyplot as plt

from lib.generate_data import Sampler, DummyData, MultitaskSparseParity
from lib.tracking import AnalyticsManager

## Set up data

In [None]:
# Problem dimensions. The evaluation loop later iterates over
# range(n_control_bits) as task indices, so this is also the task count.
n_control_bits = 5  # a larger setting that was also tried: 500
n_data_bits = 100

# Forwarded unchanged to the sampler.
# NOTE(review): presumably k is the parity subset size and alpha the
# exponent of the task-frequency distribution -- confirm in lib.generate_data.
k = 3
alpha = 0.4

sampler: Sampler = MultitaskSparseParity(
    n_control_bits,
    n_data_bits,
    k=k,
    alpha=alpha,
)
# Alternative sampler kept for reference:
# sampler: Sampler = DummyData(n_control_bits + n_data_bits)

In [None]:
# sampler.generate_data(2)

In [None]:

%timeit sampler.generate_data(20000)

## Train Network

In [None]:
# Sample counts. All three are integers because they are used as sizes
# and to derive the number of batches per epoch.
batch_size = 20000
training_size = 100_000
test_size = 1000

# Model width and optimisation settings.
n_hidden = 200
lr = 1e-3
n_epochs = 500


def optimizer_func(model):
    """Return an Adam optimizer over ``model``'s parameters.

    The learning rate is read from the module-level ``lr`` at call time,
    so changing ``lr`` before calling this function takes effect.
    """
    return torch.optim.Adam(model.parameters(), lr=lr)


# Binary cross-entropy on probabilities; the model ends in a sigmoid, so
# its outputs are already in [0, 1].
loss_func = torch.nn.BCELoss()

# Collects the metrics logged during training for plotting afterwards.
logger = AnalyticsManager()

In [None]:
class TinyModel(torch.nn.Module):
    """Single hidden layer model with ReLU activation and a sigmoid output.

    Args:
        n_hidden: Width of the hidden layer.
        n_inputs: Width of the input layer. When omitted, it falls back to
            ``n_control_bits + n_data_bits`` read from the enclosing module,
            which is what the model always used before this parameter
            existed.
    """

    def __init__(self, n_hidden: int, n_inputs: "int | None" = None):
        super().__init__()

        if n_inputs is None:
            # Default: one input per control bit plus one per data bit.
            n_inputs = n_control_bits + n_data_bits

        self.linear1 = torch.nn.Linear(n_inputs, n_hidden)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(n_hidden, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return one sigmoid-squashed output per input row.

        ``x`` must be a float tensor whose last dimension equals the input
        width; the result has a last dimension of size 1.
        """
        hidden = self.activation(self.linear1(x))
        return self.sigmoid(self.linear2(hidden))

# Instantiate the network, then bind an optimizer to its parameters.
model = TinyModel(n_hidden=n_hidden)
optimizer = optimizer_func(model)

In [None]:
# %matplotlib notebook
# from time import sleep

# fig = plt.figure()
# ax1 = fig.add_subplot(211)
# ax2 = fig.add_subplot(212, sharex=ax1)

# for i in range(30):
#     ax1.plot(range(i))
#     plt.show()
#     sleep(1)

In [None]:
def calculate_loss(X, y, model):
    """Return ``loss_func`` evaluated on ``model``'s predictions for ``X``.

    Both inputs are cast to float. ``y`` gains a new axis at position 1
    before the comparison (so a 1-D label vector becomes a column).
    """
    predictions = model(X.float())
    targets = y[:, None].float()
    return loss_func(predictions, targets)


In [None]:
# Optimisation steps per epoch. Computed once, since neither size changes
# during training.
n_batches = int(training_size // batch_size)
if n_batches < 1:
    # Without at least one batch there is no loss to report per epoch.
    raise ValueError("training_size must be at least batch_size")

for epoch in range(n_epochs):
    for _ in range(n_batches):
        # Fresh samples are drawn for every step.
        X_batch, y_batch = sampler.generate_data(batch_size)

        # Clear gradients before the backward pass so that anything left
        # over from an interrupted earlier run of this cell cannot leak
        # into this step.
        optimizer.zero_grad()

        # Calculate total loss
        loss = calculate_loss(X_batch, y_batch, model)
        loss.backward()
        optimizer.step()

        logger.log({"loss": loss.item()})

        # Calculate loss for each individual task.
        for task_index in range(n_control_bits):
            X_task, y_task = sampler.generate_data(test_size, force_task=task_index)
            # Evaluation only: no gradients are needed here.
            with torch.no_grad():
                task_loss = calculate_loss(X_task, y_task, model).item()
                logger.log({f"task_{task_index}_loss": task_loss})

    # Report the loss of the last batch of the epoch.
    print(f"Epoch: {epoch} loss: {loss.item()}")

In [None]:
# Draw every logged metric as its own labelled curve on one set of axes.
metrics = logger.metrics
for metric_name in metrics.keys():
    series = metrics[metric_name]
    plt.plot(series, label=metric_name)
plt.legend()
# Optional: switch to a logarithmic y-axis.
# plt.yscale("log")