In [14]:
!python train.py --dataset cifar10 --out ./snapshots --n_epoch 30

weight:  None
lr:  0.01
n_epoch:  30
batch_size:  64
n_gen:  3
dataset:  cifar10
outdir:  ./snapshots
print_interval:  50
Files already downloaded and verified
train...
epoch: 0, iter: 50, train_loss: 2.99985125541687, val_loss: 2.3925644394698415
epoch: 0, iter: 100, train_loss: 2.299298114776611, val_loss: 2.132373481799083
epoch: 0, iter: 150, train_loss: 2.0798945450782775, val_loss: 2.0386597168673375
epoch: 0, iter: 200, train_loss: 1.9813673377037049, val_loss: 1.9375261068344116
epoch: 0, iter: 250, train_loss: 1.9276276683807374, val_loss: 1.866250036628383
epoch: 0, iter: 300, train_loss: 1.8674716901779176, val_loss: 1.9111368041129628
epoch: 0, iter: 350, train_loss: 1.8574968981742859, val_loss: 1.7868014384227193
epoch: 0, iter: 400, train_loss: 1.7875313544273377, val_loss: 1.7574727148007436
epoch: 0, iter: 450, train_loss: 1.7844722747802735, val_loss: 1.6952189419679582
epoch: 0, iter: 500, train_loss: 1.7273043918609619, val_loss: 1.6567081509122423
epoch: 0, iter: 5

In [1]:
import argparse
import json
import time
from datetime import datetime
import warnings
import os
warnings.filterwarnings("ignore")
import numpy as np

import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.optim as optim
import random
#from logger import SummaryLogger
from torch.utils.tensorboard import SummaryWriter
#import utils
from ban.models.resnet import * 

In [2]:
parser = argparse.ArgumentParser(description='born again for CIFAR10')
parser.add_argument('--text', default='log.txt', type=str)
parser.add_argument('--load_pretrained_teacher', default='./trained/teacher-res50-cifar10.pth', type=str)
parser.add_argument('--exp_name', default='cifar10/stu_res34', type=str)
parser.add_argument('--log_time', default='1', type=str)
parser.add_argument('--lr', default='0.1', type=float)
parser.add_argument('--resume_epoch', default='0', type=int)
parser.add_argument('--epoch', default='163', type=int)
parser.add_argument('--n_gen', default='5', type=int) #决定generation的次数
#parser.add_argument('--n_models', default=[''], type=list)
parser.add_argument('--decay_epoch', default=[82, 123], nargs="*", type=int)
parser.add_argument('--w_decay', default='5e-4', type=float)
parser.add_argument('--cu_num', default='0', type=str)
parser.add_argument('--seed', default='1', type=str)
parser.add_argument('--load_pretrained', default='  ', type=str)
parser.add_argument('--save_model', default='ckpt.t7', type=str)

args = parser.parse_args(args=[])

In [3]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#### random Seed #####
num = random.randint(1, 10000)
random.seed(num)
torch.manual_seed(num)
#####################

os.environ['CUDA_VISIBLE_DEVICES'] = args.cu_num

In [4]:
#Data loader
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
#Other parameters
DEVICE = torch.device("cuda")
RESUME_EPOCH = args.resume_epoch
DECAY_EPOCH = args.decay_epoch
DECAY_EPOCH = [ep - RESUME_EPOCH for ep in DECAY_EPOCH]
FINAL_EPOCH = args.epoch
EXPERIMENT_NAME = args.exp_name
W_DECAY = args.w_decay
base_lr = args.lr

# Model
model_tea = ResNet50()
#print(model_tea)
model_tea.to(DEVICE)

#导入模型
path = args.load_pretrained_teacher
model_tea.load_state_dict(torch.load(path))

NameError: name 'resnet8_cifar' is not defined

In [6]:
def eval(net):
    loader = testloader
    flag = 'Test'

    epoch_start_time = time.time()
    net.eval()
    val_loss = 0

    correct = 0


    total = 0
    criterion_CE = nn.CrossEntropyLoss()

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs= net(inputs)


        loss = criterion_CE(outputs, targets)
        val_loss += loss.item()

        _, predicted = torch.max(outputs.detach(), 1)


        total += targets.size(0)

        correct += predicted.eq(targets.detach()).cpu().sum().float().item()
        b_idx = batch_idx

    print('%s \t Time Taken: %.2f sec' % (flag, time.time() - epoch_start_time))
    print('Loss: %.3f | Acc net: %.3f%%' % (train_loss / (b_idx + 1), 100. * correct / total))
    return val_loss / (b_idx + 1),  correct / total

In [7]:
def distillation(outputs_stu, labels, outputs_tea, temp, alpha):
    #y: 学生预测的概率分布
    #labels: 实际标签
    #teacher_scores: 老师预测的概率分布
    #temp: 温度系数
    #alpha: 损失调整因子
    criterion = nn.KLDivLoss()
    outputs_S = F.log_softmax(outputs_stu/temp, dim=1) #dim指的是归一化的方式，如果为0是对列做归一化，1是对行做归一化
    outputs_T = F.softmax(outputs_tea/temp, dim=1)
    loss = criterion(outputs_S, outputs_T) * temp * temp * 2.0 * alpha
    
    final_loss = F.cross_entropy(outputs_stu, labels) * (1 - alpha) + loss * alpha 
    return final_loss

In [19]:
def train(model_stu, epoch, model_tea):
    epoch_start_time = time.time()
    print('\n EPOCH: %d' % epoch)
    model_stu.train()
    model_tea.eval()
    
    train_loss = 0
    correct = 0
    total = 0
    global optimizer  

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()

        outputs_tea = model_tea(inputs)
        outputs_stu = model_stu(inputs)
        
        # 计算误差
        loss = distillation(outputs_stu, targets, outputs_tea, temp=5., alpha=0.7)#alpha=.7
        
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  #loss都直接用loss表示的，结果就是每次迭代，空间占用就会增加，直到cpu或者gup爆炸。解决办法：把除了loss.backward()之外的loss调用都改成loss.item()，就可以解决。
        
        _, predicted = torch.max(outputs_stu.detach(), 1)
        total += targets.size(0)

        correct += predicted.eq(targets.detach()).cpu().sum().float().item()


        b_idx = batch_idx

    print('Train s1 \t Time Taken: %.2f sec' % (time.time() - epoch_start_time))
    print('Loss: %.3f | Acc net: %.3f%%|' % (train_loss / (b_idx + 1), 100. * correct / total))
    return train_loss / (b_idx + 1), correct / total

In [8]:
time_log = datetime.now().strftime('%m-%d %H:%M')
if int(args.log_time) :
    folder_name = 'teacher_{}'.format(time_log)
path = os.path.join(EXPERIMENT_NAME, folder_name)

if not os.path.exists('born-again/' + path):
    os.makedirs('born-again/' + path)
if not os.path.exists('logs/' + path):
    os.makedirs('logs/' + path)

# Save argparse arguments as logging
with open('logs/{}/commandline_args.txt'.format(path), 'w') as f:
    json.dump(args.__dict__, f, indent=2)

# Instantiate logger
#logger = SummaryLogger(path)

In [9]:
best_model = './born-again/cifar10/stu_res34/gen2/Model_149.pth'
#best_model1 = './born-again/cifar10/stu_res34/gen0/Model_159.pth'
#best_model2 = './born-again/cifar10/stu_res34/gen1/Model_161.pth' #需要不停置换
for gen in range(3,5):
    best_acc = 0 #这里每个gen都要更新一次
    
    print("This is the {}-th generation".format(gen))
    model_stu = ResNet34()
    model_stu.to(DEVICE)

    if gen >= 1:
        model_tea = ResNet34()
        model_tea.to(DEVICE)
        model_tea.load_state_dict(torch.load(best_model))

    # Loss and Optimizer
    optimizer = optim.SGD(model_stu.parameters(), lr=base_lr, momentum=0.9, weight_decay=W_DECAY)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=DECAY_EPOCH, gamma=0.1)
    criterion_CE = nn.CrossEntropyLoss()
    
    for epoch in range(RESUME_EPOCH, FINAL_EPOCH+1):
        f = open(os.path.join("./born-again/cifar10", 'log.txt'), "a")

        ### Train ###
        
        train_loss, acc = train(model_stu, epoch, model_tea)
        
        scheduler.step()

        ### Evaluate  ###
        val_loss, test_acc  = eval(model_stu)
        if best_acc < test_acc:
            best_acc = test_acc
            torch.save(model_stu.state_dict(), './born-again/cifar10/stu_res34/gen{}/Model_{}.pth'.format(gen, epoch))
            
            best_model = './born-again/cifar10/stu_res34/gen{}/Model_{}.pth'.format(gen, epoch)
            print("best_model:", best_model)

        f.write('EPOCH {epoch} \t'
                'ACC_net : {acc_net:.4f} \t  \n'.format(epoch=epoch, acc_net=test_acc)
                )
        f.close()

"""
utils.save_checkpoint({
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, True, 'ckpt/' + path, filename='Model_{}.pth'.format(epoch))
"""

This is the 3-th generation


NameError: name 'train' is not defined

In [67]:
def load_checkpoint(model, checkpoint):
    m_keys = list(model.state_dict().keys())

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        c_keys = list(checkpoint['state_dict'].keys())
        not_m_keys = [i for i in c_keys if i not in m_keys]
        not_c_keys = [i for i in m_keys if i not in c_keys]
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    else:
        c_keys = list(checkpoint.keys())
        not_m_keys = [i for i in c_keys if i not in m_keys]
        not_c_keys = [i for i in m_keys if i not in c_keys]
        model.load_state_dict(checkpoint, strict=False)
        
def eval_multi(net1, net2, net3, net4):
    loader = testloader
    flag = 'Test'

    epoch_start_time = time.time()
    net1.eval()
    net2.eval()
    net3.eval()
    '''
   
    
    net4.eval()
    
    net5.eval()

    '''
    val_loss = 0

    correct = 0


    total = 0
    criterion_CE = nn.CrossEntropyLoss()

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs1= net1(inputs)
        outputs2= net2(inputs)
        outputs3= net3(inputs)
        outputs4= net4(inputs)
        '''    
        
        
         
        
        outputs5= net5(inputs)

        '''

        #outputs = (outputs1  + outputs2 + outputs3 + outputs4 + outputs5) / 5
        outputs = (outputs1 + outputs2 + outputs3 + outputs4 ) /4
        loss = criterion_CE(outputs, targets)
        val_loss += loss.item()

        _, predicted = torch.max(outputs.detach(), 1)


        total += targets.size(0)

        correct += predicted.eq(targets.detach()).cpu().sum().float().item()
        b_idx = batch_idx

    print('%s \t Time Taken: %.2f sec' % (flag, time.time() - epoch_start_time))
    #print('Loss: %.3f | Acc net: %.3f%%' % (train_loss / (b_idx + 1), 100. * correct / total))
    print('Loss: %.3f | Acc net: %.3f%%' % (1 / (b_idx + 1), 100. * correct / total))
    return val_loss / (b_idx + 1),  correct / total
model1 = ResNet34()
model2 = ResNet34()
model3 = ResNet34()
model4 = ResNet34()
model5 = ResNet34()

state1 = torch.load('./born-again/cifar10/stu_res34/1/gen0/Model_150.pth')
state2 = torch.load('./born-again/cifar10/stu_res34/1/gen1/Model_196.pth')
state3 = torch.load('./born-again/cifar10/stu_res34/1/gen2/Model_152.pth')
state4 = torch.load('./born-again/cifar10/stu_res34/1/gen3/Model_154.pth')
state5 = torch.load('./born-again/cifar10/stu_res34/1/gen4/Model_145.pth')
#state6 = torch.load('./born-again/cifar10/stu_res34/gen0-3/gen0/Model_159.pth')
#state7 = torch.load('./born-again/cifar10/stu_res34/gen0-3/gen1/Model_161.pth')
#state8 = torch.load('./born-again/cifar10/stu_res34/gen0-3/gen2/Model_162.pth')
load_checkpoint(model1, state1)
load_checkpoint(model2, state2)
load_checkpoint(model3, state3)
load_checkpoint(model4, state4)
load_checkpoint(model5, state5)
#load_checkpoint(model6, state6)
#load_checkpoint(model7, state7)
#load_checkpoint(model8, state8)
model1.to(DEVICE)
model2.to(DEVICE)
model3.to(DEVICE)
model4.to(DEVICE)
model5.to(DEVICE)
#model6.to(DEVICE)
#model7.to(DEVICE)
#model8.to(DEVICE)

val_loss, test_acc  = eval_multi(model2, model5, model1, model4)

Test 	 Time Taken: 7.91 sec
Loss: 0.010 | Acc net: 96.210%


Test 	 Time Taken: 3.50 sec
Loss: 0.000 | Acc net: 95.350%
Test 	 Time Taken: 3.51 sec
Loss: 0.000 | Acc net: 95.520%
Test 	 Time Taken: 3.58 sec
Loss: 0.000 | Acc net: 95.420%
