In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import lightning.pytorch as pl
from tqdm import tqdm as tqdm
import wandb
import lightning.pytorch.loggers as pl_loggers

import sys
sys.path.append('../src/td-comp/')

%reload_ext autoreload
%autoreload 2

In [2]:
# torch.set_float32_matmul_precision("high")
torch.manual_seed(1337)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

In [7]:
# Data
from torch.utils.data import random_split

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

train_ds, val_ds = random_split(trainset, [45000, 5000])

trainloader = torch.utils.data.DataLoader(
    train_ds, batch_size=128, shuffle=True, num_workers=4)
valloader = torch.utils.data.DataLoader(
    val_ds, batch_size=128, shuffle=False, num_workers=4)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [5]:
wandb_logger = pl_loggers.WandbLogger(name='wandb_test', project="td-compression")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33musainzg[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
from models import resnet
from torchmetrics.functional import accuracy

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet.ResNet18()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return preds
    
    def test_step(self, batch, batch_idx):
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, 'multiclass', num_classes=10)
        return preds, loss, acc

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [12]:
model = Model()
trainer = pl.Trainer(
    accelerator="gpu", 
    max_epochs=5, 
    logger=wandb_logger
)
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=valloader)
trainer.test(model, dataloaders=testloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [15]:
wandb_logger.experiment.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇█
test_accuracy,▁
test_loss,▁
train_accuracy,▁▁▄▄▃▄▄▅▅▄▅▆▅▆▆▇▇▆▇▇▆▇▇▇▇███▇██████
train_loss,██▆▆▆▆▅▅▄▄▄▄▄▃▃▂▂▃▂▃▃▂▂▃▂▁▁▂▂▁▂▁▁▂▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
val_accuracy,▁▃▆██
val_loss,█▇▃▁▁

0,1
epoch,5.0
test_accuracy,0.7824
test_loss,0.65465
train_accuracy,0.78906
train_loss,0.5356
trainer/global_step,1760.0
val_accuracy,0.7802
val_loss,0.64431


In [16]:
class TNModel(pl.LightningModule):
    def __init__(self, fact_model):
        super().__init__()
        self.model = fact_model
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        logits = self.model(x)
        return logits

    def training_step(self, batch, batch_idx):
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return preds
    
    def test_step(self, batch, batch_idx):
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, 'multiclass', num_classes=10)
        return preds, loss, acc

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [18]:
import copy
import tltorch

factorization = 'tucker'
rank = 0.75
decompose_weights = True
td_init = not decompose_weights

decomposition_kwargs = {'init': 'random'} if factorization == 'cp' else {}
fixed_rank_modes = 'spatial' if factorization == 'tucker' else None

resnet = resnet.ResNet18()
fact_model = copy.deepcopy(resnet)

layer_names = ['layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.1.conv1', 'layer2.1.conv2', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.1.conv1', 'layer3.1.conv2', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.1.conv1', 'layer4.1.conv2']

for i, (name, module) in enumerate(resnet.named_modules()):
    if name in layer_names:
        print(f'factorizing: {name}')
        if type(module) == torch.nn.modules.conv.Conv2d:
            fact_layer = tltorch.FactorizedConv.from_conv(
                module, 
                rank=rank, 
                decompose_weights=decompose_weights, 
                factorization=factorization,
                fixed_rank_modes=fixed_rank_modes,
                decomposition_kwargs=decomposition_kwargs,
            )
            if td_init:
                fact_layer.weight.normal_(0, td_init)
            layer, block, conv = name.split('.')
            conv_to_replace = getattr(getattr(fact_model, layer), block)
            setattr(conv_to_replace, conv, fact_layer)

factorizing: layer1.0.conv1
factorizing: layer1.0.conv2
factorizing: layer1.1.conv1
factorizing: layer1.1.conv2
factorizing: layer2.0.conv1
factorizing: layer2.0.conv2
factorizing: layer2.1.conv1
factorizing: layer2.1.conv2
factorizing: layer3.0.conv1
factorizing: layer3.0.conv2
factorizing: layer3.1.conv1
factorizing: layer3.1.conv2
factorizing: layer4.0.conv1
factorizing: layer4.0.conv2
factorizing: layer4.1.conv1
factorizing: layer4.1.conv2


In [20]:
tn_model = TNModel(fact_model=fact_model)

In [21]:
wandb_logger = pl_loggers.WandbLogger(name='tn-test', project="td-compression")
trainer = pl.Trainer(
    accelerator="gpu", 
    max_epochs=5, 
    logger=wandb_logger
)
trainer.fit(tn_model, train_dataloaders=trainloader, val_dataloaders=valloader)
trainer.test(tn_model, dataloaders=testloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 8.4 M 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
8.4 M     Trainable params
0         Non-trainable params
8.4 M     Total params
33.703    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.6699155569076538, 'test_accuracy': 0.776199996471405}]

In [22]:
wandb_logger.experiment.finish()

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇█
test_accuracy,▁
test_loss,▁
train_accuracy,▁▁▁▃▁▃▂▃▅▄▃▅▄▄▄▅▃▄▅▅▅▅▅▅▇▄▄▅▅▆▆▆█▇▅
train_loss,███▆▇▅▇▇▅▅▅▄▅▄▅▄▅▅▃▄▃▄▃▃▃▄▅▄▃▂▂▄▁▂▄
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
val_accuracy,▂▁▆▇█
val_loss,▆█▃▂▁

0,1
epoch,5.0
test_accuracy,0.7762
test_loss,0.66992
train_accuracy,0.78906
train_loss,0.62855
trainer/global_step,1760.0
val_accuracy,0.787
val_loss,0.63214
