In [9]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.datasets.imagenet import ImageNet
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np

# 网络定义

## 教师网络 ResNet18, 34, 50, 101, 152

In [10]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    The output has ``expansion * planes`` channels. When the stride is not 1
    or the channel count changes, the skip connection is a 1x1 conv +
    batch-norm projection; otherwise it is an empty ``nn.Sequential`` that
    passes the input through unchanged.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (and of the projection).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: conv-bn-relu followed by conv-bn.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip branch: project only when the shapes of x and the main branch differ.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The output has ``expansion * planes`` channels. When the stride is not 1
    or the channel count changes, the skip connection is a 1x1 conv +
    batch-norm projection; otherwise the input is passed through unchanged.

    Args:
        in_planes: number of input channels.
        planes: number of channels inside the bottleneck.
        stride: stride of the 3x3 convolution (and of the projection).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        # Apply the non-linearity after the residual sum, exactly as BasicBlock
        # does; the original returned the raw sum here.
        out = F.relu(out)
        return out



class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    Args:
        block: residual block class; it must expose an ``expansion`` class
            attribute and accept ``(in_planes, planes, stride)``.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use 1."""
        per_block_strides = [stride]
        per_block_strides.extend([1] * (num_blocks - 1))

        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x, out_feature=False):
        """Return the logits, or ``(logits, feature)`` when ``out_feature`` is set.

        ``feature`` is the flattened result of the 4x4 average pooling that is
        fed to the linear classifier.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = F.avg_pool2d(x, 4)
        feature = pooled.view(pooled.size(0), -1)
        logits = self.linear(feature)
        if out_feature == False:
            return logits
        return logits, feature
def ResNet18(num_classes=10):
    """Build a ResNet from BasicBlock with stage depths [2, 2, 2, 2]."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=num_classes)
 
def ResNet34(num_classes=10):
    """Build a ResNet from BasicBlock with stage depths [3, 4, 6, 3]."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3], num_classes=num_classes)
 
def ResNet50(num_classes=10):
    """Build a ResNet from Bottleneck with stage depths [3, 4, 6, 3]."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3], num_classes=num_classes)
 
def ResNet101(num_classes=10):
    """Build a ResNet from Bottleneck with stage depths [3, 4, 23, 3]."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3], num_classes=num_classes)
 
def ResNet152(num_classes=10):
    """Build a ResNet from Bottleneck with stage depths [3, 8, 36, 3]."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], num_classes=num_classes)

## 教师网络 LeNet5

In [11]:
class LeNet5(nn.Module):
    """LeNet-5 style classifier.

    Input:  [batch, 1, 32, 32]
    Output: [batch, 10]
    """

    def __init__(self):
        super().__init__()

        # Spatial size after each layer for a 32x32 input is noted on the right.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)                 # 28 x 28
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # 14 x 14
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)                # 10 x 10
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # 5 x 5
        self.conv3 = nn.Conv2d(16, 120, kernel_size=5)              # 1 x 1
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(84, 10)

    def forward(self, img, out_feature=False):
        """Return the logits, or ``(logits, feature)`` when ``out_feature`` is set.

        ``feature`` is the 120-dimensional activation that feeds ``fc1``.
        """
        hidden = self.maxpool1(self.relu1(self.conv1(img)))
        hidden = self.maxpool2(self.relu2(self.conv2(hidden)))
        hidden = self.relu3(self.conv3(hidden))

        feature = hidden.view(-1, 120)
        logits = self.fc2(self.relu4(self.fc1(feature)))
        if out_feature == False:
            return logits
        return logits, feature

# 网络训练

## 教师网络训练


In [15]:
class TeacherTrainer():
    """Train a ResNet34 teacher on CIFAR-10 and checkpoint the best model.

    Args:
        epochs: number of training epochs.
        path_dataset: root directory handed to ``CIFAR10``.
        path_ckpt: directory that receives the checkpoints.
        path_loss: directory that receives the per-epoch loss array.
    """

    def __init__(self, epochs, path_dataset, path_ckpt, path_loss):
        # Datasets: augmentation on the training split only.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        self.dataset_train = CIFAR10(path_dataset, transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        self.dataset_test = CIFAR10(path_dataset, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
        self.dataset_train_loader = DataLoader(self.dataset_train, batch_size=128, shuffle=True, num_workers=8)
        self.dataset_test_loader = DataLoader(self.dataset_test, batch_size=100, num_workers=8)

        # Network, loss and optimizer. Use the GPU when there is one instead
        # of failing outright on CPU-only machines.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = ResNet34().to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

        # Training bookkeeping.
        self.epochs = epochs
        self.best_accr = 0
        self.list_loss = []
        self.path_ckpt = path_ckpt
        self.path_loss = path_loss

    def train(self):
        """Run the training loop, evaluate after every epoch, then save the losses.

        The value recorded per epoch is the sum of the batch losses.
        """
        for epoch in range(1, self.epochs + 1):
            self.net.train()
            loss_epoch = 0.0
            for images, labels in self.dataset_train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                output = self.net(images)
                loss = self.criterion(output, labels)
                loss.backward()
                self.optimizer.step()

                loss_epoch += loss.item()
            self.list_loss.append(loss_epoch)
            print('Epoch:%d, Loss:%f' % (epoch, loss_epoch))
            self.test(epoch)

        if self.path_loss:
            os.makedirs(self.path_loss, exist_ok=True)
        # np.save appends the ".npy" suffix itself.
        np.save(os.path.join(self.path_loss, 'teacher_loss_{}'.format(self.epochs)),
                np.array(self.list_loss))

    def test(self, epoch):
        """Evaluate on the test split; checkpoint when accuracy improves."""
        self.net.eval()
        total_correct = 0
        with torch.no_grad():
            for images, labels in self.dataset_test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                pred = self.net(images).argmax(dim=1)
                total_correct += pred.eq(labels.view_as(pred)).sum().item()

        acc = float(total_correct) / len(self.dataset_test)
        if acc > self.best_accr:
            self.best_accr = acc
            self.save_model(self.path_ckpt, epoch)

        print('Test Accuracy:%f' % (acc))

    def save_model(self, path, epoch):
        """Write network + optimizer state to ``<path>/teacher__accr<acc>_epoch<n>.pth``.

        ``path`` is treated as a directory and joined with ``os.path.join``;
        plain string concatenation put the file next to the directory under a
        mangled name whenever ``path`` had no trailing separator.
        """
        state = {'net': self.net.state_dict(), 'optimizer': self.optimizer.state_dict(), 'epoch': epoch}
        if path:
            os.makedirs(path, exist_ok=True)
        filename = os.path.join(path, 'teacher__accr%f_epoch%d.pth' % (self.best_accr, epoch))
        torch.save(state, filename)

        

## 学生网络训练

In [13]:
class StudentTrainer():
    """Distil a ResNet34 teacher into a ResNet18 student using unlabeled images.

    The teacher scores every image found under ``path_imagenet``; the
    ``num_select`` images with the lowest cross-entropy against the teacher's
    own prediction are kept as training data. The student is trained on them
    with a temperature-scaled KL term plus an NLL term computed through a
    learnable noise-adaptation matrix.

    Args:
        epochs: number of training epochs.
        name_dataset: ``'cifar10'`` or ``'cifar100'``; selects the test set,
            the number of classes and the assumed teacher accuracy.
        path_dataset_cifar: root directory handed to ``CIFAR10``/``CIFAR100``.
        path_imagenet: root directory handed to ``ImageFolder``.
        path_loss: directory that receives the per-epoch loss array.
        path_student_ckpt: directory that receives the student checkpoints.
        path_teacher_ckpt: checkpoint file loaded with ``torch.load``; it must
            contain the teacher weights under the ``'net'`` key.
        num_select: number of images to keep for training.

    Raises:
        ValueError: if ``name_dataset`` is not one of the supported names.
    """

    def __init__(self, epochs, name_dataset, path_dataset_cifar, path_imagenet,
                 path_loss, path_student_ckpt, path_teacher_ckpt, num_select):
        self.epochs = epochs
        self.path_loss = path_loss
        self.path_student_ckpt = path_student_ckpt
        self.path_teacher_ckpt = path_teacher_ckpt
        # Use the GPU when there is one instead of failing on CPU-only machines.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Test dataset and teacher network.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        if name_dataset == 'cifar10':
            self.data_test = CIFAR10(path_dataset_cifar,
                                     train=False,
                                     transform=transform_test)
            self.teacher_acc = torch.Tensor([0.9523])
            self.class_num = 10
        elif name_dataset == 'cifar100':
            self.data_test = CIFAR100(path_dataset_cifar,
                                      train=False,
                                      transform=transform_test)
            self.teacher_acc = torch.Tensor([0.7774])
            self.class_num = 100
        else:
            raise ValueError("name_dataset must be 'cifar10' or 'cifar100', got %r" % (name_dataset,))
        self.teacher = ResNet34(num_classes=self.class_num).to(self.device)
        checkpoint = torch.load(self.path_teacher_ckpt, map_location=self.device)
        self.teacher.load_state_dict(checkpoint['net'])
        # The teacher is only ever used for inference.
        self.teacher.eval()

        self.data_test_loader = DataLoader(self.data_test, batch_size=1000, num_workers=8)

        # Unlabeled pool, un-augmented and un-shuffled, so that positions in
        # the score tensor are dataset indices.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.data_train = ImageFolder(path_imagenet, transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            normalize,
        ]))
        self.data_train_loader_noshuffle = DataLoader(self.data_train, batch_size=256, shuffle=False, num_workers=8)

        # Select the training subset from the pool.
        self.num_select = num_select
        self.positive_index = self._select_dataset()
        self.dataset_to_selected = ImageFolder(path_imagenet, transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        self.dataset_selected = torch.utils.data.Subset(self.dataset_to_selected, self.positive_index)
        self.dataset_selected_loader = torch.utils.data.DataLoader(self.dataset_selected, batch_size=256, shuffle=True, num_workers=8)

        # Student network, losses and optimizers.
        self.noise_adaption = torch.nn.Parameter(torch.zeros(self.class_num, self.class_num-1))
        self.student = ResNet18(self.class_num).to(self.device)
        # cross-entropy = softmax + log + nll (nll negates the log-probability
        # of the target class and averages over the batch).
        self.nll = nn.NLLLoss().to(self.device)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.SGD(list(self.student.parameters()), lr=0.1, momentum=0.9, weight_decay=5e-4)
        self.optimizer_noise = torch.optim.Adam([self.noise_adaption], lr=0.001)

        self.loss_list = []
        self.best_acc = 0

    def train(self):
        """Train the student, evaluate after every epoch, then save the losses.

        The value recorded per epoch is the sum of the batch losses.
        """
        for epoch in range(1, self.epochs+1):
            loss_epoch = 0.0
            self.student.train()
            # The folder labels of the pool are ignored; the teacher supplies targets.
            for images, _ in self.dataset_selected_loader:
                images = images.to(self.device)

                self.optimizer.zero_grad()
                self.optimizer_noise.zero_grad()

                output = self.student(images)
                with torch.no_grad():
                    output_t = self.teacher(images)
                pseudo_labels = output_t.argmax(dim=1)
                # Loss 1: temperature-scaled KL divergence to the teacher's soft output.
                loss = self._kdloss(output, output_t)
                # Loss 2: multiply the student probabilities [batch, class_num]
                # by the noise matrix [class_num, class_num] and score the
                # result against the teacher's pseudo labels.
                output_s = F.softmax(output, dim=1)
                output_s_adaption = torch.matmul(output_s, self._noise())
                loss = loss + self.nll(torch.log(output_s_adaption), pseudo_labels)

                loss.backward()
                self.optimizer.step()
                self.optimizer_noise.step()
                loss_epoch += loss.item()
            self.loss_list.append(loss_epoch)
            print("Epoch:%d, Loss:%f"%(epoch, loss_epoch))
            self.test(epoch)

        if self.path_loss:
            os.makedirs(self.path_loss, exist_ok=True)
        # np.save appends the ".npy" suffix itself.
        np.save(os.path.join(self.path_loss, 'student_epoch{}'.format(self.epochs)),
                np.array(self.loss_list))

    def test(self, epoch):
        """Evaluate the student on the test split; checkpoint on improvement."""
        self.student.eval()

        total_correct = 0
        with torch.no_grad():
            for images, labels in self.data_test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                pred = self.student(images).argmax(dim=1)
                total_correct += pred.eq(labels.view_as(pred)).sum().item()

        acc = float(total_correct) / len(self.data_test)
        print('Test ACC: %f'%(acc))
        if acc > self.best_acc:
            self.best_acc = acc
            self.save_model(epoch, acc)

    def save_model(self, epochs, accr):
        """Write student + optimizer state to ``<path_student_ckpt>/student_accr<acc>_epoch_<n>.pth``.

        The directory is joined with ``os.path.join`` so that it works with or
        without a trailing separator, and is created when missing.
        """
        state = {'net': self.student.state_dict(), 'optimizer': self.optimizer.state_dict(), 'epoch': epochs}
        if self.path_student_ckpt:
            os.makedirs(self.path_student_ckpt, exist_ok=True)
        filename = os.path.join(self.path_student_ckpt, 'student_accr%f_epoch_%d.pth' % (accr, epochs))
        torch.save(state, filename)

    def _kdloss(self, student_scores, teacher_scores, T=4):
        """Return ``T**2 * KL(softmax(teacher/T) || softmax(student/T))`` averaged over the batch."""
        p = F.log_softmax(student_scores/T, dim=1)
        q = F.softmax(teacher_scores/T, dim=1)
        # reduction='sum' replaces the deprecated reduce=False + manual sum.
        l_kl = F.kl_div(p, q, reduction='sum')
        loss = l_kl / teacher_scores.shape[0]
        return loss * (T**2)

    def _noise(self):
        """Build the [class_num, class_num] noise-adaptation matrix.

        Each row of the [class_num, class_num-1] parameter is soft-maxed and
        scaled by ``1 - teacher_acc``; ``teacher_acc`` is then inserted on the
        diagonal, so every row sums to one.
        """
        off_diagonal = F.softmax(self.noise_adaption, dim=1) * (1 - self.teacher_acc)
        # Empty slices make the same expression valid for the first and last
        # row, so no special cases are needed.
        rows = [
            torch.cat([off_diagonal[i][:i], self.teacher_acc, off_diagonal[i][i:]])
            for i in range(self.class_num)
        ]
        return torch.stack(rows).to(self.device)

    def _select_dataset(self):
        """Return the dataset indices of the ``num_select`` lowest-loss pool images."""
        loss_list, pseudo_labels_list = self._identify_outlier()
        # topk raises when k exceeds the number of scored images.
        k = min(self.num_select, loss_list.numel())
        positive_index = loss_list.topk(k, largest=False)[1]
        return positive_index.tolist()

    def _identify_outlier(self):
        """Score every pool image with the teacher.

        Returns:
            A pair ``(losses, pseudo_labels)`` of 1-D tensors in dataset
            order: the per-image cross-entropy of the teacher output against
            its own arg-max prediction, and that prediction.
        """
        value = []
        pseudo_labels_list = []
        celoss = nn.CrossEntropyLoss(reduction='none').to(self.device)
        self.teacher.eval()
        # Inference only: skip autograd bookkeeping for the whole pass.
        with torch.no_grad():
            for inputs, _ in self.data_train_loader_noshuffle:
                inputs = inputs.to(self.device)
                outputs = self.teacher(inputs)
                pseudo_labels = outputs.argmax(dim=1)
                value.append(celoss(outputs, pseudo_labels))
                pseudo_labels_list.append(pseudo_labels)
        # torch.cat turns [tensor([1]), tensor([2])] into tensor([1, 2]).
        return torch.cat(value, dim=0), torch.cat(pseudo_labels_list, dim=0)


            




## 训练教师网络--CIFAR10

In [16]:
# Train the CIFAR-10 teacher for 120 epochs.
path_current = os.getcwd()
path_imagenet = "/home/yinzp/workspace/dataset"
path_cifar = "/home/yinzp/workspace/dataset/cifar-10-python"
path_loss = os.path.join(path_current, 'cache/models/teacher/')
path_student_ckpt = os.path.join(path_current, 'cache/models/student')
# The trailing '' makes os.path.join end the path with a separator, so the
# checkpoint directory stays a directory even where a file name is appended
# to it by plain string concatenation.
path_teacher_ckpt = os.path.join(path_current, 'cache/models/teacher', '')
# Saving fails when the target directory does not exist yet.
os.makedirs(path_teacher_ckpt, exist_ok=True)
train_teacher = TeacherTrainer(120, path_cifar, path_teacher_ckpt, path_loss)
train_teacher.train()

Epoch:1, Loss:792.418685
Test Accuracy:0.405900
Epoch:2, Loss:606.494960
Test Accuracy:0.473700
Epoch:3, Loss:521.761431
Test Accuracy:0.482200
Epoch:4, Loss:433.693335
Test Accuracy:0.652500
Epoch:5, Loss:363.521431
Test Accuracy:0.674200
Epoch:6, Loss:308.265708
Test Accuracy:0.742400
Epoch:7, Loss:270.097119
Test Accuracy:0.708800
Epoch:8, Loss:243.298810
Test Accuracy:0.776900
Epoch:9, Loss:226.158543
Test Accuracy:0.787000
Epoch:10, Loss:213.838875
Test Accuracy:0.740400
Epoch:11, Loss:201.648535
Test Accuracy:0.755600
Epoch:12, Loss:190.229908
Test Accuracy:0.798200
Epoch:13, Loss:184.381434
Test Accuracy:0.772500
Epoch:14, Loss:177.682976
Test Accuracy:0.774500
Epoch:15, Loss:172.242573
Test Accuracy:0.803700
Epoch:16, Loss:164.769477
Test Accuracy:0.805300
Epoch:17, Loss:164.672796
Test Accuracy:0.813300
Epoch:18, Loss:158.660461
Test Accuracy:0.782400
Epoch:19, Loss:157.266840
Test Accuracy:0.787000
Epoch:20, Loss:156.301559
Test Accuracy:0.850000
Epoch:21, Loss:150.743430
Tes

## 训练学生网络--ImageNet(挑选)

In [12]:
# Build the student trainer from the stored teacher and the unlabeled image pool.
path_current = os.getcwd()
path_imagenet = "/home/yinzp/workspace/dataset"
path_cifar = "/home/yinzp/workspace/dataset/cifar10"
path_loss = os.path.join(path_current, 'cache/models/student/')
# The trailing '' makes os.path.join end the path with a separator, so the
# directory stays a directory even where a file name is appended to it by
# plain string concatenation.
path_student_ckpt = os.path.join(path_current, 'cache/models/student', '')
os.makedirs(path_student_ckpt, exist_ok=True)

# StudentTrainer passes path_teacher_ckpt straight to torch.load, so it has to
# name a checkpoint *file*. Teacher checkpoints are called
# 'teacher__accr<acc>_epoch<n>.pth'; look inside cache/models/teacher and also
# directly in cache/models, where they end up (as 'teacherteacher__...') when
# the directory was concatenated without a separator. A checkpoint is only
# written when accuracy improves, so the newest file is the best of the
# latest run.
dir_models = os.path.join(path_current, 'cache/models')
teacher_ckpt_candidates = []
for folder in (dir_models, os.path.join(dir_models, 'teacher')):
    if os.path.isdir(folder):
        teacher_ckpt_candidates += [
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.startswith('teacher') and name.endswith('.pth')
        ]
if not teacher_ckpt_candidates:
    raise FileNotFoundError('no teacher checkpoint (*.pth) found under %s' % dir_models)
path_teacher_ckpt = max(teacher_ckpt_candidates, key=os.path.getmtime)

trainstudent = StudentTrainer(120, 'cifar10', path_cifar, path_imagenet, path_loss, path_student_ckpt, path_teacher_ckpt, 60000)
print(path_loss)
# NOTE(review): the `download` keyword is omitted on purpose — the default
# never downloads, and torchvision releases whose ImageNet signature lacks it
# would forward it to ImageFolder and fail. Confirm against the installed version.
dataset_imagenet = ImageNet(path_imagenet)

/home/yinzp/gitee/paper-reading/DFND/cache/models/student/
