# **Efficient Computation of Deep Nonlinear Infinite-Width Neural Networks that Learn Features**

Welcome to the tutorial for using the Pi-Limit Python library, based on our work [published in ICLR 2022](https://openreview.net/pdf?id=tUMr0Iox8XW). Using this library, one can easily create their own trainable pi-nets.

In this notebook, we will walk through:
- Creating a basic infinite-width Pi-MLP
- Training it on some dummy data
- Training the network on CIFAR10
- Saving & reloading a trained network
- Sampling a finite Pi-Net and testing its performance

We will create a basic infinite-width Pi-MLP, first train it on some dummy data, and then train it on CIFAR10. Let's begin by installing the pilimit pip package.


## Installation

Please use the following commands to install the pilimit library (includes some extra commands for this Colab notebook - the main command is the pip install from the git repo).

In [2]:
!pip install --upgrade pip
!pip install --upgrade setuptools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-23.0.1-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 19.3 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 22.0.4
    Uninstalling pip-22.0.4:
      Successfully uninstalled pip-22.0.4
Successfully installed pip-23.0.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting setuptools
  Downloading setuptools-67.6.0-py3-none-any.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 17.4 MB/s eta 0:00:00
Installing collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 63.4.3
    Uninstalling setuptools-63.4.3:
      Successfully uninstalled

In [3]:
!pip uninstall pilimit -y
!pip install  git+https://github.com/santacml/pilim.git#egg=pilimit 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pilimit
  Cloning https://github.com/santacml/pilim.git to /tmp/pip-install-tvysssl9/pilimit_ea16e011287a471285d196cd98d1afef
  Running command git clone --filter=blob:none --quiet https://github.com/santacml/pilim.git /tmp/pip-install-tvysssl9/pilimit_ea16e011287a471285d196cd98d1afef
  Resolved https://github.com/santacml/pilim.git to commit 84968cc2046fb3c754e27230b68528c8f51f0276
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pilimit
  Building wheel for pilimit (setup.py) ... done
  Created wheel for pilimit: filename=pilimit-0.1.0-py3-none-any.whl size=140585 sha256=8afcc51dd62f6b6f8600b34608990cbb9cebd0c8ef2e896e658ee874097c404e
  Stored in directory: /tmp/pip-ephem-wheel-cache-w4gbm89h/wheels/63/9f/89/2e5302a7b80ab2280ef99317355f63d2cee8fa521516567afd
Successfully built pilimit
Installing collected packages: p

## Creating a basic Pi-Net

Inside the package pilimit_lib, we provide torch-style infinite-width fully connected layers for use in networks. These layers can be used to easily create a network using familiar torch syntax. 

Ready-to-use examples are provided in examples.networks. Let's quickly go through a standard infinite MLP as defined in the networks file (edited for this notebook).

Feel free to copy and use this network for other projects. However, our intention is to provide easy-to-use primitives which others may easily build upon.


In [4]:
import torch
from torch import nn
from pilimit_lib.inf.layers import InfPiInputLinearReLU, InfPiLinearReLU
from pilimit_lib.examples.networks import PiNet


class InfMLP(PiNet):
    def __init__(
            self, 
            d_in, 
            d_out, 
            r, 
            L, 
            first_layer_alpha=1, 
            last_layer_alpha=1, 
            bias_alpha=1, 
            last_bias_alpha=None, 
            layernorm=False, 
            cuda_batch_size=None, 
            device="cpu"):
        super(InfMLP, self).__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.r = r
        self.L = L

        # save as buffers for saving
        self.register_buffer("first_layer_alpha", torch.tensor(first_layer_alpha, dtype=torch.get_default_dtype()))
        self.register_buffer("last_layer_alpha", torch.tensor(last_layer_alpha, dtype=torch.get_default_dtype()))
        self.register_buffer("bias_alpha", torch.tensor(bias_alpha, dtype=torch.get_default_dtype()))
        if last_bias_alpha is None:
            last_bias_alpha = bias_alpha
        self.register_buffer("last_bias_alpha", torch.tensor(last_bias_alpha, dtype=torch.get_default_dtype()))
        self.layernorm = layernorm

        self.layers = nn.ModuleList()

        self.layers.append(InfPiInputLinearReLU(d_in, r, bias_alpha=bias_alpha, device=device))
        for n in range(1, L+1):
            self.layers.append(InfPiLinearReLU(r, device=device, bias_alpha=bias_alpha, layernorm=layernorm, cuda_batch_size=cuda_batch_size))
        
        self.layers.append(InfPiLinearReLU(r, r_out=d_out, output_layer=True, bias_alpha=last_bias_alpha, device=device, layernorm=layernorm, cuda_batch_size=cuda_batch_size))

        
    def forward(self, x):
        for n in range(0, self.L+2):
            x = self.layers[n](x)
            if n == 0: 
                x *= self.first_layer_alpha
            if n == self.L+1: 
                x *= self.last_layer_alpha
        return x
        

This code should look generally familiar to anyone who has created a network in torch before. There are a few key things to note:



1.   The network inherits from a class ```PiNet```; this is not strictly necessary, but ```PiNet``` simply helps with loading from a saved model. One could instead inherit from ```torch.nn.Module``` if so desired.

2.   There is a special ```InfPiInputLinearReLU``` specifically for input layers due to the definition of the infinite width limit. These layers have a different formulation than middle layers.

3. The general-purpose layer is ```InfPiLinearReLU```. As with the input layer, ReLU is baked into the layer and the activation cannot be changed. Again, this is necessary due to the formulation of the limit. We plan to add support for other activations in the future.

With these core building blocks, one can easily create a Pi-Net.


## What's in a layer?

A layer like `InfPiLinearReLU` has 3 parameters which can be thought of as 'trainable weights', though they differ from weights in the traditional sense. These parameters are:

- A - represents the *post-activation portion* of the loss gradient from previous iterations
- B - represents the *pre-activation portion* of the loss gradient from previous iterations
- Amult - represents the accumulated learning rate and weight decay from previous iterations

A and B function together to represent coefficients that multiply and sum *r* Gaussian vectors, where *r* is the rank hyperparameter. These matrices are of shape *m* by *r*, where *m* actually *grows* throughout training. For further explanation and details, please see our original work and blog post.
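
To make the phrase "coefficients that multiply and sum *r* Gaussian vectors" concrete, here is a toy sketch in plain torch. It is only an illustration of the linear-algebra idea, not the library's actual computation: the names `gaussians`, `coeffs`, and `new_rows`, the finite length `n`, and all sizes are assumptions chosen for the example.

```python
import torch

torch.manual_seed(0)

n, r, m = 1000, 4, 6  # toy values: vector length, rank, current number of rows

# r Gaussian vectors of length n, stacked as the rows of an (r, n) matrix
gaussians = torch.randn(r, n)

# stand-in for an m-by-r coefficient matrix (the role A or B plays in the text)
coeffs = torch.randn(m, r)

# each row of `combined` is a weighted sum of the r Gaussian vectors
combined = coeffs @ gaussians  # shape (m, n)

# "m grows throughout training": new coefficient rows are appended over time
new_rows = torch.randn(3, r)
coeffs = torch.cat([coeffs, new_rows], dim=0)  # shape (m + 3, r)

print(combined.shape, coeffs.shape)
```

The point of the sketch is that only the small coefficient matrix has to be stored and extended; the number of columns stays fixed at *r* while the number of rows increases.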


## Training on dummy data

Next, let's define some dummy data (just a sine wave) and train on it with the PiNet to see what a minimal training loop looks like. 

In [5]:
import numpy as np 
import time 

torch.manual_seed(3133)
np.random.seed(3331)
device = "cuda" if torch.cuda.is_available() else "cpu"

data = torch.linspace(-np.pi, np.pi, 100, device=device).reshape(-1, 1)
labels = torch.sin(data) #.reshape(-1)
data = torch.cat([data, torch.ones_like(data, device=device)], dim=1)

Now, we create an instance of the earlier defined InfMLP.

In [6]:
d_in = 2
d_out = 3  # labels have a single column, so the loss below broadcasts them across all 3 outputs
r = 20
L = 1
bias_alpha = .5
batch_size = 50
net = InfMLP(d_in, d_out, r, L, device=device, bias_alpha=bias_alpha )

Now, let's go through the core training loop. 

Again, this should look very familiar to anyone who has worked in PyTorch before. There are a few key differences which are listed below.

In [7]:
from pilimit_lib.inf.optim import PiSGD, store_pi_grad_norm_, clip_grad_norm_
import sys

net.train()
n_epochs = 20
accum_steps = 1
gclip = .1
optimizer = PiSGD(net.parameters(), lr = .02)
tic = time.time()
for epoch in range(n_epochs):
    if epoch % accum_steps == 0:
        optimizer.zero_grad()
        net.zero_grad()
    
    prediction = net(data)
    
    loss = torch.sum((prediction - labels)**2)**.5

    print('Epoch {}: train loss: {}'.format(epoch, loss.item()))
    
    loss.backward()
    # stage_grad(net)

    if epoch % accum_steps == 0:
        # unstage_grad(net)

        if gclip:
            store_pi_grad_norm_(net.modules())
            clip_grad_norm_(net.parameters(), gclip)

        optimizer.step()

    #print("Memory used", torch.cuda.memory_reserved() / 1e9, torch.cuda.max_memory_reserved()  / 1e9)
    print("Network A size", net.layers[1].A.shape[0])
print("time", time.time() - tic)


Epoch 0: train loss: 12.186058044433594
Network A size 120
Epoch 1: train loss: 12.177803993225098
Network A size 220
Epoch 2: train loss: 12.16955852508545
Network A size 320
Epoch 3: train loss: 12.161317825317383
Network A size 420
Epoch 4: train loss: 12.153082847595215
Network A size 520
Epoch 5: train loss: 12.14484977722168
Network A size 620
Epoch 6: train loss: 12.13662052154541
Network A size 720
Epoch 7: train loss: 12.12839126586914
Network A size 820
Epoch 8: train loss: 12.120162010192871
Network A size 920
Epoch 9: train loss: 12.111932754516602
Network A size 1020
Epoch 10: train loss: 12.103699684143066
Network A size 1120
Epoch 11: train loss: 12.095463752746582
Network A size 1220
Epoch 12: train loss: 12.087224006652832
Network A size 1320
Epoch 13: train loss: 12.078978538513184
Network A size 1420
Epoch 14: train loss: 12.070727348327637
Network A size 1520
Epoch 15: train loss: 12.062467575073242
Network A size 1620
Epoch 16: train loss: 12.05419921875
Network A 

Performance doesn't really matter here; we want to take a look at what the code looks like when training a Pi-Net.

Note that as we train, the size of an A matrix in the network also grows per training step. This is an important and unfortunate side effect of the inf-width limit. Using these models takes careful consideration of this growing memory requirement - the more memory that you can use, the better.
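
As a rough way to budget memory, the growth seen in the printed sizes above can be extrapolated. The sketch below assumes (from the log alone, not from the library's source) that `A` gains one row per training sample at every step, starting from `r` rows, and that entries are stored as 4-byte floats. `estimate_rows` and `matrix_bytes` are hypothetical helper names introduced only for this example.

```python
def estimate_rows(initial_rows, samples_per_step, steps):
    """Rows of A after `steps` updates, assuming linear growth per sample."""
    return initial_rows + samples_per_step * steps


def matrix_bytes(rows, cols, bytes_per_element=4):
    """Memory of a dense rows-by-cols matrix (float32 assumed by default)."""
    return rows * cols * bytes_per_element


# dummy-data run above: r = 20, 100 samples per step, 16 steps completed
rows = estimate_rows(initial_rows=20, samples_per_step=100, steps=16)
print(rows)                    # 1620, matching "Network A size 1620" after epoch 15
print(matrix_bytes(rows, 20))  # 129600 bytes

# a much larger case: 50000 * 46 rows at rank 400
print(matrix_bytes(50000 * 46, 400) / 1e9)  # 3.68 (GB for one such matrix)
```

Under these assumptions memory grows linearly in the number of samples seen, which is why full CIFAR10 training needs a GPU with substantial memory.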

Our library keeps as much native torch syntax and as many native torch functions as possible; however, some drop-in replacement functions had to be created. These can be seen above. Here are the important ones:

1.   ```PiSGD``` is the optimizer - regular torch optimizers will not work

2.   ```store_pi_grad_norm_``` and ```clip_grad_norm_``` together allow for gradient clipping. It's necessary to use these two functions together as shown, given the unique format of the layers

3. ```stage_grad``` and ```unstage_grad``` allow for gradient accumulation if desired (commented out here). It's again necessary to use both of these functions exactly as shown, due to the growing nature of the network.

With these small caveats, the core training loop is largely the same!


## Training on (a small subsample of) CIFAR10

Let's now train an example Pi-Net on CIFAR10. First, we download the dataset as normal.

**NOTE**: given the restrictions of this Colab notebook, we will use a subsample of only 10 CIFAR10 samples. Please download this notebook and run it on a GPU-equipped computer to train fully.

In [8]:
# standard cifar10 download code
from torchvision import datasets, transforms
import torch.utils.data as data_utils

total_samples = 10
batch_size = 1

transform_list = []
transform_list.extend([transforms.ToTensor()])

transform_list.extend([transforms.Normalize([0.49137255, 0.48235294, 0.44666667], [0.24705882, 0.24352941, 0.26156863])])
transform = transforms.Compose(transform_list)

trainset = datasets.CIFAR10(root=".", train=True,
                                        download=True, transform=transform)

np.random.seed(0) # reproducibility of subset
indices = np.random.choice(range(50000), size=total_samples, replace=False).tolist()
trainset = data_utils.Subset(trainset, indices)
print("Using subset of", len(trainset), "training samples")
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=0)

testset = datasets.CIFAR10(root=".", train=False,
                                      download=True, transform=transform)
np.random.seed(0) # reproducibility of subset
# the test split is smaller than the training split, so sample indices from its own length
indices = np.random.choice(range(len(testset)), size=total_samples, replace=False).tolist()
testset = data_utils.Subset(testset, indices)
print("Using subset of", len(testset), "testing samples")
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                        shuffle=False, num_workers=0)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to .
Using subset of 10 training samples
Files already downloaded and verified
Using subset of 10 testing samples


Next, we again define our InfMLP, along with a few basic hyperparameters.


In [9]:

d_in = 32*32*3
d_out = 10
r = 200
epoch = 50
L = 1 
layernorm = False
net = InfMLP(d_in, d_out, r, L, device=device, bias_alpha=bias_alpha, layernorm=layernorm)


epoch = 100  # overrides the value set above
lr = .001
gclip = 0.5  # set to 0 to disable gradient clipping
wd = 0.1     # set to 0 to disable weight decay
gclip_per_param = True
step = True

no_apply_lr_mult_to_wd = True

first_layer_lr_mult = 1
last_layer_lr_mult = 1
bias_lr_mult = .1
first_layer_alpha = 1
bias_alpha = 1
last_layer_alpha = 1

Here is some intentionally verbose code for demonstration. It would be entirely possible to simply use ``` net.parameters() ``` as normal to gather the network parameters, but if one wishes to use layer-specific LR or apply weight decay, the following syntax is necessary.

Note the following:

1.   Layer 0 is special - there is no Amult or B, and A gets a learning rate
2.   After layer 0, only Amult gets learning rate/weight decay, but A and B are still added to optimizer
  - Don't worry about configuring whether A and B get learning rate - this is all handled in the backend





In [11]:

paramgroups = []
# first layer weights
paramgroups.append({
  'params': [net.layers[0].A],
  'lr': first_layer_lr_mult * lr,
  'weight_decay': wd / first_layer_lr_mult if no_apply_lr_mult_to_wd else wd
})
if net.layers[0].bias is not None:
  paramgroups.append({
    'params': [l.bias for l in net.layers],
    'lr': bias_lr_mult * lr,
    'weight_decay': wd / bias_lr_mult if no_apply_lr_mult_to_wd else wd
  })
paramgroups.append({
  'params': [l.Amult for l in net.layers[1:-1]],
})
paramgroups.append({
  'params': [net.layers[-1].Amult],
  'lr': last_layer_lr_mult * lr,
  'weight_decay': wd / last_layer_lr_mult if no_apply_lr_mult_to_wd else wd
})
paramgroups.append({
  'params': [l.A for l in net.layers[1:]],
})
paramgroups.append({
  'params': [l.B for l in net.layers[1:]],
})
optimizer = PiSGD(paramgroups, lr = lr, weight_decay=wd)

# this is the easy way to get parameters 
# optimizer = PiSGD(net.parameters(), lr = lr, weight_decay=wd) 
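
The dictionaries passed to `PiSGD` above follow PyTorch's standard per-parameter-group convention. As a minimal sketch of how that convention behaves, the example below uses `torch.optim.SGD` and a small ordinary `nn.Sequential` model as stand-ins. This is an assumption: `PiSGD` is expected, but not verified here, to resolve missing keys the same way. A group without its own `lr` or `weight_decay` inherits the defaults given to the optimizer's constructor.

```python
import torch
from torch import nn

# ordinary finite model used purely as a stand-in for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

lr, wd = 0.001, 0.1
bias_lr_mult = 0.1

param_groups = [
    # no 'lr' / 'weight_decay' keys: this group inherits the optimizer defaults
    {"params": [model[0].weight, model[2].weight]},
    # explicit overrides, mirroring the bias group in the cell above
    {
        "params": [model[0].bias, model[2].bias],
        "lr": bias_lr_mult * lr,
        "weight_decay": wd / bias_lr_mult,
    },
]
optimizer = torch.optim.SGD(param_groups, lr=lr, weight_decay=wd)

for i, group in enumerate(optimizer.param_groups):
    print(f"group {i}: lr={group['lr']:.6f}, weight_decay={group['weight_decay']:.3f}")
```

This is why the `Amult`, `A`, and `B` groups in the cell above can be given without explicit options: under this convention they fall back to the `lr` and `weight_decay` passed to the optimizer.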

Finally, we can train on CIFAR10 using a core training loop very similar to the one above.

Here is a very short example, training for 100 "epochs" on a very small subset of samples.

In [12]:

import torch.nn.functional as F
  

net.train()
losses = []
for epoch in range(epoch):
  epoch_losses = []
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device).type(torch.get_default_dtype()), target.to(device)
    data = data.reshape(data.shape[0], -1)
    labels = target

    optimizer.zero_grad()
    net.zero_grad()
    
    prediction = net(data)
    
    oh_target = target.new_zeros(target.shape[0], 10).type(torch.get_default_dtype())
    oh_target.scatter_(1, target.unsqueeze(-1), 1)
    oh_target -= 0.1
    loss = F.mse_loss(prediction, oh_target, reduction="mean")
      
    loss.backward()

    if gclip:
        store_pi_grad_norm_(net.modules())
        if gclip_per_param:
          for param in net.parameters():
              clip_grad_norm_(param, gclip)
        else:
          clip_grad_norm_(net.parameters(), gclip)

    if step: optimizer.step()

    # detach so the stored value does not keep the autograd graph alive
    epoch_losses.append(loss.detach())

  
  if epoch % 10 == 0: 
    print('Epoch {}: train loss: {}'.format(epoch, (sum(epoch_losses) / len(epoch_losses)).item()))

Epoch 0: train loss: 0.09023816883563995
Epoch 10: train loss: 0.08447568118572235
Epoch 20: train loss: 0.08399686962366104
Epoch 30: train loss: 0.08405115455389023
Epoch 40: train loss: 0.08403242379426956
Epoch 50: train loss: 0.08391579240560532
Epoch 60: train loss: 0.08373766392469406
Epoch 70: train loss: 0.08352664113044739
Epoch 80: train loss: 0.08329982310533524
Epoch 90: train loss: 0.08306676894426346
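
The target construction inside the loop can be checked in isolation. The sketch below builds the same centered one-hot targets with `scatter_`, as in the loop, and with `torch.nn.functional.one_hot` as an alternative spelling. The example labels are made up. Subtracting 0.1 (that is, 1/10 for 10 classes) makes every target row sum to zero.

```python
import torch
import torch.nn.functional as F

num_classes = 10
target = torch.tensor([3, 0, 9])  # made-up class labels for illustration

# same construction as in the training loop above
oh_scatter = target.new_zeros(target.shape[0], num_classes).type(torch.get_default_dtype())
oh_scatter.scatter_(1, target.unsqueeze(-1), 1)
oh_scatter -= 0.1

# alternative spelling using the built-in one-hot helper
oh_builtin = F.one_hot(target, num_classes).to(torch.get_default_dtype()) - 0.1

print(torch.allclose(oh_scatter, oh_builtin))
print(oh_scatter.sum(dim=1))  # each row sums to (approximately) zero
```

Either form can be passed to `F.mse_loss` together with the network output, as long as the dtype matches the prediction.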


## Reloading a trained Infinite Pi-Net

One of the main advantages of pilimit_lib is utilizing familiar torch-style syntax. We have one of our best Pi-Nets already trained on CIFAR10 available for download and use here: https://1drv.ms/u/s!Aqm-bcw66kwDnSYUPdFw-km20Hta?e=OrujuC

In the following code block, we demonstrate loading an infinite pi-net and testing it on CIFAR10. Note **this notebook does not download the trained network** - Colab is not powerful enough. To test the net, download the file from the above link and this notebook onto a computer with a good GPU. Then, assign the file's location to the `path` variable and uncomment the commented lines.

As-is, the untrained network will reach only about 10% accuracy, equivalent to random guessing.

In [13]:

testset = datasets.CIFAR10(root=".", train=False,
                                      download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                        shuffle=False, num_workers=0)

d_in = 32*32*3
d_out = 10
r = 400
L = 1
layernorm = False
net = InfMLP(d_in, d_out, r, L, device=device, bias_alpha=bias_alpha, layernorm=layernorm)

#path = 
#net.load_state_dict(torch.load(path))


def test_nn(model, device, test_loader):
    '''
    Test a model on a dataset (for validation/testing).
    '''
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device).type(torch.get_default_dtype()), target.to(device)
            data = data.reshape(data.shape[0], -1)

            output = model(data)
            test_loss += torch.nn.functional.cross_entropy(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

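    # note: test_loss is the loss summed over the whole dataset (it is not divided by the dataset size)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(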
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    return test_loss, correct / len(test_loader.dataset)

print("Now testing the network, this will require a powerful gpu...")
test_loss, acc = test_nn(net, device, test_loader)


Files already downloaded and verified
Now testing the network, this will require a powerful gpu...


KeyboardInterrupt: ignored

The accuracy should be about 60.9%, slightly below the best accuracy reported in our paper (61.5%). The gap arises because the repository was refactored after the original results were produced; to reproduce those results exactly, please use pilimit_orig.

Let's quickly inspect this network:

In [14]:
for n, layer in enumerate(net.layers):
  print(f"Layer {n} A has shape {layer.A.shape}")

Layer 0 A has shape torch.Size([3072, 400])
Layer 1 A has shape torch.Size([400, 400])
Layer 2 A has shape torch.Size([400, 10])


Notice a couple of things **(the trained-network shapes described below apply only if the trained pi-net has been loaded)**:

- The network has 1 hidden layer in addition to the input and output layers, so there are 3 layers in total (indexed 0 through 2)
- Layer 0 has input dimension 3072 and layer 2 has output dimension 10
- With the trained pi-net loaded, the other dimensions are 50000*46=2300000, because this is the network from epoch 46 of training (out of 50) on 50000 samples

## Sampling and testing a Finite Pi-Net

Each layer class that we have defined has an associated ```sample()``` function which can be used to create a finite-layer equivalent. Under ```networks.py``` is defined a ```FinPiMLPSample``` class, which demonstrates how to use this sample function along with the example infinite network.

There are 2 important things to note here:
- Unlike the infnet, finite layer classes do not include the activation built in (subject to refactor later)
- When sampling a finite layer, it is necessary to pass in the previous layer's omega as per the paper's instructions for building a pi-net

Be sure to fully understand how finite layers are sampled from infinite ones before creating a new finite-sampled network. 

There is also a big difference between the two following types of sampling:

1. sampling from a *trained* infinite pi-net to create some finite network
2. sampling from an *untrained* infinite pi-net, then training that network

In the paper, we mostly refer to type #2 to demonstrate the difference in training an infinite or finite network from the same initialization. When sampling type #1 as we are about to show, we expect performance to heavily suffer as the network was trained in infinite-width (though performance would converge as width approaches infinity).



Now, let's sample a finite network of width 2048 from the above loaded infnet. Note that the syntax to do so is incredibly light - only one line - and the same testing function defined above can be used!

In [15]:
from pilimit_lib.examples.networks import FinPiMLPSample

mynet = FinPiMLPSample(net, 2048)
print("Now testing the finite network...")
test_loss, acc = test_nn(mynet, device, test_loader)  # evaluate the sampled finite network



Now testing the finite network...

Test set: Average loss: 23025.8512, Accuracy: 1000/10000 (10.00%)
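
The printed loss can be sanity-checked by hand. The sketch below assumes the untrained network emits identical logits for every class, which would be consistent with the 10% accuracy on a class-balanced test set. In that case the cross-entropy per sample is ln(10), and summing over 10000 test images gives the number shown above. This also indicates that the value labelled "Average loss" is the summed loss. The `cross_entropy` helper here is a hypothetical pure-Python stand-in, not the torch function.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample, using a stable log-sum-exp."""
    mx = max(logits)
    log_sum_exp = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return log_sum_exp - logits[target]


num_classes, num_samples = 10, 10000

# identical logits for all classes: the prediction carries no information
per_sample = cross_entropy([0.0] * num_classes, target=3)
total = per_sample * num_samples

print(round(per_sample, 4))  # 2.3026
print(round(total, 2))       # 23025.85
```

Dividing the reported loss by `len(test_loader.dataset)` would give the per-sample average of roughly 2.3026.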



## Final notes

For any future Pi-Nets we highly recommend using pilimit_lib as it is easier to read and modify. The results presented in our paper, however, come from pilimit_orig and due to minor issues like floating point conversions, results with pilimit_lib will be ever so slightly different.

Therefore, we also include the original pilimit_orig in the repo along with some instructions on how to use it if one wishes exact reproduction.

For best possible performance, we have an infinite pi-net of rank 400 already trained on CIFAR10 available in the repo.