In [None]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torchvision.datasets.imagenet import ImageNet
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import collections
import random

Error: Kernel is dead

# 模型搭建

## 教师模型

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out =F.relu(out)
        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return out



class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if out_feature == False:
            return out
        else:
            return out, feature
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)
 
def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3,4,6,3], num_classes)
 
def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes)
 
def ResNet101(num_classes=10):
    return ResNet(Bottleneck, [3,4,23,3], num_classes)
 
def ResNet152(num_classes=10):
    return ResNet(Bottleneck, [3,8,36,3], num_classes)

# 模型训练

## 教师训练类搭建

In [None]:
class TeacherTrainer:
	def __init__(self, path_ckpt, path_loss, path_dataset, name_dataset='cifar10', 
				 bs=512, num_epochsaving=100, resume_train=False, path_resume=None):
		if name_dataset == 'cifar10':
			transform_train = transforms.Compose([
				transforms.RandomCrop(32, padding=4),
				transforms.RandomHorizontalFlip(),
				transforms.ToTensor(),
				transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
				])
			transform_test = transforms.Compose([
				transforms.ToTensor(),
				transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
			])
			self.dataset_train = CIFAR10(path_dataset, transform=transform_train)
			self.dataset_test = CIFAR10(path_dataset, train=False, transform=transform_test)
			self.dataset_test_loader = DataLoader(self.dataset_test, batch_size=512, num_workers=0)
			self.dataset_train_loader = DataLoader(self.dataset_train, batch_size=bs, shuffle=True, num_workers=8)
			
			
			torch.cuda.set_device('cuda:0')
			net = ResNet34().cuda()
			self.net = nn.DataParallel(net,device_ids=[0,1,2,3], output_device=0)
			self.criterion = torch.nn.CrossEntropyLoss().cuda()
			self.optimizer = torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
			self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=5, T_mult=2, eta_min=0, last_epoch=-1)
			self.last_epoch = 0
			if resume_train:
				ckpt = torch.load(path_resume)
				net.load_state_dict(ckpt['net'])
				self.optimizer.load_state_dict(ckpt['optimizer'])
				self.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
				self.last_epoch = ckpt['epoch']
			
			
		
		# 训练相关参数
		self.best_accr = 0
		self.list_loss = []
		self.path_ckpt = path_ckpt
		self.path_loss = path_loss
		self.num_epochsaving = num_epochsaving
	def train(self, epochs):
		
		for epoch in range(self.last_epoch+1, epochs+1):
			self.net.train()
			loss_epoch = 0
			for i, (batch_img, batch_label) in enumerate(self.dataset_train_loader, start=1):
				batch_img, batch_label = Variable(batch_img).cuda(non_blocking=True), Variable(batch_label).cuda(non_blocking=True)
				self.optimizer.zero_grad()
				output = self.net(batch_img)
				loss = self.criterion(output, batch_label)
				loss.backward()
				self.optimizer.step()
				loss_epoch += loss.data.item()
			# 一个epoch结束
			self.lr_scheduler.step()
			self.list_loss.append(loss_epoch)
			print('EPOCH:%d, LOSS:%f'%(epoch, loss_epoch))
			# 测试
			self.test(epoch)
		self.save_experiment(epochs)
		
	
	def adjust_lr(self, epoch):
		if epoch < 80:
			lr = 0.1
		if epoch < 120:
			lr = 0.01
		else:
			lr = 0.001
		for param_group in self.optimizer.param_groups:
			param_group['lr'] = lr

	def test(self, epoch):
		self.net.eval()
		total_correct = 0
		with torch.no_grad():
			for i, (images, labels) in enumerate(self.dataset_test_loader, start=1):
				images, labels = Variable(images).cuda(), Variable(labels).cuda()
				output = self.net(images)
				pred = output.data.max(1)[1]
				total_correct += pred.eq(labels.data.view_as(pred)).sum()
            
		acc = float(total_correct) / len(self.dataset_test)
		if acc > self.best_accr:
			self.best_accr = acc
			if epoch > self.num_epochsaving:
				self.save_model(self.path_ckpt, epoch)
		
		print('Test Accuracy:%f' % (acc))

	def save_model(self, path, epoch):
		state = {
			'net': self.net.state_dict(), 
			'optimizer':self.optimizer.state_dict(), 
			'lr_scheduler':self.lr_scheduler.state_dict(),
			'epoch':epoch}
		filename = path + 'teacher__accr%f_epoch%d.pth'%(self.best_accr, epoch)
		torch.save(state, filename)
	def save_experiment(self, epochs):
		lossfile = np.array(self.list_loss)
		np.save(self.path_loss + '/teacher_loss_{}'.format(epochs), lossfile)

## 训练

## deepinversion

In [None]:
class InputTrainer:
    def __init__(self, path_ckpt_teacher, path_inputs_saving):
        # teacher and student net
        ckpt = torch.load(path_ckpt_teacher)
        print('Using trained teacher network...')
        self.teacher = resnet.ResNet34().cuda()
        self.teacher.load_state_dict(ckpt['net'])
        self.student = resnet.ResNet18().cuda()
        self.teacher.eval()
        self.student.eval()
        
        self.inputs = torch.randn((256, 3, 32, 32), requires_grad=True, device='cuda', dtype=torch.float)
        self.criterion = nn.CrossEntropyLoss().cuda()
        self.optimizer_in = torch.optim.Adam([self.inputs], lr=0.05)
        self.optimizer_in.state = collections.defaultdict(dict)
        #arguments
        self.last_epoch = 0
        self.loss_r_feature = list()
        self.path_inputs_saving = path_inputs_saving
        
    def train(self, epochs, l2_coeff=0.0, var_scale=2.5e-5, bn_reg_scale=10):
        # batch size 250+6
        targets = torch.LongTensor([0,1,2,3,4,5,6,7,8,9]*25 + [0,1,2,3,4,5]).cuda()
        largest_loss = 1e6
        for epoch in range(self.last_epoch+1, epochs+1):
            off1 = random.randint(-2,-2)
            off2 = random.randint(-2,-2)
            # shift the input along the second(w) and third(h) axis by a random number
            inputs_jit = torch.roll(self.inputs, shifts=(off1, off2), dims=(2,3))
            
            self.optimizer_in.zero_grad()
            self.teacher.zero_grad()
            outputs = self.teacher(inputs_jit)
            loss = self.criterion(outputs, targets)
            
            # apply total variation regularization
            diff1 = inputs_jit[:,:,:,:-1] - inputs_jit[:,:,:,1:]
            diff2 = inputs_jit[:,:,:-1,:] - inputs_jit[:,:,1:,:]
            diff3 = inputs_jit[:,:,1:,:-1] - inputs_jit[:,:,:-1,1:]
            diff4 = inputs_jit[:,:,:-1,:-1] - inputs_jit[:,:,1:,1:]
            loss_var = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)
            loss = loss + var_scale*loss_var

            # 每个BN层的running_mean和running_var与新输入的差值
            loss_distr = sum([r_feature for r_feature in self.loss_r_feature])
            loss = loss + bn_reg_scale*loss_distr

            # l2 loss
            loss = loss + l2_coeff*torch.norm(inputs_jit, 2)
            print(epoch, loss.item())
            # 采用这种保存方法，节省保存次数
            if largest_loss > loss.item():
                largest_loss = loss.item()
                best_inputs = self.inputs.data
            loss.backward()
            self.optimizer_in.step()
        vutils.save_image(best_inputs[:20].clone(), self.path_inputs_saving, 
                          normalize=True, scale_each=True, nrow=10)
            
    def hook_fn(self, module, input_data, output_data):
        num_ch = input_data[0].shape[1]
        mean = input_data[0].mean([0,2,3])
        var = input_data[0].permute(1,0,2,3).contiguous().view([num_ch, -1]).var(1, unbiased=False)
        r_feature = torch.norm(module.running_var.data.type(var.type()) - var, 2) + torch.norm(module.running_mean.data.type(mean.type()) - mean, 2)
        self.loss_r_feature.append(r_feature)

## 开始训练

### 训练教师网络

In [None]:
path_current = os.getcwd()
path_ckpt = os.path.join(path_current, 'cache/models/teacher/')
path_loss = os.path.join(path_current,'cache/experimental_data/')
path_cifar = "/home/ubuntu/datasets/"
train_teacher = TeacherTrainer(path_ckpt, path_loss, path_cifar, bs=2048, num_epochsaving=150)
train_teacher.train(400)

EPOCH:1, LOSS:68.896122
Test Accuracy:0.113400
EPOCH:2, LOSS:48.271791
Test Accuracy:0.289700
EPOCH:3, LOSS:44.163279
Test Accuracy:0.356200
EPOCH:4, LOSS:41.442758
Test Accuracy:0.396000
EPOCH:5, LOSS:39.679080
Test Accuracy:0.422000
EPOCH:6, LOSS:48.267901
Test Accuracy:0.304400
EPOCH:7, LOSS:42.839829
Test Accuracy:0.379700
EPOCH:8, LOSS:39.221532
Test Accuracy:0.357800
EPOCH:9, LOSS:36.694259
Test Accuracy:0.439000
EPOCH:10, LOSS:34.455036
Test Accuracy:0.467500
EPOCH:11, LOSS:32.197157
Test Accuracy:0.515300
EPOCH:12, LOSS:30.074572
Test Accuracy:0.577400
EPOCH:13, LOSS:28.109925
Test Accuracy:0.584400
EPOCH:14, LOSS:26.505617
Test Accuracy:0.628200
EPOCH:15, LOSS:25.531642
Test Accuracy:0.633900
EPOCH:16, LOSS:46.982348
Test Accuracy:0.345700
EPOCH:17, LOSS:41.582078
Test Accuracy:0.399900
EPOCH:18, LOSS:38.344050
Test Accuracy:0.450700
EPOCH:19, LOSS:35.624523
Test Accuracy:0.440200
EPOCH:20, LOSS:33.561480
Test Accuracy:0.502700
EPOCH:21, LOSS:31.083786
Test Accuracy:0.553700
E

### 训练输入

In [35]:
print('start...')
path_current = os.getcwd()

path_ckpt_teacher = os.path.join(path_current, 'paper-reading/DeepInversion/cache/models/teacher/teacher')
print(path_ckpt_teacher)
path_save_img = os.path.join(path_current, 'paper-reading/DeepInversion/cache/best_img.jpg')

# ckpt = torch.load(path_ckpt_teacher)
# net = ResNet34().cuda()
# net.load_state_dict(ckpt['net'])
trainer = InputTrainer('/home/ubuntu/YZP/gitee/paper-reading/DeepInversion/cache/models/teacher/teacher', path_save_img)
trainer.train(1000)

/home/ubuntu/YZP/gitee/paper-reading/DeepInversion/cache/models/teacher/teacher
Using trained teacher network...


TypeError: randn() received an invalid combination of arguments - got (tuple, dtype=torch.dtype, required_grad=bool, device=str), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.Generator generator, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
