# Weight sharing on CIFAR10
The pytorch-cifar (kuangliu) models use a 3x3 convolution instead of a 7x7 one for the first layer. This works significantly better for small images such as CIFAR; without it, accuracy stays at or below 90%.
This notebook compares the standard kuangliu models with my rwightman-based version, trained with Lightning (weight-sharing repo).
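
To see why the first layer matters so much on 32x32 inputs, here is a minimal sketch of the spatial size that reaches the first residual stage. It assumes the usual ImageNet-style stem (7x7 conv, stride 2, padding 3, followed by a 3x3 max-pool with stride 2, padding 1) and the kuangliu-style CIFAR stem (3x3 conv, stride 1, padding 1, no max-pool); `conv_out` is a helper introduced here for illustration only.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

# Assumed ImageNet-style stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1).
imagenet_stem = conv_out(conv_out(32, 7, 2, 3), 3, 2, 1)

# Assumed CIFAR-style stem: a single 3x3 conv (stride 1, padding 1), no max-pool.
cifar_stem = conv_out(32, 3, 1, 1)

print(imagenet_stem, cifar_stem)  # 8 32
```

Under these assumptions the ImageNet stem shrinks a CIFAR image to 8x8 before any residual block runs, while the CIFAR stem keeps the full 32x32 resolution.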

----

## Result - Weight sharing is not helpful on CIFAR10

#### Regular Convolutions:
* Weight sharing is worse than pruning.

#### Depthwise separable convolutions:
* Weight sharing is not generally useful. Spatial sharing can be used without much loss in accuracy, but it saves almost no parameters.
* Channel-wise weight sharing is worse than per-parameter pruning.
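
As a rough illustration of why spatial sharing cannot save much, the sketch below counts the weights of a single depthwise separable convolution. The 256-channel layer width is an arbitrary example value, not taken from the models in this notebook, and `separable_conv_params` is a hypothetical helper.

```python
def separable_conv_params(c_in, c_out, kernel=3):
    """Weight counts of a depthwise separable conv (biases ignored)."""
    depthwise = c_in * kernel * kernel  # one k x k spatial filter per input channel
    pointwise = c_in * c_out            # 1x1 conv that mixes channels
    return depthwise, pointwise

# Example layer width chosen for illustration only.
dw, pw = separable_conv_params(256, 256)
share = dw / (dw + pw)
print(dw, pw, f"{share:.1%}")  # 2304 65536 3.4%
```

The spatial (depthwise) filters are only a few percent of such a layer, so even sharing them aggressively leaves the total almost unchanged; the channel-mixing 1x1 weights are where the parameters are.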

In [1]:
import os
import random
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from IPython.display import display, HTML

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

from models import *

%reload_ext autoreload
%autoreload 2
%matplotlib inline

## Standard Models - Resnet: Weight sharing isn't helpful
* Both variants are trained for 350 epochs with the same hyperparameters.

In [2]:
for i in [2, 4, 8, 16]:
    # Channel-reduced model: channel factor i, no weight sharing (ws1).
    save_path = './saved_models_old/cifar10/resnet18_ws1_ch'+str(i)
    # Weight-shared model: sharing factor i, full channel width (ch1).
    save_path_ws = './saved_models/cifar10/resnet18_ws'+str(i)+'_ch1_dr'
    checkpoint = torch.load(save_path + '/ckpt.pth')
    checkpoint_ws = torch.load(save_path_ws + '/ckpt.pth')
    print(f'{i} acc frac: {checkpoint["acc"]} , acc ws: {checkpoint_ws["acc"]} {checkpoint_ws["num_params"]}')

2 acc frac: 94.79 , acc ws: 94.58 5664458
4 acc frac: 93.9 , acc ws: 93.78 2909706
8 acc frac: 92.27 , acc ws: 91.92 1532330
16 acc frac: 89.51 , acc ws: 87.84 843642
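
A small sketch to make the gap explicit, using only the accuracies printed above. It reads `acc frac` as the channel-reduced model (`ws1_ch{i}`) and `acc ws` as the weight-shared one, which is an interpretation of the checkpoint paths.

```python
# (factor, acc of channel-reduced model, acc of weight-shared model), copied from the output above.
results = [(2, 94.79, 94.58), (4, 93.90, 93.78), (8, 92.27, 91.92), (16, 89.51, 87.84)]

# Positive gap means the channel-reduced model is ahead.
gaps = {factor: round(acc_frac - acc_ws, 2) for factor, acc_frac, acc_ws in results}
for factor, gap in gaps.items():
    print(f"factor {factor:>2}: channel reduction ahead by {gap:.2f} points")
```

Weight sharing trails at every factor, and the gap is largest at factor 16.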


## Depthwise separable convolutions: Channel vs. Spatial

### Small Xception
Weight sharing either saves very few parameters (spatial) or causes a large accuracy drop (channel).

In [3]:
for i in [1, 2, 4, 8]:
    save_path_ch = './saved_models/cifar10/vsxception_ch'+str(i)+'_wss_1_wsc_1'
    save_path_wss = './saved_models/cifar10/vsxception_ch1_wss_'+str(i)+'_wsc_1'
    save_path_wsc = './saved_models/cifar10/vsxception_ch1_wss_1_wsc_' + str(i)
    checkpoint_ch = torch.load(save_path_ch + '/ckpt.pth')
    checkpoint_wss = torch.load(save_path_wss + '/ckpt.pth')
    checkpoint_wsc = torch.load(save_path_wsc + '/ckpt.pth')
    print(f'{i} acc frac: {checkpoint_ch["acc"]} {checkpoint_ch["num_params"]}, acc wss: {checkpoint_wss["acc"]} {checkpoint_wss["num_params"]}, acc wsc: {checkpoint_wsc["acc"]} {checkpoint_wsc["num_params"]}')

1 acc frac: 93.87 787242, acc wss: 93.87 787242, acc wsc: 93.87 787242
2 acc frac: 92.67 232746, acc wss: 93.98 777450, acc wsc: 92.76 508714
4 acc frac: 91.3 103338, acc wss: 93.54 772554, acc wsc: 92.82 369450
8 acc frac: 89.25 61034, acc wss: 93.44 770106, acc wsc: 91.63 299818
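
The parameter counts above fit a simple model: a fixed pool of shareable weights is stored once per `factor` copies. The sketch below checks that against the printed numbers. The pool sizes are inferred from the factor-2 rows, and labeling them as depthwise and pointwise weights is an assumption, not something read from the model code.

```python
BASE = 787242       # unshared Small Xception (factor 1), from the output above
DEPTHWISE = 19584   # inferred pool for spatial sharing: 2 * (787242 - 777450)
POINTWISE = 557056  # inferred pool for channel sharing: 2 * (787242 - 508714)

def shared_params(base, shareable, factor):
    """Parameter count if `shareable` weights are stored once per `factor` copies."""
    return base - shareable + shareable // factor

reported_wss = {2: 777450, 4: 772554, 8: 770106}
reported_wsc = {2: 508714, 4: 369450, 8: 299818}

# The inferred pools reproduce every reported count, not just the factor-2 rows.
for f in (2, 4, 8):
    assert shared_params(BASE, DEPTHWISE, f) == reported_wss[f]
    assert shared_params(BASE, POINTWISE, f) == reported_wsc[f]

print(f"spatial pool: {DEPTHWISE / BASE:.1%} of all parameters")
print(f"channel pool: {POINTWISE / BASE:.1%} of all parameters")
```

If this reading is right, spatial sharing can never remove more than about 2.5% of the parameters no matter how large the factor, which matches the near-constant `wss` counts.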


### Xception

In [26]:
for i in [1, 2, 4, 8]:
    save_path_ch = './saved_models/cifar10/xception_ch'+str(i)+'_wss_1_wsc_1'
    save_path_wss = './saved_models/cifar10/xception_ch1_wss_'+str(i)+'_wsc_1'
    save_path_wsc = './saved_models/cifar10/xception_ch1_wss_1_wsc_' + str(i)
    checkpoint_ch = torch.load(save_path_ch + '/ckpt.pth')
    checkpoint_wss = torch.load(save_path_wss + '/ckpt.pth')
    checkpoint_wsc = torch.load(save_path_wsc + '/ckpt.pth')
    print(f"""{i} acc ch: {checkpoint_ch["acc"]} {checkpoint_ch["num_params"]}, acc wss: {checkpoint_wss["acc"]} {checkpoint_wss["num_params"]}, acc wsc: {checkpoint_wsc["acc"]} {checkpoint_wsc["num_params"]}""")

1 acc ch: 95.42 7649898, acc wss: 95.42 7649898, acc wsc: 95.42 7649898
2 acc ch: 95.57 3012458, acc wss: 95.67 7617642, acc wsc: 95.43 5683818
4 acc ch: 94.93 1627626, acc wss: 95.44 7601514, acc wsc: 95.46 4700778
8 acc ch: 93.73 1168682, acc wss: 95.53 7593450, acc wsc: 94.98 4209258
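
To compare the three variants on equal footing, this sketch turns the Xception output above into "parameters removed vs. accuracy change" relative to the factor-1 row. All numbers are copied from that output; `tradeoff` is a helper introduced here.

```python
BASE_ACC, BASE_PARAMS = 95.42, 7649898  # factor-1 row of the output above

# factor -> (accuracy, parameter count) per variant, copied from the output above.
runs = {
    "ch":  {2: (95.57, 3012458), 4: (94.93, 1627626), 8: (93.73, 1168682)},
    "wss": {2: (95.67, 7617642), 4: (95.44, 7601514), 8: (95.53, 7593450)},
    "wsc": {2: (95.43, 5683818), 4: (95.46, 4700778), 8: (94.98, 4209258)},
}

def tradeoff(acc, params):
    """Return (percent of parameters removed, accuracy change in points)."""
    return round(100 * (1 - params / BASE_PARAMS), 1), round(acc - BASE_ACC, 2)

for name, by_factor in runs.items():
    for factor, (acc, params) in by_factor.items():
        saved, delta = tradeoff(acc, params)
        print(f"{name:>3} x{factor}: {saved:5.1f}% fewer params, {delta:+.2f} acc")
```

Read this way, channel reduction at factor 2 removes roughly 61% of the parameters with no accuracy loss, while channel-wise sharing at factor 8 removes about 45% and already costs accuracy.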


## Spatial Mobile Net - Baseline Accuracy too low

In [4]:
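# Unshared baseline: set the label explicitly instead of reusing a loop
# variable left over from an earlier cell.
i = 1
save_path_ws = './saved_models/cifar10/mobile'
checkpoint = torch.load(save_path_ws + '/ckpt.pth')
acc_ws = checkpoint['acc']
print(f'{i} acc ws: {acc_ws}')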

1 acc ws: 91.53


In [6]:
for i in [1, 2, 4, 8]:
    save_path_ws = './saved_models/cifar10/mobile_wss'+str(i)
    checkpoint = torch.load(save_path_ws + '/ckpt.pth')
    acc_ws = checkpoint['acc']
    print(f'{i} acc ws: {acc_ws}')

1 acc ws: 91.87
2 acc ws: 91.58
4 acc ws: 92.04
8 acc ws: 91.49


## Sanity checks
* BatchNorm before vs. after the repeat: the results are the same.
* The CIFAR-adjusted rwightman model matches the kuangliu one (`resnet18_std`).

In [None]:
save_path = './saved_models/cifar10/resnet18_std'
checkpoint = torch.load(save_path + '/ckpt.pth')
checkpoint['acc']

In [None]:
for ws in [1, 2, 4, 8, 16]:
    save_path = './saved_models/cifar10/resnet18_ws'+str(ws)+'_ch1'
    save_path_bn = './saved_models/cifar10/resnet18bn_ws'+str(ws)+'_ch1'
    checkpoint = torch.load(save_path + '/ckpt.pth')
    acc = checkpoint['acc']
    checkpoint = torch.load(save_path_bn + '/ckpt.pth')
    acc_bn = checkpoint['acc']
    print(f'{ws} acc: {acc} acc_bn: {acc_bn}')

In [None]:
save_path = './saved_models/cifar10/resnet50_ws1_ch1'
checkpoint = torch.load(save_path + '/ckpt.pth')
checkpoint['acc']

## ImageNet (rwightman) vs. CIFAR (kuangliu) ResNets
Weight differences:

In [None]:
imagenet_model = ResNetLight(BasicBlockLight, [2, 2, 2, 2], ws_factor=1, channel_factor=1,
                             num_classes=10, cifar=False)
cifar_model = ResNet18()

# Collect parameter names and shapes in definition order for a side-by-side comparison.
imagenet_model_names = []
imagenet_model_sizes = []
for name, param in imagenet_model.named_parameters():
    imagenet_model_names.append(name)
    imagenet_model_sizes.append(param.size())

cifar_model_names = []
cifar_model_sizes = []
for name, param in cifar_model.named_parameters():
    cifar_model_names.append(name)
    cifar_model_sizes.append(param.size())

In [None]:
for n,m in zip(cifar_model_sizes, imagenet_model_sizes):
    print(n,  m)

In [None]:
for n,m in zip(cifar_model_names, imagenet_model_names):
    print(n + '\t\t\t' + m)
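
Printing every pair makes the differences hard to spot. A minimal sketch that reports only the mismatching shapes is below; `shape_mismatches` is a hypothetical helper, and the two short lists are stand-ins for the `(name, param.size())` pairs collected above, not real output from these models.

```python
def shape_mismatches(named_a, named_b):
    """Yield (name_a, shape_a, name_b, shape_b) for paired parameters whose shapes differ."""
    for (name_a, shape_a), (name_b, shape_b) in zip(named_a, named_b):
        if tuple(shape_a) != tuple(shape_b):
            yield name_a, tuple(shape_a), name_b, tuple(shape_b)

# Hypothetical stand-ins for [(name, param.size()), ...] from named_parameters().
cifar = [("conv1.weight", (64, 3, 3, 3)), ("bn1.weight", (64,))]
imagenet = [("conv1.weight", (64, 3, 7, 7)), ("bn1.weight", (64,))]

diffs = list(shape_mismatches(cifar, imagenet))
print(diffs)
```

Note that `zip` stops at the shorter list, so a model with extra parameters at the end would go unnoticed; comparing the two lengths first covers that case.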

## Debug strides, etc:

In [None]:
imagenet_model.to('cuda')
y = imagenet_model(torch.randn(2,3,32,32).to('cuda'))

In [None]:
cifar_model.to('cuda')
y = cifar_model(torch.randn(2,3,32,32).to('cuda'))

In [None]:
imagenet_model = ResNetLight(BasicBlockLight, [2, 2, 2, 2], ws_factor=1, channel_factor=1,
                             num_classes=10, cifar=True)

imagenet_model.to('cuda')
y = imagenet_model(torch.randn(2,3,32,32).to('cuda'))