In [1]:
import torch
from torch.functional import F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import copy
import numpy as np
from torchvision import datasets, transforms

from sad_nns.uncertainty import *
from neurops import *

  from .autonotebook import tqdm as notebook_tqdm


**NORTH:** Define a LeNet-style model. 

Use the `ModSequential` class to wrap the `ModConv2d` and `ModLinear` model, which allows us to mask, prune, and grow the model. 

Use the `track_activations` and `track_auxiliary_gradients` arguments to enable the tracking of activations and auxiliary gradients later. 

By adding the `input_shape` of the data, we can compute the conversion factor of how many input neurons to add to the first linear layer when a new output channel is added to the final convolutional layer. 

In [6]:
model = ModSequential(
        ModConv2d(in_channels=1, out_channels=8, kernel_size=7, masked=True, padding=1, learnable_mask=True),
        ModConv2d(in_channels=8, out_channels=16, kernel_size=7, masked=True, padding=1, prebatchnorm=True, learnable_mask=True),
        ModConv2d(in_channels=16, out_channels=16, kernel_size=5, masked=True, prebatchnorm=True, learnable_mask=True),
        ModLinear(64, 32, masked=True, prebatchnorm=True, learnable_mask=True),
        ModLinear(32, 10, masked=True, prebatchnorm=True, nonlinearity=""),
        track_activations=True,
        track_auxiliary_gradients=True,
        input_shape = (1, 14, 14)
    ).to(device)
torch.compile(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Add EDL Loss Function
# KLDivergenceLoss, MaximumLikelihoodLoss, CrossEntropyBayesRisk, SquaredErrorBayesRisk
criterion = SquaredErrorBayesRisk()
kl_divergence = KLDivergenceLoss()

**NORTH:** Get a dataset and define standard training and testing functions.

In [4]:
dataset = datasets.MNIST('../data/', train=True, download=True,
                     transform=transforms.Compose([ 
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,)),
                            transforms.Resize((14,14))
                        ]))
train_set, val_set = torch.utils.data.random_split(dataset, lengths=[int(0.9*len(dataset)), int(0.1*len(dataset))])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,)),
                            transforms.Resize((14,14))
                        ])),
    batch_size=128, shuffle=True)

def train(model, train_loader, optimizer, criterion, epochs=10, num_classes=10, val_loader=None, verbose=True):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Convert target to one-hot encoding
            target = F.one_hot(target, num_classes=num_classes)

            optimizer.zero_grad()
            output = model(data)

            # Calculate uncertainty
            evidence = F.relu(output)
            loss = criterion(evidence, target)

            # Calculate KL Divergence Loss
            kl_divergence_loss = kl_divergence(evidence, target)
            annealing_step = 10
            annealing_coef = torch.min(
                torch.tensor(1.0, dtype=torch.float32),
                torch.tensor(epoch / annealing_step, dtype=torch.float32)
            )
            loss += annealing_coef * kl_divergence_loss
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if batch_idx % 100 == 0 and verbose:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        if val_loader is not None:
            print("Validation: ", end = "")
            test(model, val_loader, criterion)

def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # Convert target to one-hot encoding
            target_one_hot = F.one_hot(target, num_classes=10).float()

            output = model(data)

            # sum up batch loss
            test_loss += criterion(output, target_one_hot).item()

            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    
    print('Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 85774580.72it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 43402255.04it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 23230217.88it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8134299.22it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



**NORTH:** Pretrain the model before changing the architecture

In [7]:
train(model, train_loader, optimizer, criterion, epochs=5, val_loader=val_loader)



Validation: Average loss: 0.0050, Accuracy: 5010/6000 (83.50%)
Validation: Average loss: 0.0144, Accuracy: 5531/6000 (92.18%)
Validation: Average loss: 0.0120, Accuracy: 5684/6000 (94.73%)
Validation: Average loss: 0.0115, Accuracy: 5708/6000 (95.13%)
Validation: Average loss: 0.0108, Accuracy: 5708/6000 (95.13%)


### **NORTH:** Model Optimization Techniques

**NORTH:** Use a heuristic from `metrics.py` to measure the existing channels and neurons to determine which ones to prune.

The simplest one is measuring the norm of incoming weights to a neuron. We'll copy the model (so we have access to the original), then score each neuron and prune the lowest scoring ones within each layer. After running the following block, try uncommenting different lines to see how different metrics affect the model.

In [8]:
modded_model = copy.deepcopy(model)
modded_optimizer = torch.optim.SGD(modded_model.parameters(), lr=0.01)
modded_optimizer.load_state_dict(optimizer.state_dict())

for i in range(len(model)-1):
    scores = weight_sum(modded_model[i].weight)
    # scores = weight_sum(modded_model[i].weight) +  weight_sum(modded_model[i+1].weight, fanin=False, conversion_factor=model.conversion_factor if i == model.conversion_layer else -1)
    # scores = activation_variance(modded_model.activations[str(i)])
    # scores = svd_score(modded_model.activations[str(i)])
    # scores = nuclear_score(modded_model.activations[str(i)], average=i<3)
    # scores = modded_model[i+1].batchnorm.weight.abs() if i != modded_model.conversion_layer else modded_model[i+1].batchnorm.weight.abs().reshape(modded_model.conversion_factor,-1).sum(0) 
    # Before trying this line, run the following block: # scores = fisher_info(mask_grads[i])
    print("Layer {} scores: mean {:.3g}, std {:.3g}, min {:.3g}, smallest 25%:".format(i, scores.mean(), scores.std(), scores.min()), end=" ")
    to_prune = np.argsort(scores.detach().cpu().numpy())[:int(0.25*len(scores))]
    print(to_prune)
    modded_model.prune(i, to_prune, optimizer=modded_optimizer, clear_activations=True)
print("The pruned model has {} effective parameters.".format(modded_model.parameter_count(masked = True)))
print("Validation after pruning: ", end = "")
test(modded_model, val_loader, criterion)
train(modded_model, train_loader, modded_optimizer, criterion, epochs=2, val_loader=val_loader)

Layer 0 scores: mean 4.49, std 0.236, min 4.16, smallest 25%: [1 0]
Layer 1 scores: mean 8.73, std 0.447, min 8, smallest 25%: [7 5 1 3]
Layer 2 scores: mean 8.83, std 0.621, min 7.54, smallest 25%: [ 6 11  1 15]
Layer 3 scores: mean 3.26, std 0.33, min 2.49, smallest 25%: [19  7 11 21 29 15 28 23]
The pruned model has 9058 effective parameters.
Validation after pruning: Average loss: 0.0097, Accuracy: 4111/6000 (68.52%)
Validation: Average loss: 0.0036, Accuracy: 5643/6000 (94.05%)
Validation: Average loss: 0.0166, Accuracy: 5688/6000 (94.80%)


**NORTH:** Grow the model using a neurogenesis strategy similar to NORTH-Random.

In [9]:
modded_model_grow = copy.deepcopy(model)
modded_optimizer_grow = torch.optim.SGD(modded_model_grow.parameters(), lr=0.01)
modded_optimizer_grow.load_state_dict(optimizer.state_dict())

for iter in range(5):
    for i in range(len(modded_model_grow)-1):
        #score = orthogonality_gap(modded_model_grow.activations[str(i)])
        max_rank = modded_model_grow[i].width()
        score = effective_rank(modded_model_grow.activations[str(i)])
        to_add = max(score-int(0.95*max_rank), 0)
        print("Layer {} score: {}/{}, neurons to add: {}".format(i, score, max_rank, to_add))
        modded_model_grow.grow(i, to_add, fanin_weights="iterative_orthogonalization",
                               optimizer=modded_optimizer_grow)
    print("The grown model now has {} effective parameters.".format(modded_model_grow.parameter_count(masked = True)))
    print("Validation after growing: ", end = "")
    test(modded_model_grow, val_loader, criterion)
    train(modded_model_grow, train_loader, modded_optimizer_grow, criterion, epochs=2, val_loader=val_loader)

Layer 0 score: 8/8, neurons to add: 1
Layer 1 score: 16/16, neurons to add: 1
Layer 2 score: 16/16, neurons to add: 1
Layer 3 score: 32/32, neurons to add: 2
The grown model now has 16731 effective parameters.
Validation after growing: Average loss: 0.0108, Accuracy: 5708/6000 (95.13%)
Validation: Average loss: 0.0028, Accuracy: 5755/6000 (95.92%)
Validation: Average loss: 0.0171, Accuracy: 5732/6000 (95.53%)
Layer 0 score: 8/9, neurons to add: 0
Layer 1 score: 17/17, neurons to add: 1
Layer 2 score: 17/17, neurons to add: 1
Layer 3 score: 34/34, neurons to add: 2
The grown model now has 19143 effective parameters.
Validation after growing: Average loss: 0.0171, Accuracy: 5732/6000 (95.53%)
Validation: Average loss: 0.0020, Accuracy: 5742/6000 (95.70%)
Validation: Average loss: 0.0400, Accuracy: 5729/6000 (95.48%)
Layer 0 score: 9/9, neurons to add: 1
Layer 1 score: 18/18, neurons to add: 1
Layer 2 score: 18/18, neurons to add: 1
Layer 3 score: 36/36, neurons to add: 2
The grown model 