In [1]:
import sys
sys.path.append('/nfs/xwx/model-doctor-xwx')

import torch
import torchvision
import models
import loaders
import argparse
import os
import datetime
import time
import matplotlib
import yaml
import math


from torch import optim
from configs import config
from utils.lr_util import get_lr_scheduler
from utils.time_util import print_time, get_current_time
from sklearn.metrics import classification_report
from loss.refl import reduce_equalized_focal_loss
from loss.fl import focal_loss
from loss.hcl import hc_loss
from modify_kernel.util.draw_util import draw_lr, draw_fc_weight
from modify_kernel.util.cfg_util import print_yml_cfg
from functools import partial
from utils.args_util import print_args
from utils.general import init_seeds, get_head_and_kernel, get_head_ratio

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import warnings # ignore warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable
from torch.nn.parameter import Parameter


class LambdaLayer(nn.Module):
    """Thin ``nn.Module`` adapter that applies a user-supplied callable in ``forward``."""

    def __init__(self, lambd):
        super().__init__()
        # The callable is kept under the same attribute name callers pass it by.
        self.lambd = lambd

    def forward(self, x):
        """Return the result of calling the stored callable on ``x``."""
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (conv-BN-ReLU-conv-BN, add shortcut, ReLU).

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (and of the shortcut sub-sampling).
        option: shortcut type used when the input and output shapes differ.
            ``'A'`` is parameter-free (spatial sub-sampling plus zero-padded
            channels); ``'B'`` is a 1x1 convolution followed by batch norm.

    Raises:
        ValueError: if a non-identity shortcut is required and ``option`` is not
            ``'A'`` or ``'B'``, or if option ``'A'`` is asked to reduce channels.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the block changes resolution or width.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 ResNet paper uses option A.
                extra_channels = planes - in_planes
                if extra_channels < 0:
                    raise ValueError(
                        "option 'A' shortcut cannot reduce channels "
                        f"(in_planes={in_planes}, planes={planes})"
                    )
                # Split the zero padding around the existing channels. For the
                # usual planes == 2 * in_planes case this is planes // 4 on each
                # side, exactly what was hard-coded before.
                pad_front = extra_channels // 2
                pad_back = extra_channels - pad_front
                # Sub-sample with the block's own stride so the shortcut always
                # matches the spatial size produced by conv1.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::stride, ::stride],
                                                  (0, 0, 0, 0, pad_front, pad_back), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Previously an unknown option silently kept the identity
                # shortcut, which is not shape-compatible in this branch.
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet with 16/32/64-channel stages and a learnable channel mask.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: sequence of three ints, the number of blocks per stage.
        in_channels: number of channels of the input images.
        num_classes: size of the classifier output.

    The ``mask`` parameter has shape ``(1, 64, 1, 1)`` and is multiplied into the
    output of ``layer3`` while the module is in training mode.
    """

    def __init__(self, block, num_blocks, in_channels=3, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)
        # Start from an all-ones (identity) mask. ``torch.Tensor(1, 64, 1, 1)``
        # only allocates memory without initialising it, so a checkpoint without
        # a "mask" entry (loaded with strict=False) left arbitrary values here,
        # which can turn the training loss into nan on the first step.
        self.mask = Parameter(torch.ones(1, 64, 1, 1))

        # self._init_mask()

    def _init_mask(self, mask_path=None):
        """Overwrite ``self.mask`` with values loaded from a ``.npy`` file.

        Args:
            mask_path: path of the ``.npy`` file; defaults to the channel-gradient
                file this notebook was written against.

        The loaded array gets two trailing singleton dimensions appended.
        NOTE(review): presumably the file holds one value per channel (shape
        ``(64,)``) - confirm against the code that writes it.
        """
        import numpy as np
        if mask_path is None:
            mask_path = "/nfs/xwx/model-doctor-xwx/output/result/channels/resnet32-cifar-10-lt-ir100-refl-th-0.4-wr/channel_grads_-1_epoch0.npy"
        mask_data = np.load(mask_path)
        loaded = (torch.tensor(mask_data)).unsqueeze(-1).unsqueeze(-1)
        # Keep the parameter's dtype and device: a float64 array would otherwise
        # promote the activations and no longer match the float32 linear layer.
        self.mask.data = loaded.to(device=self.mask.device, dtype=self.mask.dtype)
        # self.mask.data[:3] = 1 - self.mask.data[:3]

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x, Y=None):
        """Return ``(logits, feature)`` for a batch of images.

        ``feature`` is the raw ``layer3`` output, captured before the mask is
        applied. ``Y`` is accepted for call-site compatibility and is not used.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        feature = out

        # NOTE(review): the mask is applied in training mode only, so a mask that
        # drifts away from ones makes eval-mode activations differ from the ones
        # the classifier was trained on - confirm this is intended.
        if self.training:
            out = out * self.mask

        # Global average pooling: the kernel covers the whole (square) feature map.
        out = F.avg_pool2d(out, out.size()[3])

        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out, feature


def resnet20(in_channels=3, num_classes=10):
    """Build a ResNet-20 (3 BasicBlocks per stage).

    ``in_channels`` and ``num_classes`` are forwarded to ``ResNet``; the defaults
    reproduce the previous zero-argument behaviour.
    """
    return ResNet(BasicBlock, [3, 3, 3], in_channels, num_classes)


def resnet32(in_channels=3, num_classes=10):
    """Build a ResNet-32: three stages of five BasicBlocks each."""
    blocks_per_stage = [5] * 3
    return ResNet(BasicBlock, blocks_per_stage, in_channels=in_channels, num_classes=num_classes)


def resnet44(in_channels=3, num_classes=10):
    """Build a ResNet-44 (7 BasicBlocks per stage).

    ``in_channels`` and ``num_classes`` are forwarded to ``ResNet``; the defaults
    reproduce the previous zero-argument behaviour.
    """
    return ResNet(BasicBlock, [7, 7, 7], in_channels, num_classes)


def resnet56(in_channels=3, num_classes=10):
    """Build a ResNet-56 (9 BasicBlocks per stage).

    ``in_channels`` and ``num_classes`` are forwarded to ``ResNet``; the defaults
    reproduce the previous zero-argument behaviour.
    """
    return ResNet(BasicBlock, [9, 9, 9], in_channels, num_classes)


def resnet110(in_channels=3, num_classes=10):
    """Build a ResNet-110 (18 BasicBlocks per stage).

    ``in_channels`` and ``num_classes`` are forwarded to ``ResNet``; the defaults
    reproduce the previous zero-argument behaviour.
    """
    return ResNet(BasicBlock, [18, 18, 18], in_channels, num_classes)


def resnet1202(in_channels=3, num_classes=10):
    """Build a ResNet-1202 (200 BasicBlocks per stage).

    ``in_channels`` and ``num_classes`` are forwarded to ``ResNet``; the defaults
    reproduce the previous zero-argument behaviour.
    """
    return ResNet(BasicBlock, [200, 200, 200], in_channels, num_classes)


In [2]:
# Experiment configuration.
# data_name is passed to loaders.load_data and model_path is the checkpoint file
# read with torch.load in later cells; model_name is not referenced by any of the
# visible cells.
model_name = "resnet32"
data_name  = "cifar-10-lt-ir100"
model_path = "/nfs/xwx/model-doctor-xwx/output/model/pretrained/resnet32-cifar-10-lt-ir100-refl-th-0.4-wr/checkpoint.pth"

In [3]:
# Prefer GPU index 2, but only when the host actually exposes that many devices:
# torch.cuda.is_available() is also true on machines with one or two GPUs, where
# 'cuda:2' would fail as soon as a tensor is moved to it.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0')
else:
    device = torch.device('cpu')

In [4]:
# Load the dataset named by data_name. Only the first returned value is kept;
# later cells index it with "train" and "val".
data_loaders, _ = loaders.load_data(data_name=data_name)

----------------------------------------
LOAD DATA: cifar-10-lt-ir100
----------------------------------------
load cifar dataset from image dir

load cifar dataset from image dir



In [5]:
base_model = resnet32()
# map_location="cpu" makes loading independent of the GPU the checkpoint was
# saved from; the model is moved to `device` afterwards.
checkpoint = torch.load(model_path, map_location="cpu")
# strict=False tolerates key mismatches, so report them instead of discarding the
# result: a parameter missing from the checkpoint keeps its freshly constructed
# value.
load_result = base_model.load_state_dict(checkpoint["model"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")
base_model.to(device);

In [6]:
def test(dataloader, model, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):

        X, y = X.to(device), y.to(device)
        pred, _ = model(X, y)
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct /= size
        
    print(f"Test Error: Accuracy: {(100*correct):>0.2f}%")

In [7]:
# Baseline: accuracy of the freshly loaded checkpoint on the "val" loader.
test(data_loaders["val"], base_model, device)

Test Error: Accuracy: 71.82%


In [8]:
import copy

# Train a deep copy so that base_model keeps the weights evaluated above.
model = copy.deepcopy(base_model)

In [9]:
# Freeze the feature-layer parameters (currently disabled: all lines are commented out)
# for param in model.parameters():
#     param.requires_grad = False
    
# for param in model.linear.parameters():
#     param.requires_grad = True
    
# model.mask.requires_grad = True
    
# Parameters handed to the optimizer: every parameter of `model` that currently
# has requires_grad set. Nothing in this cell changes requires_grad while the
# block above stays commented out.
parameters = [p for p in model.parameters() if p.requires_grad]

# for name, p in model.named_parameters():
#     if p.requires_grad:
#         print(name,':',p.size())

In [12]:
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and print accuracy and mean loss.

    The model is put in training mode; each batch is moved to ``device``, passed
    through ``model(X, y)`` (whose first return value is scored by ``loss_fn``),
    and followed by one optimizer step. The loss is averaged over batches, the
    accuracy over ``len(dataloader.dataset)``. Nothing is returned.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    running_loss = 0
    n_correct = 0

    model.train()
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)

        # Gradients are forced on even if the caller disabled them.
        with torch.set_grad_enabled(True):
            logits, _ = model(X, y)  # forward pass through the network
            loss = loss_fn(logits, y)

            running_loss += loss.item()
            hits = (logits.argmax(1) == y).type(torch.float)
            n_correct += hits.sum().item()

            # Backpropagation
            optimizer.zero_grad()  # clear gradients left from the previous step
            loss.backward()  # gradients of the loss w.r.t. the model parameters
            optimizer.step()  # update the parameters

    train_loss = running_loss / n_batches
    correct = n_correct / n_samples

    print(f"Train Error: Accuracy: {(100*correct):>0.2f}%, Avg loss: {train_loss:>8f}")

In [13]:
from utils.lr_util import get_lr_scheduler

# Optimisation hyper-parameters for this run.
base_lr = 0.1
total_epoch_num = 10
weight_decay = 2e-4

# SGD with momentum over the `parameters` list collected in an earlier cell.
optimizer = optim.SGD(parameters, lr=base_lr, momentum=0.9, weight_decay=weight_decay)
# Cosine annealing from base_lr down to eta_min over total_epoch_num scheduler
# steps; the loop below steps it once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch_num, eta_min=0.0)
# scheduler = get_lr_scheduler(optimizer, True)
loss_fn = nn.CrossEntropyLoss()

# One training pass followed by an evaluation on the "val" loader per epoch.
# NOTE(review): the output recorded for this cell reports a nan average loss from
# the first epoch on - check parameters and inputs for non-finite values before
# trusting these numbers.
for epoch in range(total_epoch_num):
    print(f"\nEpoch {epoch+1}")
    train(data_loaders["train"], model, loss_fn, optimizer, device)
    test(data_loaders["val"], model, device)
    scheduler.step()


Epoch 1
Train Error: Accuracy: 40.30%, Avg loss:      nan
Test Error: Accuracy: 10.00%

Epoch 2
Train Error: Accuracy: 40.30%, Avg loss:      nan
Test Error: Accuracy: 10.00%

Epoch 3
Train Error: Accuracy: 40.30%, Avg loss:      nan
Test Error: Accuracy: 10.00%

Epoch 4


KeyboardInterrupt: 