In [6]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List
from utils import *

import numpy as np
import torch
import torch.nn as nn
from torch.optim import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from tqdm.auto import tqdm

# Fail fast when no GPU is visible; the experiments below are meant to run on CUDA.
assert torch.cuda.is_available(), "CUDA support is not available."

import pickle

import LiveTune as lt
import timm

In [7]:
# GPU used for every experiment in this notebook. The index used to be hard-coded
# inline as "cuda:2", which raises a CUDA "invalid device ordinal" error at the first
# `.to(device)` on machines with fewer than 3 GPUs; it is now a named knob with an
# explicit check so the failure happens here with a clear message.
GPU_INDEX = 2

if torch.cuda.is_available():
    n_visible_gpus = torch.cuda.device_count()
    assert GPU_INDEX < n_visible_gpus, (
        f"cuda:{GPU_INDEX} requested but only {n_visible_gpus} GPU(s) are visible; "
        "change GPU_INDEX."
    )
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cpu")

In [8]:
# ImageNet loaders indexed by split name ('val' is what evaluate() uses below).
# 256 is presumably the batch size — TODO confirm against utils.get_dataloader.
dataloader = get_dataloader("imagenet", 256)

# Experiments on ViT-Tiny (`vit_tiny_patch16_224.augreg_in21k_ft_in1k`)

In [40]:
# Pretrained timm ViT-Tiny/16 at 224px (tag: augreg_in21k_ft_in1k), moved to the GPU.
base_vit = timm.create_model(
    "vit_tiny_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)

In [41]:
# Top-1 accuracy (%) of the unmodified ViT-Tiny on the validation split.
evaluate(
    base_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

75.34799194335938

In [42]:
# Make the MLPs of the last 25% of blocks collapsible (the log lists blocks 9-11).
collapsible_vit = get_collapsible_model(
    base_vit,
    fraction=0.25,
    device=device,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp


In [47]:
# 1-epoch fine-tune, fraction 0.25, "nolc" run (per the file name).
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/tiny/finetuned_1epoch_frac025_nolc_2.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [48]:
# Top-1 accuracy (%) with the 1-epoch "nolc" checkpoint loaded.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

74.50399780273438

In [43]:
# 11-epoch fine-tune, fraction 0.25, "lc" run (per the file name).
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/tiny/finetuned_11epoch_frac025_lc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [44]:
# Print each collapsible MLP block's current slope (one "<block> <value>" line per block).
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 0.9334723353385925
blocks.10.mlp 0.829644501209259
blocks.9.mlp 0.5449391007423401


In [45]:
# Top-1 accuracy (%) with the 11-epoch "lc" checkpoint loaded.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

73.13199615478516

In [46]:
# Parameter counts before collapsing: the collapsible wrapper only adds a handful
# of parameters (3 in the output), so the ratio is still ~1.00x.
parameters_base, parameters_collapsible = (
    get_num_parameters(base_vit),
    get_num_parameters(collapsible_vit),
)
print(f"Parameters: {parameters_base} -> {parameters_collapsible} ({parameters_collapsible/parameters_base:.2f}x)")

Parameters: 5717416 -> 5717419 (1.00x)


In [13]:
# Make every block's MLP collapsible (fraction=1 -> blocks 0-11 in the log).
collapsible_vit = get_collapsible_model(
    base_vit,
    fraction=1,
    device=device,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.5.mlp
Collapsing layer blocks.4.mlp
Collapsing layer blocks.3.mlp
Collapsing layer blocks.2.mlp
Collapsing layer blocks.1.mlp
Collapsing layer blocks.0.mlp


In [15]:
# 10-epoch "nolc" checkpoint for the fully collapsible ViT-Tiny (per the file name).
# map_location=device was missing here, unlike the other loads in this notebook:
# without it torch.load restores tensors onto the device they were saved from,
# which fails if that GPU is not visible on the current machine.
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/tiny/collapsible_vit_tiny_16.augreg_in21k_ft_10ephock_nolc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [9]:
# Top-1 accuracy (%) of the fully collapsible ViT-Tiny.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

69.33999633789062

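### After further fine-tuning (TODO: this heading said 11 epochs, but the checkpoint loaded below is named `..._15ephock_nolc.pth` — confirm which is correct)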

In [11]:
# Longer "nolc" run for the fully collapsible ViT-Tiny (file name says 15 epochs).
# map_location=device added for consistency with the other loads: without it the
# tensors are restored to their original save device, which may not exist here.
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/collapsible_vit_tiny_16augreg_in21k_ft_15ephock_nolc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [12]:
# Top-1 accuracy (%) after the longer fine-tuning run.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

68.75599670410156

## Summary
The numbers in the paper have been updated. This notebook was not saved properly the first time, so the experiments in it were re-run (training is finished and the checkpoints are in `models_archive`).

# ViT-Base: `vit_base_patch16_224.orig_in21k_ft_in1k`

In [4]:
# Pretrained timm ViT-Base/16 at 224px (tag: orig_in21k_ft_in1k), moved to the GPU.
base_vit = timm.create_model(
    "vit_base_patch16_224.orig_in21k_ft_in1k",
    pretrained=True,
).to(device)

In [5]:
# Top-1 accuracy (%) of the unmodified ViT-Base.
evaluate(
    base_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

81.43199920654297

### The MLPs hold ~56M parameters in total, so each collapsed MLP layer saves ~4M parameters

In [5]:
# Total parameter count of the unmodified ViT-Base.
get_num_parameters(base_vit)

86567656

In [5]:
# Make 10% of the blocks collapsible; for 12 blocks the log shows only blocks.11.
collapsible_vit = get_collapsible_model(
    base_vit,
    fraction=0.1,
    device=device,
)

Collapsing layer blocks.11.mlp


In [6]:
# 7-epoch "lc" checkpoint for ViT-Base with one collapsible block (per the file name).
# map_location=device added: without it torch.load puts tensors back on the device
# they were saved from, which fails if that GPU is not visible here.
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/base/finetuned_7epoch_frac01_lc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [10]:
# Top-1 accuracy (%) before collapsing.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

79.50800323486328

In [7]:
# Slope of the single collapsible block (close to 1 after fine-tuning, per the output).
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 0.9902544021606445


In [8]:
# Collapse the collapsible block(s) using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=0.1,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp


In [9]:
# Parameter count after collapsing one MLP block (compare with the base count above).
get_num_parameters(collapsible_vit)

82435816

In [10]:
# Top-1 accuracy (%) after collapsing.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

79.50599670410156

In [11]:
# Full report for the base ViT-Base: accuracy, top-5, size, MACs, latency, parameters.
# NOTE(review): the printed top-5 value (e.g. "0.96%") looks like a fraction labelled
# as a percent — check the formatting in utils.evaluate_model.
evaluate_model(
    base_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=81.43%
model has top5 accuracy=0.96%
model has size=330.23 MiB
model has macs=16.85 Gmacs
average inference time is 0.0064 seconds
model has 86.57 M parameters


In [12]:
# Same report for the collapsed ViT-Base.
evaluate_model(
    collapsible_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=79.51%
model has top5 accuracy=0.95%
model has size=314.47 MiB
model has macs=16.04 Gmacs
average inference time is 0.0045 seconds
model has 82.44 M parameters


5 percent reduction in parameter count.

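~1.9 percentage-point reduction in top-1 accuracy (81.43% → 79.51%); the collapse step itself cost only ~0.002 points (79.508% → 79.506%).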

30 percent reduction in inference time.

5 percent reduction in total MACs.

LC was run with lc=1.0.
These results were obtained with the bad version of the dataset.

# ViT-Base: new runs (`new_finetuned_*` checkpoints)

In [4]:
# Rebuild ViT-Base with its last MLP block collapsible (blocks.11 in the log)
# and load the new 5-epoch "lc" checkpoint (per the file name).
base_vit = timm.create_model(
    "vit_base_patch16_224.orig_in21k_ft_in1k",
    pretrained=True,
).to(device)
collapsible_vit = get_collapsible_model(base_vit, fraction=0.1, device=device)
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/base/new_finetuned_5epoch_frac01_lc.pth",
        map_location=device,
    )
)

Collapsing layer blocks.11.mlp


<All keys matched successfully>

In [6]:
# Collapse the collapsible block(s) using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=0.1,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp


In [7]:
# Top-1 accuracy (%) of the collapsed new-run ViT-Base.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

82.17399597167969

In [8]:
# Full report for the collapsed new-run ViT-Base.
evaluate_model(
    collapsible_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=82.17%
model has top5 accuracy=0.96%
model has size=314.47 MiB
model has macs=16.04 Gmacs
average inference time is 0.0042 seconds
model has 82.44 M parameters


In [4]:
# Two-stage ViT-Base. Stage 1: one collapsible block, load its checkpoint, collapse.
# Stage 2: make 1/6 of the blocks collapsible, load the second checkpoint, collapse again.
base_vit = timm.create_model(
    "vit_base_patch16_224.orig_in21k_ft_in1k",
    pretrained=True,
).to(device)

collapsible_vit = get_collapsible_model(base_vit, fraction=0.1, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/base/new_finetuned_5epoch_frac01_lc.pth", map_location=device)
)
collapse_model(collapsible_vit, threshold=0.05)

collapsible_vit = get_collapsible_model(collapsible_vit, fraction=1/6, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/base/new_finetuned_5epoch_frac01_lc_collapsing2.pth", map_location=device)
)
collapse_model(collapsible_vit, threshold=0.05)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp


In [5]:
# Full report for the two-stage collapsed ViT-Base.
evaluate_model(
    collapsible_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=80.31%
model has top5 accuracy=0.95%
model has size=298.71 MiB
model has macs=15.23 Gmacs
average inference time is 0.0037 seconds
model has 78.30 M parameters


In [4]:
# ViT-Large: pretrained timm ViT-Large/16 at 224px (tag: augreg_in21k_ft_in1k).
base_vit = timm.create_model(
    "vit_large_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)

In [5]:
# Top-1 accuracy (%) of the unmodified ViT-Large.
evaluate(
    base_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

85.68199920654297

In [5]:
# Total parameter count of the unmodified ViT-Large.
get_num_parameters(base_vit)

304326632

In [5]:
# Make 10% of ViT-Large's blocks collapsible (blocks 22-23 in the log).
collapse_vit = get_collapsible_model(
    base_vit,
    fraction=0.1,
    device=device,
)

Collapsing layer blocks.23.mlp
Collapsing layer blocks.22.mlp


In [None]:
# 2-epoch "nolc" checkpoint for ViT-Large (per the file name).
# NOTE(review): this cell is marked In [None] (never executed) in the saved notebook,
# so the evaluation that follows was not run with this checkpoint loaded.
collapse_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/large/finetuned_2epoch_frac01_nolc.pth",
        map_location=device,
    )
)

In [6]:
# Top-1 accuracy (%) of the collapsible ViT-Large.
evaluate(
    collapse_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

85.70199584960938

In [6]:
# 17-epoch "lc" checkpoint for ViT-Large (per the file name).
collapse_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/large/finetuned_17epoch_frac01_lc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [7]:
# Top-1 accuracy (%) with the 17-epoch "lc" checkpoint, before collapsing.
evaluate(
    collapse_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

84.30799865722656

In [8]:
# Slopes of ViT-Large's two collapsible blocks (both close to 1 in the output).
get_model_collapsible_slopes(collapse_vit)

blocks.23.mlp 0.9962775707244873
blocks.22.mlp 0.9683569073677063


In [9]:
# Collapse ViT-Large's collapsible blocks using threshold 0.05.
collapse_model(
    collapse_vit,
    fraction=0.1,
    threshold=0.05,
    device=device,
)

Collapsing layer blocks.23.mlp
Collapsing layer blocks.22.mlp


In [10]:
# Full report for the collapsed ViT-Large.
evaluate_model(
    collapse_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=84.22%
model has top5 accuracy=0.97%
model has size=1104.88 MiB
model has macs=56.77 Gmacs
average inference time is 0.0125 seconds
model has 289.64 M parameters


In [11]:
# Full report for the unmodified ViT-Large, for comparison.
evaluate_model(
    base_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=85.68%
model has top5 accuracy=0.98%
model has size=1160.91 MiB
model has macs=59.66 Gmacs
average inference time is 0.0130 seconds
model has 304.33 M parameters


# ViT Small
## vit_small_patch16_224.augreg_in21k_ft_in1k

In [4]:
# Pretrained timm ViT-Small/16 at 224px (tag: augreg_in21k_ft_in1k), moved to the GPU.
base_vit = timm.create_model(
    "vit_small_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)

In [5]:
# Total parameter count of the unmodified ViT-Small.
get_num_parameters(base_vit)

22050664

In [15]:
# Top-1 accuracy (%) of the unmodified ViT-Small.
evaluate(
    base_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

81.38200378417969

In [10]:
# Make 20% of ViT-Small's blocks collapsible (blocks 10-11 in the log).
collapse_vit = get_collapsible_model(
    base_vit,
    fraction=0.2,
    device=device,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp


In [13]:
# 1-epoch "nolc" checkpoint for ViT-Small (per the file name).
collapse_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/small/finetuned_1epoch_frac02_nolc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [14]:
# Slopes after the 1-epoch "nolc" run (far from 1 in the output, e.g. -2.91 for blocks.10).
get_model_collapsible_slopes(collapse_vit)

blocks.11.mlp 0.31300053000450134
blocks.10.mlp -2.9146413803100586


In [11]:
# Top-1 accuracy (%) of the collapsible ViT-Small.
evaluate(
    collapse_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

32.51799774169922

In [8]:
# 13-epoch "lc" checkpoint for ViT-Small (per the file name).
collapse_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/small/finetuned_13epoch_frac02_lc.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [9]:
# Top-1 accuracy (%) with the 13-epoch "lc" checkpoint, before collapsing.
evaluate(
    collapse_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

79.26799774169922

In [33]:
# Slopes after the "lc" run: blocks.11 is near 1, blocks.10 is not (see output).
get_model_collapsible_slopes(collapse_vit)

blocks.11.mlp 0.9834568500518799
blocks.10.mlp -0.02318074367940426


In [34]:
# Collapse with threshold 0.05; in the output blocks.10 reports "Not collapsible".
collapse_model(
    collapse_vit,
    fraction=0.2,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible


In [35]:
# Top-1 accuracy (%) after collapsing.
evaluate(
    collapse_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

79.20999145507812

In [36]:
# Full report for the unmodified ViT-Small.
evaluate_model(
    base_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=81.38%
model has top5 accuracy=0.96%
model has size=84.12 MiB
model has macs=4.24 Gmacs
average inference time is 0.0043 seconds
model has 22.05 M parameters


In [37]:
# Full report for the collapsed ViT-Small.
evaluate_model(
    collapse_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=79.21%
model has top5 accuracy=0.95%
model has size=80.17 MiB
model has macs=4.04 Gmacs
average inference time is 0.0061 seconds
model has 21.02 M parameters


In [38]:
# Fresh collapsible ViT-Small (blocks 10-11) with the 1-epoch "nolc" checkpoint.
collapse_vit = get_collapsible_model(base_vit, fraction=0.2, device=device)
collapse_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/small/finetuned_1epoch_frac02_nolc.pth",
        map_location=device,
    )
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp


<All keys matched successfully>

In [39]:
# Top-1 accuracy (%) of the reloaded "nolc" ViT-Small.
evaluate(
    collapse_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

52.433998107910156

# ViT-Small: new runs (`new_finetuned_*` checkpoints)

In [14]:
# New-run ViT-Small: one collapsible block, 5-epoch "lc" checkpoint, then collapse.
base_vit = timm.create_model(
    "vit_small_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)
collapsible_vit = get_collapsible_model(base_vit, fraction=0.1, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/small/new_finetuned_5epoch_frac01_lc.pth", map_location=device)
)
collapse_model(collapsible_vit, fraction=0.1, device=device, threshold=0.05)

Collapsing layer blocks.11.mlp


<All keys matched successfully>

In [10]:
# Full report for the new-run ViT-Small with one collapsed block.
evaluate_model(
    collapsible_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=80.96%
model has top5 accuracy=0.96%
model has size=80.17 MiB
model has macs=4.04 Gmacs
average inference time is 0.0043 seconds
model has 21.02 M parameters


In [9]:
# Two-stage ViT-Small. Stage 1: one collapsible block, load, collapse.
# Stage 2: make 1/6 of the blocks collapsible, load the second checkpoint, collapse again.
base_vit = timm.create_model(
    "vit_small_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)

collapsible_vit = get_collapsible_model(base_vit, fraction=0.1, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/small/new_finetuned_5epoch_frac01_lc.pth", map_location=device)
)
collapse_model(collapsible_vit, fraction=0.1, device=device, threshold=0.05)

collapsible_vit = get_collapsible_model(collapsible_vit, fraction=1/6, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/small/new_finetuned_5epoch_collapsing2.pth", map_location=device)
)
collapse_model(collapsible_vit, fraction=1/6, device=device, threshold=0.05)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp


In [10]:
# Full report for the two-stage collapsed ViT-Small.
evaluate_model(
    collapsible_vit,
    dataloader=dataloader,
    device=device,
)



model has test accuracy=79.74%
model has top5 accuracy=0.95%
model has size=76.23 MiB
model has macs=3.84 Gmacs
average inference time is 0.0040 seconds
model has 19.98 M parameters


# Sensitivity Analysis

In [4]:
# Sensitivity analysis runs on ViT-Tiny (tag: augreg_in21k_ft_in1k).
base_vit = timm.create_model(
    "vit_tiny_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)

# no layer

In [5]:
# Baseline, no layers collapsed: top-1 accuracy 75.35% (75.348 from evaluate(base_vit) in the ViT-Tiny section).

# 1 layer

In [6]:
# 1 layer: make the last 1/12 of the blocks collapsible (blocks.11) and load its checkpoint.
collapsible_vit = get_collapsible_model(base_vit, fraction=1/12, device=device)
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/tiny/sensitivity_analysis_collapsing1.pth",
        map_location=device,
    )
)

Collapsing layer blocks.11.mlp


<All keys matched successfully>

In [37]:
# Slope of the single collapsible block before collapsing.
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 0.9937456250190735


In [42]:
# Collapse the 1-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=1/12,
    threshold=0.05,
    device=device,
)

Collapsing layer blocks.11.mlp


In [39]:
# Top-1 accuracy (%) with 1 collapsed layer.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

75.78799438476562

# 2 layers

In [7]:
# 2 layers (blocks 10-11): load the "lc2" fine-tune, collapse, then load the stage-2
# collapsing checkpoint. The first collapse prints "Not collapsible" in the log.
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device)
)
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device)
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible


<All keys matched successfully>

In [8]:
# Slopes of the two collapsible blocks before the stage-2 collapse.
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 0.9963091611862183


In [9]:
# Collapse the 2-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=2/12,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp


In [10]:
# Top-1 accuracy (%) with 2 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

75.33799743652344

In [11]:
# Persist the 2-layer collapsed weights.
torch.save(
    collapsible_vit.state_dict(),
    "./models_archive/vit/tiny/sensitivity_analysis_collapsed2.pth",
)

# 3 layers

In [5]:
# 3 layers: replay stage 2 (two checkpoints, two collapses), then make 3/12 of the
# blocks collapsible. The stage-3 checkpoint is loaded in the next cell.
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device)
)
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device)
)
collapse_model(collapsible_vit, fraction=2/12, device=device, threshold=0.05)

collapsible_vit = get_collapsible_model(collapsible_vit, fraction=3/12, device=device)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp


In [6]:
# Stage-3 collapsing checkpoint.
collapsible_vit.load_state_dict(
    torch.load(
        "./models_archive/vit/tiny/sensitivity_analysis_collapsing3.pth",
        map_location=device,
    )
)

<All keys matched successfully>

In [7]:
# Slopes of the three collapsible blocks before the stage-3 collapse.
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 0.9955717921257019


In [8]:
# Collapse the 3-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=3/12,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp


In [9]:
# Top-1 accuracy (%) with 3 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

74.41999816894531

In [10]:
# Persist the 3-layer collapsed weights.
torch.save(
    collapsible_vit.state_dict(),
    "./models_archive/vit/tiny/sensitivity_analysis_collapsed3.pth",
)

In [13]:
# 4 layers: replay stages 2 and 3, then make 4/12 of the blocks collapsible and load
# the stage-4 checkpoint (its load result is the cell's displayed output).
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device)
)
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device)
)
collapse_model(collapsible_vit, fraction=2/12, device=device, threshold=0.05)

collapsible_vit = get_collapsible_model(collapsible_vit, fraction=3/12, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing3.pth", map_location=device)
)
collapse_model(collapsible_vit, fraction=3/12, device=device, threshold=0.05)

collapsible_vit = get_collapsible_model(collapsible_vit, fraction=4/12, device=device)
collapsible_vit.load_state_dict(
    torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing4.pth", map_location=device)
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp


<All keys matched successfully>

In [7]:
# Slopes of the four collapsible blocks before the stage-4 collapse.
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 1
blocks.8.mlp 0.9978395104408264


In [16]:
# Collapse the 4-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=4/12,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp


In [9]:
# Top-1 accuracy (%) with 4 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

73.13199615478516

In [10]:
# Persist the 4-layer collapsed weights.
torch.save(
    collapsible_vit.state_dict(),
    "./models_archive/vit/tiny/sensitivity_analysis_collapsed4.pth",
)

In [17]:
# Wrap one more block (5/12) on top of the 4-layer collapsed model, before any training.
collapsible_vit = get_collapsible_model(
    collapsible_vit,
    fraction=5/12,
    device=device,
)

In [18]:
# Slopes right after wrapping: blocks.7 has no checkpoint loaded yet, so it shows its
# freshly initialised value (~0.01 in the output).
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 1
blocks.8.mlp 1
blocks.7.mlp 0.009999999776482582


In [5]:
# 5 layers: rebuild the staged ViT-Tiny from scratch. Each stage k makes k/12 of the
# blocks collapsible, loads sensitivity_analysis_collapsing{k}.pth and collapses.
# The stages used to be copy-pasted one line per call; the loop below issues exactly
# the same calls in the same order.
base_vit = timm.create_model("vit_tiny_patch16_224.augreg_in21k_ft_in1k", pretrained=True).to(device)

# Stage 2 is special: it loads the "lc2" fine-tune, collapses (the log prints
# "Not collapsible"), then loads the stage-2 checkpoint and collapses again.
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device))
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device))
collapse_model(collapsible_vit, fraction=2/12, device=device, threshold=0.05)

# Stages 3-4: wrap, load, collapse.
for n_layers in range(3, 5):
    collapsible_vit = get_collapsible_model(collapsible_vit, fraction=n_layers / 12, device=device)
    collapsible_vit.load_state_dict(
        torch.load(f"./models_archive/vit/tiny/sensitivity_analysis_collapsing{n_layers}.pth", map_location=device)
    )
    collapse_model(collapsible_vit, threshold=0.05)

# Stage 5: wrap and load only; collapsing happens in a later cell. The load result
# is the cell's displayed output, so it stays the last expression.
collapsible_vit = get_collapsible_model(collapsible_vit, fraction=5/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing5.pth", map_location=device))

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp


<All keys matched successfully>

In [6]:
# Slopes of the five collapsible blocks before the stage-5 collapse.
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 1
blocks.8.mlp 1
blocks.7.mlp 0.9954009056091309


In [7]:
# Collapse the 5-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=5/12,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp


In [8]:
# Top-1 accuracy (%) with 5 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

71.3499984741211

In [10]:
# Parameters remaining after 5 collapses, as a percentage of the base ViT-Tiny.
(get_num_parameters(collapsible_vit) / get_num_parameters(base_vit)) * 100

77.3659989057994

In [11]:
# Peek at stage 6 before training: wrap 6/12 of the blocks and print the slopes.
collapsible_vit = get_collapsible_model(
    collapsible_vit,
    fraction=6/12,
    device=device,
)
get_model_collapsible_slopes(collapsible_vit)

Collapsing layer blocks.6.mlp
blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 1
blocks.8.mlp 1
blocks.7.mlp 1
blocks.6.mlp 0.009999999776482582


In [4]:
# 6 layers: rebuild the staged ViT-Tiny from scratch. Each stage k makes k/12 of the
# blocks collapsible, loads sensitivity_analysis_collapsing{k}.pth and collapses.
# The stages used to be copy-pasted one line per call; the loop below issues exactly
# the same calls in the same order.
base_vit = timm.create_model("vit_tiny_patch16_224.augreg_in21k_ft_in1k", pretrained=True).to(device)

# Stage 2 is special: it loads the "lc2" fine-tune, collapses (the log prints
# "Not collapsible"), then loads the stage-2 checkpoint and collapses again.
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device))
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device))
collapse_model(collapsible_vit, fraction=2/12, device=device, threshold=0.05)

# Stages 3-5: wrap, load, collapse.
for n_layers in range(3, 6):
    collapsible_vit = get_collapsible_model(collapsible_vit, fraction=n_layers / 12, device=device)
    collapsible_vit.load_state_dict(
        torch.load(f"./models_archive/vit/tiny/sensitivity_analysis_collapsing{n_layers}.pth", map_location=device)
    )
    collapse_model(collapsible_vit, threshold=0.05)

# Stage 6: wrap, load, and collapse with the explicit fraction/device arguments.
collapsible_vit = get_collapsible_model(collapsible_vit, fraction=6/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing6.pth", map_location=device))
collapse_model(collapsible_vit, fraction=6/12, device=device, threshold=0.05)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp


In [5]:
# Slopes after the stage-6 collapse (all 1 in the output).
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 1
blocks.8.mlp 1
blocks.7.mlp 1
blocks.6.mlp 1


In [6]:
# Top-1 accuracy (%) with 6 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

68.76799774169922

In [4]:
# 7 layers: rebuild the staged ViT-Tiny from scratch. Each stage k makes k/12 of the
# blocks collapsible, loads sensitivity_analysis_collapsing{k}.pth and collapses.
# The stages used to be copy-pasted one line per call; the loop below issues exactly
# the same calls in the same order.
base_vit = timm.create_model("vit_tiny_patch16_224.augreg_in21k_ft_in1k", pretrained=True).to(device)

# Stage 2 is special: it loads the "lc2" fine-tune, collapses (the log prints
# "Not collapsible"), then loads the stage-2 checkpoint and collapses again.
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device))
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device))
collapse_model(collapsible_vit, fraction=2/12, device=device, threshold=0.05)

# Stages 3-6: wrap, load, collapse.
for n_layers in range(3, 7):
    collapsible_vit = get_collapsible_model(collapsible_vit, fraction=n_layers / 12, device=device)
    collapsible_vit.load_state_dict(
        torch.load(f"./models_archive/vit/tiny/sensitivity_analysis_collapsing{n_layers}.pth", map_location=device)
    )
    collapse_model(collapsible_vit, threshold=0.05)

# Stage 7: wrap and load only; collapsing happens in the next cell. The load result
# is the cell's displayed output, so it stays the last expression.
collapsible_vit = get_collapsible_model(collapsible_vit, fraction=7/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing7.pth", map_location=device))

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.5.mlp


<All keys matched successfully>

In [5]:
# Collapse the 7-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=7/12,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.5.mlp


In [6]:
# Top-1 accuracy (%) with 7 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

66.93199920654297

In [4]:
# 8 layers: rebuild the staged ViT-Tiny from scratch. Each stage k makes k/12 of the
# blocks collapsible, loads sensitivity_analysis_collapsing{k}.pth and collapses.
# The stages used to be copy-pasted one line per call; the loop below issues exactly
# the same calls in the same order.
base_vit = timm.create_model("vit_tiny_patch16_224.augreg_in21k_ft_in1k", pretrained=True).to(device)

# Stage 2 is special: it loads the "lc2" fine-tune, collapses (the log prints
# "Not collapsible"), then loads the stage-2 checkpoint and collapses again.
collapsible_vit = get_collapsible_model(base_vit, fraction=2/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_10epoch_frac0.166_lc2.pth", map_location=device))
collapse_model(collapsible_vit, threshold=0.05)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing2.pth", map_location=device))
collapse_model(collapsible_vit, fraction=2/12, device=device, threshold=0.05)

# Stages 3-7: wrap, load, collapse.
for n_layers in range(3, 8):
    collapsible_vit = get_collapsible_model(collapsible_vit, fraction=n_layers / 12, device=device)
    collapsible_vit.load_state_dict(
        torch.load(f"./models_archive/vit/tiny/sensitivity_analysis_collapsing{n_layers}.pth", map_location=device)
    )
    collapse_model(collapsible_vit, threshold=0.05)

# Stage 8: wrap and load only; collapsing happens in a later cell. The load result
# is the cell's displayed output, so it stays the last expression.
collapsible_vit = get_collapsible_model(collapsible_vit, fraction=8/12, device=device)
collapsible_vit.load_state_dict(torch.load("./models_archive/vit/tiny/sensitivity_analysis_collapsing8.pth", map_location=device))

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Not collapsible
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.5.mlp
Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsi

<All keys matched successfully>

In [5]:
# Slopes of the eight collapsible blocks before the stage-8 collapse.
get_model_collapsible_slopes(collapsible_vit)

blocks.11.mlp 1
blocks.10.mlp 1
blocks.9.mlp 1
blocks.8.mlp 1
blocks.7.mlp 1
blocks.6.mlp 1
blocks.5.mlp 1
blocks.4.mlp 0.9587531089782715


In [None]:
# NOTE(review): this cell is marked In [None] (never executed) and duplicates the
# collapse_model call that was run after the next evaluation — keep only one copy.
collapse_model(
    collapsible_vit,
    fraction=8/12,
    device=device,
    threshold=0.05,
)

In [6]:
# Top-1 accuracy (%) of the 8-layer model before the final collapse.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

64.0219955444336

In [7]:
# Collapse the 8-layer model using threshold 0.05.
collapse_model(
    collapsible_vit,
    fraction=8/12,
    device=device,
    threshold=0.05,
)

Collapsing layer blocks.11.mlp
Collapsing layer blocks.10.mlp
Collapsing layer blocks.9.mlp
Collapsing layer blocks.8.mlp
Collapsing layer blocks.7.mlp
Collapsing layer blocks.6.mlp
Collapsing layer blocks.5.mlp
Collapsing layer blocks.4.mlp


In [8]:
# Top-1 accuracy (%) with 8 collapsed layers.
evaluate(
    collapsible_vit,
    dataloader=dataloader["val"],
    device=device,
)

                                                       

63.5359992980957