In [17]:
import torch
import tensorly as tl
from tensorly.decomposition import partial_tucker
from tltorch import FactorizedConv
import tltorch
from torchvision.models import resnet18
import numpy as np

In [18]:
def count_params(net: torch.nn.Module) -> int:
    """Count the trainable scalar parameters of a module.

    Parameters
    ----------
    net : torch.nn.Module
        Module whose parameters are inspected (as yielded by ``net.parameters()``).

    Returns
    -------
    int
        Sum of ``numel()`` over every parameter with ``requires_grad=True``;
        ``0`` when the module has no such parameter.
    """
    # Frozen parameters (requires_grad=False) are deliberately left out of the total.
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

In [19]:
# Baseline network to be compressed. `weights=None` requests no pretrained
# checkpoint, so the convolution kernels factorized later are whatever the
# constructor initialised them to.
model = resnet18(weights=None)

In [20]:
# Trainable-parameter count of the unmodified model; kept in `n_param` so it
# can be compared against the factorized model later in the notebook.
n_param = count_params(model)
print(f'Number of parameters (before): {n_param}')

Number of parameters (before): 11689512


In [21]:
import copy

# --- Factorization settings --------------------------------------------------
factorization = 'tucker'
# Float rank: presumably a target fraction of each layer's original parameter
# count (see the tltorch documentation) - TODO confirm for the installed version.
rank = 0.6
# True  -> the factorized weight is obtained by decomposing the existing kernel.
# False -> the factorized weight is created fresh and must be initialised below.
decompose_weights = True
# Standard deviation used to re-initialise freshly created factorized weights,
# or False to skip re-initialisation. It must be a float rather than a bare
# bool because the value is passed straight to `normal_` as the std (a `True`
# would silently act as std=1.0).
td_init = False if decompose_weights else 0.02

decomposition_kwargs = {'init': 'random'} if factorization == 'cp' else {}
# With 'spatial', the kernel-height/width modes keep their full rank (3, 3).
fixed_rank_modes = 'spatial' if factorization == 'tucker' else None

# Work on a copy so that `model` stays untouched as the reference network.
fact_model = copy.deepcopy(model)

layer_names = ['layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.1.conv1', 'layer2.1.conv2', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.1.conv1', 'layer3.1.conv2', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.1.conv1', 'layer4.1.conv2']

# Iterate over the *original* model (its module tree is not mutated here) and
# install each factorized replacement into the copy.
for name, module in model.named_modules():
    # Only the listed layers are converted, and only if they are 2D convolutions.
    if name not in layer_names or not isinstance(module, torch.nn.Conv2d):
        continue
    print(f'factorizing: {name}')
    fact_layer = tltorch.FactorizedConv.from_conv(
        module,
        rank=rank,
        decompose_weights=decompose_weights,
        factorization=factorization,
        fixed_rank_modes=fixed_rank_modes,
        decomposition_kwargs=decomposition_kwargs,
    )
    if td_init:
        fact_layer.weight.normal_(0, td_init)
    # Split 'a.b.c' into parent path 'a.b' and attribute 'c'; this works for
    # any nesting depth (an empty parent path resolves to `fact_model` itself).
    parent_name, _, attr_name = name.rpartition('.')
    setattr(fact_model.get_submodule(parent_name), attr_name, fact_layer)

n_param_fact = count_params(fact_model)

factorizing: layer1.0.conv1
factorizing: layer1.0.conv2
factorizing: layer1.1.conv1
factorizing: layer1.1.conv2
factorizing: layer2.0.conv1
factorizing: layer2.0.conv2
factorizing: layer2.1.conv1
factorizing: layer2.1.conv2
factorizing: layer3.0.conv1
factorizing: layer3.0.conv2
factorizing: layer3.1.conv1
factorizing: layer3.1.conv2
factorizing: layer4.0.conv1
factorizing: layer4.0.conv2
factorizing: layer4.1.conv1
factorizing: layer4.1.conv2


In [22]:
# Compression summary: absolute counts, the number of parameters removed, and
# the ratio original / factorized (values above 1 mean the model got smaller).
print(f'original number of parameters: {n_param}')
print(f'factorized number of parameters: {n_param_fact}')
print(f'before - after: {n_param - n_param_fact}')
print(f'compression ratio: {n_param / n_param_fact:.2f}')

original number of parameters: 11689512
factorized number of parameters: 7302933
before - after: 4386579
compression ratio: 1.60


In [23]:
# Bare expression: the notebook renders the module repr as the cell output,
# which shows which submodules were replaced by FactorizedConv layers.
fact_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): FactorizedConv(
        in_channels=64, out_channels=64, kernel_size=(3, 3), rank=(43, 43, 3, 3), order=2, padding=[1, 1], bias=False
        (weight): TuckerTensor(shape=(64, 64, 3, 3), rank=(43, 43, 3, 3))
      )
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): FactorizedConv(
        in_channels=64, out_channels=64, kernel_size=(3, 3), rank=(43, 43, 3, 3), order=2, padding=[1, 1], bias=False
        (weight): TuckerTensor(shape=(64, 64, 3, 3), rank=(43, 43, 3, 3))
      )
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True

In [24]:
# Sanity check: the factorized network must still accept the same input and
# produce an output of the same shape as the original one.
# `eval()` switches the modules in place to inference mode, so no rebinding
# of the names is needed.
model.eval()
fact_model.eval()

x = torch.randn(1, 3, 32, 32)
# Pure inference: no autograd graph is needed for a shape check.
with torch.no_grad():
    y = model(x)
    y_tn = fact_model(x)

assert y.shape == y_tn.shape, f'output shape mismatch: {y.shape} vs {y_tn.shape}'
print(y.shape)
print(y_tn.shape)

torch.Size([1, 1000])
torch.Size([1, 1000])
