In [1]:
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Softloss(nn.Module):
    """Temperature-scaled soft-target loss for knowledge distillation.

    Combines up to three terms between the student's and the teacher's
    temperature-softened distributions, weighted by ``loss_portion``:
    KL divergence, cosine-embedding loss and mean-squared error. The weighted
    sum is multiplied by ``T * T`` (the customary distillation scaling).
    """

    def __init__(self, T=4, loss_portion=(1, 0, 0)) -> None:
        '''
        T: temperature; both logit tensors are divided by it before the softmax.
        loss_portion: weights for the (KLD, cosine, mse) terms, in that order.
            A term whose weight is 0 is not computed at all. A tuple default is
            used so the default cannot be mutated across instances.
        '''
        super(Softloss, self).__init__()
        self.T = T
        self.portion = loss_portion

    def forward(self, x, y):
        """Return the weighted soft loss between student logits ``x`` and teacher logits ``y``.

        Softmax is taken over the last dimension. The cosine term uses
        F.cosine_embedding_loss, which needs 1-D or 2-D inputs; presumably
        (batch, classes) here. TODO confirm for other callers.
        """
        w_kld, w_cos, w_mse = self.portion[0], self.portion[1], self.portion[2]
        log_p_student = F.log_softmax(x / self.T, dim=-1)
        p_teacher = F.softmax(y / self.T, dim=-1)

        loss = x.new_zeros(())
        if w_kld:
            # F.kl_div expects log-probabilities as input and probabilities as target.
            loss = loss + w_kld * F.kl_div(log_p_student, p_teacher, reduction="batchmean")
        if w_cos or w_mse:
            # Cosine / MSE must compare like with like: student probabilities vs
            # teacher probabilities (not log-probabilities vs probabilities).
            p_student = log_p_student.exp()
            if w_cos:
                # Target +1 for every row: each pair should be similar.
                target = torch.ones(p_student.shape[0], device=p_student.device)
                loss = loss + w_cos * F.cosine_embedding_loss(p_student, p_teacher, target)
            if w_mse:
                loss = loss + w_mse * F.mse_loss(p_student, p_teacher)
        return loss * self.T * self.T


In [3]:
class ResBlock11(nn.Module):
    """Shape-preserving residual block with a single convolution.

    forward(x) = norm3(leaky_relu(norm1(conv1(x))) + x)

    Submodule names (conv1, norm1, norm3) are part of the state_dict keys that
    checkpoints are loaded by, so they must not be renamed.
    """
    def __init__(self, channel, filter_size=3):
        super(ResBlock11, self).__init__()
        # "Same" padding for odd filter sizes so the residual add lines up; the
        # former hard-coded padding=1 only matched filter_size=3.
        self.conv1 = nn.Conv2d(channel, channel, filter_size, padding=filter_size // 2, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(channel)
        self.norm3 = nn.BatchNorm2d(channel)

    def forward(self, x):
        out = F.leaky_relu(self.norm1(self.conv1(x)))
        # Normalize after the skip connection is added.
        return self.norm3(out + x)

class ResDownSampling11(nn.Module):
    """Residual block that halves the spatial size (rounding up) and changes channels.

    main     = leaky_relu(norm1(conv1(x)))      # filter_size conv, stride 2
    shortcut = leaky_relu(norm3(conv3(x)))      # 1x1 projection, stride 2
    forward(x) = norm4(main + shortcut)

    Submodule names are part of the state_dict keys and must stay stable.
    """
    def __init__(self, channel, out_channel, filter_size=3):
        super(ResDownSampling11, self).__init__()
        # padding=filter_size//2 keeps the main path's output size equal to the
        # 1x1 stride-2 shortcut's (ceil(H/2)) for any odd filter size.
        self.conv1 = nn.Conv2d(channel, out_channel, filter_size, stride=2,
                               padding=filter_size // 2, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.norm3 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(channel, out_channel, kernel_size=1, stride=2, bias=False)
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm4 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        out = F.leaky_relu(self.norm1(self.conv1(x)))
        shortcut = F.leaky_relu(self.norm3(self.conv3(x)))
        return self.norm4(out + shortcut)
class ResNetCIFAR11(nn.Module):
    """Small ResNet built from single-conv residual blocks (used as the teacher in ``train``).

    Layout: 3->16 stem conv, 3 x ResBlock11(16), downsample to 32, 2 x ResBlock11(32),
    downsample to 64, 2 x ResBlock11(64), global average pool, linear classifier.

    Args:
        num_classes: output size of the final linear layer (default 10).
    """
    def __init__(self, num_classes=10):
        super(ResNetCIFAR11, self).__init__()
        self.inconv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
        nn.init.xavier_normal_(self.inconv.weight)
        # NOTE: inconv is registered both as an attribute and as res_block[0], so
        # state_dict() holds it under 'inconv.weight' AND 'res_block.0.weight'.
        # Saved checkpoints therefore carry both keys; keep this layout and the
        # layer order below, or a strict load_state_dict will fail.
        self.res_block = nn.Sequential(self.inconv,
                                       nn.BatchNorm2d(16),
                                       nn.ReLU(),
                                       ResBlock11(16),
                                       ResBlock11(16),
                                       ResBlock11(16),
                                       ResDownSampling11(16, 32),
                                       ResBlock11(32),
                                       ResBlock11(32),
                                       ResDownSampling11(32, 64),
                                       ResBlock11(64),
                                       ResBlock11(64),
                                       # Global average pool: same as the former AvgPool2d(8)
                                       # on the 8x8 map a 32x32 input produces, but it also
                                       # accepts other input resolutions. Parameter-free, so
                                       # state_dict keys/indices are unchanged.
                                       nn.AdaptiveAvgPool2d(1),
                                       nn.Flatten(),
                                       nn.Linear(64, num_classes))

    def forward(self, x):
        return self.res_block(x)

In [4]:
class ResBlock(nn.Module):
    """Shape-preserving residual block with two convolutions.

    out = leaky_relu(norm2(conv2(leaky_relu(norm1(conv1(x))))))
    forward(x) = norm3(out + x)

    Submodule names are part of the state_dict keys and must stay stable.
    """
    def __init__(self, channel, filter_size=3):
        super(ResBlock, self).__init__()
        # "Same" padding for odd filter sizes so the residual add lines up; the
        # former hard-coded padding=1 only matched filter_size=3.
        pad = filter_size // 2
        self.conv1 = nn.Conv2d(channel, channel, filter_size, padding=pad, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, filter_size, padding=pad, bias=False)
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm2 = nn.BatchNorm2d(channel)
        self.norm3 = nn.BatchNorm2d(channel)

    def forward(self, x):
        out = F.leaky_relu(self.norm1(self.conv1(x)))
        out = F.leaky_relu(self.norm2(self.conv2(out)))
        # Normalize after the skip connection is added.
        return self.norm3(out + x)

class ResDownSampling(nn.Module):
    """Two-conv residual block that halves the spatial size (rounding up) and changes channels.

    main     = leaky_relu(norm2(conv2(leaky_relu(norm1(conv1(x))))))  # conv1 has stride 2
    shortcut = leaky_relu(norm3(conv3(x)))                            # 1x1 projection, stride 2
    forward(x) = norm4(main + shortcut)

    Submodule names are part of the state_dict keys and must stay stable.
    """
    def __init__(self, channel, out_channel, filter_size=3):
        super(ResDownSampling, self).__init__()
        # padding=filter_size//2 keeps the main path's output size equal to the
        # 1x1 stride-2 shortcut's (ceil(H/2)) for any odd filter size.
        pad = filter_size // 2
        self.conv1 = nn.Conv2d(channel, channel, filter_size, stride=2, padding=pad, bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, out_channel, filter_size, padding=pad, bias=False)
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.norm3 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(channel, out_channel, kernel_size=1, stride=2, bias=False)
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm4 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        out = F.leaky_relu(self.norm1(self.conv1(x)))
        out = F.leaky_relu(self.norm2(self.conv2(out)))
        shortcut = F.leaky_relu(self.norm3(self.conv3(x)))
        return self.norm4(out + shortcut)
class ResNetCIFAR(nn.Module):
    """Small ResNet built from two-conv residual blocks (used as the student in ``train``).

    Layout: 3->16 stem conv, 3 x ResBlock(16), downsample to 32, 2 x ResBlock(32),
    downsample to 64, 2 x ResBlock(64), global average pool, linear classifier.

    Args:
        num_classes: output size of the final linear layer (default 10).
    """
    def __init__(self, num_classes=10):
        super(ResNetCIFAR, self).__init__()
        self.inconv = nn.Conv2d(3, 16, 3, padding=1, bias=False)
        nn.init.xavier_normal_(self.inconv.weight)
        # NOTE: inconv is registered both as an attribute and as res_block[0], so
        # state_dict() holds it under 'inconv.weight' AND 'res_block.0.weight'.
        # Keep this layout and the layer order so existing checkpoints still load.
        self.res_block = nn.Sequential(self.inconv,
                                       nn.BatchNorm2d(16),
                                       nn.ReLU(),
                                       ResBlock(16),
                                       ResBlock(16),
                                       ResBlock(16),
                                       ResDownSampling(16, 32),
                                       ResBlock(32),
                                       ResBlock(32),
                                       ResDownSampling(32, 64),
                                       ResBlock(64),
                                       ResBlock(64),
                                       # Global average pool: same as the former AvgPool2d(8)
                                       # on the 8x8 map a 32x32 input produces, but it also
                                       # accepts other input resolutions. Parameter-free, so
                                       # state_dict keys/indices are unchanged.
                                       nn.AdaptiveAvgPool2d(1),
                                       nn.Flatten(),
                                       nn.Linear(64, num_classes))

    def forward(self, x):
        return self.res_block(x)


In [5]:
# useful libraries
import torchvision
import torchvision.transforms as transforms

#############################################
# Preprocessing pipelines.
# Per-channel normalization statistics shared by every pipeline below.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Evaluation pipeline: tensor conversion + normalization only, no augmentation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Training pipeline: random 32x32 crop from a 4-pixel-padded image and a random
# horizontal flip, followed by the same tensor/normalize steps.
transform_train = transforms.Compose([
    transforms.RandomCrop((32, 32), padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Validation reuses the evaluation pipeline object.
transform_val = transform
#############################################
# do NOT change these
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# a few arguments, do NOT change these
DATA_ROOT = "./data"
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 100

#############################################
# Datasets: the CIFAR-10 training split with the augmenting pipeline, and the
# test split (used for validation here) with the evaluation pipeline.
# download=True fetches the archive into DATA_ROOT only if it is missing.
train_set = CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform_train)
val_set = CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform_val)

# Loaders: only the training data is shuffled; validation iterates in a fixed order.
train_loader = DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=2)
#############################################

Files already downloaded and verified
Files already downloaded and verified


In [58]:
# Module-level device string and a freshly initialized ResNetCIFAR11 moved onto it.
# NOTE(review): train() below builds its own device, student and teacher, so these
# globals appear unused in this notebook; candidates for removal.
device = "cuda" if torch.cuda.is_available() else "cpu"
student = ResNetCIFAR11().to(device)


In [6]:
import torch.nn as nn
import torch.optim as optim



In [10]:
def _train_one_epoch(student, teacher, loader, optimizer, criterion, soft_criterion, alpha, device):
    """Run one distillation epoch over ``loader``; return (mean batch loss, training accuracy)."""
    student.train()
    train_loss = 0.0
    correct_examples = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        out = student(inputs)
        # The teacher is frozen: no autograd graph is needed for its logits.
        with torch.no_grad():
            soft_target = teacher(inputs)
        # Blend hard-label cross-entropy with the temperature-scaled soft loss.
        loss = (1 - alpha) * criterion(out, targets) + alpha * soft_criterion(out, soft_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct_examples += torch.sum(out.argmax(-1) == targets).item()
    return train_loss / len(loader), correct_examples / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return (mean batch loss, accuracy) of ``model`` on ``loader`` in eval mode, without gradients."""
    model.eval()
    val_loss = 0.0
    correct_examples = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device).long()
            out = model(inputs)
            val_loss += criterion(out, targets).item()
            correct_examples += torch.sum(out.argmax(-1) == targets).item()
    return val_loss / len(loader), correct_examples / len(loader.dataset)


def train(T, portion, alpha, train_loader, val_loader, EPOCHS=200,
          teacher_path="./saved_model/res11_8968_sgd.pth", lr_milestones=(70, 140)):
    """Distill a fresh ResNetCIFAR student from a pretrained ResNetCIFAR11 teacher.

    Args:
        T: temperature passed to Softloss.
        portion: (KLD, cosine, mse) weights passed to Softloss.
        alpha: weight of the soft loss; cross-entropy on the hard labels gets (1 - alpha).
        train_loader, val_loader: data loaders yielding (inputs, targets) batches.
        EPOCHS: number of training epochs.
        teacher_path: checkpoint file whose "state_dict" entry holds the teacher weights.
        lr_milestones: epochs at which the learning rate is multiplied by 0.1.

    Returns:
        Best validation accuracy (fraction of correct predictions) over all epochs.

    Side effects:
        Whenever validation accuracy improves, saves {'state_dict', 'epoch'} to
        ./tmp_model/<T>_distilled.pth. Runs with the same T overwrite each other.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    student = ResNetCIFAR().to(device)

    teacher = ResNetCIFAR11()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state = torch.load(teacher_path, map_location=device)["state_dict"]
    teacher.load_state_dict(state)
    del state  # free the loaded copy; the weights now live in `teacher`
    teacher = teacher.to(device)
    teacher.eval()

    # hyperparameters, do NOT change right now
    INITIAL_LR = 0.1
    MOMENTUM = 0.9
    REG = 0.0004  # L2 regularization strength (SGD weight_decay)
    DECAY = 0.1   # learning-rate multiplier applied at each milestone
    CHECKPOINT_FOLDER = "./tmp_model"

    criterion = nn.CrossEntropyLoss()
    soft_criterion = Softloss(T, portion)
    optimizer = optim.SGD(student.parameters(), weight_decay=REG, lr=INITIAL_LR,
                          momentum=MOMENTUM, nesterov=True)

    best_val_acc = 0
    current_learning_rate = INITIAL_LR

    print("==> Training starts!")
    print("=" * 50)
    for i in range(EPOCHS):
        # Step-decay learning-rate schedule.
        if i in lr_milestones:
            current_learning_rate *= DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate

        print("Epoch %d:" % i)
        avg_loss, avg_acc = _train_one_epoch(student, teacher, train_loader, optimizer,
                                             criterion, soft_criterion, alpha, device)
        print("Training loss: %.4f, Training accuracy: %.4f" % (avg_loss, avg_acc))

        avg_loss, avg_acc = _evaluate(student, val_loader, criterion, device)
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))

        # Keep only the best-so-far student checkpoint.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
            print("Saving ...")
            checkpoint = {'state_dict': student.state_dict(),
                          'epoch': i,
                          }
            torch.save(checkpoint, os.path.join(CHECKPOINT_FOLDER, str(T) + '_distilled.pth'))

    print("=" * 50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    return best_val_acc

In [11]:
train(T=4, portion=[1, 0, 0], alpha=0.2, train_loader=train_loader, val_loader=val_loader, EPOCHS=200)

==> Training starts!
Epoch 0:
Training loss: 1.5067, Training accuracy: 0.4879
Validation loss: 1.1423, Validation accuracy: 0.6079
Saving ...
Epoch 1:
Training loss: 0.9481, Training accuracy: 0.6831
Validation loss: 0.9367, Validation accuracy: 0.6774
Saving ...
Epoch 2:
Training loss: 0.7466, Training accuracy: 0.7497
Validation loss: 0.7266, Validation accuracy: 0.7511
Saving ...
Epoch 3:
Training loss: 0.6561, Training accuracy: 0.7805
Validation loss: 0.7169, Validation accuracy: 0.7559
Saving ...
Epoch 4:
Training loss: 0.6055, Training accuracy: 0.7967
Validation loss: 0.7179, Validation accuracy: 0.7580
Saving ...
Epoch 5:
Training loss: 0.5713, Training accuracy: 0.8082
Validation loss: 0.6357, Validation accuracy: 0.7928
Saving ...
Epoch 6:
Training loss: 0.5473, Training accuracy: 0.8182
Validation loss: 0.6555, Validation accuracy: 0.7812
Epoch 7:
Training loss: 0.5293, Training accuracy: 0.8237
Validation loss: 0.6590, Validation accuracy: 0.7856
Epoch 8:
Training loss: 0

0.925