In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import lightning.pytorch as pl
from tqdm import tqdm as tqdm

In [21]:
torch.set_float32_matmul_precision("high")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [22]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

In [23]:
class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ResNet18()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [24]:
autoencoder = LitAutoEncoder()
trainer = pl.Trainer(accelerator="gpu", max_epochs=10)
trainer.fit(autoencoder, trainloader, testloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


Epoch 9: 100%|██████████| 391/391 [00:12<00:00, 31.06it/s, v_num=3]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 391/391 [00:12<00:00, 30.57it/s, v_num=3]


In [25]:
import copy
import tltorch

model = ResNet18()
factorization = 'tucker'
rank = 0.5
decompose_weights = True
td_init = not decompose_weights

decomposition_kwargs = {'init': 'random'} if factorization == 'cp' else {}
fixed_rank_modes = 'spatial' if factorization == 'tucker' else None

fact_model = copy.deepcopy(model)

ranks = {
    'layer1.0.conv1': (64, 16),
    'layer1.0.conv2': (64, 16),
    'layer1.1.conv1': (64, 16),
    'layer1.1.conv2': (64, 16),
    'layer2.0.conv1': (128, 16),
    'layer2.0.conv2': (128, 16),
    'layer2.1.conv1': (128, 16),
    'layer2.1.conv2': (128, 16),
    'layer3.0.conv1': (256, 16),
    'layer3.0.conv2': (256, 16),
    'layer3.1.conv1': (256, 16),
    'layer3.1.conv2': (256, 16),
    'layer4.0.conv1': (512, 16),
    'layer4.0.conv2': (512, 16),
    'layer4.1.conv1': (512, 16),
    'layer4.1.conv2': (512, 16),
}

layer_names = list(ranks.keys())

for i, (name, module) in enumerate(model.named_modules()):
    if name in layer_names:
        print(f'factorizing: {name}')
        if type(module) == torch.nn.modules.conv.Conv2d:
            fact_layer = tltorch.FactorizedConv.from_conv(
                module, 
                rank=rank, 
                decompose_weights=decompose_weights, 
                factorization=factorization,
                fixed_rank_modes=fixed_rank_modes,
                decomposition_kwargs=decomposition_kwargs,
            )
            if td_init:
                fact_layer.weight.normal_(0, td_init)
            layer, block, conv = name.split('.')
            conv_to_replace = getattr(getattr(fact_model, layer), block)
            setattr(conv_to_replace, conv, fact_layer)

factorizing: layer1.0.conv1
factorizing: layer1.0.conv2
factorizing: layer1.1.conv1
factorizing: layer1.1.conv2
factorizing: layer2.0.conv1
factorizing: layer2.0.conv2
factorizing: layer2.1.conv1
factorizing: layer2.1.conv2
factorizing: layer3.0.conv1
factorizing: layer3.0.conv2
factorizing: layer3.1.conv1
factorizing: layer3.1.conv2
factorizing: layer4.0.conv1
factorizing: layer4.0.conv2
factorizing: layer4.1.conv1
factorizing: layer4.1.conv2


In [26]:
class TNAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = fact_model

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        print(x.shape)
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [27]:
autoencoder_tn = TNAutoEncoder()
trainer = pl.Trainer(accelerator="gpu", max_epochs=10)
trainer.fit(autoencoder_tn, trainloader, testloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 5.7 M 
---------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.761    Total estimated model params size (MB)


Epoch 9: 100%|██████████| 391/391 [00:15<00:00, 25.47it/s, v_num=4]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 391/391 [00:15<00:00, 25.25it/s, v_num=4]
