In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [2]:
import torch
import torch.nn as nn

from torchvision import datasets, transforms as T
from torch.utils import data

In [3]:
from tqdm import tqdm
import os, time, sys

In [4]:
import resnet_mixer

In [5]:
# Per-channel normalization statistics used for CIFAR-10, defined once so the
# train and test pipelines cannot drift apart.
# For CIFAR-100 use mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023].
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

# Directory the dataset is read from (and downloaded to when missing).
CIFAR10_ROOT = "../../../../../_Datasets/cifar10/"

# Training pipeline: random crop + horizontal flip augmentation, then
# tensor conversion and normalization.
cifar_train = T.Compose([
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
cifar_test = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=True, transform=cifar_train)
test_dataset = datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=True, transform=cifar_test)

In [6]:
# train_dataset.data = train_dataset.data.view(-1, 28*28)
# test_dataset.data = test_dataset.data.view(-1, 28*28)

In [7]:
batch_size = 128

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [8]:
# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU or the CPU so later `.to(device)` calls do not fail on hosts with
# fewer than two CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device  # displayed as the cell output so a silent CPU fallback is visible

In [9]:
# Module-level loss read as a global by train() and test();
# the training cell further down rebinds it to a fresh CrossEntropyLoss.
criterion = nn.CrossEntropyLoss()

In [10]:
# Sanity check: pull a single training batch, move it to the target device
# and show the input / label tensor shapes. Only the first batch is inspected.
for xx, yy in train_loader:
    xx = xx.to(device)
    yy = yy.to(device)
    print(xx.shape, yy.shape)
    break

## Group-Butterfly for CNN

In [11]:
# model = resnet_mixer.cifar_resnet20(mixer=True).to(device)

In [12]:
## Adapted from
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one training epoch over the module-level ``train_loader``.

    Uses the module-level ``criterion`` as the loss and ``device`` as the
    target device. Prints the mean per-batch loss and the accuracy (percent)
    accumulated over the epoch.

    Args:
        epoch: Epoch index; only used in the printed summary line.
        model: Network to train; switched to training mode.
        optimizer: Optimizer stepped once per batch.

    Returns:
        None.

    Raises:
        ValueError: If ``train_loader`` yields no batches (the original code
            failed with an unbound ``batch_idx`` in that case).
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Index of the largest output along dim 1 is the predicted class.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches; nothing to train on")

    print(f"[Train] {epoch} Loss: {train_loss/num_batches:.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    return

In [13]:
# Best test accuracy (percent) seen so far; updated by test().
best_acc = -1
def test(epoch, model, model_name):
    """Evaluate ``model`` on the module-level ``test_loader`` and checkpoint on improvement.

    Uses the module-level ``criterion`` and ``device``. Prints the mean
    per-batch loss and the accuracy (percent). When the accuracy exceeds the
    global ``best_acc``, the model state dict, accuracy and epoch are saved to
    ``./models/{model_name}.pth`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch index; printed and stored in the checkpoint.
        model: Network to evaluate; switched to eval mode.
        model_name: Checkpoint file stem under ``./models``.

    Returns:
        None.

    Raises:
        ValueError: If ``test_loader`` yields no batches (the original code
            failed with an unbound ``batch_idx`` in that case).
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            # Index of the largest output along dim 1 is the predicted class.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    acc = 100.*correct/total
    print(f"[Test] {epoch} Loss: {test_loss/num_batches:.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [14]:
# Number of training epochs; also passed as T_max to the cosine LR schedule in the training cell.
EPOCHS = 200

In [15]:
# The original cell held an undefined gibberish name, presumably a deliberate
# NameError to keep "Run All" from starting the long training cell below.
# Keep that stop, but make it explicit and self-describing.
raise RuntimeError("Run All halted on purpose: execute the training cell below manually.")

In [None]:
acc_dict = {}
# Optional torch.compile variants, kept for reference:
#     net = torch.compile(net)
#     net = torch.compile(net, mode="reduce-overhead")
#     net = torch.compile(net, mode="max-autotune")

### e0: default cifar_resnet20 with 4,8,8 groups per block
# model_name = f"00.0_c10_butterfly_e0"
# net = resnet_mixer.cifar_resnet20(num_classes=10, mixer=True).to(device)

### e1: planes=32 (original note: 32, 64, 128) cifar_resnet20 with 8,8,16 groups per block
model_name = "00.0_c10_butterfly_e1"
net = resnet_mixer.cifar_resnet20(
    num_classes=10,
    mixer=True,
    planes=32,
    G=[8, 8, 16],
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
# Cosine schedule spanning the whole run; stepped once per epoch below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Reset so test() checkpoints from the first epoch of this run.
best_acc = -1

for epoch in range(EPOCHS):
    train(epoch, net, optimizer)
    test(epoch, net, model_name)
    scheduler.step()
# acc_dict[key] = float(best_acc)

In [None]:
# !mkdir models

In [None]:
# Display the best test accuracy (percent) recorded by test() during the run above.
best_acc

In [None]:
# The original cell held an undefined gibberish name, presumably a deliberate
# NameError to stop "Run All" after training. Keep that stop, but make it
# explicit and self-describing.
raise RuntimeError("Run All halted on purpose: execute the cells below manually.")

In [None]:
# Manually recorded results from earlier runs.
# NOTE(review): presumably best test accuracy (%) per mixer variant — TODO confirm
# which configuration and dataset each entry refers to.
'''
{'stereographic': 90.51}
{'linear': 92.77}
{'butterfly': 88.61}
{'butterfly-in32': 91.68}
'''

In [None]:
# Total parameter element count of cifar_resnet20 with the mixer enabled (default settings).
sum(p.numel() for p in resnet_mixer.cifar_resnet20(mixer=True).parameters())

In [None]:
# Total parameter element count of cifar_resnet20 with the mixer disabled.
sum(p.numel() for p in resnet_mixer.cifar_resnet20(mixer=False).parameters())

In [None]:
# NOTE(review): hard-coded ratio; presumably the mixer=True parameter count divided by the
# mixer=False parameter count from the two cells above — TODO confirm, and compute it
# from the models instead of literals.
43802/272474

In [None]:
# Total parameter element count of the planes=32, G=[8, 8, 16] mixer variant.
sum(p.numel() for p in resnet_mixer.cifar_resnet20(mixer=True, planes=32, G=[8, 8, 16]).parameters())

## Computing the MACs

In [16]:
from ptflops import get_model_complexity_info

# Keyword arguments for each cifar_resnet20 variant to profile. The list
# position is the "MODEL INDEX" printed below. Iterating an explicit list
# replaces the original range(9) loop, which printed a header for a
# non-existent model 5 before breaking out.
model_configs = [
    # Network without the mixer (original note: "hard core ignore").
    dict(mixer=False),
    # Original mixer v1 -- default values used.
    dict(mixer=True),
    # Larger planes (original note: "creates smaller than original").
    dict(mixer=True, planes=32, G=[8, 8, 16]),
    # Original note: use filter-wise conv in second conv of resblock.
    dict(mixer=True, G=[1, 1, 1]),
    # Original note: use filter-wise conv in first conv of resblock.
    dict(mixer=True, G=[16, 32, 64]),
]

for i, config in enumerate(model_configs):
    print("MODEL INDEX: ", i)
    model = resnet_mixer.cifar_resnet20(**config)

    # Profile on a single 3x32x32 input; results come back as formatted strings.
    macs, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True,
                                   print_per_layer_stat=False, verbose=False)

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print('')