In [8]:
import torch
import timm
import lightning.pytorch as pl
import lightning.pytorch.loggers as pl_loggers
from torchvision import models
import copy
import tltorch
import tensorly as tl

In [9]:
NUM_WORKERS = 4
# reproducibility
SEED = 42
# reproducibility
pl.seed_everything(42)
# allow tf32 (TENSOR CORES)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
torch.backends.cudnn.deterministic = True  # deterministic cudnn

Global seed set to 42


In [10]:
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)

Using cache found in /home/usainzg/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [11]:
model

CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [12]:
layer_names = ['layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.1.conv1', 'layer2.1.conv2', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.1.conv1', 'layer3.1.conv2', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.1.conv1', 'layer4.1.conv2']

In [13]:
def factorize_layer(
    module,
    factorization='tucker',
    rank=None,
    decompose_weights=False,
    vbmf=0,
    implementation='reconstructed'
):
    init_std = None if decompose_weights else 0.01
    decomposition_kwargs = {'init': 'random'} if factorization == 'cp' else {}
    fixed_rank_modes = 'spatial' if factorization == 'tucker' else None
    # implementation see: https://github.com/tensorly/torch/blob/d27d58f16101b7ecc431372eb218ceda59d8b043/tltorch/functional/convolution.py#L286
    
    if rank is None and vbmf == 0 and factorization != 'tucker':
        raise ValueError('rank must be specified for non-tucker factorization')
    
    if not decompose_weights:
        vbmf = 0 

    if type(module) == torch.nn.modules.conv.Conv2d:
        # rank selection
        
        if rank is not None:
            ranks = rank
        else:
            weights = module.weight.data
            ranks = [weights.shape[0]//3, weights.shape[1]//3, weights.shape[2], weights.shape[3]]
        
        # factorize from conv layer
        fact_module = tltorch.FactorizedConv.from_conv(
            module,
            rank=ranks,
            decompose_weights=decompose_weights,
            factorization=factorization,
            fixed_rank_modes=fixed_rank_modes,
            implementation=implementation,
            decomposition_kwargs=decomposition_kwargs
        )
    elif type(module) == torch.nn.modules.linear.Linear:
        fact_module = tltorch.FactorizedLinear.from_linear(
            module,
            n_tensorized_modes=3,
            rank=rank,
            factorization=factorization,
            decompose_weights=decompose_weights,
            fixed_rank_modes=fixed_rank_modes,
            decomposition_kwargs=decomposition_kwargs
        )
    else:
        raise NotImplementedError(type(module))
    
    if init_std:
        #print('Initializing with std')
        fact_module.weight.normal_(0, init_std)
    
    return fact_module

In [14]:
fact_model = copy.deepcopy(model)
tn_decomp = 'tucker'
rank = 0.8
vbmf = 0
decompose_weights = True
implementation = 'reconstructed'

# factorize resnet
for i, (name, module) in enumerate(model.named_modules()):
    if name in layer_names:
        
        print(f'factorizing: {name}')
        fact_module = factorize_layer(
            module=module, 
            factorization=tn_decomp, 
            rank=rank, 
            vbmf=vbmf,
            decompose_weights=decompose_weights,
            implementation=implementation
        )
        layer, block, conv = name.split('.')
        conv_to_replace = getattr(getattr(fact_model, layer), block)
        setattr(conv_to_replace, conv, fact_module)

factorizing: layer1.0.conv1
factorizing: layer1.0.conv2
factorizing: layer1.1.conv1
factorizing: layer1.1.conv2
factorizing: layer2.0.conv1
factorizing: layer2.0.conv2
factorizing: layer2.1.conv1
factorizing: layer2.1.conv2
factorizing: layer3.0.conv1
factorizing: layer3.0.conv2
factorizing: layer3.1.conv1
factorizing: layer3.1.conv2


In [16]:
fact_model

CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): FactorizedConv(
        in_channels=16, out_channels=16, kernel_size=(3, 3), rank=(13, 13, 3, 3), order=2, padding=[1, 1], bias=False
        (weight): TuckerTensor(shape=(16, 16, 3, 3), rank=(13, 13, 3, 3))
      )
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): FactorizedConv(
        in_channels=16, out_channels=16, kernel_size=(3, 3), rank=(13, 13, 3, 3), order=2, padding=[1, 1], bias=False
        (weight): TuckerTensor(shape=(16, 16, 3, 3), rank=(13, 13, 3, 3))
      )
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): FactorizedCon

In [17]:
%timeit model(torch.randn(1, 3, 32, 32))

2.12 ms ± 62.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
%timeit fact_model(torch.randn(1, 3, 32, 32))

5.61 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
