In [1]:
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class Softloss(nn.Module):
    '''Temperature-scaled distillation loss between student and teacher logits.

    Three terms are combined, weighted by ``loss_portion``:
      * KL divergence   KL(teacher || student) on softened distributions,
      * cosine embedding loss between the softened probability vectors,
      * mean squared error between the softened probability vectors.
    The weighted sum is multiplied by ``T**2`` (Hinton et al.) so the gradient
    scale of the soft term stays comparable across temperatures.
    '''

    def __init__(self, T=4, loss_portion=(1, 0, 0)) -> None:
        '''
        T: temperature used to soften both sets of logits.
        loss_portion: weights for (KLD, cosine, mse), in that order.
        '''
        super(Softloss, self).__init__()
        self.T = T
        # Copy so the instance neither shares a mutable default nor aliases
        # a list the caller may later mutate.
        self.portion = list(loss_portion)

    def forward(self, x, y):
        '''Return the weighted soft loss as a scalar tensor.

        x: student logits; softmax is taken over the last dim.
        y: teacher logits, same shape as x.
        '''
        log_p_x = F.log_softmax(x / self.T, dim=-1)
        p_x = log_p_x.exp()
        p_y = F.softmax(y / self.T, dim=-1)

        # kl_div expects log-probabilities as input and probabilities as target.
        kld = F.kl_div(log_p_x, p_y, reduction="batchmean")
        # Cosine / MSE must compare like with like: probabilities on both sides
        # (comparing log-probs to probs is nonzero even for identical logits).
        ones = torch.ones(p_x.shape[0], device=p_x.device, dtype=p_x.dtype)
        cos = F.cosine_embedding_loss(p_x, p_y, ones)
        mse = F.mse_loss(p_x, p_y)

        loss = self.portion[0] * kld + self.portion[1] * cos + self.portion[2] * mse
        return loss * self.T * self.T


In [4]:
class Teacher(nn.Module):
    """Fully connected classifier for 28x28 single-image inputs.

    Architecture: flatten -> Linear(784, 1200) -> ReLU -> Linear(1200, 10).
    Returns raw (unnormalized) logits for 10 classes.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``layers`` and the module order define the
        # state_dict keys ("layers.1.weight", ...) used by saved checkpoints.
        hidden = 1200
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Map a batch of inputs to class logits."""
        logits = self.layers(x)
        return logits


In [5]:
class Student(nn.Module):
    """Smaller fully connected classifier for 28x28 single-image inputs.

    Architecture: flatten -> Linear(784, 400) -> ReLU -> Linear(400, 10).
    Returns raw (unnormalized) logits for 10 classes.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``layers`` and the module order define the
        # state_dict keys ("layers.1.weight", ...) used by saved checkpoints.
        hidden = 400
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Map a batch of inputs to class logits."""
        logits = self.layers(x)
        return logits

In [6]:
# useful libraries
import torchvision
import torchvision.transforms as transforms

#############################################
# Preprocessing.
# Evaluation images are only converted to float tensors.
transform = transforms.Compose([transforms.ToTensor()])

# Training images additionally get a random 28x28 crop taken from the image
# zero-padded by 2 pixels on every side, i.e. a small random translation.
transform_train = transforms.Compose([
    transforms.RandomCrop((28, 28), padding=2),
    transforms.ToTensor(),
])

transform_val = transform
#############################################
# do NOT change these
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# a few arguments, do NOT change these
DATA_ROOT = "./data"
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 100

#############################################
# Datasets: augmented training split, plain test split used for validation.
train_set = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform_train)
val_set = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform_val)

# Loaders: shuffle only the training data so validation order is fixed.
train_loader = DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=2)
#############################################
#############################################

In [7]:
import torch.nn as nn
import torch.optim as optim


In [9]:
def _distill_one_epoch(model, guide, loader, optimizer, criterion, soft_criterion, alpha, device):
    """Run one distillation training epoch.

    Returns (mean per-batch loss, accuracy over ``len(loader.dataset)``).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device).long()

        out = model(inputs)
        # The guide network is frozen; no graph is needed for its outputs.
        with torch.no_grad():
            soft_target = guide(inputs)
        loss = (1 - alpha) * criterion(out, targets) + alpha * soft_criterion(out, soft_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (out.argmax(-1) == targets).sum().item()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return running_loss / max(len(loader), 1), correct / max(len(loader.dataset), 1)


def _evaluate(model, loader, criterion, device):
    """Evaluate with gradients disabled.

    Returns (mean per-batch cross-entropy, accuracy over ``len(loader.dataset)``).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device).long()
            out = model(inputs)
            running_loss += criterion(out, targets).item()
            correct += (out.argmax(-1) == targets).sum().item()
    return running_loss / max(len(loader), 1), correct / max(len(loader.dataset), 1)


def train(T, portion, alpha, train_loader, val_loader, EPOCHS=150,
          teacher_ckpt="./saved_model/4.5114_distilled_9921.pth",
          checkpoint_folder="./tmp_model"):
    """Train a network by knowledge distillation and save the best checkpoint.

    Per-batch loss: ``(1 - alpha) * CE(out, labels) + alpha * Softloss(T, portion)(out, guide_out)``.

    NOTE(review): the network being optimised is built from the ``Teacher``
    class (1200 hidden units), while the frozen guide is a ``Student``
    (400 hidden units) restored from ``teacher_ckpt``. Confirm this reversal
    is intentional.

    Args:
        T: distillation temperature passed to ``Softloss``.
        portion: (KLD, cosine, mse) weights passed to ``Softloss``.
        alpha: weight of the soft loss; ``1 - alpha`` weights cross-entropy.
        train_loader, val_loader: iterables of ``(inputs, targets)`` batches
            supporting ``len()`` and exposing ``.dataset``.
        EPOCHS: number of training epochs.
        teacher_ckpt: checkpoint file holding a dict with a ``"state_dict"``
            entry for the frozen guide network.
        checkpoint_folder: directory where the best model is written as
            ``str(T) + '_distilled.pth'`` (runs with the same T overwrite it).

    Returns:
        Best validation accuracy seen (a fraction).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Network being optimised (see NOTE above about the class choice).
    model = Teacher().to(device)

    # Frozen guide. map_location lets a checkpoint saved on GPU load on a
    # CPU-only machine (and vice versa).
    guide = Student()
    guide.load_state_dict(torch.load(teacher_ckpt, map_location=device)["state_dict"])
    guide = guide.to(device)
    guide.eval()

    # hyperparameters, do NOT change right now
    INITIAL_LR = 0.1
    MOMENTUM = 0.9
    REG = 0.00       # L2 regularization strength
    DECAY_EPOCHS = 1
    DECAY = 0.95

    criterion = nn.CrossEntropyLoss()
    soft_criterion = Softloss(T, portion)
    # REG used to be defined but never applied; pass it as weight decay.
    optimizer = optim.SGD(model.parameters(), lr=INITIAL_LR, momentum=MOMENTUM,
                          nesterov=True, weight_decay=REG)

    best_val_acc = 0
    current_learning_rate = INITIAL_LR

    print("==> Training starts!")
    print("="*50)
    for i in range(EPOCHS):
        # Step decay: multiply the LR by DECAY every DECAY_EPOCHS epochs after epoch 0.
        if i % DECAY_EPOCHS == 0 and i != 0:
            current_learning_rate = current_learning_rate * DECAY
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate

        print("Epoch %d:" % i)
        avg_loss, avg_acc = _distill_one_epoch(model, guide, train_loader, optimizer,
                                               criterion, soft_criterion, alpha, device)
        print("Training loss: %.4f, Training accuracy: %.4f" % (avg_loss, avg_acc))

        avg_loss, avg_acc = _evaluate(model, val_loader, criterion, device)
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))

        # Keep only the checkpoint with the best validation accuracy so far.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            os.makedirs(checkpoint_folder, exist_ok=True)
            print("Saving ...")
            state = {'state_dict': model.state_dict(),
                     'epoch': i,
                     }
            torch.save(state, os.path.join(checkpoint_folder, str(T) + '_distilled.pth'))

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    return best_val_acc

In [10]:
# Pure soft-target run (alpha=1 removes the CE term), KL-only loss at T=4.
train(T=4, portion=[1, 0, 0], alpha=1, train_loader=train_loader, val_loader=val_loader)

==> Training starts!
Epoch 0:
Training loss: 1.8436, Training accuracy: 0.9232
Validation loss: 0.0498, Validation accuracy: 0.9830
Saving ...
Epoch 1:
Training loss: 0.2759, Training accuracy: 0.9814
Validation loss: 0.0365, Validation accuracy: 0.9872
Saving ...
Epoch 2:
Training loss: 0.1772, Training accuracy: 0.9845
Validation loss: 0.0347, Validation accuracy: 0.9884
Saving ...
Epoch 3:
Training loss: 0.1364, Training accuracy: 0.9868
Validation loss: 0.0325, Validation accuracy: 0.9886
Saving ...
Epoch 4:
Training loss: 0.1137, Training accuracy: 0.9883
Validation loss: 0.0322, Validation accuracy: 0.9883
Epoch 5:
Training loss: 0.0984, Training accuracy: 0.9888
Validation loss: 0.0308, Validation accuracy: 0.9894
Saving ...
Epoch 6:
Training loss: 0.0874, Training accuracy: 0.9894
Validation loss: 0.0305, Validation accuracy: 0.9895
Saving ...
Epoch 7:
Training loss: 0.0793, Training accuracy: 0.9898
Validation loss: 0.0297, Validation accuracy: 0.9891
Epoch 8:
Training loss: 0

0.9915

In [11]:
# Same setup as the T=4 run, but at temperature 3.
train(T=3, portion=[1, 0, 0], alpha=1, train_loader=train_loader, val_loader=val_loader)

==> Training starts!
Epoch 0:
Training loss: 1.3589, Training accuracy: 0.9217
Validation loss: 0.0491, Validation accuracy: 0.9833
Saving ...
Epoch 1:
Training loss: 0.2146, Training accuracy: 0.9794
Validation loss: 0.0383, Validation accuracy: 0.9860
Saving ...
Epoch 2:
Training loss: 0.1284, Training accuracy: 0.9842
Validation loss: 0.0320, Validation accuracy: 0.9893
Saving ...
Epoch 3:
Training loss: 0.0932, Training accuracy: 0.9870
Validation loss: 0.0289, Validation accuracy: 0.9895
Saving ...
Epoch 4:
Training loss: 0.0787, Training accuracy: 0.9873
Validation loss: 0.0298, Validation accuracy: 0.9895
Epoch 5:
Training loss: 0.0667, Training accuracy: 0.9884
Validation loss: 0.0274, Validation accuracy: 0.9901
Saving ...
Epoch 6:
Training loss: 0.0583, Training accuracy: 0.9892
Validation loss: 0.0261, Validation accuracy: 0.9914
Saving ...
Epoch 7:
Training loss: 0.0529, Training accuracy: 0.9895
Validation loss: 0.0280, Validation accuracy: 0.9910
Epoch 8:
Training loss: 0

0.9915