In [9]:
# Auto reload modules
%load_ext autoreload
%autoreload 2

# Set CUDA_VISIBLE_DEVICES

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# matplotlib setting
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (8, 8)
plt.rcParams["font.size"] = 14
%matplotlib inline


import numpy as np
from functools import partial

import torch
from torch import nn as nn

from itertools import repeat
import collections.abc

import torchvision.models as models
import copy
 
from utils.aggregate_block.model_trainer_generate import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# torchvision's ViT feed-forward block *class* (a class, not an instance, despite the name).
# get_collapsible_model matches sub-modules against it with isinstance() to find the blocks to swap.
MLPBLOCK_INSTANCE = models.vision_transformer.MLPBlock

In [11]:
def change_module(model, name, module):
    """Replace the sub-module at dotted path `name` (e.g. "1.encoder.layers.encoder_layer_11.mlp").

    Every path component but the last is looked up in `_modules` to reach the parent;
    the last component is then (re)bound on that parent to `module`. Works in place.
    """
    *parent_path, leaf_name = name.split(".")
    parent = model
    for part in parent_path:
        parent = parent._modules[part]
    parent._modules[leaf_name] = module

def collapse_model(model, fraction=1.0, threshold=0.05, device=None):
    """Collapse the last `fraction` of the CollapsibleMlp layers in `model`, in place.

    Layers are visited in reverse `named_modules()` order. For each one, a progress line is
    printed and `collapse(threshold=threshold)` is called. A layer whose PReLU slope is not
    within `threshold` of 1 prints "Not collapsible" and is left unchanged.

    Args:
        model: nn.Module that may contain CollapsibleMlp sub-modules.
        fraction: share of the CollapsibleMlp layers to visit, counted from the end;
            int(num_layers * fraction) layers are visited (1.0 = all of them).
        threshold: tolerance forwarded to CollapsibleMlp.collapse.
        device: unused; kept so the signature matches get_collapsible_model.
    """
    # `fraction` must be a share of the collapsible layers only. Counting every sub-module
    # (Linear, Dropout, LayerNorm, ...) made small fractions pick an arbitrary module slice.
    collapsible = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, CollapsibleMlp)
    ]
    num_to_collapse = int(len(collapsible) * fraction)
    for name, module in collapsible[::-1][:num_to_collapse]:
        print("Collapsing layer {}".format(name))
        module.collapse(threshold=threshold)

In [12]:
# Build the "vit_b_16" model through the project's factory; the next cell shows its structure.
model = generate_cls_model("vit_b_16")

In [13]:
# Display the module tree: a Resize to 224x224 followed by the VisionTransformer,
# whose encoder layers each contain an `mlp` MLPBlock (linear_1 -> GELU -> linear_2).
model

Sequential(
  (0): Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
  (1): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear_1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (dropout_1): Dropout(p=0.0, inplace=False)
            (linear_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout_2): Dropout(p=0.0, inpl

In [14]:
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)


class CollapsibleMlp(nn.Module):
    """Two-layer MLP whose activation can be driven to linear and then folded away.

    Layout: fc1 -> PReLU -> drop1 -> norm (BatchNorm1d or Identity) -> fc2 -> drop2.

    The PReLU has a single slope parameter, initialised to 0.01. `linear_loss` returns
    (slope - 1) ** 2 so it can be added to a training objective. Once the slope is within
    `threshold` of 1, `collapse` merges fc1, norm and fc2 into one Linear stored in `fc1`,
    since a PReLU with slope 1 is the identity.

    Args:
        in_features: input feature size.
        hidden_features: hidden size; defaults to `in_features`.
        out_features: output size; defaults to `in_features`.
        batch_norm: if True, apply BatchNorm1d(hidden_features) after the activation.
        bias: bool, or a pair of bools for (fc1, fc2).
        drop: dropout probability, or a pair for (drop1, drop2).
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            batch_norm=False,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # to_2tuple expands a scalar to (x, x) and converts an iterable to a tuple.
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = nn.PReLU(num_parameters=1, init=0.01)
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = nn.BatchNorm1d(hidden_features) if batch_norm else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])
        self.batch_norm = batch_norm

    def forward(self, x):
        """Apply fc1 -> act -> drop1 -> norm -> fc2 -> drop2 (collapsed stages are Identity)."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

    def linear_loss(self):
        """Return the squared distance of the PReLU slope from 1, or 0 once collapsed."""
        if isinstance(self.act, nn.Identity):
            return 0
        return (self.act.weight - 1) ** 2

    @staticmethod
    def _weight_and_bias(linear):
        """Return (weight, bias) of `linear`, substituting zeros when it has no bias."""
        bias = linear.bias
        if bias is None:
            bias = linear.weight.new_zeros(linear.out_features)
        return linear.weight, bias

    def collapse(self, threshold=0.05):
        """Fold the MLP into a single Linear if the PReLU slope is within `threshold` of 1.

        On success, `fc1` holds the merged Linear, and `act`, `drop1`, `norm` and `fc2`
        become nn.Identity; `drop2` is kept. Otherwise "Not collapsible" is printed and
        nothing changes. A second call after collapsing is a no-op.

        The fold is exact only when the slope equals 1; within the threshold it is an
        approximation. With batch_norm=True the BatchNorm running statistics are folded
        in, so the merged layer reproduces eval-mode behaviour. This requires running
        statistics to be tracked.
        """
        if isinstance(self.act, nn.Identity):
            return  # already collapsed
        is_near_linear = bool((self.act.weight - 1).abs() < threshold)
        if not is_near_linear:
            print("Not collapsible")
            return

        with torch.no_grad():
            # Layers built with bias=False contribute a zero bias to the fold.
            W1, B1 = self._weight_and_bias(self.fc1)
            W2, B2 = self._weight_and_bias(self.fc2)
            if self.batch_norm:
                # Eval-mode BN maps h -> gamma * (h - mean) / sqrt(var + eps) + beta;
                # fold that affine map into fc1 before merging with fc2.
                bn = self.norm
                gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
                beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_var)
                scale = gamma / torch.sqrt(bn.running_var + bn.eps)
                W1 = scale[:, None] * W1  # same as diag(scale) @ W1, without building the matrix
                B1 = scale * (B1 - bn.running_mean) + beta
            new_W = W2 @ W1
            new_B = W2 @ B1 + B2

            # Create the merged layer directly on the weights' device/dtype.
            merged = nn.Linear(
                self.fc1.in_features,
                self.fc2.out_features,
                device=new_W.device,
                dtype=new_W.dtype,
            )
            merged.weight.copy_(new_W)
            merged.bias.copy_(new_B)

        self.fc1 = merged
        self.act = nn.Identity()
        self.drop1 = nn.Identity()
        self.norm = nn.Identity()
        self.fc2 = nn.Identity()

    def load_from_Mlp(self, module):
        """Load fc1/fc2 parameters from a module with `linear_1` / `linear_2` Linear children.

        Tensors are assigned through `.data`, so fc1/fc2 share storage with `module`'s
        parameters rather than copying them. Only the linear layers are transferred; this
        class keeps its own PReLU in place of the source activation.

        Raises:
            ValueError: if fc1/fc2 and the matching source layer disagree on having a bias.
        """
        pairs = (("fc1", self.fc1, module.linear_1), ("fc2", self.fc2, module.linear_2))
        # Validate everything first so a mismatch leaves this module untouched.
        for name, dst, src in pairs:
            if (dst.bias is None) != (src.bias is None):
                raise ValueError(
                    "{} bias presence does not match the source layer".format(name)
                )
        for _, dst, src in pairs:
            dst.weight.data = src.weight.data
            if dst.bias is not None:
                dst.bias.data = src.bias.data
    

In [19]:
def get_collapsible_model(model, fraction=1.0, device=None):
    """Return a deep copy of `model` whose last MLP blocks are replaced by CollapsibleMlp.

    Every sub-module that is an instance of MLPBLOCK_INSTANCE is a candidate. The last
    int(num_blocks * fraction) of them, in reverse `named_modules()` order, are swapped
    for a CollapsibleMlp that holds the same linear_1 / linear_2 weights. For the ViT in
    this notebook these are the final encoder layers.

    The replacement uses a PReLU (initial slope 0.01) instead of the block's original
    activation, with no dropout. The copy's outputs are therefore not expected to match
    `model`'s. `model` itself is not modified.

    Args:
        model: network containing MLPBLOCK_INSTANCE sub-modules.
        fraction: share of the MLP blocks to replace, counted from the end (1.0 = all).
        device: if given, the returned copy is moved to this device.

    Returns:
        The modified deep copy.
    """
    copy_model = copy.deepcopy(model)
    if device is not None:
        copy_model = copy_model.to(device)

    # `fraction` is a share of the MLP blocks, not of every sub-module in the network.
    mlp_blocks = [
        (name, module)
        for name, module in copy_model.named_modules()
        if isinstance(module, MLPBLOCK_INSTANCE)
    ]
    num_to_replace = int(len(mlp_blocks) * fraction)
    for name, module in mlp_blocks[::-1][:num_to_replace]:
        print("Collapsing layer {}".format(name))
        collapsible_mlp = CollapsibleMlp(
            in_features=module.linear_1.in_features,
            hidden_features=module.linear_2.in_features,
            out_features=module.linear_2.out_features,
            batch_norm=False,
            # One bool per layer. Passing the bias *tensor* made to_2tuple split it into
            # per-element scalars and use their truthiness as the has-bias flag.
            bias=(module.linear_1.bias is not None, module.linear_2.bias is not None),
            drop=0,
        )
        # Weights come from the (already device-placed) copy, so no extra .to() is needed.
        collapsible_mlp.load_from_Mlp(module)
        change_module(copy_model, name, collapsible_mlp)
    return copy_model

In [21]:
# Swap the last 10% of the MLP blocks for CollapsibleMlp on CPU.
# Keep the converted copy so later cells can train or collapse it.
collapsible_model = get_collapsible_model(model, fraction=0.1, device="cpu")
collapsible_model

Collapsing layer 1.encoder.layers.encoder_layer_11.mlp


Sequential(
  (0): Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
  (1): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (linear_1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU()
            (dropout_1): Dropout(p=0.0, inplace=False)
            (linear_2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout_2): Dropout(p=0.0, inpl