# 1-总体思路
GAN + 知识蒸馏

网络总共有:
1. generator
2. teacher（已经训练好）
3. student

In [4]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
from torch.autograd import Variable
# import import_ipynb
# from model import Generator, LeNet5, LeNet5Half

# 2-网络的定义

## 2.1 教师网络LeNet

In [5]:
class LeNet5(nn.Module):
    """
    Input: [batch, 1, 32, 32]
    Output:[batch, 10]
    """
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5))    #[28, 28]
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  #[14, 14]
        self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5))       #[10, 10]
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)  #[5, 5]
        self.conv3 = nn.Conv2d(16, 120, kernel_size=(5, 5)) #[1, 1]
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(84, 10)

    def forward(self, img, out_feature=False):
        output = self.conv1(img)
        output = self.relu1(output)
        output = self.maxpool1(output)  
        
        output = self.conv2(output)
        output = self.relu2(output)
        output = self.maxpool2(output)
        
        output = self.conv3(output)
        output = self.relu3(output)
        
        feature = output.view(-1, 120)
        output = self.fc1(feature)
        output = self.relu4(output)
        output = self.fc2(output)
        if out_feature == False:
            return output
        else:
            return output,feature

## 2.2 学生网络LeNet5Half

In [6]:
class LeNet5Half(nn.Module):

    def __init__(self):
        super(LeNet5Half, self).__init__()

        self.conv1 = nn.Conv2d(1, 3, kernel_size=(5, 5))
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv2 = nn.Conv2d(3, 8, kernel_size=(5, 5))
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        self.conv3 = nn.Conv2d(8, 60, kernel_size=(5, 5))
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(60, 42)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(42, 10)

    def forward(self, img, out_feature=False):
        output = self.conv1(img)
        output = self.relu1(output)
        output = self.maxpool1(output)

        output = self.conv2(output)
        output = self.relu2(output)
        output = self.maxpool2(output)

        output = self.conv3(output)
        output = self.relu3(output)
        feature = output.view(-1, 60)

        output = self.fc1(feature)
        output = self.relu4(output)
        output = self.fc2(output)

        if out_feature == False:
            return output
        else:
            return output, feature

## 2.3生成器网络
input:[batch_size, 100]

output:[batch_size, 32, 32]

In [7]:
class Generator(nn.Module):
    """
    Input:[batch, 100]
    Out:[batch, 1, 32, 32]
    """
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = 32 // 4
        self.l1 = nn.Sequential(nn.Linear(100, 128*self.init_size**2))

        self.conv_blocks0 = nn.Sequential(
            nn.BatchNorm2d(128),
        )
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(1, affine=False) 
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks0(out)
        img = nn.functional.interpolate(img,scale_factor=2)
        img = self.conv_blocks1(img)
        img = nn.functional.interpolate(img,scale_factor=2)
        img = self.conv_blocks2(img)
        return img

# 3-教师网络的训练

数据集：MNIST

损失函数：交叉熵

梯度更新方式：Adam

学习率：lr = 0.001

## 3.1 训练类定义

## 3.2 开始训练及保存参数

# 3-训练类定义
激活损失函数：loss_activation

生成器损失函数：loss_noe_hot

## 3.1教师网络训练类

In [None]:
class TeacherTrainer:
    def __init__(self):
        self.teacher = LeNet5().cuda()
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.optimizer = torch.optim.Adam(self.teacher.parameters(), lr = 0.001)
        self.data_train = MNIST('~/workspace/dataset/', transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), download=False)
        self.data_test = MNIST('~/workspace/dataset/', train=False, transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
        self.data_train_loader = DataLoader(self.data_train, batch_size=256, shuffle=True, num_workers=8)
        self.data_test_loader = DataLoader(self.data_test, batch_size=1024, num_workers=8)

    def train(self, epochs):
        self.teacher.train()
        for epoch in range(1, epochs+1):
            loss_list, batch_list = [], []
            for i, (images, labels) in enumerate(self.data_train_loader, start=1):
                images, labels = Variable(images).cuda(), Variable(labels).cuda()
                self.optimizer.zero_grad()
                outputs = self.teacher(images)
                loss = self.criterion(outputs, labels)
                loss_list.append(loss.data.item())
                batch_list.append(i+1)
                # if i %100== 0:
                #     print('Train-Epoch %d, Batch:%d, Loss %f' % (epoch, i, loss.data.item()))
                loss.backward()
                self.optimizer.step()
        
            print('Finish epoch: %d, sum loss:%f' % (epoch, sum(loss_list)))
            if epoch % 1== 0:
                self.test()
    
    def test(self):
        self.teacher.eval()
        total_correct = 0
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.data_test_loader, start=1):
                images, labels = Variable(images).cuda(), Variable(labels).cuda()
                output = self.teacher(images)
                pred = output.data.max(1)[1]
                total_correct += pred.eq(labels.data.view_as(pred)).sum()
            
        acc = float(total_correct) / len(self.data_test)
        print('Test Accuracy:%f' % (acc))

    def save_model(self, path, epochs):
        state = {'net': self.teacher.state_dict(), 'optimizer':self.optimizer.state_dict(), 'epoch':epochs}
        filename = path + 'net_epoch_%d.pth'%(epochs)
        torch.save(state, filename)
        


## 3.2学生和生成器网络训练类

In [15]:
class StudentTrainer:
    def __init__(self, teacher_ckpt_path, student_ckpt_path, path_dataset, epochs, batch_size, latent_dim, lr_G, lr_S, oh, ie, a):
        ## 训练参数
        self.epochs = epochs
        self.lr_G = lr_G
        self.lr_S = lr_S
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.oh = oh
        self.ie = ie
        self.a = a
        self.student_ckpt_path = student_ckpt_path
        ## 测试数据集
        self.data_test = MNIST(path_dataset, train=False, transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
        self.data_test_loader = DataLoader(self.data_test, batch_size=64, num_workers=8)
        ## 网络定义
        self.teacher = LeNet5().cuda()
        self.teacher.load_state_dict(torch.load(teacher_ckpt_path)['net'])
        self.student = LeNet5Half().cuda()
        self.generate = Generator().cuda()
        ## 损失函数和参数更新方式
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.optimizer_G = torch.optim.Adam(self.generate.parameters(), lr = self.lr_G)
        self.optimizer_S = torch.optim.Adam(self.student.parameters(), lr = self.lr_S)
        
    def train(self):
        for epoch in range(1, epochs+1):
            self.generate.train()
            self.student.train()
            loss_epoch = 0
            for i in range(1, 121):
                z = Variable(torch.randn(self.batch_size, self.latent_dim)).cuda()
                self.optimizer_G.zero_grad()
                self.optimizer_S.zero_grad()
                gen_img = self.generate(z)
                output, features = self.teacher(gen_img, out_feature = True)
                pseudo_labels = output.data.max(1)[1]
                # 用于纠正Generator的损失
                ## 损失1：激活损失
                loss_active = -features.abs().mean()
                ## 损失2：one-hot损失
                loss_onehot = self.criterion(output, pseudo_labels)
                ## 损失3：信息熵损失
                softmax_o_T = F.softmax(output, dim = 1).mean(dim=0)
                loss_information_entropy = (softmax_o_T * torch.log10(softmax_o_T)).sum()

                # 用于纠正学生网络的损失：知识蒸馏损失
                loss_kd = self.kdloss(self.student(gen_img.detach()), output.detach())

                loss = loss_onehot * self.oh + loss_information_entropy * self.ie + loss_active * self.a + loss_kd
                loss.backward()
                self.optimizer_S.step()
                self.optimizer_G.step()
                loss_epoch += loss.data.item()
            print('Epoch%d, loss%f' % (epoch, loss_epoch))
            self.test(epoch)

    def test(self, epoch):
        total_corerect = 0
        best_accr = 0
        best_epoch = 0
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.data_test_loader, start=1):
                images, labels = images.cuda(), labels.cuda()
                self.student.eval()
                output = self.student(images)
                pred = output.data.max(1)[1]
                total_corerect += pred.eq(labels.view_as(pred)).sum()
        accr = float(total_corerect) / len(self.data_test)
        print('Test Accuracy: %f' % (accr))
        if accr > best_accr:
            best_accr, best_epoch = accr, epoch
            self.save_model(best_epoch, best_accr)

    def kdloss(self, y, teacher_scores):
        p = F.log_softmax(y, dim=1)
        q = F.softmax(teacher_scores, dim=1)
        l_kl = F.kl_div(p, q, size_average=False) / y.shape[0]
        return l_kl
    
    def save_model(self, epochs, accr):
        state = {'net': self.student.state_dict(), 'optimizer':self.optimizer_S.state_dict(), 'epoch':epochs}
        filename = self.student_ckpt_path + 'student_accr%f_epoch_%d.pth' %(accr, epochs)
        torch.save(state, filename)

# 开始训练

## 训练教师网络

In [9]:
trainteacher = TeacherTrainer()
epochs = 20
trainteacher.train(epochs)
path = os.getcwd() + '/cache/models/'
trainteacher.save_model('cache/models/', epochs)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Finish epoch: 1, sum loss:101.072858
Test Accuracy:0.962800
Finish epoch: 2, sum loss:24.796381
Test Accuracy:0.978200
Finish epoch: 3, sum loss:16.870111
Test Accuracy:0.982900
Finish epoch: 4, sum loss:13.435808
Test Accuracy:0.981900
Finish epoch: 5, sum loss:11.124334
Test Accuracy:0.985200
Finish epoch: 6, sum loss:9.164217
Test Accuracy:0.985900
Finish epoch: 7, sum loss:8.336562
Test Accuracy:0.983500
Finish epoch: 8, sum loss:7.125623
Test Accuracy:0.988100
Finish epoch: 9, sum loss:6.302748
Test Accuracy:0.989300
Finish epoch: 10, sum loss:5.492025
Test Accuracy:0.989400
Finish epoch: 11, sum loss:5.196884
Test Accuracy:0.991200
Finish epoch: 12, sum loss:4.570235
Test Accuracy:0.988500
Finish epoch: 13, sum loss:3.954003
Test Accuracy:0.988200
Finish epoch: 14, sum loss:3.541092
Test Accuracy:0.989000
Finish epoch: 15, sum loss:3.209433
Test Accuracy:0.988700
Finish epoch: 16, sum loss:2.979558
Test Accuracy:0.991400
Finish epoch: 17, sum loss:3.152918
Test Accuracy:0.988500


## 训练学生网络


In [17]:
current_path = os.getcwd()
path_ckpt_t_teacher = current_path + '/cache/models/net_epoch_20.
ckpt_path_student = current_path + '/cache/models'
path_dataset = '/home/yinzp/workspace/dataset/'
lr_G = 0.2
lr_S = 2e-3
epochs = 50
batch_size = 512
latent_dim = 100
oh = 1
ie = 5
a = 0.1
torch.cuda.empty_cache()
trainstudent = StudentTrainer(path_ckpt_t_teacher, ckpt_path_student, path_dataset, epochs, batch_size, latent_dim, lr_G, lr_S, oh, ie, a)
trainstudent.train()

Epoch1, loss-414.989403
Test Accuracy: 0.609200
Epoch2, loss-481.897155
Test Accuracy: 0.825900
Epoch3, loss-508.609571
Test Accuracy: 0.863800
Epoch4, loss-520.428852
Test Accuracy: 0.900500
Epoch5, loss-522.409932
Test Accuracy: 0.885700
Epoch6, loss-525.143704
Test Accuracy: 0.925700
Epoch7, loss-443.878463
Test Accuracy: 0.725800
Epoch8, loss-386.356447
Test Accuracy: 0.724700
Epoch9, loss-386.357029
Test Accuracy: 0.724700
Epoch10, loss-386.357031
Test Accuracy: 0.724700
Epoch11, loss-386.357031
Test Accuracy: 0.724700
Epoch12, loss-386.357031
Test Accuracy: 0.724700
Epoch13, loss-386.357031
Test Accuracy: 0.724700
Epoch14, loss-386.357031
Test Accuracy: 0.724700
Epoch15, loss-386.357031
Test Accuracy: 0.724700
Epoch16, loss-386.357031
Test Accuracy: 0.724700
Epoch17, loss-386.357031
Test Accuracy: 0.724700
Epoch18, loss-386.357031
Test Accuracy: 0.724700
Epoch19, loss-386.357031
Test Accuracy: 0.724700
Epoch20, loss-386.357031
Test Accuracy: 0.724700
Epoch21, loss-386.357031
Test

In [18]:
def save_list(list, path, filename):
    file = open(path + filename, 'w')
    for i in range(len(list)):
        file.write(str(list[i]))
        file.write('\n')
    file.close
def save_model(epoch):
    s_path = 'cache/models/student%d.pth'%(epoch)
    g_path = 'cache/models/generate%d.pth'%(epoch)
    state = {'net': student.state_dict(), 'optimizer':optimizer_S.state_dict(), 'epoch':epochs}
    torch.save(state, s_path)
    state = {'net': generator.state_dict(), 'optimizer':optimizer_G.state_dict(), 'epoch':epochs}
    torch.save(state, g_path)

In [36]:
loss_l, loss_kd_l, loss_ie_l, loss_oh_l, loss_a_l, accr_l = [], [], [], [], [], []
torch.cuda.empty_cache()
for epoch in range(1, epochs+1):
    loss_s, loss_oh_s, loss_a_s, loss_ie_s, loss_kd_s= 0, 0, 0, 0, 0
    for i in range(1, 121):
        student.train()
        generator.train()
        z = Variable(torch.randn(batch_size, latent_dim)).cuda()
        optimizer_G.zero_grad()
        optimizer_S.zero_grad()
        gen_images = generator(z)
        outputs_T, features_T = teacher(gen_images, out_feature=True)
        pred = outputs_T.data.max(1)[1]
        ## 损失1--激活损失
        loss_activation = -features_T.abs().mean()
        ## 损失2
        loss_one_hot = criterion(outputs_T, pred)
        ## 损失3
        softmax_o_T = F.softmax(outputs_T, dim=1).mean(dim=0)
        loss_information_entropy = (softmax_o_T * torch.log10(softmax_o_T)).sum()
        ## 总损失
        loss = loss_one_hot * oh + loss_information_entropy * ie + loss_activation * a

        ## 用于优化学生网络的知识蒸馏损失
        loss_kd = kdloss(student(gen_images.detach()), outputs_T.detach())
        loss += loss_kd
        loss.backward()
        optimizer_G.step()
        optimizer_S.step()
        ## 累计损失
        loss_s += loss.data.item()
        loss_a_s += loss_activation.data.item()
        loss_ie_s += loss_information_entropy.data.item()
        loss_oh_s += loss_one_hot.data.item()
        loss_kd_s += loss_kd.data.item()

        
    
    ## 结束一个epoch
    loss_l.append(loss_s)
    loss_a_l.append(loss_a_s)
    loss_ie_l.append(loss_ie_s)
    loss_oh_l.append(loss_oh_s)
    loss_kd_l.append(loss_kd_s)
    print('Epoch%d, loss%f, loss_oh:%f, loss_ie:%f, loss_a:%f, loss_kd:%f' % (epoch, loss_s, loss_oh_s, loss_ie_s, loss_a_s, loss_kd_s))
    total_corerect = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_test_loader, start=1):
            images, labels = images.cuda(), labels.cuda()
            student.eval()
            output = student(images)
            pred = output.data.max(1)[1]
            total_corerect += pred.eq(labels.view_as(pred)).sum()
    accr = float(total_corerect) / len(data_test)
    accr_l.append(accr)
    print('Test Accuracy: %f' % (accr))
    if epoch%20 == 0:
        save_model(epoch)



Epoch1, loss-423.809659, loss_oh:39.549178, loss_ie:-116.175034, loss_a:-170.437260, loss_kd:134.560059
Test Accuracy: 0.481800
Epoch2, loss-498.079248, loss_oh:35.703057, loss_ie:-118.230725, loss_a:-178.568815, loss_kd:75.228198
Test Accuracy: 0.701300
Epoch3, loss-522.322384, loss_oh:34.482052, loss_ie:-118.655312, loss_a:-196.556510, loss_kd:56.127772
Test Accuracy: 0.847800
Epoch4, loss-534.332417, loss_oh:34.483799, loss_ie:-118.676752, loss_a:-181.096145, loss_kd:42.677158
Test Accuracy: 0.911200
Epoch5, loss-540.048560, loss_oh:32.938086, loss_ie:-118.605378, loss_a:-206.691228, loss_kd:40.709370
Test Accuracy: 0.946800
Epoch6, loss-526.139665, loss_oh:37.252659, loss_ie:-116.936448, loss_a:-179.553354, loss_kd:39.245259
Test Accuracy: 0.932700
Epoch7, loss-539.169344, loss_oh:36.853369, loss_ie:-118.631240, loss_a:-177.298605, loss_kd:34.863345
Test Accuracy: 0.946800
Epoch8, loss-537.478224, loss_oh:35.957967, loss_ie:-118.319979, loss_a:-186.753548, loss_kd:36.839061
Test Ac

保存训练数据

In [37]:

save_list(loss_l, test_path, 'loss.txt')
save_list(loss_ie_l, test_path, 'loss_ie.txt')
save_list(loss_a_l, test_path, 'loss_a.txt')
save_list(loss_kd_l, test_path, 'loss_kd.txt')
save_list(loss_oh_l, test_path, 'loss_oh.txt')
save_list(accr_l, test_path, 'accr.txt')


FileNotFoundError: [Errno 2] No such file or directory: '/home/yinzp/workspace/testdata/loss.txt'