
# "Write a script to train a small CNN on MNIST, or find one you have written previously."

Day 1 turned out to be more of a "relearn how to actually make a CNN, and then implement it" kind of day; I got a working CNN that classifies MNIST digits with 98% accuracy.

Day 2, I finally implemented sweeping on the CNN! I'm not getting clean scaling results, mainly because the learning rate has a far larger effect than the parameter count, and I'm not sure how the learning rate should change as the parameter count changes.

Day 3, I restarted my hyperparameter search; day 2 was kind of sloppy, and I didn't get any clear results. I'll start by finetuning the learning rate for "scale 5" CNNs. From there, I can move between parameter counts easily: increasing the scale by one doubles the parameter count, which (by my rule of thumb) roughly corresponds to multiplying the learning rate by $1/\sqrt{2}$.

For a scale of 5, the optimal lr turned out to be around 0.0009.
This means that for a scale of 4, I predict the optimum to be around 0.0012; let's see if that holds up! For a scale of 3, I predict around 0.0017; for 2, around 0.0024; and for 1, around 0.0033.
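The rule of thumb above can be written as a one-line formula, $\mathrm{lr}(s) = 0.0009 \cdot (1/\sqrt{2})^{\,s-5}$. A minimal sketch, assuming the scale-5 optimum of 0.0009 as the anchor; `predicted_lr` is just an illustrative helper, not something used by the training code:

```python
import math

BASE_SCALE = 5
BASE_LR = 0.0009  # tuned optimum for scale-5 CNNs

def predicted_lr(scale, base_scale=BASE_SCALE, base_lr=BASE_LR):
    """Heuristic: each +1 in scale doubles the parameter count and
    multiplies the learning rate by 1/sqrt(2)."""
    return base_lr * (1 / math.sqrt(2)) ** (scale - base_scale)

for scale in range(1, 6):
    print(f"scale {scale}: predicted lr ~ {predicted_lr(scale):.5f}")
```

Applied directly, the formula gives about 0.0013, 0.0018, 0.0025 and 0.0036 for scales 4 down to 1; the hand-computed predictions are slightly lower because each step compounds an already-rounded value.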

In [None]:
!pip install wandb

In [45]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import wandb

device = "cuda" if t.cuda.is_available() else "cpu"
t.manual_seed(1) # reproducibility!


In [16]:
class ConvNet(nn.Module):

    def __init__(self, scale = 3, kernel_size = 5):
        super().__init__()

        # each +1 in `scale` multiplies every hidden width by sqrt(2)
        width_mult = (2 ** 0.5) ** scale
        out_one = int(2 * width_mult)      # conv1 output channels
        out_two = int(6 * width_mult)      # conv2 output channels
        lin_out_one = int(40 * width_mult)
        lin_out_two = int(20 * width_mult)
        lin_out_three = 10                 # one logit per digit class

        self.conv1 = nn.Conv2d(1, out_one, kernel_size)
        self.conv2 = nn.Conv2d(out_one, out_two, kernel_size)

        # spatial side length after two (unpadded conv -> 2x2 max pool) stages;
        # for kernel_size = 5: 28 -> 24 -> 12 -> 8 -> 4
        final_dim = ((28 - kernel_size + 1) // 2 - kernel_size + 1) // 2
        self.lin1 = nn.Linear(out_two * final_dim * final_dim, lin_out_one)
        self.lin2 = nn.Linear(lin_out_one, lin_out_two)
        self.lin3 = nn.Linear(lin_out_two, lin_out_three)

    def forward(self, x):
        # apply convolutions
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        # flatten everything except the batch
        x = t.flatten(x, 1) 
        # apply linear layers
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        return x

    def getParamCount(self):
        # total number of scalar weights and biases across all layers
        return sum(p.numel() for p in self.parameters())
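The `final_dim` line packs two conv + pool stages into one expression. Tracing the side length stage by stage makes the numbers concrete. A small plain-Python sketch, assuming 28×28 inputs, stride-1 unpadded convolutions and 2×2 max pools as in `forward`; the helper name `conv_pool_sizes` is mine, not part of the model:

```python
def conv_pool_sizes(image_size=28, kernel_size=5):
    """Trace the spatial side length through two (conv -> 2x2 max pool) stages.

    A stride-1, unpadded convolution shrinks each side by kernel_size - 1;
    a 2x2 max pool with its default stride of 2 halves it, rounding down.
    """
    sizes = [image_size]
    size = image_size
    for _ in range(2):
        size = size - kernel_size + 1  # Conv2d with no padding, stride 1
        sizes.append(size)
        size = size // 2               # max_pool2d with kernel (2, 2)
        sizes.append(size)
    return sizes

print(conv_pool_sizes())               # default kernel_size=5
print(conv_pool_sizes(kernel_size=3))
```

The last entry is the `final_dim` used to size `lin1`, so the flattened input to `lin1` has `out_two * final_dim * final_dim` features.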

In [17]:
# testing how scale affects size: each +1 in scale multiplies layer widths by sqrt(2), which roughly doubles the parameter count (so growth is exponential in scale, not quadratic)
for i in range(9):
  print("num parameters with the following 'scale' of ConvNet():" + str(i))
  net = ConvNet(i)
  print(net.getParamCount())

num parameters with the following 'scale' of ConvNet():0
5268
num parameters with the following 'scale' of ConvNet():1
9570
num parameters with the following 'scale' of ConvNet():2
20406
num parameters with the following 'scale' of ConvNet():3
38141
num parameters with the following 'scale' of ConvNet():4
80322
num parameters with the following 'scale' of ConvNet():5
155739
num parameters with the following 'scale' of ConvNet():6
318714
num parameters with the following 'scale' of ConvNet():7
627133
num parameters with the following 'scale' of ConvNet():8
1269738
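Why does +1 scale roughly double the parameter count? The biggest layers (`conv2`, `lin1`, `lin2`) each have a weight count proportional to the product of two widths, and every width grows by $\sqrt{2}$ per step, so those layers grow by $(\sqrt{2})^2 = 2$. A closed-form count as a sketch, written as a hypothetical helper that mirrors `ConvNet.__init__` and assumes the default `kernel_size=5` (so the flattened feature map is 4×4):

```python
def conv_net_param_count(scale, kernel_size=5, final_dim=4):
    """Closed-form parameter count for ConvNet, layer by layer.

    Assumes kernel_size=5, for which the conv stack ends in a 4x4 map
    (final_dim=4). Widths are computed exactly as in ConvNet.__init__.
    """
    width_mult = (2 ** 0.5) ** scale
    c1, c2 = int(2 * width_mult), int(6 * width_mult)
    h1, h2 = int(40 * width_mult), int(20 * width_mult)
    k2 = kernel_size * kernel_size
    return (
        (1 * c1 * k2 + c1)                         # conv1: weights + biases
        + (c1 * c2 * k2 + c2)                      # conv2
        + (c2 * final_dim * final_dim * h1 + h1)   # lin1 on flattened features
        + (h1 * h2 + h2)                           # lin2
        + (h2 * 10 + 10)                           # lin3: 10 output logits
    )

counts = [conv_net_param_count(s) for s in range(9)]
ratios = [b / a for a, b in zip(counts, counts[1:])]
print(counts)
print([round(r, 2) for r in ratios])
```

This reproduces the nine counts printed above. The step-to-step ratios stay between about 1.8 and 2.1 rather than exactly 2 because `int()` truncates the widths by slightly different amounts at each scale.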


In [18]:
mnist_trainset = datasets.MNIST(root = './data', train=True, download=True, transform=transforms.ToTensor()) # each image becomes a [1, 28, 28] float tensor in [0, 1]
mnist_testset = datasets.MNIST(root = './data', train=False, download=True, transform=transforms.ToTensor()) # same format, 10,000 test images

train_dl = DataLoader(mnist_trainset, batch_size=64, shuffle=True)
test_dl = DataLoader(mnist_testset, batch_size=64, shuffle=False) # order doesn't matter when measuring accuracy

(Apologies, the following training code is messy and could probably look better. For now, I'm lazy.)

In [20]:
def train_with_wandb():
    wandb.init()
    num_epochs = 1  # constant across the sweep

    # hyperparameters chosen by the sweep
    lr = wandb.config.lr

    # set-up
    model = ConvNet(wandb.config.scale).to(device)
    optim = t.optim.Adam(model.parameters(), lr)
    criterion = nn.CrossEntropyLoss()  # takes raw logits and int64 class indices

    examples_seen = 0
    wandb.watch(model, criterion=criterion, log="all", log_freq=10, log_graph=True)
    # log the parameter count once, at step 0
    wandb.log({"Parameter Size": model.getParamCount()}, step=examples_seen)

    mid_log_count = 60  # training steps between test-accuracy evaluations
    count = 0
    # train!
    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_dl:
            inputs = inputs.to(device)
            labels = labels.to(device)  # MNIST labels are already int64 class indices

            optim.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optim.step()

            examples_seen += len(inputs)
            wandb.log({"Train Loss": loss.item()}, step=examples_seen)

            count += 1
            if count == mid_log_count:
                test_acc = find_test_acc(model, criterion)
                wandb.log({"Test Accuracy": test_acc}, step=examples_seen)
                model.train()  # find_test_acc switches the model to eval mode
                count = 0

    filename = f"{wandb.run.dir}/model_state_dict.pt"
    print(f"Saving model to: {filename}")
    t.save(model.state_dict(), filename)
    wandb.save(filename)
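Every metric in `train_with_wandb` is logged against `examples_seen`, and test accuracy is only computed every `mid_log_count` batches. A quick arithmetic sketch of when those evaluations happen in a one-epoch run, assuming the 60,000-image MNIST training set and `batch_size=64` from the loaders above (`eval_schedule` is a hypothetical helper):

```python
import math

def eval_schedule(n_train=60_000, batch_size=64, mid_log_count=60):
    """Return (batches per epoch, batch indices that trigger a test evaluation,
    examples_seen at each of those evaluations) for a single epoch."""
    n_batches = math.ceil(n_train / batch_size)  # last batch is partial if not divisible
    eval_batches = list(range(mid_log_count, n_batches + 1, mid_log_count))
    eval_examples = [min(b * batch_size, n_train) for b in eval_batches]
    return n_batches, eval_batches, eval_examples

n_batches, eval_batches, eval_examples = eval_schedule()
print(n_batches, len(eval_batches), eval_examples[0], eval_examples[-1])
```

Under these assumptions an epoch has 938 batches (the last holding only 32 images) and the test set is evaluated 15 times, first at 3,840 examples seen and last at 57,600, so the final 38 batches of the run are not followed by a test evaluation.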

In [21]:
def find_test_acc(model, criterion):
    # criterion is unused here (accuracy only); the argument is kept so existing calls still work
    correct = 0
    model.eval()  # the caller is responsible for switching back to train mode
    with t.no_grad():  # no gradients needed for evaluation
        for inputs, labels in test_dl:
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = t.argmax(model(inputs), dim=-1)
            correct += (preds == labels).sum().item()
    return correct / len(test_dl.dataset)
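The loss can be computed from integer class labels or from one-hot probability vectors: recent PyTorch versions of `nn.CrossEntropyLoss` accept either form of target, and for one-hot targets the two agree, because the soft-target formula $-\sum_k q_k \log p_k$ collapses to $-\log p_y$ for the true class $y$. A plain-Python check of that arithmetic on made-up logits (not PyTorch's implementation; it covers a single example and ignores reduction, class weights and label smoothing):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def ce_from_index(logits, label):
    """Cross-entropy with an integer class-index target: -log p[label]."""
    return -log_softmax(logits)[label]

def ce_from_probs(logits, target_probs):
    """Cross-entropy with a probability-vector target: -sum_k q_k * log p_k."""
    return -sum(q * lp for q, lp in zip(target_probs, log_softmax(logits)))

logits = [2.0, -1.0, 0.5, 0.0, 1.5, -0.5, 0.3, -2.0, 1.0, 0.1]  # made-up logits, 10 classes
label = 4
one_hot = [1.0 if k == label else 0.0 for k in range(10)]
print(ce_from_index(logits, label), ce_from_probs(logits, one_hot))
```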

In [41]:
sweep_config = {
    'method': 'random',
    'name': 'CNN_scale_sweep',
    'metric': {'name' : 'Test Accuracy', 'goal' : 'maximize'},
    'parameters':
    {
        'lr': {'max': 0.01, 'min': 0.0001, 'distribution': 'log_uniform_values'},
        'scale': {'values': [1]} # one scale per sweep; +1 scale roughly doubles the parameter count
    }
}
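The sweep draws `lr` with `'distribution': 'log_uniform_values'`, which, per the W&B sweep documentation, picks values whose logarithm is uniform between log(min) and log(max). That spreads trials evenly across orders of magnitude instead of crowding them near 0.01. A minimal plain-Python sketch of that sampling rule (an illustration, not W&B's code; `sample_log_uniform` is a hypothetical name):

```python
import math
import random

def sample_log_uniform(low, high, rng):
    """Draw x in [low, high] such that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(0)  # fixed seed for a repeatable illustration
samples = [sample_log_uniform(1e-4, 1e-2, rng) for _ in range(10_000)]

# with a 100x range, each decade (1e-4..1e-3 and 1e-3..1e-2) should get about half the draws
low_decade = sum(s < 1e-3 for s in samples) / len(samples)
print(f"fraction below 1e-3: {low_decade:.2f}")
```

A plain uniform draw over the same range would put only about 9% of trials below 1e-3, which is why log-uniform sampling suits learning rates.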

In [47]:
sweep_id = "oufsr67l" # wandb.sweep(sweep = sweep_config, project='Scaling_Laws_CNN')

# sweep scale - associated sweep_id
# 1 - hcidgdal
# 2 - 03tt0k9t
# 3 - 9xwcgfwm
# 4 - g9wdllp4
# 5 - oufsr67l

In [None]:
# the loop below isn't exactly what I ran; I also changed the random seed in the middle of each sweep's training
for sweep_id in ["oufsr67l", "g9wdllp4", "9xwcgfwm", "03tt0k9t", "hcidgdal"]:  # scales 5, 4, 3, 2, 1
    wandb.agent(sweep_id=sweep_id, function=train_with_wandb, count=12)