In [None]:
import torch
from torch import nn


class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the input features.
        hidden_features: hidden width; a falsy value means ``in_features``.
        out_features: output width; a falsy value means ``in_features``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability used after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
        super().__init__()
        n_out = out_features if out_features else in_features
        n_hidden = hidden_features if hidden_features else in_features
        # Registration order (fc1, act, fc2, drop) fixes parameter/state-dict order.
        self.fc1 = nn.Linear(in_features, n_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(n_hidden, n_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Attend from each query to a learnable bank of per-class key vectors.

    For every query the block builds one attention-weighted summary per class
    (``num_keys`` learnable vectors per class, stored in ``self.kx``). It then
    returns the summary whose cosine similarity to the query is highest.

    Args:
        num_classes: number of key groups in the learnable bank.
        dim: feature size C of queries and keys.
        qkv_bias: whether the query/key projections have a bias.
        qk_scale: scale applied to the dot products; defaults to ``dim ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        num_keys: learnable vectors per class (previously hard-coded to 10).
    """

    def __init__(self, num_classes, dim, qkv_bias=False, qk_scale=None, attn_drop=0., num_keys=10):
        super().__init__()
        self.scale = qk_scale or dim ** -0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        # Learnable key/value bank: [num_classes, num_keys, C].
        self.kx = nn.Parameter(torch.empty(num_classes, num_keys, dim).normal_(std=0.01))

    def forward(self, qx):
        """Return, for each query, the class summary most similar to it.

        Args:
            qx: query features, shape [Bq, C].

        Returns:
            Tensor of shape [Bq, C].
        """
        qx = qx.unsqueeze(1)  # [Bq, 1, C]

        q = self.wq(qx)        # [Bq, 1, C]
        k = self.wk(self.kx)   # [K, N, C]
        v = self.kx            # [K, N, C]
        # Scores of each query against the N keys of every class: [Bq, K, N].
        attn = torch.einsum('qoc,knc->qkn', q, k) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # One attention-weighted summary per class: [Bq, K, C].
        x = torch.einsum('knc,qkn->qkc', v, attn)

        # Class summary most similar to each query: [Bq].
        idx = torch.cosine_similarity(qx, x, dim=2, eps=1e-6).argmax(-1)
        # Select summary idx[b] from query b's own summaries. The former
        # x[:, idx, :][0] read every query's choice from query 0's summaries,
        # which is only correct when Bq == 1.
        batch_idx = torch.arange(x.size(0), device=x.device)
        return x[batch_idx, idx]


class GGR(nn.Module):
    """Residual head that adds an ``Attention`` read-out to its input.

    ``self.mlp`` is constructed (so its parameters are part of the module) but
    the forward pass does not use it; the output is ``x + self.attn(x)``.

    Args:
        num_classes: number of key groups passed to ``Attention``.
        dim: feature size of the input vectors.
    """

    def __init__(self, num_classes=10, dim=64):
        super().__init__()
        # Keep construction order (mlp before attn) so parameter init is unchanged.
        self.mlp = Mlp(dim)
        self.attn = Attention(num_classes, dim)

    def forward(self, x):
        residual = self.attn(x)
        return x + residual

In [9]:
'''
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file is hugely influenced by [2]
which is implemented for ImageNet and doesn't have option A for identity.
Moreover, most of the implementations on the web are copy-pasted from
torchvision's resnet and has wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
number of layers and parameters:

name      | layers | params
ResNet20  |    20  | 0.27M
ResNet32  |    32  | 0.46M
ResNet44  |    44  | 0.66M
ResNet56  |    56  | 0.85M
ResNet110 |   110  |  1.7M
ResNet1202|  1202  | 19.4M

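which this implementation has when every stage uses ungrouped convolutions (groups=1); with the default groups=[1, 1, 4] the parameter counts are lower.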

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


class LambdaLayer(nn.Module):
    """Adapter that lets an arbitrary callable be used as an ``nn.Module``.

    Args:
        lambd: callable applied to the input in ``forward``; it holds no parameters.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class _PadShortcut(nn.Module):
    """Parameter-free "option A" shortcut: subsample spatially by ``stride`` and
    zero-pad the channel dimension from ``in_planes`` up to ``planes``."""

    def __init__(self, in_planes, planes, stride):
        super().__init__()
        self.stride = stride
        extra = planes - in_planes
        # Split the extra channels before/after the input channels. For the
        # 16->32 and 32->64 transitions used here this equals planes // 4 per side.
        self.pad_front = extra // 2
        self.pad_back = extra - extra // 2

    def forward(self, x):
        x = x[:, :, ::self.stride, ::self.stride]
        return F.pad(x, (0, 0, 0, 0, self.pad_front, self.pad_back), "constant", 0)


class BasicBlock(nn.Module):
    """Two 3x3 (optionally grouped) conv-BN layers plus a residual shortcut.

    Args:
        in_planes: input channels.
        planes: output channels.
        stride: stride of the first convolution and of the shortcut.
        option: shortcut type used when the shape changes: 'A' subsamples and
            zero-pads channels (no parameters), 'B' uses a 1x1 conv + BN projection.
        groups: groups for both 3x3 convs (and the option-B 1x1 conv); in_planes
            and planes must be divisible by it.

    Raises:
        ValueError: if a projection shortcut is required and ``option`` is
            neither 'A' nor 'B'.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A', groups=4):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A. A module instead of a
                # lambda keeps the model picklable and handles any stride/width gap.
                self.shortcut = _PadShortcut(in_planes, planes, stride)
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False, groups=groups),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # An identity shortcut would only fail later with a shape mismatch.
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three stages (16/32/64 channels), global
    average pooling, and a linear classifier.

    Args:
        block: residual block class. It is called as
            ``block(in_planes, planes, stride, groups=g)`` and must expose an
            ``expansion`` attribute.
        num_blocks: number of blocks in each of the three stages.
        in_channels: channels of the input images.
        num_classes: number of classifier outputs.
        use_ggr: accepted for API compatibility; currently ignored because the
            GGR head is disabled.
        groups: conv groups for stages 1, 2 and 3.
    """

    def __init__(self, block, num_blocks, in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, groups=groups[0])
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, groups=groups[1])
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, groups=groups[2])
        # _make_layer leaves self.in_planes at the last stage's output width
        # (64 * block.expansion), so the classifier matches any block type.
        self.linear = nn.Linear(self.in_planes, num_classes)
        # NOTE(review): the GGR head is disabled, so `use_ggr` has no effect. To
        # re-enable: self.ggr = GGR(num_classes, 64); self.use_ggr = use_ggr, and
        # apply self.ggr to the pooled features in forward().

    def _make_layer(self, block, planes, num_blocks, stride, groups=4):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, groups=groups))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map images [B, in_channels, H, W] to logits [B, num_classes]."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool to 1x1. The former avg_pool2d(out, out.size()[3])
        # used the width as the kernel and broke for non-square feature maps.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)  # b x 64
        return self.linear(out)


# All CIFAR ResNet factories share one signature (previously only resnet32 accepted
# these options). The defaults match ResNet's own defaults, so zero-argument
# calls build the same network as before.


def resnet20(in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
    """ResNet-20 for CIFAR: 3 BasicBlocks per stage."""
    return ResNet(BasicBlock, [3, 3, 3], in_channels, num_classes, use_ggr, groups=groups)


def resnet32(in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
    """ResNet-32 for CIFAR: 5 BasicBlocks per stage."""
    return ResNet(BasicBlock, [5, 5, 5], in_channels, num_classes, use_ggr, groups=groups)


def resnet44(in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
    """ResNet-44 for CIFAR: 7 BasicBlocks per stage."""
    return ResNet(BasicBlock, [7, 7, 7], in_channels, num_classes, use_ggr, groups=groups)


def resnet56(in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
    """ResNet-56 for CIFAR: 9 BasicBlocks per stage."""
    return ResNet(BasicBlock, [9, 9, 9], in_channels, num_classes, use_ggr, groups=groups)


def resnet110(in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
    """ResNet-110 for CIFAR: 18 BasicBlocks per stage."""
    return ResNet(BasicBlock, [18, 18, 18], in_channels, num_classes, use_ggr, groups=groups)


def resnet1202(in_channels=3, num_classes=10, use_ggr=False, groups=(1, 1, 4)):
    """ResNet-1202 for CIFAR: 200 BasicBlocks per stage."""
    return ResNet(BasicBlock, [200, 200, 200], in_channels, num_classes, use_ggr, groups=groups)


In [2]:
import sys
# Make the project root importable so the local packages imported below
# (models, loaders, configs, utils, loss, modify_kernel) resolve.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
sys.path.append('/nfs/xwx/model-doctor-xwx')

import torch
import torchvision
import models
import loaders
import argparse
import os
import datetime
import time
import matplotlib
import yaml
import math


from torch import optim
from configs import config
from utils.lr_util import get_lr_scheduler
from utils.time_util import print_time, get_current_time
from sklearn.metrics import classification_report
from loss.refl import reduce_equalized_focal_loss
from loss.fl import focal_loss
from loss.hcl import hc_loss
from modify_kernel.util.draw_util import draw_lr, draw_fc_weight
from modify_kernel.util.cfg_util import print_yml_cfg
from functools import partial
from utils.args_util import print_args
from utils.general import init_seeds, get_head_and_kernel, get_head_ratio

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): the filter below hides every warning (deprecations, numerical
# issues, ...); consider filtering specific categories instead.
import warnings # ignore warnings
warnings.filterwarnings("ignore")


In [None]:
# Pretrained checkpoint evaluated below; the file stores its weights under the "model" key.
model_path = (
    "/nfs/xwx/model-doctor-xwx/output/model/pretrained/resnet32/cifar-10-lt-ir100/"
    "lr0.01/cosine_lr_scheduler/ce_loss/2022-07-15_17-27-58/best-model-acc0.7144.pth"
)

In [3]:
# Experiment identifiers; data_name selects the dataset passed to loaders.load_data.
model_name, data_name = "resnet32", "cifar-10-lt-ir100"

# Saved .npy array that the next cell loads into `kx`.
grad_path = (
    "/nfs/xwx/model-doctor-xwx/output/result/channels/"
    "resnet32-cifar-10-lt-ir100/channel_grads_-1.npy"
)

In [None]:
# Load the saved array at grad_path as a CPU tensor (torch.from_numpy shares memory
# with the NumPy array) and display its shape as a sanity check.
# NOTE(review): no visible cell uses `kx`; its only reference in test() is commented out.
kx = torch.from_numpy(
    np.asarray(np.load(grad_path))
)
kx.shape

In [4]:
# Preferred GPU index. Fall back to the last available GPU, or to the CPU, so the
# notebook still runs on hosts with fewer than three GPUs. There, a hard-coded
# 'cuda:2' fails as soon as a tensor is moved to it.
gpu_index = 2
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(gpu_index, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

In [5]:
# Build data loaders for `data_name` with the project's `loaders` package. The first
# return value is used below as a dict indexed by "train" and "val"; the second is
# discarded. NOTE(review): confirm what the second return value holds.
data_loaders, _ = loaders.load_data(data_name=data_name)

----------------------------------------
LOAD DATA: cifar-10-lt-ir100
----------------------------------------
load cifar dataset from image dir

load cifar dataset from image dir



In [None]:
# Baseline network restored from the checkpoint at `model_path`.
base_model = resnet32()

# map_location="cpu": load_state_dict copies values into the model's own tensors,
# so there is no need to materialise the checkpoint on the device it was saved from.
checkpoint = torch.load(model_path, map_location="cpu")
load_result = base_model.load_state_dict(checkpoint["model"], strict=False)
# strict=False silently skips keys that do not match. Report them so a naming
# mismatch, which would leave those layers at their random init, is visible.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")
base_model.to(device);

In [6]:
def test(dataloader, model, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):

        X, y = X.to(device), y.to(device)
        # kx = kx.to(device)
        pred = model(X)
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct /= size
        
    print(f"Test Error: Accuracy: {(100*correct):>0.2f}%")
    
    return correct

In [None]:
# Validation accuracy of the restored baseline checkpoint.
test(dataloader=data_loaders["val"], model=base_model, device=device)

In [7]:
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one training epoch, then print and return loss/accuracy.

    Args:
        dataloader: iterable of ``(X, y)`` batches.
        model: classifier returning per-class scores; switched to train mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over the parameters to update.
        device: device each batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: mean per-batch loss and the fraction of
        correctly classified samples over the epoch. Callers that ignore the
        result, as the original did, are unaffected.
    """
    model.train()
    train_loss, correct = 0.0, 0
    seen, num_batches = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Make sure autograd is on even if the caller disabled it.
        with torch.set_grad_enabled(True):
            pred = model(X)  # forward pass

            loss = loss_fn(pred, y)
            train_loss += loss.item()

            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

            # Backpropagation
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()  # gradients of the loss w.r.t. the model parameters
            optimizer.step()  # update the parameters

        seen += len(y)
        num_batches += 1

    # Average over what was actually processed (robust to drop_last / samplers).
    train_loss /= num_batches
    correct /= seen

    print(f"Train Error: Accuracy: {(100*correct):>0.2f}%, Avg loss: {train_loss:>8f}")

    return train_loss, correct
    

In [None]:
# NOTE(review): ResNet currently ignores `use_ggr` (its GGR head is commented out),
# so this evaluates the same network as the plain resnet32() cell above.
base_model = resnet32(use_ggr=True)

# map_location="cpu" avoids materialising the checkpoint on the device it was saved from.
checkpoint = torch.load(model_path, map_location="cpu")
load_result = base_model.load_state_dict(checkpoint["model"], strict=False)
# strict=False silently skips keys that do not match; report them so that layers
# left at their random init are visible.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")
base_model.to(device)

test(data_loaders["val"], base_model, device)

In [15]:
import math
import copy

# Freshly initialised ResNet-32 (no checkpoint loaded) with grouped convolutions
# (groups=4) in stage 2 only.
model = resnet32(groups=[1, 4, 1])
model.to(device)

# Every trainable parameter is optimised. Earlier experiments froze the backbone
# and left only the classifier (linear weight/bias), the GGR head, or
# layer3[4].conv2 trainable.
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))

In [16]:
import copy

base_lr = 0.01
total_epoch_num = 200
weight_decay = 5e-4  # L2 weight-decay coefficient passed to SGD

optimizer = optim.SGD(parameters, lr=base_lr, momentum=0.9, weight_decay=weight_decay)
# Cosine decay from base_lr to 0 over total_epoch_num scheduler steps (one per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch_num, eta_min=0.0)
loss_fn = nn.CrossEntropyLoss()

best_acc = 0
# Weights of the best-validation epoch. Without this copy only the last epoch's
# weights survive. Restore with model.load_state_dict(best_state).
best_state = None
for epoch in range(total_epoch_num):
    print(f"\nEpoch {epoch+1}")
    train(data_loaders["train"], model, loss_fn, optimizer, device)
    val_acc = test(data_loaders["val"], model, device)
    if val_acc > best_acc:
        best_acc = val_acc
        best_state = copy.deepcopy(model.state_dict())
        print(f"[FEAT] best acc: {best_acc}")
    scheduler.step()


Epoch 1
Train Error: Accuracy: 53.56%, Avg loss: 1.352593
Test Error: Accuracy: 22.32%
[FEAT] best acc: 0.2232

Epoch 2
Train Error: Accuracy: 62.78%, Avg loss: 1.066049
Test Error: Accuracy: 26.97%
[FEAT] best acc: 0.2697

Epoch 3
Train Error: Accuracy: 66.48%, Avg loss: 0.978785
Test Error: Accuracy: 29.49%
[FEAT] best acc: 0.2949

Epoch 4
Train Error: Accuracy: 68.47%, Avg loss: 0.919876
Test Error: Accuracy: 26.62%

Epoch 5
Train Error: Accuracy: 70.10%, Avg loss: 0.868026
Test Error: Accuracy: 28.80%

Epoch 6
Train Error: Accuracy: 71.32%, Avg loss: 0.839674
Test Error: Accuracy: 36.63%
[FEAT] best acc: 0.3663

Epoch 7
Train Error: Accuracy: 72.96%, Avg loss: 0.794101
Test Error: Accuracy: 36.71%
[FEAT] best acc: 0.3671

Epoch 8
Train Error: Accuracy: 73.61%, Avg loss: 0.761874
Test Error: Accuracy: 39.51%
[FEAT] best acc: 0.3951

Epoch 9
Train Error: Accuracy: 75.25%, Avg loss: 0.717668
Test Error: Accuracy: 37.90%

Epoch 10
Train Error: Accuracy: 76.29%, Avg loss: 0.698325
Test