In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR

device = torch.device("cuda") if torch.cuda.is_available else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


## Generate synthetic data, train linear regression model

In [3]:
my_net = nn.Sequential(
    nn.Linear(in_features=2, out_features=3, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=3, out_features=2, bias=False),
).to(device)

In [4]:
dataset_size = 1000
batch_size = 10

x = np.random.randn(dataset_size, 2)
y1 = x[:, 0] * 5
y2 = x[:, 0] * 3 - x[:, 1] * 2
y = np.zeros_like(x)
y[:, 0] = y1
y[:, 1] = y2

tensor_x = torch.Tensor(x) # transform to torch tensor
tensor_y = torch.Tensor(y)

my_dataset = TensorDataset(tensor_x, tensor_y) # create your datset
my_dataloader = DataLoader(my_dataset, batch_size=batch_size) # create your dataloader

In [5]:
epochs = 100
lr = 5e-3
lr_decay = 0.99
my_net.train()
optim = torch.optim.SGD(my_net.parameters(), lr=lr)

train_scheduler = LambdaLR(
    optimizer=optim, lr_lambda=lambda epoch: lr_decay**epoch
)

loss_fn = nn.MSELoss()

In [6]:
for epoch in range(epochs):
    for i, (net_inp, target) in enumerate(my_dataloader):
        net_inp = net_inp.to(device)
        target = target.to(device)

        optim.zero_grad()
        pred = my_net(net_inp)
        loss = loss_fn(pred, target)
        loss.backward()
        optim.step()

        if epoch % 10 == 0 and i % 100 == 0:
            print(f"=== loss === {loss}")

    train_scheduler.step(epoch)

=== loss === 18.090970993041992




=== loss === 1.0944627523422241
=== loss === 1.0761324167251587
=== loss === 1.0718954801559448
=== loss === 1.074318289756775
=== loss === 1.0783538818359375
=== loss === 1.0761851072311401
=== loss === 1.0689371824264526
=== loss === 1.0653672218322754
=== loss === 1.063279151916504


## Expand model (create excess weights)

In [7]:
expanded_model = nn.Sequential(
    nn.Linear(in_features=2, out_features=4, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=4, out_features=2, bias=False),
)

nn.init.zeros_(expanded_model[0].weight);
nn.init.zeros_(expanded_model[2].weight);

In [8]:
sh0 = my_net[0].weight.shape
sh2 = my_net[2].weight.shape

expanded_model[0].weight.data[:sh0[0], :sh0[1]] = my_net[0].weight.data
expanded_model[2].weight.data[:sh2[0], :sh2[1]] = my_net[2].weight.data

expanded_model.to(device);

In [9]:
expanded_model[0].weight

Parameter containing:
tensor([[ 2.4329, -0.4615],
        [-0.6964, -1.3951],
        [-2.4095,  0.5311],
        [ 0.0000,  0.0000]], device='cuda:0', requires_grad=True)

In [10]:
my_net[0].weight

Parameter containing:
tensor([[ 2.4329, -0.4615],
        [-0.6964, -1.3951],
        [-2.4095,  0.5311]], device='cuda:0', requires_grad=True)

In [11]:
# Check that outputs are the same

inp = torch.rand((10, 2)).to(device)
print(expanded_model(inp) - my_net(inp))

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0', grad_fn=<SubBackward0>)


## Pruning

Pruning criterion taken from https://arxiv.org/pdf/2004.14340.pdf

In [12]:
class FisherMatrix:
    """Based on article: https://arxiv.org/pdf/2004.14340.pdf"""
    def __init__(self, device=None):
        self._n_batches_tracked = 0
        self._fisher_matrix = None
        if device:
            self._device = device
        else:
            self._device = "cuda" if torch.cuda.is_available() else "cpu"

    def update(self, grads: torch.Tensor, batch_size: int):
        grads_T = grads.reshape((*grads.shape, 1))

        if self._fisher_matrix is None:
            self._fisher_matrix = 0

        self._fisher_matrix += grads_T * grads
        self._n_batches_tracked += 1

    def get(self):
        if self._fisher_matrix is not None:
            return self._fisher_matrix / self._n_batches_tracked
        return None

In [13]:
fm = FisherMatrix(device)

In [14]:
for net_inp, target in my_dataloader:
    optim.zero_grad()

    net_inp = net_inp.to(device)
    target = target.to(device)    

    out = expanded_model(net_inp)
    loss = loss_fn(out, target)
    loss.backward()

    fm.update(
        grads=torch.flatten(expanded_model[0].weight.grad),
        batch_size=my_dataloader.batch_size,
    )

In [15]:
def pruning_criterion(ind_x, ind_y, weights, hessian, grads):
    try:
        hessian_inv = torch.linalg.inv(hessian)
    except:
        hessian_inv = torch.linalg.inv(hessian + 1e-6 * torch.eye(hessian.shape[0]).to(hessian.device))

    # print(hessian_inv)

    w_q = weights[ind_y, ind_x]
    h, w = weights.shape
    idx = ind_x + ind_y * w
    h_inv_qq = hessian_inv[idx, idx]

    l1 = 0.5 * w_q**2 / h_inv_qq
    l2 = 0.5 * (hessian_inv[idx, :] @ grads)**2 / h_inv_qq
    l3 = -w_q * (hessian_inv[idx, :] @ grads) / h_inv_qq
    l4 = -0.5 * grads.T @ hessian_inv @ grads
    print(f"l1={l1.item()}, l2={l2.item()}, l3={l3.item()}, l4={l4.item()}")

    return l1.item() + l2.item() + l3.item() + l4.item()

In [16]:
ind_y = 3
ind_x = 1

print(f"Weights on first linear layer: {expanded_model[0].weight.data}")
print(f"We want to prune weight at indices {ind_y, ind_x} with value {expanded_model[0].weight[ind_y, ind_x]}")

Weights on first linear layer: tensor([[ 2.4329, -0.4615],
        [-0.6964, -1.3951],
        [-2.4095,  0.5311],
        [ 0.0000,  0.0000]], device='cuda:0')
We want to prune weight at indices (3, 1) with value 0.0


In [17]:
pruning_criterion(
    ind_x=ind_x,
    ind_y=ind_y,
    weights=expanded_model[0].weight,
    hessian=fm.get(),
    grads=torch.flatten(expanded_model[0].weight.grad),
)

l1=0.0, l2=0.0, l3=-0.0, l4=-3.492868185043335


  l4 = -0.5 * grads.T @ hessian_inv @ grads


-3.492868185043335