# import modules

In [1]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.datasets.imagenet import ImageNet
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import collections
import random

# model structure definition

# ResNet modules

In [2]:
class BasicBlock(nn.Module):
    expansion = 1
 
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
 
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
 
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
 
 
class Bottleneck(nn.Module):
    expansion = 4
 
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
 
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
 
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
 
 
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
 
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
 
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
 
    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if out_feature == False:
            return out
        else:
            return out,feature
 
 
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)
 
def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3,4,6,3], num_classes)
 
def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes)
 
def ResNet101(num_classes=10):
    return ResNet(Bottleneck, [3,4,23,3], num_classes)
 
def ResNet152(num_classes=10):
    return ResNet(Bottleneck, [3,8,36,3], num_classes)
 


# Model Trainer

In [3]:
class TeacherTrainer:
	def __init__(self, path_ckpt, path_loss, path_dataset, name_dataset='cifar10', 
				 bs=512, num_epochsaving=100, resume_train=False, path_resume=None):
		if name_dataset == 'cifar10':
			transform_train = transforms.Compose([
				transforms.RandomCrop(32, padding=4),
				transforms.RandomHorizontalFlip(),
				transforms.ToTensor(),
				transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
				])
			transform_test = transforms.Compose([
				transforms.ToTensor(),
				transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
			])
			self.dataset_train = CIFAR10(path_dataset, transform=transform_train)
			self.dataset_test = CIFAR10(path_dataset, train=False, transform=transform_test)
			self.dataset_test_loader = DataLoader(self.dataset_test, batch_size=100, num_workers=0)
			self.dataset_train_loader = DataLoader(self.dataset_train, batch_size=128, shuffle=True, num_workers=8)
			
			
			torch.cuda.set_device('cuda:0')
			self.net = ResNet34().cuda()
			#self.net = nn.DataParallel(self.net,device_ids=[0,1,2,3], output_device=0)
			self.criterion = torch.nn.CrossEntropyLoss().cuda()
			self.optimizer = torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
			#self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=5, T_mult=2, eta_min=0, last_epoch=-1)
			self.last_epoch = 0
			if resume_train:
				ckpt = torch.load(path_resume)
				net.load_state_dict(ckpt['net'])
				self.optimizer.load_state_dict(ckpt['optimizer'])
				self.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
				self.last_epoch = ckpt['epoch']
			
			
		
		# 训练相关参数
		self.best_accr = 0
		self.list_loss = []
		self.path_ckpt = path_ckpt
		self.path_loss = path_loss
		self.num_epochsaving = num_epochsaving
	def train(self, epochs):
		
		for epoch in range(self.last_epoch+1, epochs+1):
			self.net.train()
			loss_epoch = 0
			for i, (batch_img, batch_label) in enumerate(self.dataset_train_loader, start=1):
				batch_img, batch_label = Variable(batch_img).cuda(), Variable(batch_label).cuda()
				self.optimizer.zero_grad()
				output = self.net(batch_img)
				loss = self.criterion(output, batch_label)
				loss.backward()
				self.optimizer.step()
				loss_epoch += loss.data.item()
			# 一个epoch结束
			#self.lr_scheduler.step()
			self.adjust_lr(epoch)
			self.list_loss.append(loss_epoch)
			print('Train-Epoch:%d, Loss:%f'%(epoch, loss_epoch))
			# 测试
			self.test(epoch)
		self.save_experiment(epochs)
		
	
	def adjust_lr(self, epoch):
		if epoch < 80:
			lr = 0.1
		if epoch < 120:
			lr = 0.01
		else:
			lr = 0.001
		for param_group in self.optimizer.param_groups:## ??
			param_group['lr'] = lr

	def test(self, epoch):
		self.net.eval()
		total_correct = 0
		with torch.no_grad():
			for i, (images, labels) in enumerate(self.dataset_test_loader, start=1):
				images, labels = Variable(images).cuda(), Variable(labels).cuda()
				output = self.net(images)
				pred = output.data.max(1)[1]
				total_correct += pred.eq(labels.data.view_as(pred)).sum()
            
		acc = float(total_correct) / len(self.dataset_test)
		if acc > self.best_accr:
			self.best_accr = acc
			if epoch > self.num_epochsaving:
				self.save_model(self.path_ckpt, epoch)
		
		print('Test-Accuracy:%f' % (acc))

	def save_model(self, path, epoch):
		state = {
			'net': self.net.state_dict(), 
			'optimizer':self.optimizer.state_dict(), 
			#'lr_scheduler':self.lr_scheduler.state_dict(),
			'epoch':epoch}
		filename = path + 'teacher__accr%f_epoch%d.pth'%(self.best_accr, epoch)
		torch.save(state, filename)
	def save_experiment(self, epochs):
		lossfile = np.array(self.list_loss)
		np.save(self.path_loss + '/teacher_loss_{}'.format(epochs), lossfile)

In [4]:
path_current = os.getcwd()
path_ckpt = os.path.join(path_current, 'cache/models/teacher/')
path_loss = os.path.join(path_current,'cache/experimental_data/')
path_cifar = "/home/ubuntu/datasets/"
train_teacher = TeacherTrainer(path_ckpt, path_loss, path_cifar, bs=128, num_epochsaving=150)
train_teacher.train(310)

Train-Epoch:1, Loss:828.806569
Test-Accuracy:0.376300
Train-Epoch:2, Loss:638.174941
Test-Accuracy:0.440600
Train-Epoch:3, Loss:591.991789
Test-Accuracy:0.472200
Train-Epoch:4, Loss:553.309129
Test-Accuracy:0.494200
Train-Epoch:5, Loss:520.737592
Test-Accuracy:0.533500
Train-Epoch:6, Loss:485.463630
Test-Accuracy:0.582400
Train-Epoch:7, Loss:448.980289
Test-Accuracy:0.614300
Train-Epoch:8, Loss:413.368319
Test-Accuracy:0.649300
Train-Epoch:9, Loss:380.640997
Test-Accuracy:0.647500
Train-Epoch:10, Loss:353.763481
Test-Accuracy:0.679700
Train-Epoch:11, Loss:326.419531
Test-Accuracy:0.722100
Train-Epoch:12, Loss:305.624478
Test-Accuracy:0.709900
Train-Epoch:13, Loss:285.205608
Test-Accuracy:0.754300
Train-Epoch:14, Loss:267.143549
Test-Accuracy:0.747900
Train-Epoch:15, Loss:252.011575
Test-Accuracy:0.773600
Train-Epoch:16, Loss:237.373681
Test-Accuracy:0.780900
Train-Epoch:17, Loss:223.197099
Test-Accuracy:0.784000
Train-Epoch:18, Loss:210.005565
Test-Accuracy:0.810800
Train-Epoch:19, Los

# test model accuracy

In [5]:
net = ResNet34().cuda()
path_cifar = "/home/ubuntu/datasets/"
ckpt = torch.load('/home/ubuntu/YZP/gitee/paper-reading/DeepInversion/cache/models/teacher/teacher__accr0.938000_epoch200.pth')
#print(ckpt['epoch'])
net.load_state_dict(ckpt['net'])
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

data_train = CIFAR10(path_cifar,
                    transform=transform_train)
data_test = CIFAR10(path_cifar,
                    train=False,
                    transform=transform_test)

data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=100, num_workers=0)
criterion = nn.CrossEntropyLoss().cuda()
def test():
    global acc, acc_best
    net.eval()
    total_correct = 0
    avg_loss = 0.0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_test_loader):
            images, labels = Variable(images).cuda(), Variable(labels).cuda()
            output = net(images)
            avg_loss += criterion(output, labels).sum()
            pred = output.data.max(1)[1]
            total_correct += pred.eq(labels.data.view_as(pred)).sum()
 
    avg_loss /= len(data_test)
    acc = float(total_correct) / len(data_test)
    #if acc_best < acc:
    #    acc_best = acc
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.data.item(), acc))
test()


Test Avg. Loss: 0.002731, Accuracy: 0.938000
