In [10]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.datasets.imagenet import ImageNet
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np

# 模型搭建

## 教师模型

In [11]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out =F.relu(out)
        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return out



class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if out_feature == False:
            return out
        else:
            return out, feature
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)
 
def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3,4,6,3], num_classes)
 
def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes)
 
def ResNet101(num_classes=10):
    return ResNet(Bottleneck, [3,4,23,3], num_classes)
 
def ResNet152(num_classes=10):
    return ResNet(Bottleneck, [3,8,36,3], num_classes)

# 模型训练

## 教师训练类搭建

In [12]:
class TeacherTrainer:
	def __init__(self, path_ckpt, path_loss, path_dataset, name_dataset='cifar10'):
		if name_dataset == 'cifar10':
			transform_train = transforms.Compose([
				transforms.RandomCrop(32, padding=4),
				transforms.RandomHorizontalFlip(),
				transforms.ToTensor(),
				transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
				])
			transform_test = transforms.Compose([
				transforms.ToTensor(),
				transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
			])
			self.dataset_train = CIFAR10(path_dataset, transform=transform_train)
			self.dataset_test = CIFAR10(path_dataset, train=False, transform=transform_test)
			self.dataset_test_loader = DataLoader(self.dataset_test, batch_size=100, num_workers=0)
			self.dataset_train_loader = DataLoader(self.dataset_train, batch_size=128, shuffle=True, num_workers=8)
			self.net = ResNet34().cuda()
			self.criterion = torch.nn.CrossEntropyLoss().cuda()
			self.optimizer = torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
		
		# 训练相关参数
		self.best_accr = 0
		self.list_loss = []
		self.path_ckpt = path_ckpt
		self.path_loss = path_loss
	def train(self, epochs):
		self.net.train()
		for epoch in range(1, epochs+1):
			loss_epoch = 0
			for i, (batch_img, batch_label) in enumerate(self.dataset_train_loader, start=1):
				batch_img, batch_label = Variable(batch_img).cuda(), Variable(batch_label).cuda()
				self.optimizer.zero_grad()
				output = self.net(batch_img)
				loss = self.criterion(output, batch_label)
				loss.backward()
				self.optimizer.step()
				loss_epoch += loss.data.item()
			# 一个epoch结束
			self.adjust_lr(epoch)
			self.list_loss.append(loss_epoch)
			print('EPOCH:%d, LOSS:%f'%(epoch, loss_epoch))
			# 测试
			self.test(epoch)
		self.save_experiment(epochs)
		
	
	def adjust_lr(self, epoch):
		if epoch < 80:
			lr = 0.1
		if epoch < 120:
			lr = 0.01
		else:
			lr = 0.001
		for param_group in self.optimizer.param_groups:
			param_group['lr'] = lr

	def test(self, epoch):
		self.net.eval()
		total_correct = 0
		with torch.no_grad():
			for i, (images, labels) in enumerate(self.dataset_test_loader, start=1):
				images, labels = Variable(images).cuda(), Variable(labels).cuda()
				output = self.net(images)
				pred = output.data.max(1)[1]
				total_correct += pred.eq(labels.data.view_as(pred)).sum()
            
		acc = float(total_correct) / len(self.dataset_test)
		if acc > self.best_accr:
			self.best_accr = acc
			self.save_model(self.path_ckpt, epoch)
		
		print('Test Accuracy:%f' % (acc))

	def save_model(self, path, epoch):
		state = {'net': self.net.state_dict(), 'optimizer':self.optimizer.state_dict(), 'epoch':epoch}
		filename = path + 'teacher__accr%f_epoch%d.pth'%(self.best_accr, epoch)
		torch.save(state, filename)
	def save_experiment(self, epochs):
		lossfile = np.array(self.list_loss)
		np.save(self.path_loss + '/teacher_loss_{}'.format(epochs), lossfile)

## 训练

In [13]:
path_current = os.getcwd()
path_ckpt = os.path.join(path_current, 'cache/models/teacher/')
path_loss = os.path.join(path_current,'cache/experimental_data/')
path_cifar = "/home/yinzp/workspace/dataset/cifar-10-python"
train_teacher = TeacherTrainer(path_ckpt, path_loss, path_cifar)
train_teacher.train(200)

EPOCH:1, LOSS:783.125555
Test Accuracy:0.403300
EPOCH:2, LOSS:603.690376
Test Accuracy:0.435400
EPOCH:3, LOSS:576.128691
Test Accuracy:0.469500
EPOCH:4, LOSS:558.985040
Test Accuracy:0.502200
EPOCH:5, LOSS:537.809400
Test Accuracy:0.516200
EPOCH:6, LOSS:518.755001
Test Accuracy:0.523300
EPOCH:7, LOSS:503.027959
Test Accuracy:0.542300
EPOCH:8, LOSS:483.501203
Test Accuracy:0.563400
EPOCH:9, LOSS:466.483572
Test Accuracy:0.590600
EPOCH:10, LOSS:447.249925
Test Accuracy:0.569200
EPOCH:11, LOSS:432.437818
Test Accuracy:0.614100
EPOCH:12, LOSS:416.612284
Test Accuracy:0.640500
EPOCH:13, LOSS:402.118391
Test Accuracy:0.616800
EPOCH:14, LOSS:385.008771
Test Accuracy:0.665600
EPOCH:15, LOSS:374.456891
Test Accuracy:0.667100
EPOCH:16, LOSS:363.665389
Test Accuracy:0.653800
EPOCH:17, LOSS:354.664164
Test Accuracy:0.686000
EPOCH:18, LOSS:341.136677
Test Accuracy:0.677200
EPOCH:19, LOSS:334.333832
Test Accuracy:0.712000
EPOCH:20, LOSS:328.903583
Test Accuracy:0.713200
EPOCH:21, LOSS:324.452787
Tes

KeyboardInterrupt: 