In [11]:
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Softloss(nn.Module):
    def __init__(self,T=4,loss_portion=[1,0,0]) -> None:
        '''
        T: temperature
        loss_portion: KLD, cosine, mse
        '''
        super(Softloss,self).__init__()
        self.T=T
        self.portion=loss_portion
    def forward(self,x,y):
        soft_x=F.log_softmax(x/self.T,dim=-1)
        soft_y=F.softmax(y/self.T,dim=-1)
        loss=self.portion[0]*F.kl_div(soft_x,soft_y,reduction="batchmean")\
            +self.portion[1]*F.cosine_embedding_loss(soft_x,soft_y,torch.ones(soft_x.shape[0]).to(soft_x.device))\
            +self.portion[2]*F.mse_loss(soft_x,soft_y)
        return loss*self.T*self.T


In [3]:
class ResBlock11(nn.Module):
    def __init__(self,channel,filter_size=3):
        super(ResBlock11, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel, filter_size,padding=1,bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(channel)
        self.norm3 = nn.BatchNorm2d(channel)
    def forward(self,x):
        out=F.leaky_relu(self.norm1(self.conv1(x)))
        out=self.norm3(out+x)
        return out

class ResDownSampling11(nn.Module):
    def __init__(self,channel,out_channel,filter_size=3):
        super(ResDownSampling11, self).__init__()
        self.conv1 = nn.Conv2d(channel, out_channel, filter_size,stride=2,padding=1,bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.norm3 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(channel,out_channel,kernel_size=1,stride=2,bias=False)
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm4 = nn.BatchNorm2d(out_channel)
    def forward(self,x):
        out=F.leaky_relu(self.norm1(self.conv1(x)))
        x=F.leaky_relu(self.norm3(self.conv3(x)))
        out=self.norm4(out+x)
        return out
class ResNetCIFAR11(nn.Module):
    def __init__(self ):
        super(ResNetCIFAR11, self).__init__()
        self.inconv=nn.Conv2d(3, 16,3,padding=1,bias=False)
        nn.init.xavier_normal_(self.inconv.weight)
        self.res_block=nn.Sequential(self.inconv,
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    ResBlock11(16),
                                    ResBlock11(16),
                                    ResBlock11(16),
                                    ResDownSampling11(16,32),
                                    ResBlock11(32),
                                    ResBlock11(32),
                                    ResDownSampling11(32,64),
                                    ResBlock11(64),
                                    ResBlock11(64),
                                    nn.AvgPool2d(8),
                                    nn.Flatten(),
                                    nn.Linear(64*1*1,10))
    def forward(self, x):
        return self.res_block(x)

In [4]:
class ResBlock(nn.Module):
    def __init__(self,channel,filter_size=3):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel, filter_size,padding=1,bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, channel, filter_size,padding=1,bias=False)
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm2 = nn.BatchNorm2d(channel)
        self.norm3 = nn.BatchNorm2d(channel)
    def forward(self,x):
        out=F.leaky_relu(self.norm1(self.conv1(x)))
        out=F.leaky_relu(self.norm2(self.conv2(out)))
        out=self.norm3(out+x)
        return out

class ResDownSampling(nn.Module):
    def __init__(self,channel,out_channel,filter_size=3):
        super(ResDownSampling, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel, filter_size,stride=2,padding=1,bias=False)
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm1 = nn.BatchNorm2d(channel)
        self.conv2 = nn.Conv2d(channel, out_channel, filter_size,padding=1,bias=False)
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.norm3 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(channel,out_channel,kernel_size=1,stride=2,bias=False)
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_in', nonlinearity='leaky_relu')
        self.norm4 = nn.BatchNorm2d(out_channel)
    def forward(self,x):
        out=F.leaky_relu(self.norm1(self.conv1(x)))
        out=F.leaky_relu(self.norm2(self.conv2(out)))
        x=F.leaky_relu(self.norm3(self.conv3(x)))
        out=self.norm4(out+x)
        return out
class ResNetCIFAR(nn.Module):
    def __init__(self ):
        super(ResNetCIFAR, self).__init__()
        self.inconv=nn.Conv2d(3, 16,3,padding=1,bias=False)
        nn.init.xavier_normal_(self.inconv.weight)
        self.res_block=nn.Sequential(self.inconv,
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    ResBlock(16),
                                    ResBlock(16),
                                    ResBlock(16),
                                    ResDownSampling(16,32),
                                    ResBlock(32),
                                    ResBlock(32),
                                    ResDownSampling(32,64),
                                    ResBlock(64),
                                    ResBlock(64),
                                    nn.AvgPool2d(8),
                                    nn.Flatten(),
                                    nn.Linear(64*1*1,10))
    def forward(self, x):
        return self.res_block(x)


In [5]:
# useful libraries
import torchvision
import torchvision.transforms as transforms

#############################################
# your code here
# specify preprocessing function
transform = transforms.Compose(
    (
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    )
)
transform_train = transforms.Compose(
    (
    
    transforms.RandomCrop((32,32),padding=4),
    transforms.RandomHorizontalFlip(),
    #transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    
    #
    #transforms.ColorJitter(0.2,0,0)
    
    )
)

transform_val = transform
#############################################
# do NOT change these
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# a few arguments, do NOT change these
DATA_ROOT = "./data"
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 100

#############################################
# your code here
# construct dataset
train_set = CIFAR10(
    root=DATA_ROOT, 
    train=True, 
    download=True,
    transform=transform_train    # your code
)

val_set = CIFAR10(
    root=DATA_ROOT, 
    train=False, 
    download=True,
    transform=transform_val    # your code
)

# construct dataloader
train_loader = DataLoader(
    train_set, 
    batch_size=TRAIN_BATCH_SIZE,  # your code
    shuffle=True,     # your code
    num_workers=2
)

val_loader = DataLoader(
    val_set, 
    batch_size=VAL_BATCH_SIZE,  # your code
    shuffle=False,     # your code
    num_workers=2
)
#############################################

Files already downloaded and verified
Files already downloaded and verified


In [58]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
student=ResNetCIFAR11()
student=student.to(device)


In [59]:
import torch.nn as nn
import torch.optim as optim

# hyperparameters, do NOT change right now
# initial learning rate
#INITIAL_LR = 0.0001

# momentum for optimizer
#MOMENTUM = 0.9

# L2 regularization strength
REG = 0.00

#############################################
# your code here
# create loss function
criterion = nn.CrossEntropyLoss()

# Add optimizer
optimizer = optim.Adam(student.parameters(),weight_decay=REG,amsgrad=True)

In [60]:
teacher=ResNetCIFAR()
state=torch.load("./saved_model/res20_9188.pth")["state_dict"]
teacher.load_state_dict(state)
teacher=teacher.to(device)
state=None

In [9]:
def train(model,teacher,optimizer,criterion,T,portion,alpha,train_loader,val_loader,EPOCHS=200):
    # some hyperparameters
    # total number of training epochs
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    teacher.eval()
    soft_criterion=Softloss(T,portion)
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"
    #DECAY_EPOCHS=5
    #DECAY=0.75
    # start the training/validation process
    # the process should take about 5 minutes on a GTX 1070-Ti
    # if the code is written efficiently.
    best_val_acc = 0
    #current_learning_rate = INITIAL_LR
    
    print("==> Training starts!")
    print("="*50)
    for i in range(0, EPOCHS):
        # handle the learning rate scheduler.
        '''
        if i % DECAY_EPOCHS == 0 and i != 0 :
            current_learning_rate = current_learning_rate * DECAY
        
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)
        '''
        #######################
        # your code here
        # switch to train mode
        model.train()
        
        #######################
        
        print("Epoch %d:" %i)
        # this help you compute the training accuracy
        total_examples = 0
        correct_examples = 0

        train_loss = 0 # track training loss if you want
        loader=train_loader
        
        # Train the model for 1 epoch.
        for batch_idx, (inputs, targets) in enumerate(loader):
            ####################################
            # your code here
            # copy inputs to device
            inputs=inputs.to(device)
            targets=targets.to(device).long()

            
            # compute the output and loss
            out=model(inputs)
            with torch.no_grad():
                soft_target=teacher(inputs)
            loss=(1-alpha)*criterion(out,targets)+alpha*soft_criterion(out,soft_target)
            
            # zero the gradient
            
            optimizer.zero_grad()
            # backpropagation
            loss.backward()

            
            # apply gradient and update the weights
            optimizer.step()
            train_loss+=loss.item()
            
            # count the number of correctly predicted samples in the current batch
            correct_examples+=torch.sum(out.argmax(-1)==targets).item()
            ####################################
        total_examples=len(train_loader.dataset)      
        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # Validate on the validation dataset
        #######################
        # your code here
        # switch to eval mode
        model.eval()
        
        #######################

        # this help you compute the validation accuracy
        total_examples = 0
        correct_examples = 0
        
        val_loss = 0 # again, track the validation loss if you want

        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(val_loader):
                ####################################
                # your code here
                # copy inputs to device
                inputs=inputs.to(device)
                targets=targets.to(device).long()
                # compute the output and loss
                out=model(inputs)
                loss=criterion(out,targets)
                # count the number of correctly predicted samples in the current batch
                val_loss+=loss.item()
                correct_examples+=torch.sum(out.argmax(-1)==targets).item()
                
                ####################################
        total_examples=len(val_loader.dataset)
        avg_loss = val_loss / len(val_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f" % (avg_loss, avg_acc))
        
        # save the model checkpoint
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            if not os.path.exists(CHECKPOINT_FOLDER):
                os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': model.state_dict(),
                    'epoch': i,
                    }
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'distilled.pth'))
            
        print('')

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    return best_val_acc

In [28]:
train(student,teacher,optimizer,criterion,1,[1,0,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 1.3926, Training accuracy: 0.4149
Validation loss: 1.3530, Validation accuracy: 0.5114
Saving ...

Epoch 1:
Training loss: 1.0502, Training accuracy: 0.5635
Validation loss: 1.0980, Validation accuracy: 0.6066
Saving ...

Epoch 2:
Training loss: 0.9056, Training accuracy: 0.6254
Validation loss: 0.9878, Validation accuracy: 0.6418
Saving ...

Epoch 3:
Training loss: 0.8198, Training accuracy: 0.6604
Validation loss: 0.9314, Validation accuracy: 0.6679
Saving ...

Epoch 4:
Training loss: 0.7583, Training accuracy: 0.6873
Validation loss: 0.8622, Validation accuracy: 0.7000
Saving ...

Epoch 5:
Training loss: 0.7086, Training accuracy: 0.7081
Validation loss: 0.8287, Validation accuracy: 0.7045
Saving ...

Epoch 6:
Training loss: 0.6557, Training accuracy: 0.7299
Validation loss: 0.7754, Validation accuracy: 0.7257
Saving ...

Epoch 7:
Training loss: 0.6124, Training accuracy: 0.7482
Validation loss: 0.7244, Validation accuracy: 0.7478
Saving 

0.89

In [33]:
train(student,teacher,optimizer,criterion,2,[1,0,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 2.5655, Training accuracy: 0.4233
Validation loss: 1.3180, Validation accuracy: 0.5279
Saving ...

Epoch 1:
Training loss: 1.9464, Training accuracy: 0.5716
Validation loss: 1.1163, Validation accuracy: 0.6095
Saving ...

Epoch 2:
Training loss: 1.6837, Training accuracy: 0.6352
Validation loss: 1.0471, Validation accuracy: 0.6463
Saving ...

Epoch 3:
Training loss: 1.5188, Training accuracy: 0.6716
Validation loss: 0.9085, Validation accuracy: 0.6855
Saving ...

Epoch 4:
Training loss: 1.4000, Training accuracy: 0.7055
Validation loss: 0.8691, Validation accuracy: 0.6981
Saving ...

Epoch 5:
Training loss: 1.3154, Training accuracy: 0.7224
Validation loss: 0.8272, Validation accuracy: 0.7175
Saving ...

Epoch 6:
Training loss: 1.2348, Training accuracy: 0.7404
Validation loss: 0.8601, Validation accuracy: 0.7113

Epoch 7:
Training loss: 1.1705, Training accuracy: 0.7554
Validation loss: 0.7240, Validation accuracy: 0.7535
Saving ...

Epoch 

0.8878

In [37]:
train(student,teacher,optimizer,criterion,4,[1,0,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 2.1979, Training accuracy: 0.4179
Validation loss: 1.4050, Validation accuracy: 0.4981
Saving ...

Epoch 1:
Training loss: 1.7070, Training accuracy: 0.5686
Validation loss: 1.2272, Validation accuracy: 0.5743
Saving ...

Epoch 2:
Training loss: 1.4861, Training accuracy: 0.6276
Validation loss: 1.0451, Validation accuracy: 0.6436
Saving ...

Epoch 3:
Training loss: 1.3582, Training accuracy: 0.6641
Validation loss: 0.9836, Validation accuracy: 0.6619
Saving ...

Epoch 4:
Training loss: 1.2665, Training accuracy: 0.6885
Validation loss: 0.9084, Validation accuracy: 0.6915
Saving ...

Epoch 5:
Training loss: 1.2010, Training accuracy: 0.7094
Validation loss: 0.8758, Validation accuracy: 0.6994
Saving ...

Epoch 6:
Training loss: 1.1415, Training accuracy: 0.7236
Validation loss: 0.8436, Validation accuracy: 0.7207
Saving ...

Epoch 7:
Training loss: 1.0881, Training accuracy: 0.7383
Validation loss: 0.7945, Validation accuracy: 0.7347
Saving 

0.8906

In [40]:
train(student,teacher,optimizer,criterion,8,[1,0,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 1.6901, Training accuracy: 0.4247
Validation loss: 1.5441, Validation accuracy: 0.4756
Saving ...

Epoch 1:
Training loss: 1.3540, Training accuracy: 0.5568
Validation loss: 1.1554, Validation accuracy: 0.6042
Saving ...

Epoch 2:
Training loss: 1.1841, Training accuracy: 0.6200
Validation loss: 1.0486, Validation accuracy: 0.6425
Saving ...

Epoch 3:
Training loss: 1.0758, Training accuracy: 0.6593
Validation loss: 0.9456, Validation accuracy: 0.6724
Saving ...

Epoch 4:
Training loss: 0.9906, Training accuracy: 0.6916
Validation loss: 0.8990, Validation accuracy: 0.6989
Saving ...

Epoch 5:
Training loss: 0.9288, Training accuracy: 0.7142
Validation loss: 0.8528, Validation accuracy: 0.7130
Saving ...

Epoch 6:
Training loss: 0.8692, Training accuracy: 0.7350
Validation loss: 0.7906, Validation accuracy: 0.7356
Saving ...

Epoch 7:
Training loss: 0.8199, Training accuracy: 0.7506
Validation loss: 0.7455, Validation accuracy: 0.7523
Saving 

0.8878

In [10]:
train(student,teacher,optimizer,criterion,1,[0,1,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 1.1248, Training accuracy: 0.3453
Validation loss: 2.6631, Validation accuracy: 0.3827
Saving ...

Epoch 1:
Training loss: 1.0688, Training accuracy: 0.4397
Validation loss: 2.4484, Validation accuracy: 0.4463
Saving ...

Epoch 2:
Training loss: 1.0560, Training accuracy: 0.4807
Validation loss: 2.1280, Validation accuracy: 0.5024
Saving ...

Epoch 3:
Training loss: 1.0474, Training accuracy: 0.5131
Validation loss: 2.3435, Validation accuracy: 0.4942

Epoch 4:
Training loss: 1.0413, Training accuracy: 0.5384
Validation loss: 2.1788, Validation accuracy: 0.5265
Saving ...

Epoch 5:
Training loss: 1.0377, Training accuracy: 0.5527
Validation loss: 2.0903, Validation accuracy: 0.5628
Saving ...

Epoch 6:
Training loss: 1.0347, Training accuracy: 0.5683
Validation loss: 2.0740, Validation accuracy: 0.5646
Saving ...

Epoch 7:
Training loss: 1.0324, Training accuracy: 0.5817
Validation loss: 2.1203, Validation accuracy: 0.5775
Saving ...

Epoch 

0.7865

In [26]:
train(student,teacher,optimizer,criterion,2,[0,1,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 5.0655, Training accuracy: 0.3171
Validation loss: 4.0881, Validation accuracy: 0.3150
Saving ...

Epoch 1:
Training loss: 4.7682, Training accuracy: 0.3813
Validation loss: 4.2499, Validation accuracy: 0.3694
Saving ...

Epoch 2:
Training loss: 4.6991, Training accuracy: 0.4147
Validation loss: 4.2745, Validation accuracy: 0.4280
Saving ...

Epoch 3:
Training loss: 4.6682, Training accuracy: 0.4336
Validation loss: 4.3875, Validation accuracy: 0.4478
Saving ...

Epoch 4:
Training loss: 4.6359, Training accuracy: 0.4529
Validation loss: 4.1668, Validation accuracy: 0.4612
Saving ...

Epoch 5:
Training loss: 4.6100, Training accuracy: 0.4739
Validation loss: 3.6246, Validation accuracy: 0.4742
Saving ...

Epoch 6:
Training loss: 4.5862, Training accuracy: 0.4881
Validation loss: 3.5046, Validation accuracy: 0.4818
Saving ...

Epoch 7:
Training loss: 4.5630, Training accuracy: 0.4977
Validation loss: 3.2782, Validation accuracy: 0.4993
Saving 

0.6771

In [33]:
train(student,teacher,optimizer,criterion,4,[0,1,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 25.6500, Training accuracy: 0.2139
Validation loss: 5.0232, Validation accuracy: 0.1875
Saving ...

Epoch 1:
Training loss: 23.9535, Training accuracy: 0.2259
Validation loss: 6.6620, Validation accuracy: 0.2895
Saving ...

Epoch 2:
Training loss: 23.5759, Training accuracy: 0.2528
Validation loss: 11.0997, Validation accuracy: 0.2514

Epoch 3:
Training loss: 23.3857, Training accuracy: 0.2690
Validation loss: 11.6303, Validation accuracy: 0.2634

Epoch 4:
Training loss: 23.2681, Training accuracy: 0.2833
Validation loss: 14.3683, Validation accuracy: 0.2860

Epoch 5:
Training loss: 23.1979, Training accuracy: 0.2923
Validation loss: 18.2550, Validation accuracy: 0.2955
Saving ...

Epoch 6:
Training loss: 23.1366, Training accuracy: 0.2923
Validation loss: 16.1116, Validation accuracy: 0.3010
Saving ...

Epoch 7:
Training loss: 23.0790, Training accuracy: 0.3015
Validation loss: 16.5622, Validation accuracy: 0.2586

Epoch 8:
Training loss: 2

0.4668

In [42]:
train(student,teacher,optimizer,criterion,8,[0,1,0],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 116.2444, Training accuracy: 0.1834
Validation loss: 10.3513, Validation accuracy: 0.1781
Saving ...

Epoch 1:
Training loss: 108.7878, Training accuracy: 0.1872
Validation loss: 25.3271, Validation accuracy: 0.1691

Epoch 2:
Training loss: 106.7553, Training accuracy: 0.1820
Validation loss: 17.2843, Validation accuracy: 0.2095
Saving ...

Epoch 3:
Training loss: 105.5789, Training accuracy: 0.1953
Validation loss: 43.4810, Validation accuracy: 0.1905

Epoch 4:
Training loss: 105.0658, Training accuracy: 0.1859
Validation loss: 40.0285, Validation accuracy: 0.1738

Epoch 5:
Training loss: 104.6674, Training accuracy: 0.1865
Validation loss: 41.3829, Validation accuracy: 0.1920

Epoch 6:
Training loss: 104.3655, Training accuracy: 0.1786
Validation loss: 50.6629, Validation accuracy: 0.1899

Epoch 7:
Training loss: 104.2609, Training accuracy: 0.1714
Validation loss: 57.1645, Validation accuracy: 0.1835

Epoch 8:
Training loss: 103.9941, Tra

0.2544

In [47]:
train(student,teacher,optimizer,criterion,1,[0,0,1],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 5.8541, Training accuracy: 0.3465
Validation loss: 2.2467, Validation accuracy: 0.4461
Saving ...

Epoch 1:
Training loss: 5.8493, Training accuracy: 0.4747
Validation loss: 2.2316, Validation accuracy: 0.5216
Saving ...

Epoch 2:
Training loss: 5.8478, Training accuracy: 0.5400
Validation loss: 2.2165, Validation accuracy: 0.5579
Saving ...

Epoch 3:
Training loss: 5.8465, Training accuracy: 0.5909
Validation loss: 2.1990, Validation accuracy: 0.6020
Saving ...

Epoch 4:
Training loss: 5.8455, Training accuracy: 0.6250
Validation loss: 2.1881, Validation accuracy: 0.6407
Saving ...

Epoch 5:
Training loss: 5.8447, Training accuracy: 0.6536
Validation loss: 2.1804, Validation accuracy: 0.6625
Saving ...

Epoch 6:
Training loss: 5.8440, Training accuracy: 0.6743
Validation loss: 2.1719, Validation accuracy: 0.6777
Saving ...

Epoch 7:
Training loss: 5.8434, Training accuracy: 0.6907
Validation loss: 2.1675, Validation accuracy: 0.6869
Saving 

0.8794

In [52]:
train(student,teacher,optimizer,criterion,2,[0,0,1],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 23.2290, Training accuracy: 0.3860
Validation loss: 2.2135, Validation accuracy: 0.4815
Saving ...

Epoch 1:
Training loss: 23.2228, Training accuracy: 0.5219
Validation loss: 2.1835, Validation accuracy: 0.5646
Saving ...

Epoch 2:
Training loss: 23.2197, Training accuracy: 0.5834
Validation loss: 2.1652, Validation accuracy: 0.5893
Saving ...

Epoch 3:
Training loss: 23.2175, Training accuracy: 0.6286
Validation loss: 2.1423, Validation accuracy: 0.6485
Saving ...

Epoch 4:
Training loss: 23.2156, Training accuracy: 0.6680
Validation loss: 2.1333, Validation accuracy: 0.6877
Saving ...

Epoch 5:
Training loss: 23.2140, Training accuracy: 0.6961
Validation loss: 2.1150, Validation accuracy: 0.7041
Saving ...

Epoch 6:
Training loss: 23.2129, Training accuracy: 0.7185
Validation loss: 2.1140, Validation accuracy: 0.7228
Saving ...

Epoch 7:
Training loss: 23.2120, Training accuracy: 0.7352
Validation loss: 2.1091, Validation accuracy: 0.7360

0.8856

In [57]:
train(student,teacher,optimizer,criterion,4,[0,0,1],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 92.4496, Training accuracy: 0.3695
Validation loss: 2.2326, Validation accuracy: 0.4687
Saving ...

Epoch 1:
Training loss: 92.4442, Training accuracy: 0.4989
Validation loss: 2.2116, Validation accuracy: 0.5524
Saving ...

Epoch 2:
Training loss: 92.4424, Training accuracy: 0.5601
Validation loss: 2.1980, Validation accuracy: 0.5824
Saving ...

Epoch 3:
Training loss: 92.4410, Training accuracy: 0.6017
Validation loss: 2.1843, Validation accuracy: 0.6229
Saving ...

Epoch 4:
Training loss: 92.4397, Training accuracy: 0.6368
Validation loss: 2.1699, Validation accuracy: 0.6362
Saving ...

Epoch 5:
Training loss: 92.4390, Training accuracy: 0.6609
Validation loss: 2.1701, Validation accuracy: 0.6699
Saving ...

Epoch 6:
Training loss: 92.4382, Training accuracy: 0.6847
Validation loss: 2.1591, Validation accuracy: 0.6921
Saving ...

Epoch 7:
Training loss: 92.4376, Training accuracy: 0.7026
Validation loss: 2.1509, Validation accuracy: 0.7022

0.884

In [61]:
train(student,teacher,optimizer,criterion,8,[0,0,1],1,train_loader,val_loader)

==> Training starts!
Epoch 0:
Training loss: 369.4914, Training accuracy: 0.3525
Validation loss: 2.2486, Validation accuracy: 0.4358
Saving ...

Epoch 1:
Training loss: 369.4879, Training accuracy: 0.4772
Validation loss: 2.2339, Validation accuracy: 0.5185
Saving ...

Epoch 2:
Training loss: 369.4867, Training accuracy: 0.5374
Validation loss: 2.2169, Validation accuracy: 0.5455
Saving ...

Epoch 3:
Training loss: 369.4859, Training accuracy: 0.5851
Validation loss: 2.2086, Validation accuracy: 0.5941
Saving ...

Epoch 4:
Training loss: 369.4853, Training accuracy: 0.6183
Validation loss: 2.1956, Validation accuracy: 0.6325
Saving ...

Epoch 5:
Training loss: 369.4847, Training accuracy: 0.6433
Validation loss: 2.1948, Validation accuracy: 0.6543
Saving ...

Epoch 6:
Training loss: 369.4843, Training accuracy: 0.6634
Validation loss: 2.1902, Validation accuracy: 0.6625
Saving ...

Epoch 7:
Training loss: 369.4838, Training accuracy: 0.6799
Validation loss: 2.1838, Validation accuracy

0.8786