# 1-总体思路
GAN + 知识蒸馏

网络总共有:
1. generator
2. teacher（已经训练好）
3. student

In [44]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
from torch.autograd import Variable
# import import_ipynb
# from model import Generator, LeNet5, LeNet5Half

# 2-网络的定义

## 2.1 教师网络LeNet

In [45]:
class LeNet5(nn.Module):
    """LeNet-5 classifier used as the (pre-trained) teacher network.

    Input:  [batch, 1, 32, 32]
    Output: [batch, 10] class logits
    """

    def __init__(self):
        super().__init__()

        # Attribute names are kept as-is so existing checkpoints
        # (state_dict keys) still load.
        # Trailing comments give the spatial size for a 32x32 input.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5))            # [28, 28]
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # [14, 14]
        self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5))           # [10, 10]
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  # [5, 5]
        self.conv3 = nn.Conv2d(16, 120, kernel_size=(5, 5))         # [1, 1]
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(84, 10)

    def forward(self, img, out_feature=False):
        """Run the network on a batch of images.

        Args:
            img: image batch, expected as [batch, 1, 32, 32].
            out_feature: when true, also return the 120-d activation that
                feeds the first fully connected layer.

        Returns:
            The logits [batch, 10], or ``(logits, feature)`` with feature
            shaped [batch, 120] when ``out_feature`` is true.

        Raises:
            RuntimeError: if the convolutional stack does not reduce the
                input to exactly 120 values per sample (wrong input size).
        """
        output = self.maxpool1(self.relu1(self.conv1(img)))
        output = self.maxpool2(self.relu2(self.conv2(output)))
        output = self.relu3(self.conv3(output))

        # Pin the batch dimension instead of using view(-1, 120): with -1 a
        # wrong-sized input would be silently folded into extra "samples".
        feature = output.view(output.size(0), 120)
        output = self.fc2(self.relu4(self.fc1(feature)))

        if out_feature:
            return output, feature
        return output

## 2.2 学生网络LeNet5Half

In [46]:
class LeNet5Half(nn.Module):
    """Student network: LeNet-5 layout with half the channels/units per layer.

    Input:  [batch, 1, 32, 32]
    Output: [batch, 10] class logits
    """

    def __init__(self):
        super().__init__()

        # Attribute names are kept as-is so existing checkpoints
        # (state_dict keys) still load.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=(5, 5))
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv2 = nn.Conv2d(3, 8, kernel_size=(5, 5))
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv3 = nn.Conv2d(8, 60, kernel_size=(5, 5))
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(60, 42)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(42, 10)

    def forward(self, img, out_feature=False):
        """Run the network on a batch of images.

        Args:
            img: image batch, expected as [batch, 1, 32, 32].
            out_feature: when true, also return the 60-d activation that
                feeds the first fully connected layer.

        Returns:
            The logits [batch, 10], or ``(logits, feature)`` with feature
            shaped [batch, 60] when ``out_feature`` is true.

        Raises:
            RuntimeError: if the convolutional stack does not reduce the
                input to exactly 60 values per sample (wrong input size).
        """
        output = self.maxpool1(self.relu1(self.conv1(img)))
        output = self.maxpool2(self.relu2(self.conv2(output)))
        output = self.relu3(self.conv3(output))

        # Pin the batch dimension instead of using view(-1, 60): with -1 a
        # wrong-sized input would be silently folded into extra "samples".
        feature = output.view(output.size(0), 60)
        output = self.fc2(self.relu4(self.fc1(feature)))

        if out_feature:
            return output, feature
        return output

## 2.3生成器网络
input:[batch_size, 100]

output:[batch_size, 1, 32, 32]

In [47]:
class Generator(nn.Module):
    """Map latent noise vectors to single-channel 32x32 images.

    Input: [batch, 100]
    Out:   [batch, 1, 32, 32]
    """

    def __init__(self):
        super(Generator, self).__init__()

        # The linear layer produces a 128-channel 8x8 map; two 2x upsampling
        # steps in forward() bring it to 32x32.
        self.init_size = 32 // 4
        self.l1 = nn.Sequential(
            nn.Linear(100, 128 * self.init_size ** 2),
        )

        self.conv_blocks0 = nn.Sequential(nn.BatchNorm2d(128))

        # NOTE(review): the second positional argument of BatchNorm2d is
        # `eps`, so the original `BatchNorm2d(n, 0.8)` sets eps=0.8 (momentum
        # stays at its default). It is spelled out here unchanged; whether
        # momentum=0.8 was intended should be confirmed.
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(1, affine=False),
        )

    def forward(self, z):
        """Generate an image batch from latent vectors ``z`` ([batch, 100])."""
        projected = self.l1(z)
        side = self.init_size
        feature_map = projected.view(projected.size(0), 128, side, side)

        img = self.conv_blocks0(feature_map)
        # Each stage doubles the spatial size, then applies its conv block.
        for stage in (self.conv_blocks1, self.conv_blocks2):
            img = stage(F.interpolate(img, scale_factor=2))
        return img

# 3-教师网络的训练

数据集：MNIST

损失函数：交叉熵

梯度更新方式：Adam

学习率：lr = 0.001

## 3.1 训练类定义

## 3.2 开始训练及保存参数

# 3-训练类定义
激活损失函数：loss_active

生成器 one-hot 损失函数：loss_onehot

## 3.1教师网络训练类

In [48]:
class TeacherTrainer:
    """Train the LeNet5 teacher on MNIST and checkpoint the best model.

    After every epoch the teacher is evaluated on the MNIST test split and a
    checkpoint is written whenever the test accuracy improves.
    """

    def __init__(self, path_teacher_ckpt, device=None):
        """Build the model, optimizer and MNIST data loaders.

        Args:
            path_teacher_ckpt: directory that receives the checkpoints and
                the per-epoch loss history.
            device: torch device (or device string) to train on. Defaults to
                CUDA when available, otherwise CPU.
        """
        self.path_teacher_ckpt = path_teacher_ckpt
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.teacher = LeNet5().to(self.device)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.Adam(self.teacher.parameters(), lr=0.001)
        self.loss_list = []   # summed training loss of each epoch
        self.best_accr = 0    # best test accuracy seen so far

        # Same preprocessing for both splits: resize to the 32x32 input that
        # LeNet5 expects, then normalize.
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.data_train = MNIST('~/workspace/dataset/', transform=transform, download=False)
        self.data_test = MNIST('~/workspace/dataset/', train=False, transform=transform)
        self.data_train_loader = DataLoader(self.data_train, batch_size=256, shuffle=True, num_workers=8)
        self.data_test_loader = DataLoader(self.data_test, batch_size=1024, num_workers=8)

    def train(self, epochs):
        """Train for ``epochs`` epochs, testing after each one.

        The list of per-epoch summed losses is saved with ``np.save`` as
        ``teacher_loss_<epochs>`` inside ``path_teacher_ckpt``.
        """
        for epoch in range(1, epochs + 1):
            # test() leaves the model in eval mode, so training mode has to
            # be restored at the start of every epoch, not just once.
            self.teacher.train()
            loss_epoch = 0.0
            for images, labels in self.data_train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                loss = self.criterion(self.teacher(images), labels)
                loss.backward()
                self.optimizer.step()
                loss_epoch += loss.item()
            self.loss_list.append(loss_epoch)
            print('Finish epoch: %d, sum loss:%f' % (epoch, loss_epoch))
            self.test(epoch)

        loss_path = os.path.join(self.path_teacher_ckpt, 'teacher_loss_{}'.format(epochs))
        np.save(loss_path, np.array(self.loss_list))

    def test(self, epoch):
        """Evaluate on the test loader; checkpoint when accuracy improves.

        Leaves the teacher in eval mode.
        """
        self.teacher.eval()
        total_correct = 0
        with torch.no_grad():
            for images, labels in self.data_test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                pred = self.teacher(images).argmax(dim=1)
                total_correct += pred.eq(labels.view_as(pred)).sum().item()

        acc = float(total_correct) / len(self.data_test)
        if acc > self.best_accr:
            self.best_accr = acc
            self.save_model(self.path_teacher_ckpt, epoch)

        print('Test Accuracy:%f' % (acc))

    def save_model(self, path, epoch):
        """Write model weights, optimizer state and epoch number under ``path``."""
        state = {
            'net': self.teacher.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': epoch,
        }
        filename = os.path.join(path, 'teacher__accr%f_epoch%d.pth' % (self.best_accr, epoch))
        torch.save(state, filename)
        


## 3.2学生和生成器网络训练类

In [49]:
class StudentTrainer:
    """Jointly train a generator and a student against a fixed teacher.

    No training data is used: the generator turns random noise into images,
    the teacher labels them, and the student is trained to match the teacher
    on those images. Accuracy is measured on the MNIST test split.

    NOTE(review): the run log recorded in this notebook shows the student
    staying near 0.10 test accuracy with the SGD optimizer below; the
    commented-out Adam alternative may be worth re-trying -- confirm.
    """

    def __init__(self, teacher_ckpt_path, student_ckpt_path, path_dataset, path_loss, epochs, batch_size, latent_dim, lr_G, lr_S, oh, ie, a, iters_per_epoch=120, device=None):
        """Set up networks, optimizers and the test loader.

        Args:
            teacher_ckpt_path: checkpoint file whose ``'net'`` entry holds
                the teacher state_dict.
            student_ckpt_path: directory for student checkpoints.
            path_dataset: MNIST root directory (test split only).
            path_loss: directory for the saved loss history.
            epochs: number of training epochs.
            batch_size: number of generated images per step.
            latent_dim: size of the generator's noise input.
            lr_G: generator learning rate (Adam).
            lr_S: student learning rate (SGD).
            oh: weight of the one-hot loss.
            ie: weight of the information-entropy loss.
            a: weight of the activation loss.
            iters_per_epoch: generator/student steps per epoch.
            device: torch device (or device string). Defaults to CUDA when
                available, otherwise CPU.
        """
        # Training hyper-parameters
        self.epochs = epochs
        self.iters_per_epoch = iters_per_epoch
        self.lr_G = lr_G
        self.lr_S = lr_S
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.oh = oh
        self.ie = ie
        self.a = a
        self.student_ckpt_path = student_ckpt_path
        self.best_accr = 0
        self.best_epoch = 0
        self.loss_list = []
        self.path_loss = path_loss
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Test data set
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.data_test = MNIST(path_dataset, train=False, transform=transform)
        self.data_test_loader = DataLoader(self.data_test, batch_size=64, num_workers=8)

        # Networks
        self.teacher = LeNet5().to(self.device)
        checkpoint = torch.load(teacher_ckpt_path, map_location=self.device)
        self.teacher.load_state_dict(checkpoint['net'])
        # The teacher is only a fixed reference: keep it in eval mode and
        # freeze its weights so backward() does not compute unused gradients.
        # Gradients still flow through it to the generator's images.
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad_(False)
        self.student = LeNet5Half().to(self.device)
        self.generate = Generator().to(self.device)

        # Loss and optimizers
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer_G = torch.optim.Adam(self.generate.parameters(), lr=self.lr_G)
        self.optimizer_S = torch.optim.SGD(self.student.parameters(), lr=self.lr_S)
        # Schedulers are created but never stepped by train().
        self.lr_scheduler_S = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer_S, T_0=5, T_mult=2)
        self.lr_scheduler_G = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer_G, T_0=5, T_mult=2)

    def train(self):
        """Run all epochs, test after each, then save the loss history.

        The per-epoch summed losses are saved with ``np.save`` as
        ``student_SGD_epoch_<epochs>`` inside ``path_loss``.
        """
        for epoch in range(1, self.epochs + 1):
            # test() leaves the student in eval mode; restore train mode.
            self.generate.train()
            self.student.train()
            loss_epoch = 0.0
            for _ in range(self.iters_per_epoch):
                loss_epoch += self._train_step()
            print('Epoch%d, loss%f' % (epoch, loss_epoch))
            self.loss_list.append(loss_epoch)
            self.test(epoch)

        loss_path = os.path.join(self.path_loss, 'student_SGD_epoch_{}'.format(self.epochs))
        np.save(loss_path, np.array(self.loss_list))

    def _train_step(self):
        """Do one joint generator/student update and return the loss value."""
        z = torch.randn(self.batch_size, self.latent_dim).to(self.device)
        self.optimizer_G.zero_grad()
        self.optimizer_S.zero_grad()

        gen_img = self.generate(z)
        output, features = self.teacher(gen_img, out_feature=True)
        pseudo_labels = output.detach().argmax(dim=1)

        loss_generator = self._generator_loss(output, features, pseudo_labels)

        # The student sees detached images, so its losses do not send
        # gradients back into the generator. One forward pass serves both
        # the soft (distillation) and the hard (pseudo-label) loss.
        student_out = self.student(gen_img.detach())
        loss_kd = self.kdloss(student_out, output.detach())
        loss_kd_h = self.criterion(student_out, pseudo_labels)

        loss = loss_generator + loss_kd + loss_kd_h
        loss.backward()
        self.optimizer_S.step()
        self.optimizer_G.step()
        return loss.item()

    def _generator_loss(self, output, features, pseudo_labels):
        """Weighted sum of the three losses that steer the generator."""
        # Loss 1: activation loss (negative mean absolute teacher feature)
        loss_active = -features.abs().mean()
        # Loss 2: one-hot loss (teacher logits vs. its own argmax labels)
        loss_onehot = self.criterion(output, pseudo_labels)
        # Loss 3: information-entropy loss on the batch-mean class
        # distribution (sum of p * log10(p))
        softmax_o_T = F.softmax(output, dim=1).mean(dim=0)
        loss_information_entropy = (softmax_o_T * torch.log10(softmax_o_T)).sum()
        return loss_onehot * self.oh + loss_information_entropy * self.ie + loss_active * self.a

    def test(self, epoch):
        """Evaluate the student; checkpoint when accuracy improves.

        Leaves the student in eval mode.
        """
        self.student.eval()
        total_correct = 0
        with torch.no_grad():
            for images, labels in self.data_test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                pred = self.student(images).argmax(dim=1)
                total_correct += pred.eq(labels.view_as(pred)).sum().item()
        accr = float(total_correct) / len(self.data_test)
        print('Test Accuracy: %f' % (accr))
        if accr > self.best_accr:
            self.best_accr, self.best_epoch = accr, epoch
            self.save_model(self.best_epoch, self.best_accr)

    def kdloss(self, y, teacher_scores):
        """KL divergence from teacher to student distributions, per sample.

        Args:
            y: student logits, [batch, classes].
            teacher_scores: teacher logits, same shape.

        Returns:
            Scalar tensor: summed KL divergence divided by the batch size.
        """
        p = F.log_softmax(y, dim=1)
        q = F.softmax(teacher_scores, dim=1)
        # 'batchmean' is sum / batch size, i.e. the same value as the
        # deprecated size_average=False followed by "/ y.shape[0]".
        return F.kl_div(p, q, reduction='batchmean')

    def save_model(self, epochs, accr):
        """Write student weights, optimizer state and epoch to the ckpt dir."""
        state = {
            'net': self.student.state_dict(),
            'optimizer': self.optimizer_S.state_dict(),
            'epoch': epochs,
        }
        filename = os.path.join(self.student_ckpt_path, 'student_accr%f_epoch_%d.pth' % (accr, epochs))
        torch.save(state, filename)

## 3.3 生成器和学生函数分开训练

# 开始训练

## 训练教师网络

In [36]:
# Train the teacher; checkpoints and the loss history are written to the
# teacher cache directory given below.
epochs = 150
trainteacher = TeacherTrainer(
    path_teacher_ckpt='/home/yinzp/gitee/paper-reading/DAFL/cache/models/teacher/',
)
trainteacher.train(epochs=epochs)

Finish epoch: 1, sum loss:103.485972
Test Accuracy:0.959100
Finish epoch: 2, sum loss:24.745323
Test Accuracy:0.975300
Finish epoch: 3, sum loss:15.846061
Test Accuracy:0.984200
Finish epoch: 4, sum loss:12.867239
Test Accuracy:0.984600
Finish epoch: 5, sum loss:10.997393
Test Accuracy:0.985500
Finish epoch: 6, sum loss:9.722271
Test Accuracy:0.986100
Finish epoch: 7, sum loss:8.097537
Test Accuracy:0.989200
Finish epoch: 8, sum loss:7.421795
Test Accuracy:0.987300
Finish epoch: 9, sum loss:6.605359
Test Accuracy:0.987400
Finish epoch: 10, sum loss:5.920653
Test Accuracy:0.988600
Finish epoch: 11, sum loss:5.374886
Test Accuracy:0.990200
Finish epoch: 12, sum loss:4.719485
Test Accuracy:0.990100
Finish epoch: 13, sum loss:4.174470
Test Accuracy:0.990400
Finish epoch: 14, sum loss:4.173966
Test Accuracy:0.989100
Finish epoch: 15, sum loss:3.220392
Test Accuracy:0.989200
Finish epoch: 16, sum loss:3.376388
Test Accuracy:0.989400
Finish epoch: 17, sum loss:2.903660
Test Accuracy:0.990100


## 训练学生网络


In [50]:
# Paths: teacher checkpoint to distill from, output directories, MNIST root.
# The directory paths keep their trailing slash.
current_path = os.getcwd()
path_ckpt_t_teacher = f"{current_path}/cache/models/teacher/teacher__accr0.992300_epoch118.pth"
path_ckpt_student = f"{current_path}/cache/models/student/"
path_loss = f"{current_path}/cache/models/student/"
path_dataset = '/home/yinzp/workspace/dataset/'

# Hyper-parameters
epochs = 200
batch_size = 512
latent_dim = 100
lr_G = 0.2    # generator learning rate
lr_S = 2e-3   # student learning rate
oh = 1        # weight of the one-hot loss
ie = 5        # weight of the information-entropy loss
a = 0.1       # weight of the activation loss

torch.cuda.empty_cache()
trainstudent = StudentTrainer(
    teacher_ckpt_path=path_ckpt_t_teacher,
    student_ckpt_path=path_ckpt_student,
    path_dataset=path_dataset,
    path_loss=path_loss,
    epochs=epochs,
    batch_size=batch_size,
    latent_dim=latent_dim,
    lr_G=lr_G,
    lr_S=lr_S,
    oh=oh,
    ie=ie,
    a=a,
)
trainstudent.train()



Epoch1, loss-45.903053
Test Accuracy: 0.102900
Epoch2, loss-67.552167
Test Accuracy: 0.103200
Epoch3, loss-69.354913
Test Accuracy: 0.103200
Epoch4, loss-71.558287
Test Accuracy: 0.103200
Epoch5, loss-73.528788
Test Accuracy: 0.103200
Epoch6, loss-84.936398
Test Accuracy: 0.103400
Epoch7, loss-146.076957
Test Accuracy: 0.097400
Epoch8, loss-193.976094
Test Accuracy: 0.097400
Epoch9, loss-251.662215
Test Accuracy: 0.097400
Epoch10, loss-280.368194
Test Accuracy: 0.097400
Epoch11, loss-281.429190
Test Accuracy: 0.097400
Epoch12, loss-281.806456
Test Accuracy: 0.097400
Epoch13, loss-282.056922
Test Accuracy: 0.097400
Epoch14, loss-282.225792
Test Accuracy: 0.097400
Epoch15, loss-282.340767
Test Accuracy: 0.097400
Epoch16, loss-282.419549
Test Accuracy: 0.097400
Epoch17, loss-282.473769
Test Accuracy: 0.097400
Epoch18, loss-282.511194
Test Accuracy: 0.097400
Epoch19, loss-282.536912
Test Accuracy: 0.097400
Epoch20, loss-282.554568
Test Accuracy: 0.097400
Epoch21, loss-282.566879
Test Accur

KeyboardInterrupt: 