This notebook replicates the second experiment (ResNet on the CIFAR-10 dataset) from our paper "MARS: Masked Automatic Ranks Selection in Tensor Decompositions".

Compression mode: **base** (no compression).

**Preliminaries**

In [1]:
import os

# Run the rest of the notebook from ../main. The project modules imported below
# (mars, models, augmentations) and the relative "../models/..." checkpoint paths
# presumably resolve from there — confirm against the repository layout.
os.chdir("../main")

In [2]:
from matplotlib import pyplot as plt
import seaborn as sns
sns.set()  # apply seaborn's default theme to matplotlib figures

# Render figures inline in the notebook, as SVG.
%matplotlib inline
%config InlineBackend.figure_format = 'svg' 

import torch
import numpy as np
import random

import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms

from mars import MARS, MARSLoss, get_MARS_attr, set_MARS_attr
from models import MarsConfig, ResNet

In [3]:
import os
# Expose only GPU 0 to PyTorch (only takes effect if set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# Ask cuDNN for deterministic kernels and disable benchmark-driven algorithm
# selection, so repeated runs follow the same computation path.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# One seed for Python's, NumPy's and PyTorch's (CPU and CUDA) generators.
seed = 228
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

**Model and hyperparameter definitions**

In [5]:
# Keyword arguments (pi, alpha) passed to ResNet for each compression mode;
# "base" uses NaN for both (no MARS compression in this notebook's base run).
modes_dict = dict(
    base=dict(pi=np.nan, alpha=np.nan),
    naive=dict(pi=1e-2, alpha=2.25),
    proper=dict(pi=4e-3, alpha=3),
)

In [6]:
# Training hyperparameters.
# Enough epochs are used here for training to fully converge; when training for
# fewer epochs, adjust the temperature annealing schedule (temp_anneal) to match.
n_epochs = 50
batch_size = 128
lr = 5e-3             # peak learning rate of the one-cycle schedule
weight_decay = 1e-4   # AdamW weight decay
gamma = 0.94          # per-epoch temperature decay factor


def temp_anneal(t):
    """Return the next temperature: ``t`` decayed by ``gamma``, never below 1e-2."""
    return max(1e-2, gamma * t)

In [7]:
import os

from augmentations import MixupWrapper, CutoutWrapper, CombineWrapper, SmoothOHEWrapper

# Dataset root. Set the CIFAR10_DATA_DIR environment variable to point elsewhere
# instead of editing a hard-coded, machine-specific path.
data_dir = os.environ.get("CIFAR10_DATA_DIR", "/home/sergej/data")
test_batch = 2048

# Per-channel normalisation constants, shared by the train and test pipelines
# so the two can never drift apart.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

trainset = datasets.CIFAR10(root=data_dir, train=True,
                            download=True, transform=train_transform)
# Extra project-specific augmentation on top of the transforms above;
# presumably Cutout of size 10 applied with probability p — see augmentations.py.
trainset = CutoutWrapper(trainset, size=10, p=0.25)
testset = datasets.CIFAR10(root=data_dir, train=False,
                           download=True, transform=test_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


**Function definitions**

In [8]:
def _train_one_epoch(model, criterion, optimizer, scheduler):
    """Run one optimisation pass over ``trainloader``.

    Steps ``optimizer`` and ``scheduler`` after every batch and returns
    ``(mean batch loss, training accuracy)`` for the epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_batches = n_correct = n_seen = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the one-cycle schedule is stepped per batch

        # update statistics
        with torch.no_grad():
            loss_sum += loss.item()
            n_correct += (outputs.argmax(-1) == labels).sum().item()
        n_seen += labels.size(0)
        n_batches += 1

    if n_batches == 0:
        raise ValueError("trainloader yielded no batches")
    return loss_sum / n_batches, n_correct / n_seen


def train_model(model_idx, mode="soft", save=True, load=True):
    """
    Train the model or load the trained one.

    Parameters are:
        model_idx : int
            Model index to load or save (used as the file-name prefix).
        mode : str
            Compression mode; must be a key of the global ``modes_dict``
            ('base', 'naive' or 'proper' in this notebook). NOTE: the default
            'soft' is not one of those keys, so callers currently have to pass
            ``mode`` explicitly.
        save : bool
            Whether to save the final weights and per-epoch losses, plus the
            best-train-accuracy and best-test-accuracy checkpoints.
        load : bool
            Whether to load a previously saved model instead of training.

    Returns:
        (model, losses): the model and a numpy array holding the mean
        training loss of every epoch.

    Raises:
        KeyError: if ``mode`` is not a key of ``modes_dict``. This is raised
        before any directory is created.
    """
    # Validate first so an unknown mode doesn't leave an empty directory behind.
    if mode not in modes_dict:
        raise KeyError(f"unknown compression mode {mode!r}; expected one of {list(modes_dict)}")

    model_directory_path = f"../models/CIFAR10-ResNet/{mode}/"
    prefix = str(model_idx)
    model_path = model_directory_path + prefix + '-model.pt'
    losses_path = model_directory_path + prefix + '-losses.npy'
    # Checkpoint names reuse the model path without its ".pt" suffix.
    checkpoint_stem = os.path.splitext(model_path)[0]
    print("Model path: ", model_path)

    if save:
        os.makedirs(model_directory_path, exist_ok=True)

    model = ResNet(config, **modes_dict[mode]).to(device)

    if load and os.path.isfile(model_path):
        # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        losses = np.load(losses_path)
        print('Loaded model parameters from disk.')
        return model, losses

    criterion = MARSLoss(model, len(trainset), nn.CrossEntropyLoss())
    optimizer = optim.AdamW(model.parameters(), lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=n_epochs,
        steps_per_epoch=len(trainloader),
        anneal_strategy="cos",
        pct_start=0.1,
    )

    print("Training...")
    losses = []
    # Tracked independently of `save`, so the summary printed at the end always
    # has values (previously save=False crashed with a NameError there).
    best_train_acc, best_train_epoch = 0, 0
    best_test_acc, best_test_epoch = 0, 0

    for epoch in range(1, n_epochs + 1):
        epoch_loss, train_acc = _train_one_epoch(model, criterion, optimizer, scheduler)
        losses.append(epoch_loss)
        test_acc = eval_model(model)
        temp = get_MARS_attr(model, "temperature")
        print('[%d] \t Loss: %.3f \t Train Acc: %.2f%% \t Test Acc: %.2f%% \t T: %.3f' %
              (epoch,
               epoch_loss,
               100 * train_acc,
               100 * test_acc,
               np.nan if temp is None else temp))

        if train_acc > best_train_acc:
            best_train_acc, best_train_epoch = train_acc, epoch
            if save:
                torch.save(model.state_dict(), checkpoint_stem + "-best_train.pt")
        if test_acc > best_test_acc:
            best_test_acc, best_test_epoch = test_acc, epoch
            if save:
                torch.save(model.state_dict(), checkpoint_stem + "-best_test.pt")

        # Anneal the MARS temperature once per epoch; get_MARS_attr returns None
        # when there is nothing to anneal (the base run logs "T: nan").
        if temp is not None:
            set_MARS_attr(model, "temperature", temp_anneal(temp))

    losses = np.array(losses)
    print('Finished Training.')
    print("Best train accuracy:\t%.2f%% on epoch %d" % (100 * best_train_acc, best_train_epoch))
    print("Best test accuracy:\t%.2f%% on epoch %d" % (100 * best_test_acc, best_test_epoch))

    if save:
        torch.save(model.state_dict(), model_path)
        np.save(losses_path, losses)
        print('Saved model parameters to disk.')

    return model, losses

In [9]:
def eval_model(model):
    """Return the top-1 accuracy of ``model`` on ``testloader`` as a fraction.

    Switches the model to eval mode (and leaves it there) and runs without
    building an autograd graph.
    """
    model.eval()

    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_labels = batch_labels.to(device)
            batch_preds = model(batch_images.to(device)).argmax(dim=-1)

            n_seen += batch_labels.size(0)
            n_correct += batch_preds.eq(batch_labels).sum().item()

    return n_correct / n_seen

In [10]:
def eval_ensemble(models):
    """Return the test accuracy of the ensemble ``models`` on ``testloader``.

    The ensemble prediction is the argmax of the members' softmax
    probabilities averaged over members. Every member is left in eval mode.
    """
    for member in models:
        member.eval()

    hits = seen = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            stacked_logits = torch.stack([member(images) for member in models])
            mean_probs = torch.softmax(stacked_logits, -1).mean(0)
            ensemble_preds = mean_probs.argmax(-1)

            seen += labels.size(0)
            hits += ensemble_preds.eq(labels).sum().item()

    return hits / seen

In [11]:
def get_comp_info(model):
    """Report per-layer and whole-model compression for the MARS layers of ``model``.

    For every MARS layer: plot the probabilities of each rank mask together with
    the rounding threshold, print how many ranks survive the threshold, and print
    the layer's compression ratio (``tensorized_model.total`` divided by the
    ``calc_dof`` of the surviving ranks). Finally, print and return the
    compression ratio of the whole model.
    """
    mars_layers = [module for module in model.modules() if isinstance(module, MARS)]

    full_sizes = []   # tensorized_model.total of each MARS layer
    kept_sizes = []   # calc_dof(surviving ranks) of each MARS layer

    for layer in mars_layers:
        print("Layer: ", layer.tensorized_model)
        to_prob = layer.F
        logit_cut = layer.eval_logits_threshold
        prob_cut = to_prob(torch.tensor(logit_cut)).item()

        layer_ranks = []
        for mask_no, raw_logits in enumerate(layer.phi_logits_list, start=1):
            mask_logits = raw_logits.detach().cpu()
            mask_probs = to_prob(mask_logits).data.numpy()
            n_slots = len(mask_probs)

            plt.title(f"Mask {mask_no}")
            plt.bar(np.arange(1, n_slots + 1), mask_probs)
            plt.xlabel('Rank')
            plt.ylabel(r'$\phi$ value')
            plt.hlines(prob_cut, 0, n_slots + 1, linestyles='--')
            plt.text(0, prob_cut * 1.05, 'Rounding threshold')
            plt.show()

            # A rank survives when its logit exceeds the evaluation threshold.
            kept = (mask_logits > logit_cut).sum().item()
            print("#nz ranks: {0}/{1}".format(kept, len(mask_logits)))
            layer_ranks.append(kept)
        print()

        kept_sizes.append(layer.tensorized_model.calc_dof(layer_ranks))
        full_sizes.append(layer.tensorized_model.total)
        print("Compression:\t%.3f" % (full_sizes[-1] / kept_sizes[-1]))
        print(100 * "=")

    # Parameters that do not belong to MARS: everything except the mask logits and
    # whatever calc_dof() with no ranks counts for each layer (presumably the
    # layer's current factor parameters — confirm in the MARS implementation).
    n_all = sum(param.numel() for param in model.parameters())
    n_mars = sum(lg.numel() for layer in mars_layers for lg in layer.phi_logits_list)
    n_mars += sum(layer.tensorized_model.calc_dof() for layer in mars_layers)
    n_other = n_all - n_mars

    total_comp = (n_other + sum(full_sizes)) / (n_other + sum(kept_sizes))
    print("Total compression:\t%.3f" % total_comp)

    return total_comp

**Baseline training without MARS**

In [12]:
from dataclasses import dataclass


@dataclass
class CIFARResNetConfig:
    """Architecture settings passed to ``ResNet(config, ...)`` for CIFAR-10.

    The attributes are annotated so that ``@dataclass`` turns them into real
    fields; without annotations the decorator generated no fields and the
    config could not be customised. ``CIFARResNetConfig()`` still yields the
    defaults below, and any of them can now be overridden via keywords, e.g.
    ``CIFARResNetConfig(width=32)``. The defaults also remain readable as
    class attributes.
    """

    # Presumably the number of residual blocks in each of the three stages — confirm in models.py.
    blocks_per_group: tuple = (18, 18, 18)
    num_classes: int = 10
    width: int = 16
    # Presumably one MarsConfig per stage; all disabled here, i.e. no MARS compression.
    mars_configs: tuple = (
        MarsConfig(enabled=False),
        MarsConfig(enabled=False),
        MarsConfig(enabled=False),
    )


config = CIFARResNetConfig()

In [13]:
# Train the uncompressed baseline (index 0), or load it if a saved copy exists,
# then report its test-set accuracy.
model, loss = train_model(0, mode="base")
acc = eval_model(model)
print("Accuracy of base model:\t%.2f%%" % (acc * 100))

Model path:  ../models/CIFAR10-ResNet/base/0-model.pt
Training...
[1] 	 Loss: 1.867 	 Train Acc: 29.49% 	 Test Acc: 39.33% 	 T: nan
[2] 	 Loss: 1.628 	 Train Acc: 39.05% 	 Test Acc: 41.56% 	 T: nan
[3] 	 Loss: 1.440 	 Train Acc: 46.72% 	 Test Acc: 54.15% 	 T: nan
[4] 	 Loss: 1.304 	 Train Acc: 52.31% 	 Test Acc: 55.80% 	 T: nan
[5] 	 Loss: 1.177 	 Train Acc: 57.33% 	 Test Acc: 57.28% 	 T: nan
[6] 	 Loss: 1.084 	 Train Acc: 61.17% 	 Test Acc: 64.82% 	 T: nan
[7] 	 Loss: 0.995 	 Train Acc: 64.46% 	 Test Acc: 66.70% 	 T: nan
[8] 	 Loss: 0.920 	 Train Acc: 67.26% 	 Test Acc: 69.22% 	 T: nan
[9] 	 Loss: 0.864 	 Train Acc: 69.34% 	 Test Acc: 71.13% 	 T: nan
[10] 	 Loss: 0.784 	 Train Acc: 72.32% 	 Test Acc: 72.54% 	 T: nan
[11] 	 Loss: 0.739 	 Train Acc: 73.97% 	 Test Acc: 76.40% 	 T: nan
[12] 	 Loss: 0.689 	 Train Acc: 75.78% 	 Test Acc: 76.30% 	 T: nan
[13] 	 Loss: 0.647 	 Train Acc: 77.30% 	 Test Acc: 77.06% 	 T: nan
[14] 	 Loss: 0.608 	 Train Acc: 78.90% 	 Test Acc: 80.53% 	 T: nan
[15] 