In [None]:
import torch
import tensorly as tl
from torch import nn
from torchvision import models
from tensorly import decomposition

tl.set_backend("pytorch")


def cp_decomposition(layer, rank, random_state=None):
    """Replace a 2-D convolution with a rank-``rank`` CP-factorized equivalent.

    The 4-D kernel of ``layer`` (out_channels, in_channels, kernel_h, kernel_w)
    is approximated by a CP (PARAFAC) decomposition, and the four resulting
    factor matrices are mapped onto a chain of four cheaper convolutions:

    1. 1x1 pointwise conv: in_channels -> rank
    2. depthwise (kernel_h x 1) conv over the vertical axis
    3. depthwise (1 x kernel_w) conv over the horizontal axis
    4. 1x1 pointwise conv: rank -> out_channels (carries the original bias)

    Args:
        layer: The ``nn.Conv2d`` to decompose. Must have ``groups == 1``.
        rank: Number of rank-1 components used by the decomposition.
        random_state: Optional seed / ``numpy.random.RandomState`` forwarded to
            ``parafac`` so the random initialization is reproducible.

    Returns:
        An ``nn.Sequential`` of the four convolutions listed above.

    Raises:
        ValueError: If ``layer`` is a grouped convolution; its kernel does not
            span all input channels, so this factorization does not apply.
    """
    if layer.groups != 1:
        raise ValueError(
            "cp_decomposition only supports convolutions with groups == 1"
        )

    # With normalize_factors=False the returned component weights are all ones,
    # so only the factor matrices are needed.
    _, factors = decomposition.parafac(
        tensor=layer.weight.data,
        rank=rank,
        init="random",
        normalize_factors=False,
        random_state=random_state,
    )
    # Factors follow the kernel's axis order: (out, in, height, width).
    last, first, vertical, horizontal = factors

    stride_h, stride_w = layer.stride
    pad_h, pad_w = layer.padding
    dilation_h, dilation_w = layer.dilation
    has_bias = layer.bias is not None

    # Stride/padding/dilation are meaningless for 1x1 kernels, so the pointwise
    # layers use the defaults.
    pointwise_s_to_r_layer = nn.Conv2d(
        first.shape[0],
        first.shape[1],
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False,
    )
    # Each depthwise layer handles only its own spatial axis; striding along
    # that axis here (instead of striding both axes in the horizontal layer)
    # yields the same output without computing rows that would be discarded.
    depthwise_vertical_layer = nn.Conv2d(
        vertical.shape[1],
        vertical.shape[1],
        kernel_size=(vertical.shape[0], 1),
        stride=(stride_h, 1),
        padding=(pad_h, 0),
        dilation=(dilation_h, 1),
        groups=vertical.shape[1],
        bias=False,
    )
    depthwise_horizontal_layer = nn.Conv2d(
        horizontal.shape[1],
        horizontal.shape[1],
        kernel_size=(1, horizontal.shape[0]),
        stride=(1, stride_w),
        # Width padding comes from padding[1], not padding[0].
        padding=(0, pad_w),
        dilation=(1, dilation_w),
        groups=horizontal.shape[1],
        bias=False,
    )
    pointwise_r_to_t_layer = nn.Conv2d(
        last.shape[1],
        last.shape[0],
        kernel_size=1,
        stride=1,
        padding=0,
        bias=has_bias,
    )
    if has_bias:
        # Clone so the new layer does not share bias storage with ``layer``.
        pointwise_r_to_t_layer.bias.data = layer.bias.data.clone()

    # Reshape each (dim, rank) factor into the conv weight layout
    # (out_channels, in_channels / groups, kernel_h, kernel_w).
    depthwise_horizontal_layer.weight.data = (
        torch.transpose(horizontal, 1, 0).unsqueeze(1).unsqueeze(1).contiguous()
    )
    depthwise_vertical_layer.weight.data = (
        torch.transpose(vertical, 1, 0).unsqueeze(1).unsqueeze(-1).contiguous()
    )
    pointwise_s_to_r_layer.weight.data = (
        torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1).contiguous()
    )
    pointwise_r_to_t_layer.weight.data = (
        last.unsqueeze(-1).unsqueeze(-1).contiguous()
    )

    return nn.Sequential(
        pointwise_s_to_r_layer,
        depthwise_vertical_layer,
        depthwise_horizontal_layer,
        pointwise_r_to_t_layer,
    )


# Build the 2-class VGG16 architecture and restore its trained weights.
model = models.vgg16(num_classes=2)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
# and keeps the weights on the CPU for the decomposition that follows.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("../models/VGG16.pt", map_location="cpu"))
model.eval()

In [None]:
import copy


# Each conv layer is decomposed with rank = (largest kernel dimension) // RANK_DIVISOR.
RANK_DIVISOR = 3

# Work on a deep copy so the original model stays intact for comparison.
decomposed_model = copy.deepcopy(model)
# Iterate over a snapshot because entries of `features` are replaced in the loop.
for idx, module in list(enumerate(decomposed_model.features)):
    if isinstance(module, nn.Conv2d):
        # Read the shape straight from the tensor (no numpy round-trip, which
        # would fail for CUDA tensors) and never let the rank drop to zero.
        rank = max(1, max(module.weight.shape) // RANK_DIVISOR)
        decomposed_model.features[idx] = cp_decomposition(module, rank)

# Report the parameter counts before and after the CP decomposition.
print("CP 분해 전 가중치 수 :", sum(param.numel() for param in model.parameters()))
print("CP 분해 후 가중치 수 :", sum(param.numel() for param in decomposed_model.parameters()))