# A Replication of "The Quantization Model of Neural Scaling"
https://arxiv.org/abs/2303.13506

This paper proposes explaining neural scaling as the result of a model learning a set of discrete tasks (which the paper calls "quanta") individually.

To demonstrate this, we train a model on a "Multitask Sparse Parity" dataset. Each example is composed of two sections:

$$[\:\underbrace{0,1,0,0,}_\text{Control bits} \; \underbrace{1,0,0,1,0,1,1,1,0}_\text{Data bits}\:]$$

The control bits are a one-hot encoding of the "task" to be computed on the data bits. Here every task is a Sparse Parity task, i.e. the parity of a task-specific subset of the data bits.



This notebook replicates Figure 7 from the addendum:

![Expected training dynamics for multitask parity.](image.png)

This plot shows the loss curve of each subtask alongside the loss on the full dataset. Each task appears to be learned individually and abruptly, while the overall loss follows a smooth training curve!

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from matplotlib import pyplot as plt
import time
from tqdm import tqdm

from lib.generate_data import Sampler, DummyData, MultitaskSparseParity
from lib.tracking import AnalyticsManager

## Set up data

In [None]:
# Dataset configuration for multitask sparse parity.
n_control_bits, n_data_bits = 10, 100  # alternative setting tried: n_control_bits = 500
k = 3  # presumably the number of data bits in each task's parity subset — confirm in lib.generate_data
alpha = 0.4  # NOTE(review): presumably the exponent of the task-frequency distribution — confirm

sampler: Sampler = MultitaskSparseParity(
    n_control_bits,
    n_data_bits,
    k=k,
    alpha=alpha,
)
# Alternative sampler over the same input width (DummyData; presumably a
# structure-free baseline — check lib.generate_data):
# sampler: Sampler = DummyData(n_control_bits + n_data_bits)

In [None]:

# Benchmark how long the sampler takes to generate 20000 examples in one call.
%timeit sampler.generate_data(20000)

## Train Network

In [None]:
batch_size = 20000
# Examples drawn per epoch. Kept as an int (not 1e5) because it is a count
# used in integer step arithmetic by the training loop.
training_size = 100_000
test_size = 1000  # Examples generated per task for each per-task loss evaluation.

n_hidden = 200
lr = 1e-3
n_epochs = 500


def optimizer_func(model):
    """Return an Adam optimizer over ``model``'s parameters using the notebook-level ``lr``."""
    return torch.optim.Adam(model.parameters(), lr=lr)


# TinyModel already applies a sigmoid, so plain BCE on probabilities is used.
loss_func = torch.nn.BCELoss()

logger = AnalyticsManager()
print_frequency = 30  # Seconds between logging messages. In between, we use TQDM to show the current loss.

In [None]:
class TinyModel(torch.nn.Module):
    """Single-hidden-layer MLP: Linear -> ReLU -> Linear -> Sigmoid.

    Produces one probability per example, so its output can be fed directly
    to ``torch.nn.BCELoss``.

    Args:
        n_hidden: Width of the hidden layer.
        n_input: Width of each input vector. When omitted, it defaults to the
            notebook-level ``n_control_bits + n_data_bits`` (read at
            construction time), matching the original behavior.
    """

    def __init__(self, n_hidden: int, n_input: "int | None" = None):
        super().__init__()

        if n_input is None:
            n_input = n_control_bits + n_data_bits

        # Layers are created in the same order as before, so parameter
        # initialization under a fixed seed is unchanged.
        self.linear1 = torch.nn.Linear(n_input, n_hidden)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(n_hidden, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map inputs of shape ``(..., n_input)`` to probabilities of shape ``(..., 1)``."""
        hidden = self.activation(self.linear1(x))
        return self.sigmoid(self.linear2(hidden))

# Build a freshly initialized model and an optimizer bound to its parameters.
model = TinyModel(n_hidden=n_hidden)
optimizer = optimizer_func(model)

In [None]:
# %matplotlib notebook
# from time import sleep

# fig = plt.figure()
# ax1 = fig.add_subplot(211)
# ax2 = fig.add_subplot(212, sharex=ax1)

# for i in range(30):
#     ax1.plot(range(i))
#     plt.show()
#     sleep(1)

In [None]:
def calculate_loss(X, y, model, loss_fn=None):
    """Return the loss of ``model`` on inputs ``X`` against targets ``y``.

    Args:
        X: Input tensor; cast to float before the forward pass.
        y: Targets with one label per example, either shape ``(batch,)`` or
            already shaped like the model output (e.g. ``(batch, 1)``).
        model: Callable mapping float inputs to predictions.
        loss_fn: Criterion called as ``loss_fn(pred, target)``. Defaults to
            the notebook-level ``loss_func``, looked up at call time.

    Returns:
        Whatever ``loss_fn`` returns; a scalar tensor for a mean-reduction
        criterion such as ``torch.nn.BCELoss()``.
    """
    if loss_fn is None:
        loss_fn = loss_func
    y_pred = model(X.float())
    # Match the target shape to the predictions. The original ``y[:, None]``
    # only handled 1-D targets and failed for targets already of shape
    # (batch, 1), since BCELoss requires identical input/target sizes.
    target = y.float().reshape(y_pred.shape)
    return loss_fn(y_pred, target)


In [None]:
# Train on freshly sampled batches. After every optimizer step, log the
# training loss and a separate loss for each task.
steps_per_epoch = int(training_size // batch_size)
if steps_per_epoch < 1:
    # With zero steps the inner loop never runs and `loss` below is undefined.
    raise ValueError(
        f"training_size ({training_size}) must be at least batch_size ({batch_size})"
    )

# Start at -inf so the first epoch summary is always written.
last_print_time = float("-inf")

for epoch in (pbar := tqdm(range(n_epochs))):
    for _ in range(steps_per_epoch):
        X_batch, y_batch = sampler.generate_data(batch_size)

        # Loss on a freshly sampled training batch.
        loss = calculate_loss(X_batch, y_batch, model)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        logger.log({"loss": loss.item()})

        # Evaluate each task separately on examples generated with `force_task`.
        with torch.no_grad():
            for task_index in range(n_control_bits):
                X_task, y_task = sampler.generate_data(test_size, force_task=task_index)
                task_loss = calculate_loss(X_task, y_task, model).item()
                logger.log({f"task_{task_index}_loss" : task_loss})

    epoch_str = f"Epoch: {epoch} loss: {loss.item()}"
    pbar.set_description(epoch_str)
    # monotonic() measures elapsed time immune to wall-clock adjustments.
    if (time.monotonic() - last_print_time) > print_frequency:
        tqdm.write(epoch_str)
        last_print_time = time.monotonic()

In [None]:
# Overlay the overall training loss and every per-task loss on one set of axes.
plt.figure(figsize=(20, 10))
metrics = logger.metrics
for metric_name in metrics.keys():
    plt.plot(metrics[metric_name], label=metric_name)
plt.legend()
# plt.yscale("log")