In [7]:
import torch
import tensorly as tl
from tensorly.decomposition import partial_tucker
from tltorch import FactorizedConv
import tltorch
import numpy as np
import sys
# Make the project sources importable. The `models` package imported in a later
# cell is presumably located under this directory — TODO confirm. The path is
# relative, so it is resolved against the kernel's current working directory.
sys.path.append('../src/td-comp/')
# Select PyTorch as the TensorLy backend.
tl.set_backend('pytorch')

In [14]:
from models import vgg

In [9]:
def count_params(net: torch.nn.Module) -> int:
    """Count the trainable parameters of a module.

    Parameters
    ----------
    net : torch.nn.Module
        Module whose parameters are inspected via ``net.parameters()``.

    Returns
    -------
    int
        Sum of ``numel()`` over every parameter of ``net`` that has
        ``requires_grad=True``. Parameters with ``requires_grad=False`` are
        excluded; a module without parameters yields ``0``.
    """
    # The result is a plain Python int (a sum of ints), not a NumPy array.
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

In [28]:
# Instantiate the baseline (unfactorized) network from the 'VGG11' configuration name.
model = vgg.VGG('VGG11')

In [29]:
# Trainable-parameter count of the baseline model, stored in `n_param` and printed.
n_param = count_params(model)
print(f'Number of parameters (before): {n_param}')

Number of parameters (before): 9231114


In [30]:
# Display the module tree of the baseline model.
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

In [33]:
import copy

# --- Factorization settings -------------------------------------------------
factorization = 'tucker'
# Float rank: presumably the fraction of the original parameters to keep in
# each factorized layer — TODO confirm against the tltorch documentation.
rank = 0.75
# When True, the factorized layer is built by decomposing the existing
# convolution weights rather than being re-initialised below.
decompose_weights = True
# Re-initialise the factorized weights only when they were NOT obtained by
# decomposing the pretrained ones.
td_init = not decompose_weights
# Standard deviation used for that re-initialisation.
# NOTE(review): the bool `td_init` used to be passed as the std, i.e. an
# implicit value of 1 (True == 1). The value is kept unchanged here but made
# explicit; confirm that 1.0 is really the intended scale.
td_init_std = 1.0
# Fully qualified names of Conv2d layers that are left unfactorized
# (here: the first convolution of the network).
skip_layers = {'features.0'}

decomposition_kwargs = {'init': 'random'} if factorization == 'cp' else {}
fixed_rank_modes = 'spatial' if factorization == 'tucker' else None

# Work on a deep copy so the original `model` stays untouched: layers are read
# from `model` and their factorized replacements are installed in `fact_model`.
fact_model = copy.deepcopy(model)

for name, module in model.named_modules():
    # Only plain Conv2d layers are factorized (exact type match, as before).
    if type(module) is not torch.nn.Conv2d or name in skip_layers:
        continue
    print(f'factorizing: {name}')
    fact_layer = tltorch.FactorizedConv.from_conv(
        module,
        rank=rank,
        decompose_weights=decompose_weights,
        factorization=factorization,
        fixed_rank_modes=fixed_rank_modes,
        decomposition_kwargs=decomposition_kwargs,
    )
    if td_init:
        fact_layer.weight.normal_(0, td_init_std)

    # Walk down to the parent of the layer being replaced. Splitting off only
    # the last path component supports any nesting depth ('conv',
    # 'features.4', 'block.0.conv', ...), whereas unpacking `name.split('.')`
    # into exactly two names raised ValueError for anything but one dot.
    *parent_path, child_name = name.split('.')
    parent = fact_model
    for attr in parent_path:
        parent = getattr(parent, attr)
    setattr(parent, child_name, fact_layer)

n_param_fact = count_params(fact_model)
print(f'Number of parameters (after): {n_param_fact}')

factorizing: features.4
factorizing: features.8
factorizing: features.11
factorizing: features.15
factorizing: features.18
factorizing: features.22
factorizing: features.25
Number of parameters (after): 6922095


In [34]:
# Display the module tree of the factorized copy to inspect the replaced layers.
fact_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): FactorizedConv(
      in_channels=64, out_channels=128, kernel_size=(3, 3), rank=(94, 47, 3, 3), order=2, padding=[1, 1], 
      (weight): TuckerTensor(shape=(128, 64, 3, 3), rank=(94, 47, 3, 3))
    )
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): FactorizedConv(
      in_channels=128, out_channels=256, kernel_size=(3, 3), rank=(189, 94, 3, 3), order=2, padding=[1, 1], 
      (weight): TuckerTensor(shape=(256, 128, 3, 3), rank=(189, 94, 3, 3))
    )
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, tr

In [35]:
# Summary of the compression: both parameter counts, the absolute number of
# parameters removed, and the ratio original / factorized (two decimals).
print(f'original number of parameters: {n_param}')
print(f'factorized number of parameters: {n_param_fact}')
print(f'before - after: {n_param - n_param_fact}')
print(f'compression ratio: {n_param / n_param_fact:.2f}')

original number of parameters: 9231114
factorized number of parameters: 6922095
before - after: 2309019
compression ratio: 1.33
