In [None]:
import os
import datetime
import argparse
import wandb
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data.distributed
import torch.utils.data
import tensorly as tl

import utils.common as utils
from data import cifar10
from models.cifar10.vgg import vgg_16_bn
from ptflops import get_model_complexity_info
import numpy as np


In [None]:
import copy
from tqdm.auto import tqdm


In [None]:
from tensorly.decomposition import parafac
# Make tensorly use PyTorch as its backend so that `parafac` / `cp_to_tensor`
# below operate directly on (and return) torch tensors such as conv weights.
tl.set_backend('pytorch')


In [None]:
# Build VGG-16-BN on the GPU and load the weights from the CIFAR-10 checkpoint.
origin_model = vgg_16_bn().cuda()
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
ckpt = torch.load('./checkpoint/cifar10/vgg_16_bn.pt', map_location='cuda:0')
# The checkpoint is expected to be a dict holding the weights under 'state_dict'.
origin_model.load_state_dict(ckpt['state_dict'])


In [None]:
# Show the layer names (e.g. features.convN) that the decomposition code below targets.
print(origin_model)


In [None]:
import torchvision
from torchvision import datasets, transforms


def load_data(root='./data/cifar-10-batches-py/', batch_size=256, num_workers=2):
    """Build the CIFAR-10 train and test DataLoaders.

    Args:
        root: directory handed to ``torchvision.datasets.CIFAR10``. NOTE:
            torchvision extracts the archive into ``<root>/cifar-10-batches-py``,
            so the default results in a nested ``cifar-10-batches-py`` folder;
            kept as the default for compatibility with existing downloads.
        batch_size: batch size for both loaders.
        num_workers: worker processes per loader.

    Returns:
        ``(train_loader, val_loader)``. The training loader uses random-crop +
        horizontal-flip augmentation and is shuffled; the "val" loader is the
        CIFAR-10 *test* split, unaugmented and unshuffled.
    """
    # Per-channel mean/std used for both splits (defined once so they can't drift apart).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test)
    val_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader


In [None]:
# Emit a training progress line every 50 iterations (the original expression (256*50)//256 == 50).
print_freq = 50


In [None]:
def train(epoch, train_loader, model, criterion, optimizer, scheduler):
    """Run one training epoch.

    ``scheduler.step()`` is called after every batch, i.e. the scheduler is
    expected to be a per-iteration schedule (here OneCycleLR built with
    ``steps_per_epoch=len(train_loader)``). A progress line is printed every
    ``print_freq`` batches, where ``print_freq`` is a notebook-level global.

    Args:
        epoch: epoch index, used only for logging.
        train_loader: iterable of ``(images, target)`` batches.
        model: network to train (moved to train mode).
        criterion: loss function applied to ``(logits, target)``.
        optimizer: optimizer stepping ``model``'s parameters.
        scheduler: LR scheduler stepped once per batch.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` over the epoch.
    """
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    model.train()

    # LR of the last param group (same group the original loop ended on).
    cur_lr = optimizer.param_groups[-1]['lr']
    print('learning_rate: ' + str(cur_lr))

    num_iter = len(train_loader)
    for i, (images, target) in enumerate(train_loader):
        images = images.cuda()
        target = target.cuda()

        # The scheduler changes the LR every iteration, so re-read it here;
        # reading it only once per epoch made every log line report a stale LR.
        cur_lr = optimizer.param_groups[-1]['lr']

        # compute output
        logits = model(images)
        loss = criterion(logits, target)

        # measure accuracy and record loss
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)  # accumulated loss
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        # compute gradient, do SGD step, advance the per-iteration schedule
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if i % print_freq == 0:
            print(
                'Epoch[{0}]({1}/{2}): '
                'Loss {loss.avg:.4f} '
                'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f} '
                'Lr {cur_lr:.4f}'.format(
                    epoch, i, num_iter, loss=losses,
                    top1=top1, top5=top5, cur_lr=cur_lr))

    return losses.avg, top1.avg, top5.avg


In [None]:
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Accuracies are accumulated as Python floats via ``.item()`` (the same way
    ``train`` does). Previously ``pred1[0]`` was accumulated, which keeps a 0-dim
    CUDA tensor and made the returned accuracies CUDA tensors.

    Args:
        val_loader: iterable of ``(images, target)`` batches.
        model: network to evaluate (switched to eval mode).
        criterion: loss function applied to ``(logits, target)``.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` over the whole loader.
    """
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    # switch to evaluation mode
    model.eval()
    with torch.no_grad():
        for images, target in val_loader:
            images = images.cuda()
            target = target.cuda()

            # compute output
            logits = model(images)
            loss = criterion(logits, target)

            # measure accuracy and record loss
            pred1, pred5 = utils.accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(pred1.item(), n)
            top5.update(pred5.item(), n)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg


In [None]:
# Data loaders, GPU loss, and the accuracy of the un-decomposed baseline network.
train_loader, val_loader = load_data()
criterion = nn.CrossEntropyLoss().cuda()
_, ori_acc, _ = validate(val_loader, origin_model, criterion)


In [None]:
class LayerC(nn.Module):
    """First stage of the experimental CP convolution: a per-pixel projection.

    Stores a factor tensor ``C`` of shape ``(Cout, Cin, r)``. The forward pass
    zero-pads the ``(batch, Cin, h, w)`` input by ``padding`` on every spatial
    side and projects each padded pixel's ``Cin`` values onto ``Cout * r``
    outputs, returning a ``(batch, h + 2p, w + 2p, Cout, r)`` tensor.
    """

    def __init__(self, C, padding):
        super(LayerC, self).__init__()
        self.C = nn.Parameter(torch.Tensor(C))
        self.padding = padding

    def forward(self, input):
        n_batch, c_in, height, width = input.shape
        pad = self.padding
        c_out, _, rank = self.C.shape

        # Zero-pad left/right/top/bottom, then move channels to the last axis.
        channels_last = nn.functional.pad(input, [pad, pad, pad, pad]).permute(0, 2, 3, 1)
        full_h = height + 2 * pad
        full_w = width + 2 * pad

        # One row per padded pixel: (N*Hp*Wp, Cin) @ (Cin, Cout*r).
        pixel_rows = channels_last.reshape(n_batch * full_h * full_w, c_in)
        proj_matrix = self.C.permute(1, 0, 2).reshape(c_in, c_out * rank)
        projected = torch.matmul(pixel_rows, proj_matrix)

        return projected.reshape(n_batch, full_h, full_w, c_out, rank)


In [None]:
class LayerB(nn.Module):
    """Second stage of the experimental CP convolution: 1-D filtering along width.

    Stores ``B`` of shape ``(Cout, d, r)``. Given the
    ``(N, Hp, Wp, Cout, r)`` output of ``LayerC``, each ``(Cout, r)`` slice
    is cross-correlated (no further padding) with its length-``d`` filter along
    the width axis. The result has shape ``(N, Hp, Cout, r, Wout)`` with
    ``Wout = Wp - d + 1``.
    """

    def __init__(self, B, padding):
        super(LayerB, self).__init__()
        self.B = nn.Parameter(torch.Tensor(B))
        self.padding = padding

    def forward(self, input):
        padded_w = input.shape[2]
        _, d, _ = self.B.shape
        # Valid-correlation output width. The original derived this from the
        # *height* (padded_h - 2 * padding), which is wrong for non-square
        # inputs; for square inputs with d == 2 * padding + 1 it is identical.
        w = padded_w - d + 1

        # (N, Hp, Cout, r, Wp): put width last so it can be windowed.
        Oc = input.permute(0, 1, 3, 4, 2)
        # (Cout, r, d) -> (1, 1, Cout, r, 1, d) to broadcast over N, Hp and Wout.
        B_expanded = self.B.permute(0, 2, 1)[None, None, :, :, None, :]

        # window_indices[x, k] = x + k: the k-th tap of output column x.
        # Built on the input's device so indexing stays on one device.
        window_indices = (torch.arange(w, device=input.device)[:, None]
                          + torch.arange(d, device=input.device))
        Oc_expanded = Oc[:, :, :, :, window_indices]  # (N, Hp, Cout, r, Wout, d)

        # Multiply by the taps and sum over d.
        return torch.sum(Oc_expanded * B_expanded, dim=-1)


In [None]:
class LayerA(nn.Module):
    """Final stage of the experimental CP convolution: 1-D filtering along height.

    Stores ``A`` of shape ``(Cout, d, r)`` and an optional bias of shape
    ``(Cout,)``. It takes the ``(N, Hp, Cout, r, Wout)`` output of ``LayerB``,
    cross-correlates along the height axis with the length-``d`` filters, sums
    over the rank axis and adds the bias. The result is an
    ``(N, Cout, Hout, Wout)`` tensor with ``Hout = Hp - d + 1``.
    """

    def __init__(self, A, padding, bias):
        super(LayerA, self).__init__()
        self.A = nn.Parameter(torch.Tensor(A))
        self.padding = padding
        if bias is None:
            # Convs built with bias=False have no bias; nn.Parameter(None) would
            # create an empty tensor that breaks the broadcast in forward().
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(bias)

    def forward(self, input):
        padded_h = input.shape[1]
        _, d, _ = self.A.shape
        # Valid-correlation output height; equals padded_h - 2 * padding when
        # d == 2 * padding + 1 (the original's hard-coded assumption).
        h = padded_h - d + 1

        # (Cout, r, d) -> (1, 1, Cout, r, 1, d) for broadcasting.
        A_expanded = self.A.permute(0, 2, 1)[None, None, :, :, None, :]
        # (N, Wout, Cout, r, Hp): put height last so it can be windowed.
        Ob = input.permute(0, 4, 2, 3, 1)

        # window_indices[y, k] = y + k: the k-th tap of output row y.
        window_indices = (torch.arange(h, device=input.device)[:, None]
                          + torch.arange(d, device=input.device))
        Ob_expanded = Ob[:, :, :, :, window_indices]  # (N, Wout, Cout, r, Hout, d)

        # Multiply by the taps and sum over d -> (N, Wout, Cout, r, Hout).
        Oa = torch.sum(Ob_expanded * A_expanded, dim=-1)
        Oa = Oa.permute(0, 4, 1, 2, 3)  # (N, Hout, Wout, Cout, r)

        # Sum the rank-one terms -> (N, Hout, Wout, Cout), then add bias per channel.
        output = torch.sum(Oa, dim=-1)
        if self.bias is not None:
            output = output + self.bias

        return output.permute(0, 3, 1, 2)  # (N, Cout, Hout, Wout)


In [None]:
from decomposition.CPDLayers import *


In [None]:
def cp_decomposition_conv_layer(layer, rank):
    """Gets a conv layer and a target rank, returns a nn.Sequential object with the decomposition.

    Each output filter ``W[i]`` of shape ``(Cin, kh, kw)`` is decomposed
    independently with tensorly's ``parafac`` into ``rank`` rank-one terms. This
    gives, per output channel, a factor over input channels (head), over kernel
    rows (tail) and over kernel columns (body). The factors are permuted into
    the layouts used by ``CPDHead`` / ``CPDBody`` / ``CPDTail``.

    Only ``kernel_size[0]`` and ``padding[0]`` are read, so a square kernel and
    symmetric padding are assumed.

    Args:
        layer: the ``nn.Conv2d`` to decompose.
        rank: CP rank used for every output channel.

    Returns:
        ``nn.Sequential(head, body, tail)`` replacing ``layer``.

    Raises:
        AssertionError: if any factor produced by ``parafac`` is NaN.
    """
    padding = layer.padding[0]
    kernel_size = layer.kernel_size[0]
    Cin = layer.in_channels
    Cout = layer.out_channels
    W = layer.weight.data
    # W.device works for CPU and CUDA; W.get_device() returns -1 for CPU
    # tensors, which torch.zeros(..., device=-1) rejects.
    device = W.device

    # Per-output-channel factor matrices, filled row by row below.
    head_factors = torch.zeros(Cout, Cin, rank, device=device)
    tail_factors = torch.zeros(Cout, kernel_size, rank, device=device)
    body_factors = torch.zeros(Cout, kernel_size, rank, device=device)

    rel_errors = []
    progress = tqdm(range(Cout))
    for i in progress:
        filt = W[i, :, :, :]
        # Decompose ONCE and measure the error of that same decomposition.
        # The original ran a second random-init parafac just for the error, so
        # the reported error did not describe the stored factors (and the work
        # was doubled).
        cp = parafac(filt, rank=rank, n_iter_max=1000, tol=1e-32, init='random')
        weights, (head, tail, body) = cp[0], cp[1]
        # Fold the CP weights into the head factor so the stored factors alone
        # reconstruct the filter (with tensorly's default normalize_factors=False
        # the weights are all ones and this is a no-op).
        head_factors[i] = head * weights
        tail_factors[i] = tail
        body_factors[i] = body

        # Relative squared reconstruction error ||W_i - R_i||^2 / ||W_i||^2.
        reconstructed = tl.cp_to_tensor(cp)
        filt_norm = torch.norm(filt, p=2).item()
        err_norm = torch.norm(filt - reconstructed, p=2).item()
        if filt_norm > 0:
            rel_err = (err_norm / filt_norm) ** 2
        else:
            rel_err = float('inf') if err_norm > 0 else 0.0
        rel_errors.append(rel_err)
        progress.set_postfix(rel_err='{:.3e}'.format(rel_err))

    if rel_errors:
        print('relative squared error: mean {:.4e}, max {:.4e}'.format(
            sum(rel_errors) / len(rel_errors), max(rel_errors)))

    assert not torch.isnan(head_factors).any(
    ), "head_factors tensor from parafac is nan"
    assert not torch.isnan(body_factors).any(
    ), "body_factors tensor from parafac is nan"
    assert not torch.isnan(tail_factors).any(
    ), "tail_factors tensor from parafac is nan"

    # Layouts expected by the CPD layers: head (Cin, Cout, r), body/tail (Cout, r, k).
    head_factors = head_factors.permute(1, 0, 2)
    body_factors = body_factors.permute(0, 2, 1)
    tail_factors = tail_factors.permute(0, 2, 1)

    if layer.bias is not None:
        bias = layer.bias.data
    else:
        # Convs built with bias=False: pass an explicit zero bias instead of
        # crashing on `None.data`.
        bias = torch.zeros(Cout, device=device)

    head = CPDHead(Cin, Cout, rank, padding, head_factors)
    body = CPDBody(Cin, Cout, rank, kernel_size, padding, body_factors)
    tail = CPDTail(Cin, Cout, rank, kernel_size, padding,
                   tail_factors, bias)

    return nn.Sequential(head, body, tail)


In [None]:
# W = torch.randn(128, 64, 3, 3)


In [None]:
# head_factors = torch.randn(128, 64, 9, device=0)
# body_factors = torch.randn(128, 3, 9, device=0)
# tail_factors = torch.randn(128, 3, 9, device=0)
# bias = torch.randn([128], device=0)


In [None]:
# for i in tqdm(range(128)):
#         head_factors[i], _, _ = parafac(W[i, :, :, :], rank=9, n_iter_max=1000, tol= 1e-32, init='random')[1]


In [None]:
# head = LayerC(head_factors, 1)
# body = LayerB(body_factors, 1)
# tail = LayerA(tail_factors, 1, bias)


In [None]:
# input = torch.randn((1, 64, 32, 32), device=0)


In [None]:
# tail(body(head(input)))


In [None]:
# new_head = CPDHead(64, 128, 9, 1, head_factors)
# new_body = CPDBody(64, 128, 9, 3, 1, body_factors)
# new_tail = CPDTail(64, 128, 9, 3, 1, tail_factors, bias)


In [None]:
# new_tail(new_body(new_head(input)))


In [None]:
# ori, backtracking
# def cp_decompose_model(model, exclude_first_conv=True, exclude_linears=True, passed_first_conv=False):
#     for name, module in model._modules.items():
#         if len(list(module.children())) > 0:
#             # recurse
#             model._modules[name] = cp_decompose_model(
#                 module, exclude_first_conv, exclude_linears, passed_first_conv)
#         elif type(module) == nn.Conv2d:
#             if passed_first_conv is False:
#                 passed_first_conv = True
#                 if exclude_first_conv is True:
#                     continue

#             # if (name == 'conv3'):
#             print('name ', name)
#             conv_layer = module
#             rank = 9

#             decomposed = cp_decomposition_conv_layer(conv_layer, rank)

#             model._modules[name] = decomposed

#             print(model)

#             _, decomposed_acc, _ = validate(val_loader, model, criterion)

#             print(decomposed_acc)


#     return model


In [None]:
def cp_decompose_model(model, exclude_first_conv=True, exclude_linears=True, passed_first_conv=False, rank=1):
    """Replace each ``nn.Conv2d`` directly inside ``model.features`` with its CP decomposition.

    The model is modified in place and also returned. After every replacement
    the whole model is re-evaluated with ``validate`` on the notebook-level
    ``val_loader`` / ``criterion``, so the accuracy drop per layer is printed.

    Args:
        model: network with a ``features`` container of named children. Modules
            outside ``features`` are left untouched.
        exclude_first_conv: keep the first Conv2d encountered unchanged.
        exclude_linears: accepted for backward compatibility; currently unused
            (only conv layers inside ``features`` are ever decomposed).
        passed_first_conv: pass True if the first conv has already been seen.
        rank: CP rank forwarded to ``cp_decomposition_conv_layer``
            (previously hard-coded to 1, which remains the default).

    Returns:
        The same ``model`` object with its conv layers replaced.
    """
    features = model._modules.get('features')
    if features is None:
        return model

    for name, module in features._modules.items():
        # Exact type match: subclasses of Conv2d are intentionally not decomposed.
        if type(module) != nn.Conv2d:
            continue

        if not passed_first_conv:
            passed_first_conv = True
            if exclude_first_conv:
                continue

        print('name ', name)
        # Re-assigning an existing key while iterating is safe (dict size is unchanged).
        features._modules[name] = cp_decomposition_conv_layer(module, rank)

        # Report accuracy after this layer's replacement (validate prints it).
        _, decomposed_acc, _ = validate(val_loader, model, criterion)

    return model


In [None]:
# Work on a deep copy so origin_model stays intact as the baseline for comparisons.
model = copy.deepcopy(origin_model)


In [None]:
# Replace the conv layers of model.features with CP-decomposed versions
# (the model is re-evaluated after each replacement).
model = cp_decompose_model(model)


In [None]:
print(model)


In [None]:
# Fine-tune the decomposed model.
# NUM_EPOCHS drives BOTH the training loop and the OneCycleLR horizon. Previously
# the schedule was built for 100 epochs while the loop ran only 1, so the LR
# never got past the warm-up ramp (pct_start=0.1 -> 10 warm-up epochs).
NUM_EPOCHS = 1

optimizer = torch.optim.SGD(
    model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, div_factor=10, epochs=NUM_EPOCHS,
    steps_per_epoch=len(train_loader), pct_start=0.1, final_div_factor=100)

start_epoch = 0
best_top1_acc = 0

# train the model
for epoch in range(start_epoch, NUM_EPOCHS):
    train(epoch, train_loader, model, criterion, optimizer, scheduler)
    _, valid_top1_acc, valid_top5_acc = validate(
        val_loader, model, criterion)

    is_best = valid_top1_acc > best_top1_acc
    if is_best:
        best_top1_acc = valid_top1_acc

    # best_top1_acc is already the max of all epochs seen so far.
    print('epoch', epoch, 'best_acc', best_top1_acc, 'top1', valid_top1_acc)
    print("=>Best accuracy {:.3f}".format(best_top1_acc))


In [None]:
# Sanity check: compare one conv layer of the original model with its rank-1
# CP replacement on a random input.
conv = origin_model.features.conv14

In [None]:
conv.weight.shape

In [None]:
conv.weight

In [None]:
dcp_layer = cp_decomposition_conv_layer(conv, rank=1)

In [None]:
# NOTE(review): 512 input channels at 2x2 spatial size is presumed to match this
# layer's in_channels / feature-map size on 32x32 CIFAR inputs -- confirm against conv.weight.shape.
rand_inp = torch.randn((1, 512, 2, 2), device=0)

In [None]:
print(conv)
print(dcp_layer)

In [None]:
ori_output = conv(rand_inp)

In [None]:
ori_output

In [None]:
my_output = dcp_layer(rand_inp)

In [None]:
# Relative squared output error ||y - y_cp||^2 / ||y||^2 between original and decomposed layer.
(torch.norm(ori_output-my_output, p=2).item()/torch.norm(ori_output, p=2).item())**2

In [None]:
# Absolute L2 (Frobenius) norm of the output difference.
torch.norm(ori_output-my_output)

In [None]:
def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param model: module whose ``parameters()`` are counted
    :param count_nonzero_only: only count nonzero weights
    :return: parameter count as a plain Python int (previously the
        ``count_nonzero_only`` path returned a tensor, contradicting ``-> int``)
    """
    if count_nonzero_only:
        return sum(int(param.count_nonzero()) for param in model.parameters())
    return sum(param.numel() for param in model.parameters())

In [None]:
# Parameter count of the original (un-decomposed) network.
get_num_parameters(origin_model)

In [None]:
# Parameter count after CP decomposition (and fine-tuning).
get_num_parameters(model)

In [None]:
from torchprofile import profile_macs

In [None]:
def get_model_macs(model, inputs) -> int:
    """Count multiply-accumulate operations of ``model`` on ``inputs``.

    Thin wrapper around ``torchprofile.profile_macs``; the result is returned unchanged.
    """
    macs = profile_macs(model, inputs)
    return macs

In [None]:
# A single CIFAR-sized (1, 3, 32, 32) random input, used only to trace MACs.
dummy_input = torch.randn(1, 3, 32, 32).cuda()

In [None]:
# MACs of the decomposed model.
get_model_macs(model, dummy_input)

In [None]:
# MACs of the original model, for comparison.
get_model_macs(origin_model, dummy_input)