In [2]:
###### MODIFIED from 
## https://github.com/chenyaofo/CIFAR-pretrained-models

import torch
import torch.nn as nn
from typing import Union, Tuple
import numpy as np

# Public names exported by `from <module> import *`.
# Only names that are actually defined may be listed: the resnet32/44/56
# factories are commented out below, and naming them here would make a
# star-import raise AttributeError. `cifar_resnet23` is defined, so it is listed.
__all__ = ['CifarResNet', 'cifar_resnet20', 'cifar_resnet23']
    

class ChannelNorm2D(nn.Module):
    """Parameter-free normalisation over dim 1 of the input.

    For every index of the remaining dimensions the values along dim 1 are
    centred (mean subtracted) and divided by ``sqrt(var + eps)``, where
    ``var`` is ``Tensor.var`` (torch's default estimator) of the centred values.

    Args:
        eps: constant added to the variance before the square root.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Return ``x`` normalised along dim 1; the shape is unchanged."""
        centered = x - x.mean(dim=1, keepdim=True)
        scale = torch.sqrt(centered.var(dim=1, keepdim=True) + self.eps)
        return centered / scale
    

def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups (1 = ordinary dense convolution).
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1, groups=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` (no padding).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups (1 = ordinary dense convolution).
    """
    conv_kwargs = dict(kernel_size=1, stride=stride, groups=groups, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block made of two (optionally grouped) 3x3 convolutions.

    Main path: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The block input (or ``downsample(input)`` when a downsample module is
    given) is added to that result, followed by a final ReLU.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution (the second uses stride 1).
        downsample: optional module applied to the input before the addition.
        groups: group count used by both convolutions.
    """
    # Channel multiplier read by the enclosing network (planes * expansion).
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1):
        super().__init__()

        self.groups = groups
        self.stride = stride

        # Sub-modules are registered in the same order as before so that
        # state_dict keys and the printed module tree stay the same.
        self.conv1 = conv3x3(inplanes, planes, stride, groups)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, groups=groups)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class SequentialMixer(nn.Module):
    """Run blocks one after another, regrouping channels around each block.

    Before block ``i`` the channel axis is viewed as ``(-1, group_sz, gap_i)``
    and the last two axes are swapped, so channels that were ``gap_i`` apart
    become neighbours (and therefore share a group in a grouped convolution).
    After the block the inverse view/transpose is applied to its output.
    The gaps cycle through ``group_sz**0, group_sz**1, ...`` for
    ``num_layers`` steps; a gap that would overshoot ``inplanes`` is replaced
    by ``ceil(inplanes / group_sz)``.

    Args:
        blocks: iterable of modules, applied in order.
        inplanes: channel count used to derive the gap schedule.
        group_size: channels per group. Values ``<= 1`` (including the ``-1``
            "no grouping" sentinel) disable the regrouping entirely and the
            blocks are simply applied in sequence.

    Attributes:
        num_layers: number of distinct gaps in one mixing cycle.
        gaps: gap used for each block (all ``1`` when mixing is disabled).
        mixing: whether channel regrouping is performed in ``forward``.
    """

    def __init__(self, blocks, inplanes, group_size):
        super().__init__()

        self.blocks = nn.ModuleList(blocks)
        self.inplanes = inplanes
        self.group_sz = group_size

        # With one channel per group (or the non-positive sentinel) the
        # regrouping would be the identity permutation, and a logarithm with
        # base <= 1 is undefined, so mixing is switched off in that case.
        self.mixing = group_size > 1
        self.groups = inplanes // group_size if group_size > 0 else 1

        ### total number of layers to complete mixing
        self.num_layers = self._mixing_depth(inplanes, group_size) if self.mixing else 1

        self.gaps = []
        for i in range(len(self.blocks)):
            if not self.mixing:
                self.gaps.append(1)
                continue
            butterfly_layer_index = i % self.num_layers  ## repeated index in blocks (for layers)
            gap = group_size ** butterfly_layer_index
            if gap * group_size > inplanes:
                # Integer ceil-division; same value as ceil(inplanes / group_size).
                gap = -(-inplanes // group_size)
            self.gaps.append(gap)

    @staticmethod
    def _mixing_depth(inplanes, group_size):
        """Smallest ``n >= 1`` with ``group_size**n >= inplanes``.

        Computed with integers: a floating-point ``log(a) / log(b)`` can land
        slightly above an exact integer and then be rounded up by one.
        """
        depth, reach = 1, group_size
        while reach < inplanes:
            reach *= group_size
            depth += 1
        return depth

    def forward(self, x):
        """Apply every block in order, with channel regrouping when enabled."""
        if not self.mixing:
            for fn in self.blocks:
                x = fn(x)
            return x

        B, _, H, W = x.shape
        for gap, fn in zip(self.gaps, self.blocks):
            # Bring channels that are `gap` apart next to each other.
            x = x.view(B, -1, self.group_sz, gap, H, W).transpose(2, 3).contiguous().view(B, -1, H, W)
            x = fn(x)
            # The block may have changed the spatial size (strided conv).
            _, _, H, W = x.shape
            # Undo the regrouping on the block's output.
            x = x.view(B, -1, gap, self.group_sz, H, W).transpose(2, 3).contiguous().view(B, -1, H, W)

        return x
        

class CifarResNet(nn.Module):
    """Three-stage ResNet whose stages are wrapped in ``SequentialMixer``.

    Layout: 3x3 stem conv -> BN -> ReLU -> layer1 -> layer2 (stride 2)
    -> layer3 (stride 2) -> global average pool -> linear classifier.
    Stage widths are ``planes``, ``2 * planes`` and ``4 * planes``.

    Args:
        block: residual block class (must expose ``expansion``).
        layers: number of blocks in each of the three stages.
        num_classes: size of the classifier output.
        planes: width of the stem and of the first stage.
        group_sizes: per-stage channels-per-group; ``None`` means
            ``[-1, -1, -1]``, i.e. ungrouped convolutions (``groups=1``).
            NOTE(review): that -1 sentinel is also forwarded to
            ``SequentialMixer`` - confirm it accepts a non-positive group size.
    """

    def __init__(self, block, layers, num_classes=10, planes=16, group_sizes=None):
        super().__init__()

        stage_widths = (planes, planes * 2, planes * 4)

        # Stem. `self.inplanes` tracks the running channel count and is
        # advanced by every `_make_layer` call.
        self.inplanes = planes
        self.conv1 = conv3x3(3, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if group_sizes is None:
            group_sizes = [-1, -1, -1]

        self.layer1 = self._make_layer(block, stage_widths[0], layers[0], group_sz=group_sizes[0])
        self.layer2 = self._make_layer(block, stage_widths[1], layers[1], stride=2, group_sz=group_sizes[1])
        self.layer3 = self._make_layer(block, stage_widths[2], layers[2], stride=2, group_sz=group_sizes[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(stage_widths[2] * block.expansion, num_classes)

        # Weight init: Kaiming-normal (fan_out, relu) for every convolution,
        # unit scale / zero shift for every BatchNorm.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, stride=1, group_sz=-1):
        """Build one stage of ``blocks`` residual blocks wrapped in a mixer.

        Args:
            block: residual block class.
            planes: stage width (before ``block.expansion``).
            blocks: number of residual blocks in the stage.
            stride: stride of the first block.
            group_sz: channels per group; ``<= 0`` selects ``groups=1``.
        """
        # NOTE(review): the group count is derived from the channel count
        # *entering* the stage and reused for every block, so in a stage that
        # widens the channels the later blocks have more than `group_sz`
        # channels per group - confirm this is intended.
        groups = self.inplanes // group_sz if group_sz > 0 else 1
        out_planes = planes * block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride, groups=groups),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample, groups=groups)]

        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, groups=groups))

        return SequentialMixer(stage, self.inplanes, group_sz)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


def cifar_resnet20(**kwargs):
    """Build a ``CifarResNet`` of ``BasicBlock``s with 3 blocks per stage.

    Group sizes default to ``[4, 8, 8]``. They are injected with
    ``setdefault`` so a caller may pass ``group_sizes=...`` to override them;
    previously that raised a duplicate-keyword ``TypeError``.

    Args:
        **kwargs: forwarded to ``CifarResNet`` (e.g. ``num_classes``, ``planes``).
    """
    kwargs.setdefault('group_sizes', [4, 8, 8])
    return CifarResNet(BasicBlock, [3, 3, 3], **kwargs)

def cifar_resnet23(**kwargs):
    """Build a ``CifarResNet`` of ``BasicBlock``s with 4 blocks per stage.

    Group sizes default to ``[4, 8, 8]``. They are injected with
    ``setdefault`` so a caller may pass ``group_sizes=...`` to override them;
    previously that raised a duplicate-keyword ``TypeError``.

    NOTE(review): 3 stages x 4 blocks x 2 convolutions, plus the stem and the
    classifier, is 26 weighted layers, not 23. The name is kept so existing
    callers keep working.

    Args:
        **kwargs: forwarded to ``CifarResNet`` (e.g. ``num_classes``, ``planes``).
    """
    kwargs.setdefault('group_sizes', [4, 8, 8])
    return CifarResNet(BasicBlock, [4, 4, 4], **kwargs)


# def cifar_resnet32(**kwargs):
#     model = CifarResNet(BasicBlock, [5, 5, 5], **kwargs)
#     return model


# def cifar_resnet44(**kwargs):
#     model = CifarResNet(BasicBlock, [7, 7, 7], **kwargs)
#     return model


# def cifar_resnet56(**kwargs):
#     model = CifarResNet(BasicBlock, [9, 9, 9], **kwargs)
#     return model

In [3]:
# Build the cifar_resnet20 configuration; the bare `model` on the last line
# makes the notebook print the module tree.
model = cifar_resnet20()
model

CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): SequentialMixer(
    (blocks): ModuleList(
      (0-2): 3 x BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2): SequentialMixer(
    (blocks): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=2, bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affin

In [5]:
# Smoke test: forward a single random 3x32x32 input and display the output.
model(torch.randn(1, 3, 32, 32))

tensor([[ 0.2257,  0.4249,  0.2232,  0.7458, -0.2010, -0.6317, -0.0889,  0.7805,
          1.1233,  1.2063]], grad_fn=<AddmmBackward0>)

In [6]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [7]:
import torch
import torch.nn as nn

from torchvision import datasets, transforms as T
from torch.utils import data

In [8]:
from tqdm import tqdm
import os, time, sys

In [9]:
# import resnet_mixer

In [10]:
# Per-channel normalisation statistics for CIFAR-10.
# For CIFAR-100 use mean=[0.5071, 0.4865, 0.4409], std=[0.2009, 0.1984, 0.2023].
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]
CIFAR10_ROOT = "../../../../../_Datasets/cifar10/"

# Training pipeline: random 32x32 crop from the 4-pixel padded image and a
# random horizontal flip, then tensor conversion and normalisation.
cifar_train = T.Compose([
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
cifar_test = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(
    root=CIFAR10_ROOT, train=True, download=True, transform=cifar_train,
)
test_dataset = datasets.CIFAR10(
    root=CIFAR10_ROOT, train=False, download=True, transform=cifar_test,
)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
# train_dataset.data = train_dataset.data.view(-1, 28*28)
# test_dataset.data = test_dataset.data.view(-1, 28*28)

In [12]:
# Mini-batch size shared by both loaders.
batch_size = 128

# Training batches are reshuffled on every pass; test batches keep a fixed order.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [13]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on a machine without a usable GPU instead of failing at the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [14]:
# Cross-entropy loss; train() and test() read this module-level name directly.
criterion = nn.CrossEntropyLoss()

In [15]:
# Sanity check: pull the first training batch, move it to `device`
# and print the input / target shapes.
for xx, yy in train_loader:
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)
    break  # one batch is enough

torch.Size([128, 3, 32, 32]) torch.Size([128])


## Group-Butterfly for CNN

In [16]:
# model = resnet_mixer.cifar_resnet20(mixer=True).to(device)

In [17]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Train ``model`` for one pass over ``train_loader`` and print a summary.

    Uses the module-level ``train_loader``, ``device`` and ``criterion``.

    Args:
        epoch: epoch index, only used in the printed summary.
        model: network to optimise (switched to train mode).
        optimizer: optimizer stepped once per mini-batch.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for step, (inputs, targets) in enumerate(tqdm(train_loader), start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for the end-of-epoch report.
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    mean_loss = running_loss / step          # average of the per-batch losses
    accuracy = 100. * n_correct / n_seen     # percentage of correct predictions
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {accuracy:.3f} {n_correct}/{n_seen}")

In [18]:
# Best test accuracy (in percent) seen so far; updated by test().
best_acc = -1
def test(epoch, model, model_name):
    """Evaluate ``model`` on ``test_loader`` and checkpoint it on improvement.

    Uses the module-level ``test_loader``, ``device`` and ``criterion`` and
    updates the module-level ``best_acc``. When the accuracy beats
    ``best_acc``, the weights, the accuracy and the epoch are saved to
    ``./models/<model_name>.pth``.

    Args:
        epoch: epoch index, printed and stored in the checkpoint.
        model: network to evaluate (switched to eval mode).
        model_name: file stem of the checkpoint.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # makedirs(exist_ok=True) replaces the isdir()+mkdir() pair: one call,
        # and no failure if the directory appears between the check and the create.
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [19]:
# Number of training epochs; also passed as T_max to the cosine LR schedule.
EPOCHS = 200

In [20]:
# Training driver: build the network, optimizer and LR schedule, then run
# EPOCHS rounds of train() + test(). Statement order matters here (model
# construction consumes RNG, `best_acc` must be reset before the loop).
# NOTE(review): `acc_dict` is created but never filled - the only write to it
# is the commented-out line at the end, which refers to an undefined `key`.
acc_dict = {}
#     net = torch.compile(net)
#     net = torch.compile(net, mode="reduce-overhead")
#     net = torch.compile(net, mode="max-autotune")


# --- alternative experiment configurations (kept for reference) ---
# model_name = f"00.0_c10_ordinary_e0"
# net = resnet_mixer.cifar_resnet20(num_classes=10, mixer=False).to(device)

### e0 default cifar_resnet20 with 4,8,8 groups per block
# model_name = f"00.0_c10_butterfly_e0"
# net = resnet_mixer.cifar_resnet20(num_classes=10, mixer=True).to(device)

### e0 32, 64, 128 cifar_resnet20 with 8,8,16 groups per block
# model_name = f"00.0_c10_butterfly_e1"
# net = resnet_mixer.cifar_resnet20(num_classes=10, mixer=True, planes=32, G=[8, 8, 16]).to(device)

# Active configuration: 4 blocks per stage, width 16, group sizes 4/8/8.
# `model_name` is the checkpoint file stem used by test().
model_name = f"00.1_c10_butterfly_block_e1"
net = CifarResNet(BasicBlock, [4, 4, 4], num_classes=10, planes=16, group_sizes=[4, 8, 8]).to(device)


# Rebinds the module-level `criterion` that train()/test() read.
criterion = nn.CrossEntropyLoss()
# SGD with momentum and weight decay; the learning rate follows one cosine
# schedule spanning all EPOCHS.
optimizer = torch.optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Reset the checkpointing threshold so this run starts from scratch.
best_acc = -1

# NOTE(review): the saved output of this cell stops during the sixth epoch
# with "CUDA error: an illegal memory access"; the cause cannot be determined
# from this cell alone.
for epoch in range(EPOCHS):
    train(epoch, net, optimizer)
    test(epoch, net, model_name)
    scheduler.step()  # advance the LR schedule once per epoch
# acc_dict[key] = float(best_acc)

100%|███████████████████████████████████████████████████| 391/391 [00:16<00:00, 23.87it/s]


[Train] 0 Loss: 1.537 | Acc: 42.884 21442/50000


100%|████████████████████████████████████████████████████| 79/79 [00:00<00:00, 141.79it/s]


[Test] 0 Loss: 1.318 | Acc: 54.610 5461/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:06<00:00, 55.87it/s]


[Train] 1 Loss: 1.078 | Acc: 61.322 30661/50000


100%|████████████████████████████████████████████████████| 79/79 [00:00<00:00, 146.60it/s]


[Test] 1 Loss: 1.214 | Acc: 58.990 5899/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:07<00:00, 55.27it/s]


[Train] 2 Loss: 0.914 | Acc: 67.646 33823/50000


100%|████████████████████████████████████████████████████| 79/79 [00:00<00:00, 145.78it/s]


[Test] 2 Loss: 1.228 | Acc: 57.830 5783/10000


100%|███████████████████████████████████████████████████| 391/391 [00:07<00:00, 54.49it/s]


[Train] 3 Loss: 0.805 | Acc: 71.770 35885/50000


100%|████████████████████████████████████████████████████| 79/79 [00:00<00:00, 144.27it/s]


[Test] 3 Loss: 0.992 | Acc: 66.000 6600/10000
Saving..


100%|███████████████████████████████████████████████████| 391/391 [00:07<00:00, 53.32it/s]


[Train] 4 Loss: 0.742 | Acc: 74.210 37105/50000


100%|████████████████████████████████████████████████████| 79/79 [00:00<00:00, 140.78it/s]


[Test] 4 Loss: 0.911 | Acc: 68.950 6895/10000
Saving..


 25%|████████████▉                                       | 97/391 [00:02<00:06, 44.26it/s]


RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [16]:
# !mkdir models

In [55]:
best_acc

90.14

In [18]:
# NOTE(review): undefined name - presumably a deliberate "stop here" guard that
# halts Run All with a NameError before the cells below; confirm, and consider
# removing it so the notebook can run top to bottom.
asdasd

NameError: name 'asdasd' is not defined

In [20]:
'''
FOR RESNET 20
{'stereographic': 90.51}
{'linear': 92.77} ??? uses different settings ?!
{'linear': 90.14} /!\ using the same configs
{'butterfly': 88.61}
{'butterfly-in32': 91.68}
{'block-butterfly': 90.61}

RESNET 26 -- for even mixing in each resolution
{'block-butterfly[4,4,4]': 91.2}
'''

"\n{'stereographic': 90.51}\n{'linear': 92.77}\n{'butterfly': 88.61}\n{'butterfly-in32': 91.68}\n"

In [15]:
# Total parameter count of the resnet_mixer ResNet-20 with mixer=True.
# NOTE(review): `resnet_mixer` is not imported in the visible notebook (its
# import cell is commented out) - a fresh kernel needs `import resnet_mixer` first.
sum(p.numel() for p in resnet_mixer.cifar_resnet20(mixer=True).parameters())

43802

In [16]:
# Total parameter count of the resnet_mixer ResNet-20 with mixer=False.
sum(p.numel() for p in resnet_mixer.cifar_resnet20(mixer=False).parameters())

130346

In [17]:
# Parameter ratio of the two counts above: mixer=True (43802) / mixer=False (130346).
43802/130346

0.33604406732849496

In [18]:
# Total parameter count of the wider mixer variant (planes=32, G=[8, 8, 16]).
sum(p.numel() for p in resnet_mixer.cifar_resnet20(mixer=True, planes=32, G=[8, 8, 16]).parameters())

129578

## Computing the MACs

In [56]:
from ptflops import get_model_complexity_info

# One zero-argument factory per configuration; the list position is the
# "MODEL INDEX" printed below. Models are built lazily, one per iteration.
# NOTE(review): `resnet_mixer` is not imported in the visible notebook (its
# import cell is commented out) - a fresh kernel needs `import resnet_mixer` first.
model_factories = [
    # 0: resnet_mixer ResNet-20 with mixer=False
    lambda: resnet_mixer.cifar_resnet20(mixer=False),
    # 1: original mixer v1
    lambda: resnet_mixer.cifar_resnet20(mixer=True),
    # 2: wider mixer variant
    lambda: resnet_mixer.cifar_resnet20(mixer=True, planes=32, G=[8, 8, 16]),
    # 3: block-butterfly network defined in this notebook. Other settings tried:
    #    CifarResNet(BasicBlock, [3, 3, 3], num_classes=10, group_sizes=[4, 8, 8])
    #    CifarResNet(BasicBlock, [2, 2, 2], num_classes=10, planes=32, group_sizes=[8, 8, 16])
    lambda: CifarResNet(BasicBlock, [4, 4, 4], num_classes=10, planes=16, group_sizes=[4, 8, 8]),
]

for i, make_model in enumerate(model_factories):
    model = make_model()

    # MACs and parameter count for a single 3x32x32 input, as formatted strings.
    macs, params = get_model_complexity_info(
        model, (3, 32, 32),
        as_strings=True, print_per_layer_stat=False, verbose=False,
    )

    print("MODEL INDEX: ", i)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print('')

MODEL INDEX:  0
Computational complexity:       20.95 MMac
Number of parameters:           130.35 k

MODEL INDEX:  1
Computational complexity:       8.97 MMac
Number of parameters:           43.8 k  

MODEL INDEX:  2
Computational complexity:       25.24 MMac
Number of parameters:           129.58 k

MODEL INDEX:  3
Computational complexity:       19.3 MMac
Number of parameters:           112.15 k

