In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision
import torch.nn.functional as F
import time
from torch.optim.lr_scheduler import StepLR
import myResnet
import os
import math
# from AE import AE_conv
# Expose only the GPU with index 1 to this process.
# NOTE: this has to be set before CUDA is initialised, which is why it sits in the import cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
### hyper-parameter settings

# Optimisation settings.
batch_size = 128
lr = 0.1
num_epochs = 200

# Which dataset to train on and which teacher / student architectures to pair.
dataset = 'cifar100'
t_name = 'resnet101'
s_name = 'resnet18'

# Shared tag for artefacts of this run: last two characters of the student name,
# last three of the teacher name, then the dataset (e.g. "S18_T101_cifar100").
run_tag = f"S{s_name[-2:]}_T{t_name[-3:]}_{dataset}"

# Teacher checkpoint to read, student checkpoint to write, and the figure with the curves.
t_path = f"../saved_model/{dataset}_{t_name}_Fitnet_32*32.pkl"
s_path = f"saved_model/{run_tag}.pkl"
curve_path = f"saved_curve/{run_tag}_curve.jpg"

# Auto-encoder checkpoints; only referenced by commented-out code in the cells visible here.
AE4_path = f"AE_conv_{t_name}_{dataset}features4.pkl"
AE3_path = f"AE_conv_{t_name}_{dataset}features3.pkl"

for shown_path in (t_path, s_path, curve_path):
    print(shown_path)

../saved_model/cifar100_resnet101_Fitnet_32*32.pkl
saved_model/S18_T101_cifar100.pkl
saved_curve/S18_T101_cifar100_curve.jpg


In [3]:
# Per-channel mean / std used to normalise every image.
# NOTE(review): the same statistics are applied to CIFAR-10 and CIFAR-100 — confirm this is intended.
norm_mean = [0.4914, 0.4822, 0.4465]
norm_std = [0.2023, 0.1994, 0.2010]

# Training images get random crop (with 4-pixel padding) and horizontal flip; test images do not.
train_process = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize(norm_mean, norm_std)])
test_process = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(norm_mean, norm_std)])

# dataset name -> (torchvision dataset class, storage root, number of classes)
# NOTE(review): the roots are absolute, machine-specific paths; consider making them configurable.
dataset_specs = {
    'cifar100': (datasets.CIFAR100, '/root/userfolder/xxx/cifar100', 100),
    'cifar10': (datasets.CIFAR10, '/root/userfolder/xxx/cifar10', 10),
}
if dataset not in dataset_specs:
    # Fail here with a clear message instead of a NameError on train_data further down.
    raise ValueError('unsupported dataset %r; expected one of %s' % (dataset, sorted(dataset_specs)))

dataset_cls, data_root, num_classes = dataset_specs[dataset]
train_data = dataset_cls(root=data_root, transform=train_process, train=True, download=True)
test_data = dataset_cls(root=data_root, transform=test_process, train=False, download=True)

# Only the training split is shuffled.
train_dataLoader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataLoader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# The teacher is stored as a whole pickled module.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only load
# checkpoints from a trusted source.
Teacher = torch.load(t_path)
#Student = torchvision.models.resnet34()

# student name -> (input features of the final fc layer, input channels of branch1..branch4)
student_specs = {
    'resnet18': (128, (16, 32, 64, 128)),
    'resnet34': (128, (16, 32, 64, 128)),
    'resnet50': (512, (64, 128, 256, 512)),
}
# Output channels of branch1..branch4; identical for every student.
branch_out_channels = (128, 256, 512, 1024)

if s_name not in student_specs:
    # Fail here with a clear message instead of a NameError on Student further down.
    raise ValueError('unsupported student %r; expected one of %s' % (s_name, sorted(student_specs)))

fc_in_features, branch_in_channels = student_specs[s_name]
# The constructor in myResnet has the same name as the architecture string.
Student = getattr(myResnet, s_name)(num_classes=num_classes)
Student.fc = nn.Linear(fc_in_features, num_classes)
# Attach the four 1x1-conv branches in order (branch1 .. branch4).
for branch_idx, (c_in, c_out) in enumerate(zip(branch_in_channels, branch_out_channels), start=1):
    setattr(Student, 'branch%d' % branch_idx, myResnet.conv1x1(c_in, c_out))

# AE4 = AE_conv(1024, 4)
# AE4.load_state_dict(torch.load(AE4_path))
# AE3 = AE_conv(512, 2)
# AE3.load_state_dict(torch.load(AE3_path))
print('#params of Student:', sum(para.numel() for para in Student.parameters()))

# The teacher is frozen: none of its parameters receive gradients.
for para in Teacher.parameters():
    para.requires_grad = False

#params of Student: 1329760


In [5]:
### move both networks onto the GPU
Teacher, Student = Teacher.cuda(), Student.cuda()
# AE4 = AE4.cuda()
# AE3 = AE3.cuda()

In [6]:
def test(model):
    """Return the top-1 accuracy of ``model`` over ``test_dataLoader``.

    The model is switched to eval mode and is left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: network whose forward pass returns an indexable result with the
            class scores at position 0.

    Returns:
        Number of correctly classified test samples divided by ``len(test_data)``.
    """
    model.eval()
    num_correct = 0
    # Evaluation never calls backward(), so do not record an autograd graph.
    with torch.no_grad():
        for img, label in test_dataLoader:
            img = img.cuda()
            label = label.cuda()
            scores = model(img)[0]
            _, pred = torch.max(scores, 1)
            num_correct += (pred == label).sum().item()
    return num_correct / len(test_data)

def correlation(feature):
    """Pairwise cosine similarity between the samples of a 4-D batch.

    Args:
        feature: tensor indexed as (sample, d1, d2, d3).

    Returns:
        Square tensor ``CR`` where ``CR[i][j]`` is the dot product of sample ``i``
        and sample ``j`` (summed over dims 1-3) divided by both of their L2 norms.
        The norms are computed on detached tensors, so gradients flow only
        through the dot product.
    """
    # Norm of every sample; identical for each row, so computed once.
    batch_norms = (feature.detach() ** 2).sum(dim=[1, 2, 3]).sqrt()
    rows = []
    for sample in feature:
        # Leading axis of size 1 lets the sample broadcast against the whole batch.
        anchor = sample.reshape(1, *sample.shape[:3])
        dots = (anchor * feature).sum(dim=[1, 2, 3])
        anchor_norm = (anchor.detach() ** 2).sum(dim=[1, 2, 3]).sqrt()
        rows.append(dots / anchor_norm / batch_norms)
    return torch.stack(rows)

def cluster(feature, label, num_classes):
    """Compute per-class mean features for one batch.

    Args:
        feature: batch tensor; ``feature[i]`` is the feature of sample ``i``.
        label: integer class of every sample, indexable by sample position.
        num_classes: number of classes (labels index a list of this length).

    Returns:
        A pair ``(per_sample, present)``:
        ``per_sample[i]`` is the mean feature of the class of sample ``i``;
        ``present`` stacks the means of the classes that occur in the batch,
        in increasing class order.
    """
    batch = feature.size(0)

    # Group the samples by class.
    members = [[] for _ in range(num_classes)]
    for idx in range(batch):
        members[label[idx]].append(feature[idx])

    # Mean of every non-empty class; classes absent from the batch stay None.
    centers = [None] * num_classes
    present = []
    for cls, group in enumerate(members):
        if not group:
            continue
        centers[cls] = sum(group) / len(group)
        present.append(centers[cls])

    per_sample = [centers[label[idx]] for idx in range(batch)]
    return torch.stack(per_sample), torch.stack(present)

def cluster_filter(feature, pred, label, num_classes, mode='intra'):
    """Per-class mean features built only from correctly predicted samples.

    A sample contributes to the mean of its class only when ``pred[i] == label[i]``.

    Args:
        feature: batch tensor; ``feature[i]`` is the feature of sample ``i``.
        pred: predicted class of every sample.
        label: true class of every sample.
        num_classes: number of classes (labels index a list of this length).
        mode: ``'intra'`` returns one row per sample; any other value returns
            one row per class that has a mean.

    Returns:
        ``mode == 'intra'``: stacked tensor whose row ``i`` is the mean of the
        class of sample ``i``, or ``feature[i]`` itself when that class has no
        correctly predicted sample in the batch.
        Otherwise: the stacked class means in increasing class order, or an
        empty list when no sample was predicted correctly.
    """
    batch = feature.size(0)

    # Collect the correctly predicted samples of every class.
    members = [[] for _ in range(num_classes)]
    for idx in range(batch):
        if pred[idx] == label[idx]:
            members[label[idx]].append(feature[idx])

    # None marks "no mean available". The presence check must not rely on len() of the
    # mean itself, which raises for 0-d means (scalar features).
    centers = [sum(group) / len(group) if group else None for group in members]

    if mode == 'intra':
        per_sample = []
        for idx in range(batch):
            center = centers[label[idx]]
            per_sample.append(feature[idx] if center is None else center)
        return torch.stack(per_sample)

    valid_centers = [center for center in centers if center is not None]
    if valid_centers:
        return torch.stack(valid_centers)
    return []

def boundary_loss(T_mean, S_feature, label, criterion):
    """Weighted squared error between student features and teacher class means.

    For every sample the teacher mean with the smallest L1 distance to the
    student feature is looked up. If that nearest mean belongs to a sample with
    a different label, the sample's squared error is weighted by 0.8, otherwise
    by 0.2.

    Args:
        T_mean: 4-D tensor; ``T_mean[i]`` is the teacher mean paired with sample ``i``.
        S_feature: 4-D tensor of student features, same layout as ``T_mean``.
        label: class of every sample.
        criterion: unused; kept so existing call sites keep working.

    Returns:
        Mean of the per-sample terms as a tensor, or the float ``0.0`` for an
        empty batch.
    """
    balance = 0.8
    terms = []
    for i in range(S_feature.size(0)):
        sample = S_feature[i]
        # Leading axis of size 1 so the sample broadcasts against every teacher mean.
        query = sample.reshape(1, sample.size(0), sample.size(1), sample.size(2))
        l1_dist = (T_mean - query).abs().sum(dim=[1, 2, 3])
        nearest = torch.min(l1_dist, -1)[1]

        # Heavier weight when the closest teacher mean belongs to another class.
        weight = balance if label[nearest] != label[i] else (1 - balance)
        sq_err = ((T_mean[i].detach() - sample) ** 2).sum()
        terms.append(weight * sq_err / T_mean.size(1) / T_mean.size(2) / T_mean.size(3))

    if not terms:
        return 0.0
    return torch.stack(terms).sum() / len(terms)


def correlation_logits(feature):
    """Pairwise cosine similarity between the rows of a 2-D batch.

    Args:
        feature: tensor indexed as (sample, dim).

    Returns:
        Square tensor ``CR`` where ``CR[i][j]`` is the dot product of row ``i``
        and row ``j`` divided by both of their L2 norms. The norms are computed
        on detached tensors, so gradients flow only through the dot product.
    """
    # Norm of every row; identical for each output row, so computed once.
    batch_norms = (feature.detach() ** 2).sum(dim=-1).sqrt()
    rows = []
    for sample in feature:
        # Leading axis of size 1 lets the row broadcast against the whole batch.
        anchor = sample.reshape(1, sample.size(0))
        dots = (anchor * feature).sum(dim=-1)
        anchor_norm = (anchor.detach() ** 2).sum(dim=-1).sqrt()
        rows.append(dots / anchor_norm / batch_norms)
    return torch.stack(rows)

def reg_loss(feature, label, num_classes):
    """Intra-class and inter-class regularisation terms for a 2-D batch.

    Uses the module-level ``MSE`` criterion and ``cluster``.

    Args:
        feature: tensor indexed as (sample, dim).
        label: class of every sample.
        num_classes: number of classes, forwarded to ``cluster``.

    Returns:
        ``(intra_var_loss, inter_var_loss)``:
        the MSE between every sample and the (detached) mean of its class, and
        the MSE between the per-sample class means and their (detached) sum over
        the batch.
    """
    per_sample_center = cluster(feature, label, num_classes)[0]
    intra_var_loss = MSE(feature, per_sample_center.detach())

    n_rows, n_cols = per_sample_center.size(0), per_sample_center.size(1)
    # NOTE(review): this is the SUM of the per-sample centres, not their mean — confirm intended.
    summed_center = per_sample_center.sum(dim=0).reshape(1, n_cols).expand(n_rows, -1)
    inter_var_loss = MSE(per_sample_center, summed_center.detach())

    return intra_var_loss, inter_var_loss

def cosine_sim(x, y):
    """Cosine similarity between matching samples of ``x`` and ``y``.

    Every dimension except the first is reduced, giving one value per sample.
    """
    reduce_dims = list(range(1, x.dim()))
    dot = (x * y).sum(dim=reduce_dims)
    x_norm = (x ** 2).sum(dim=reduce_dims).sqrt()
    y_norm = (y ** 2).sum(dim=reduce_dims).sqrt()
    return dot / x_norm / y_norm

def margin_loss(t_feature, t_prob, s_feature, s_prob, label, num_classes, criterion):
    """Match student and teacher offsets between a class's boundary sample and other class means.

    Only samples the teacher classifies correctly (``argmax(t_prob) == label``) are
    used. For every pair of classes ``i < j`` that both have such samples, the
    sample of class ``i`` with the highest teacher score for class ``j`` is
    selected, and its offset to the mean of class ``j`` is computed once with
    teacher features and once with student features. The teacher offsets and the
    student class means are detached.

    Args:
        t_feature: teacher features, one row per sample.
        t_prob: teacher class scores, indexed as (sample, class).
        s_feature: student features, one row per sample.
        s_prob: unused; kept so existing call sites keep working.
        label: true class of every sample.
        num_classes: number of classes (labels index lists of this length).
        criterion: callable applied as ``criterion(teacher_offsets, student_offsets)``.

    Returns:
        The criterion value, or a zero scalar tensor when fewer than two classes
        have a correctly classified sample (no pair can be formed).
    """
    _, t_pred = torch.max(t_prob, 1)

    # Bucket the correctly classified samples by their true class.
    t_feats = [[] for _ in range(num_classes)]
    s_feats = [[] for _ in range(num_classes)]
    t_probs = [[] for _ in range(num_classes)]
    for idx in range(len(t_pred)):
        if t_pred[idx] == label[idx]:
            cls = label[idx]
            t_feats[cls].append(t_feature[idx])
            s_feats[cls].append(s_feature[idx])
            t_probs[cls].append(t_prob[idx])

    # Classes that have at least one usable sample, in increasing order.
    kept = [cls for cls in range(num_classes) if t_feats[cls]]
    t_mean = {cls: sum(t_feats[cls]) / len(t_feats[cls]) for cls in kept}
    s_mean = {cls: sum(s_feats[cls]) / len(s_feats[cls]) for cls in kept}
    t_stack = {cls: torch.stack(t_feats[cls]) for cls in kept}
    s_stack = {cls: torch.stack(s_feats[cls]) for cls in kept}
    p_stack = {cls: torch.stack(t_probs[cls]) for cls in kept}

    t_cor = []
    s_cor = []
    for pos, i in enumerate(kept):
        for j in kept[pos + 1:]:
            # Sample of class i that the teacher scores highest for class j.
            _, sample_id = torch.max(p_stack[i][:, j], 0)
            t_cor.append((t_stack[i][sample_id] - t_mean[j]).detach())
            s_cor.append(s_stack[i][sample_id] - s_mean[j].detach())

    if not t_cor:
        # torch.stack([]) would raise; return a zero that keeps s_feature's dtype/device
        # and stays attached to its graph.
        return s_feature.sum() * 0.0
    return criterion(torch.stack(t_cor), torch.stack(s_cor))

In [7]:
# Loss criteria, moved to the GPU.
CE = nn.CrossEntropyLoss().cuda()
MSE = nn.MSELoss().cuda()

# Per-epoch history, consumed by the plotting / saving cells.
loss_list = []
train_acc_list = []
test_acc_list = []

#initialize center list
center_list = [[] for _ in range(num_classes)]

# When True, class means are built only from correctly predicted samples (cluster_filter).
use_filter = False

# Text log with one line per epoch, written next to the student checkpoint.
log_path = os.path.splitext(s_path)[0] + '_log.txt'
for out_path in (s_path, log_path):
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

# The teacher only provides targets; keep it in inference mode.
Teacher.eval()

# One optimizer for the whole run so the SGD momentum buffers survive across epochs;
# only the learning rate is changed (previously a fresh optimizer was built every epoch).
optimizer = optim.SGD(Student.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)

for epoch in range(num_epochs):
    Student.train()

    # Step schedule: 0.1 for epochs 0-79, 0.01 for 80-119, 0.001 afterwards.
    if epoch < 80:
        epoch_lr = 0.1
    elif epoch < 120:
        epoch_lr = 0.01
    else:
        epoch_lr = 0.001
    for param_group in optimizer.param_groups:
        param_group['lr'] = epoch_lr

    cost = 0.0
    cnt = 0
    num_correct = 0
    s_tick = time.time()

    for i, (img, label) in enumerate(train_dataLoader):
        s = time.time()
        img, label = img.cuda(), label.cuda()
        # NOTE(review): the recorded run of this cell ended with
        # "ValueError: too many values to unpack (expected 5)", presumably raised by one of
        # the next two lines — check how many outputs the models in myResnet return.
        t_out, tf1, tf2, tf3, tf4 = Teacher(img)
        s_out, sf1, sf2, sf3, sf4 = Student(img)

        _, t_pred = torch.max(t_out, 1)
        _, s_pred = torch.max(s_out, 1)
        if use_filter:
            TF4 = cluster_filter(tf4, t_pred, label, num_classes)
            SF4 = cluster_filter(sf4, s_pred, label, num_classes)
            T_OUT = cluster_filter(t_out, t_pred, label, num_classes)
            S_OUT = cluster_filter(s_out, s_pred, label, num_classes)

            # The 'inter' means of both networks are filtered with the student's predictions.
            T_center_list = cluster_filter(tf4, s_pred, label, num_classes, mode='inter')
            S_center_list = cluster_filter(sf4, s_pred, label, num_classes, mode='inter')

            T_logits = cluster_filter(t_out, s_pred, label, num_classes, mode='inter')
            S_logits = cluster_filter(s_out, s_pred, label, num_classes, mode='inter')
        else:
            TF4, T_center_list = cluster(tf4, label, num_classes)
            SF4, S_center_list = cluster(sf4, label, num_classes)
            T_OUT, T_logits = cluster(t_out, label, num_classes)
            S_OUT, S_logits = cluster(s_out, label, num_classes)

        # Match each sample's offset from its class mean between student and teacher,
        # on the last feature map and on the logits.
        sim = MSE((tf4-TF4).detach(), sf4-SF4.detach()) / TF4.size(1) / TF4.size(2) / TF4.size(3) + \
                MSE((t_out-T_OUT).detach(), s_out-S_OUT.detach()) / T_OUT.size(1)
        cluster_loss = 0.002*sim

        # class_correlation and loss_KD are only reported in the progress line below;
        # neither is part of the optimised loss.
        class_correlation = margin_loss(t_out, t_out, s_out, s_out, label, num_classes, MSE) / t_out.size(1)
        loss_KD = 0.00001*MSE(t_out.detach(), s_out)

        cross_entropy = CE(s_out, label)
        alpha = 1
        intra_var_loss, inter_var_loss = reg_loss(s_out, label, num_classes)
        loss = alpha*cross_entropy + cluster_loss + 0.1*intra_var_loss

        _, pred = torch.max(s_out, 1)
        num_correct += (pred == label).sum().item()
        cost += loss.item()
        cnt += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        e = time.time()
        if(i%5 == 0):
            print('epoch %d: %d/%d, time:%.2fs  ce:%.3f sim:%.3f cor:%.3f cluster:%.3f kd:%.3f'%(epoch, cnt, math.ceil(len(train_data)/batch_size), e-s, cross_entropy.item(), sim.item(),class_correlation.item(), cluster_loss.item(),loss_KD.item()), end='\r')

    e_tick = time.time()
    epoch_loss = cost/cnt
    train_acc = num_correct/len(train_data)
    # Evaluate once per epoch and reuse the value for both the history and the log line.
    test_acc = test(Student)
    loss_list.append(epoch_loss)
    train_acc_list.append(train_acc)
    test_acc_list.append(test_acc)

    log = 'epoch:%d loss:%f train_acc:%f test_acc:%f time:%.2fs'%(epoch, epoch_loss, train_acc, test_acc, e_tick-s_tick)
    print(log)
    # One line per epoch; the context manager closes the file handle.
    with open(log_path, 'a') as log_file:
        log_file.write(log + '\n')
    torch.save(Student.state_dict(), s_path)
print('Maximum test accuracy:', max(test_acc_list))

ValueError: too many values to unpack (expected 5)

交叉熵+AE压缩最后一层：91.22
交叉熵：90.94
交叉熵+AE压缩最后二层：90.45
交叉熵+最后两层特征图逼近：90.75

四个交叉熵：91.64
四个交叉熵+3个浅层软标签约束：90.93
四个交叉熵+四个KL软标签约束：91.36
最后一层KL软约束：92.01
四个KL软约束：91.35
交叉熵+3个浅层软约束：90.81
最后一层softmax后的MSE软约束：90
交叉熵+最后一层KL软约束：91.95
交叉熵+logits的MSE：91.74
四个交叉熵+四个logits的MSE软标签约束：91.75
四个交叉熵+四个logits的MSE软标签约束+最后一层AE的约束:91.9
交叉熵+四层特征图逼近：90.97
交叉熵+四层特征图逼近+AE压缩最后一层：91.63
交叉熵+四层特征图逼近（大权重）+AE压缩最后一层：91.18
四个交叉熵+四个logits的MSE+最后一层AE：91.75
交叉熵+cluster：91.37
交叉熵+Fitnet最后一层：91.33
交叉熵+cluster+logits的MSE：91.43
交叉熵+Fitnet最后一层+logits的MSE：91.56
交叉熵+KL散度+cluster：91.20

In [None]:
import matplotlib.pyplot as plt

# One panel per recorded curve: (values, line format, y-axis label).
curves = [
    (loss_list, 'r-', 'loss'),
    (train_acc_list, 'g-', 'train accuracy'),
    (test_acc_list, 'b-', 'test accuracy'),
]

# Use the number of epochs actually recorded, so the cell also works when training
# was stopped before num_epochs.
x_axis = range(len(loss_list))

plt.figure(figsize=(10,13))
for row, (values, line_fmt, y_label) in enumerate(curves, start=1):
    plt.subplot(3, 1, row)
    plt.plot(range(len(values)), values, line_fmt)
    plt.xlabel('epoch')
    plt.ylabel(y_label)

# savefig does not create missing directories.
curve_dir = os.path.dirname(curve_path)
if curve_dir:
    os.makedirs(curve_dir, exist_ok=True)
plt.savefig(curve_path, dpi=600)
plt.show()

In [None]:
import numpy as np

# output file -> name of the notebook variable holding the history
# NOTE(review): intra_CS_loss_list and inter_CS_loss_list are not defined in any cell visible
# here; they are skipped (with a message) instead of aborting the cell with a NameError.
histories_to_save = {
    'intra_loss1.npy': 'intra_CS_loss_list',
    'inter_loss1.npy': 'inter_CS_loss_list',
    'loss1.npy': 'loss_list',
    'train_acc1.npy': 'train_acc_list',
    'test_acc1.npy': 'test_acc_list',
}
for npy_path, var_name in histories_to_save.items():
    if var_name not in globals():
        print('skip %s: %s is not defined' % (npy_path, var_name))
        continue
    np.save(npy_path, globals()[var_name])