# Conv ResNet

we are applying tucker compression for mode-0 and mode-1 (Channel in and channel out).

suppose $\mathbf{W} \in \mathbb{R}^{C_{out} \times C_{in} \times k_{h} \times k_{w}}$ is a 2D 

Let 
$\mathbf{X}^{(0)} \in \mathbb{R}^{C_{in} \times H_{in} \times W_{in}}$ : input feature map

therefore, we perform tucker decomposition:
$$\mathbf{W} \approx \mathbf{S} \times_{0} U^{(0)} \times_{1}U^{(1)}$$
where:
$U^{(0)} \in \mathbb{R}^{C_{out} \times r_{0}}$ : output (channel) basis
$U^{(1)} \in \mathbb{R}^{C_{in} \times r_{1}}$ : input (channel) basis
$\mathbf{S} \in \mathbb{R}^{r_{0} \times r_{1} \times k_{h} \times k_{w}}$ : compressed core kernel

## 1. reduce channels
from $C_{in} \rightarrow r_{0}$ using mode-1 multiplication
$$\mathbf{X}^{(1)} = U^{(1)\intercal} \times_{1} \mathbf{X}^{(0)}\quad \in \mathbb{R}^{r_{1} \times H_{in} \times W_{in}}$$
which is equivalent to $1\times 1$ convolution with weight $U^{(1)\intercal} \in \mathbb{R}^{r_{1} \times C_{in}}$

## 2. spatial convolution with core
$$\mathbf{X}^{(2)} = \mathbf{S} * \mathbf{X}^{(1)} \quad \in \mathbb{R}^{r_{0} \times H_{out} \times W_{out}}$$
convolution using the compressed core
$\mathbf{S} \in \mathbb{R}^{r_{0} \times r_{1} \times k_{h} \times k_{w}}$

## 3. expand channels
From $r_{0} \rightarrow C_{out}$ using mode-0 multiplication
$$
\mathbf{Y} = U^{(0)} \times_{1} \mathbf{X}^{(2)} \quad \in \mathbb{R}^{C_{out} \times H_{out} \times W_{out}}
$$


**therefore**
$$\mathbf{Y} \approx U^{(0)} \cdot \left(  \mathbf{S} * \left( U^{(1)\intercal} \cdot \mathbf{X}^{(0)} \right)   \right) $$

this reduces cost 
from: $\mathcal{O}(C_{out} \cdot  C_{in} \cdot H \cdot W)$
to: $\mathcal{O}(C_{out} r_{0} + r_{0}r_{1}  H W + r_{1} C_{in})$


CIFAR-100 ImageFolder already prepared.


In [2]:
import torch
import torch.nn as nn
from typing import Tuple
from src.logger import setup_logger


class TuckerCompressor:
    def __init__(self) -> None:
        self.logger = setup_logger("api.log")

    def compress_conv2d(
        self,
        conv: nn.Conv2d,
        rank: Tuple[int, int]
    ) -> nn.Sequential:
        """Compress a Conv2d layer using Tucker decomposition (mode-0 and mode-1)."""
        C_out, C_in = conv.out_channels, conv.in_channels
        kh, kw = conv.kernel_size
        stride, padding, dilation = conv.stride, conv.padding, conv.dilation
        r_out, r_in = rank

        self.logger.info(
            f"Compressing Conv2d: Cin={C_in}, Cout={C_out}, kh={kh}, kw={kw}, rank={rank}"
        )

        weight = conv.weight.data  # (C_out, C_in, kh, kw)
        weight_unfold_0 = weight.reshape(C_out, -1)
        weight_unfold_1 = weight.permute(1, 0, 2, 3).reshape(C_in, -1)

        U_0, _, _ = torch.linalg.svd(weight_unfold_0, full_matrices=False)
        U_1, _, _ = torch.linalg.svd(weight_unfold_1, full_matrices=False)

        U_0_tilde = U_0[:, :r_out]  # (C_out, r_out)
        U_1_tilde = U_1[:, :r_in]   # (C_in, r_in)

        # Compute core tensor: W ×₀ U₀ᵀ ×₁ U₁ᵀ
        core = torch.einsum('oc, cihw -> oihw', U_0_tilde.T, weight)
        core = torch.einsum('ci, oihw -> ochw', U_1_tilde.T, core)


        first_1x1 = nn.Conv2d(
            in_channels=C_in,
            out_channels=r_in,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False
        )

        core_conv = nn.Conv2d(
            in_channels=r_in,
            out_channels=r_out,
            kernel_size=(kh, kw),
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False
        )

        last_1x1 = nn.Conv2d(
            in_channels=r_out,
            out_channels=C_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True
        )

        with torch.no_grad():
            first_1x1.weight.copy_(U_1_tilde.T.unsqueeze(-1).unsqueeze(-1))
            core_conv.weight.copy_(core)
            if conv.bias is not None:
                last_1x1.bias.copy_(conv.bias)
            last_1x1.weight.copy_(U_0_tilde.unsqueeze(-1).unsqueeze(-1))

        return nn.Sequential(first_1x1, core_conv, last_1x1)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models 
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader
from typing import Tuple, Optional
from tqdm import tqdm
import copy
import time


from src.logger import setup_logger


logger = setup_logger("tucker_eval.log")



def get_dataloader(data_root: str, batch_size: int = 128, num_workers: int = 4):
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    train_set = datasets.ImageFolder(root=f"{data_root}/train", transform=transform_train)
    val_set = datasets.ImageFolder(root=f"{data_root}/val", transform=transform_test)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader



def evaluate(model: nn.Module, dataloader: DataLoader, device: torch.device) -> float:
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            pred = outputs.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total


def fine_tune(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader, device: torch.device,
              epochs: int = 5, lr: float = 0.01):
    model.train()
    optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()), 
    lr=lr, momentum=0.9, weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
        scheduler.step()
        acc = evaluate(model, test_loader, device)
        logger.info(f"Fine-tune Epoch {epoch+1}, Accuracy: {acc:.4f}")


def compress_model(model: nn.Module, ratio: float = 0.2) -> nn.Module:
    compressor = TuckerCompressor()

    def compress_layer(module: nn.Module):
        for name, child in module.named_children():
            if isinstance(child, nn.Conv2d):
                r_out = max(1, int(child.out_channels * ratio))
                r_in = max(1, int(child.in_channels * ratio))
                setattr(module, name, compressor.compress_conv2d(child, rank=(r_out, r_in)))
            else:
                compress_layer(child)

    model = copy.deepcopy(model)
    compress_layer(model)
    return model



def get_resnet18_100_classes():
    model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, 100)  # 100 CIFAR-100 classes
    return model


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


# 1) Setup
train_loader, test_loader = get_dataloader(
    data_root=dataset_path, 
    batch_size=64, 
    num_workers=4
    )
model = get_resnet18_100_classes().to(device)

t0 = time.time()
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
fine_tune(model, train_loader, test_loader, device, epochs=5)
logger.info(f"Finished finetuning original model in {time.time() - t0:.4f}s")



cuda


Epoch 1/5: 100%|██████████| 782/782 [04:36<00:00,  2.83it/s]
2025-04-09 16:27:35,987 - INFO - Fine-tune Epoch 1, Accuracy: 0.4408
Epoch 2/5: 100%|██████████| 782/782 [03:40<00:00,  3.54it/s]
2025-04-09 16:32:12,038 - INFO - Fine-tune Epoch 2, Accuracy: 0.4781
Epoch 3/5: 100%|██████████| 782/782 [04:18<00:00,  3.02it/s]
2025-04-09 16:37:27,021 - INFO - Fine-tune Epoch 3, Accuracy: 0.5105
Epoch 4/5: 100%|██████████| 782/782 [09:17<00:00,  1.40it/s]
2025-04-09 16:48:27,710 - INFO - Fine-tune Epoch 4, Accuracy: 0.5108
Epoch 5/5: 100%|██████████| 782/782 [05:28<00:00,  2.38it/s]
2025-04-09 16:55:04,072 - INFO - Fine-tune Epoch 5, Accuracy: 0.5196
2025-04-09 16:55:04,076 - INFO - Finished finetuning original model in 1964.6274065971375s


In [5]:
# 2) Evaluate_original
t0 = time.time()
original_acc = evaluate(model, test_loader, device)
logger.info(f"Original ResNet18 Accuracy: {original_acc:.4f}")
logger.info(f"Evaluated in {time.time() - t0:.4f}s")




2025-04-09 16:56:18,613 - INFO - Original ResNet18 Accuracy: 0.5196
2025-04-09 16:56:18,615 - INFO - Evaluated in 40.9793s


In [6]:

# 3) Compress_model
t0 = time.time()
compressed_model = compress_model(model, ratio=0.2)
compressed_model.to(device)
logger.info(f"Model compressed in {time.time() - t0:.4f}s")
del model
torch.cuda.empty_cache()




2025-04-09 16:56:51,632 - INFO - Logger already initialized
2025-04-09 16:56:51,664 - INFO - Compressing Conv2d: Cin=3, Cout=64, kh=7, kw=7, rank=(12, 1)
2025-04-09 16:56:51,759 - INFO - Compressing Conv2d: Cin=64, Cout=64, kh=3, kw=3, rank=(12, 12)
2025-04-09 16:56:51,767 - INFO - Compressing Conv2d: Cin=64, Cout=64, kh=3, kw=3, rank=(12, 12)
2025-04-09 16:56:51,774 - INFO - Compressing Conv2d: Cin=64, Cout=64, kh=3, kw=3, rank=(12, 12)
2025-04-09 16:56:51,780 - INFO - Compressing Conv2d: Cin=64, Cout=64, kh=3, kw=3, rank=(12, 12)
2025-04-09 16:56:51,787 - INFO - Compressing Conv2d: Cin=64, Cout=128, kh=3, kw=3, rank=(25, 12)
2025-04-09 16:56:51,797 - INFO - Compressing Conv2d: Cin=128, Cout=128, kh=3, kw=3, rank=(25, 25)
2025-04-09 16:56:51,809 - INFO - Compressing Conv2d: Cin=64, Cout=128, kh=1, kw=1, rank=(25, 12)
2025-04-09 16:56:51,816 - INFO - Compressing Conv2d: Cin=128, Cout=128, kh=3, kw=3, rank=(25, 25)
2025-04-09 16:56:51,829 - INFO - Compressing Conv2d: Cin=128, Cout=128, 

In [7]:
# 4) Evaluate_compressing
compressed_acc = evaluate(compressed_model, test_loader, device)
logger.info(f"Compressed ResNet18 (no fine-tune) Accuracy: {compressed_acc:.4f}")
logger.info(f"Evaluated in {time.time() - t0:.4f}s")




2025-04-09 16:57:28,796 - INFO - Compressed ResNet18 (no fine-tune) Accuracy: 0.0096
2025-04-09 16:57:28,799 - INFO - Evaluated in 37.1672s


In [8]:
# 5) fine_tune
t0 = time.time()
fine_tune(compressed_model, train_loader, test_loader, device, epochs=5)
logger.info(f"Finetuned in {time.time() - t0:.4f}s")



Epoch 1/5: 100%|██████████| 782/782 [10:28<00:00,  1.24it/s]
2025-04-09 17:09:13,088 - INFO - Fine-tune Epoch 1, Accuracy: 0.2414
Epoch 2/5: 100%|██████████| 782/782 [11:14<00:00,  1.16it/s]
2025-04-09 17:21:33,421 - INFO - Fine-tune Epoch 2, Accuracy: 0.3454
Epoch 3/5: 100%|██████████| 782/782 [14:07<00:00,  1.08s/it]
2025-04-09 17:37:00,586 - INFO - Fine-tune Epoch 3, Accuracy: 0.5040
Epoch 4/5: 100%|██████████| 782/782 [15:39<00:00,  1.20s/it]
2025-04-09 17:54:00,698 - INFO - Fine-tune Epoch 4, Accuracy: 0.5178
Epoch 5/5: 100%|██████████| 782/782 [11:25<00:00,  1.14it/s]
2025-04-09 18:06:38,770 - INFO - Fine-tune Epoch 5, Accuracy: 0.5336
2025-04-09 18:06:38,773 - INFO - Finetuned in 4133.8973s


In [9]:
# 6) evaluate finetuned
t0 = time.time()
final_acc = evaluate(compressed_model, test_loader, device)
logger.info(f"Compressed + Fine-Tuned ResNet18 Accuracy: {final_acc:.4f}")
logger.info(f"Evaluated in {time.time() - t0:.4f}s")

2025-04-09 18:07:57,728 - INFO - Compressed + Fine-Tuned ResNet18 Accuracy: 0.5336
2025-04-09 18:07:57,729 - INFO - Evaluated in 27.3281s
