In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# numpy and matplotlib
import numpy as np
import matplotlib.pyplot as plt

import os
from PIL import Image
import sys
import torch.nn.functional as F

In [2]:
# initialize for  each parameters
DATASET = 'MNIST'
BATCH_SIZE = 100
NUM_WORKERS = 2

WEIGHT_DECAY = 0.007
LEARNING_RATE = 0.01
MOMENTUM = 0.9

SCHEDULER_STEPS = 100
SCHEDULER_GAMMA = 0.1

SEED = 1

EPOCH = 150

KD_LAMBDA = 2.0

TRIPLET_MARGINE = 5.0

In [3]:
!pwd

/ssd_scratch/cvit/sashank.sridhar


In [4]:
# fixing the seed
torch.cuda.manual_seed_all(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
# check if gpu is available
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("gpu mode")
else:
    device = torch.device("cpu")
    print("cpu mode")

gpu mode


In [6]:
# the name of results files
codename = 'mnist_kd'

fnnname = codename + "_fnn_model"

total_loss_name = codename + "_total_loss"
soft_loss_name = codename + "_soft_loss"
tri_loss_name = codename + "_tri_loss"
acc_name = codename + "_accuracy"

result_name = codename + "_result"

In [7]:
class Datasets(object):
    def __init__(self, dataset_name, batch_size = 100, num_workers = 2, transform = None, shuffle = True):
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.shuffle = shuffle
        
    def create(self, path = None):
        print("Dataset :",self.dataset_name)
        if self.transform is None:
                self.transform = transforms.Compose([transforms.ToTensor()])
        
        
        if path is None:
            path = "./"+self.dataset_name+"Dataset/data"
        
        
        if self.dataset_name == "MNIST":
            trainset = torchvision.datasets.MNIST(root = path,
                                       train = True, download = True, transform = self.transform)
            testset = torchvision.datasets.MNIST(root = path,
                                                 train = False, download = True, transform = self.transform)
            classes = list(range(10))
            base_labels = trainset.classes
            
        elif self.dataset_name == "FashionMNIST":
            trainset = torchvision.datasets.FashionMNIST(root = path,
                                       train = True, download = True, transform = self.transform)
            testset = torchvision.datasets.FashionMNIST(root = path,
                                                 train = False, download = True, transform = self.transform)
            classes = list(range(10))
            base_labels = trainset.classes
            
        elif self.dataset_name == "CIFAR10":
            trainset = torchvision.datasets.CIFAR10(root = path,
                                       train = True, download = True, transform = self.transform)
            testset = torchvision.datasets.CIFAR10(root = path,
                                                 train = False, download = True, transform = self.transform)
            classes = list(range(10))
            base_labels = trainset.classes
            
        elif self.dataset_name == "CIFAR100":
            trainset = torchvision.datasets.CIFAR100(root = path,
                                       train = True, download = True, transform = self.transform)
            testset = torchvision.datasets.CIFAR100(root = path,
                                                 train = False, download = True, transform = self.transform)
            classes = list(range(100))
            base_labels = trainset.classes
        
        else:
            raise KeyError("Unknown dataset: {}".format(self.dataset_name))
            
        
        trainloader = torch.utils.data.DataLoader(trainset, batch_size = self.batch_size,
                        shuffle = self.shuffle, num_workers = self.num_workers)
        
        if testset is not None:
            testloader = torch.utils.data.DataLoader(testset, batch_size = self.batch_size,
                        shuffle = False, num_workers = self.num_workers)
        else:
            testloader = None
            
            
        return [trainloader, testloader, classes, base_labels, trainset, testset]
    
    def worker_init_fn(self, worker_id):                                                          
        np.random.seed(worker_id)

In [8]:
# load the data set
instance_datasets = Datasets(DATASET, BATCH_SIZE, NUM_WORKERS, shuffle = False)
data_sets = instance_datasets.create()

#trainloader = data_sets[0]
#testloader = data_sets[1]
classes = data_sets[2]
based_labels = data_sets[3]
trainset = data_sets[4]
testset = data_sets[5]

Dataset : MNIST


In [9]:
class KDTripletDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        
        data = dataset.data
        labels = dataset.targets
        if type(labels) is not torch.Tensor:
                labels = torch.tensor(labels)
        
        # make label set 0-9
        labels_set = set(labels.numpy())
        
        # make the indices excepted each classes
        label_to_indices = {label : np.where(labels.numpy() != label)[0] for label in labels_set}
        
        if self.dataset.train:
            self.negative_indices = label_to_indices
        else:
            self.negative_indices = [[np.random.choice(label_to_indices[labels[i].item()])] for i in range(len(data))]
        

            
    def __getitem__(self, index):
        if self.dataset.train:
            img1_2, label1_2 = self.dataset[index]
            if type(label1_2) is not torch.Tensor:
                label1_2 = torch.tensor(label1_2)
            img3, label3 = self.dataset[np.random.choice(self.negative_indices[label1_2.item()])]
        else:
            img1_2, label1_2 = self.dataset[index]
            img3, label3 = self.dataset[self.negative_indices[index][0]]
        
            
        return (img1_2, img3), (label1_2, label3)
    
    def __len__(self):
        return len(self.dataset)

In [10]:
# use the KD Triplet Dataset by using above dataset
tri_trainset = KDTripletDataset(trainset)
tri_testset = KDTripletDataset(testset)
tri_trainloader = torch.utils.data.DataLoader(tri_trainset, batch_size = BATCH_SIZE, shuffle = True, num_workers = NUM_WORKERS)
tri_testloader = torch.utils.data.DataLoader(tri_testset, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS)

In [11]:
class TeacherNetwork(nn.Module):
    def __init__(self):
        super(TeacherNetwork, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)
        self.dropout_input = 0.0
        self.dropout_hidden = 0.0
        self.is_training = True
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.dropout(x, p=self.dropout_input, training=self.is_training)
        x = F.dropout(F.relu(self.fc1(x)), p=self.dropout_hidden, training=self.is_training)
        x = F.dropout(F.relu(self.fc2(x)), p=self.dropout_hidden, training=self.is_training)
        x = self.fc3(x)
        return x

class StudentNetwork(nn.Module):
    def __init__(self):
        super(StudentNetwork, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 400)
        self.fc2 = nn.Linear(400, 10)
        self.dropout_input = 0.0
        self.dropout_hidden = 0.0
        self.is_training = True
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.dropout(x, p=self.dropout_input, training=self.is_training)
        x = F.dropout(F.relu(self.fc1(x)), p=self.dropout_hidden, training=self.is_training)
        x = self.fc2(x)
        return x


In [46]:
net = TeacherNetwork()
bat_size = 100
max_epochs = 20  # 100 gives better results
ep_log_interval = 5
lrn_rate = 0.005

loss_func = nn.CrossEntropyLoss()  # does log-softmax()
optimizer = optim.SGD(net.parameters(), lr=lrn_rate)

print("\nbat_size = %3d " % bat_size)
print("loss = " + str(loss_func))
print("optimizer = SGD")
print("max_epochs = %3d " % max_epochs)
print("lrn_rate = %0.3f " % lrn_rate)

print("\nStarting training")
net.train()  # set mode

for epoch in range(0, max_epochs):
    
    accuracy = 0.0
    ep_loss = 0  # for one full epoch
    c = 0
    for (batch_idx, batch) in enumerate(data_sets[0]):
        c+=1
        print("\rIteration: {}/{}".format(c, len(data_sets[0])), end="")
        (X, y) = batch  # X = pixels, y = target labels
        
        optimizer.zero_grad()
        X.to(device)
#         print(X.shape)
        y.to(device)
        oupt = net(X)
        loss_val = loss_func(oupt, y)  # a tensor
        ep_loss += loss_val.item()  # accumulate
        loss_val.backward()  # compute grads
        optimizer.step()     # update weights
        accuracy += float(torch.sum(torch.argmax(oupt, dim=1) == y).item()) / y.shape[0]
    print("train mean loss={}, accuracy={}".format(ep_loss*bat_size/len(data_sets[0]), accuracy/len(data_sets[0])))
print("Done ") 



bat_size = 100 
loss = CrossEntropyLoss()
optimizer = SGD
max_epochs =  20 
lrn_rate = 0.005 

Starting training
Iteration: 600/600train mean loss=214.56449995438257, accuracy=0.5914333333333328
Iteration: 600/600train mean loss=138.40460643172264, accuracy=0.768183333333333
Iteration: 600/600train mean loss=74.1871007680893, accuracy=0.8315833333333333
Iteration: 600/600train mean loss=54.11694197356701, accuracy=0.8622833333333335
Iteration: 600/600train mean loss=45.53262098878622, accuracy=0.8792833333333342
Iteration: 600/600train mean loss=40.77618765706817, accuracy=0.8889666666666668
Iteration: 600/600train mean loss=37.73662828281522, accuracy=0.8952500000000005
Iteration: 600/600train mean loss=35.58648294645051, accuracy=0.9002000000000002
Iteration: 600/600train mean loss=33.94134741773208, accuracy=0.9046833333333334
Iteration: 600/600train mean loss=32.60682564539214, accuracy=0.9082500000000007
Iteration: 600/600train mean loss=31.476194728165865, accuracy=0.91090000000

In [47]:
torch.save(net.state_dict(), "Mnist_teacher.pt")

In [12]:
# network and criterions
model_t = TeacherNetwork().to(device)
model_s = StudentNetwork().to(device)

model_t.load_state_dict(torch.load("Mnist_teacher.pt"))

optimizer = optim.SGD(model_s.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEPS, gamma=SCHEDULER_GAMMA)

soft_criterion = nn.CrossEntropyLoss()
triplet_loss = nn.TripletMarginLoss(margin=TRIPLET_MARGINE)

In [13]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [14]:
class NetworkFit(object):
    def __init__(self, model_t, model_s, optimizer, soft_criterion, triplet_loss):
        self.model_t = model_t
        self.model_s = model_s
        self.optimizer = optimizer
        
        self.soft_criterion = soft_criterion
        self.triplet_loss = triplet_loss
        
        self.model_t.eval()
        

    def train(self, inputs, labels, kd_lambda = 2.0):
        self.optimizer.zero_grad()
        self.model_s.train()

        img1_t = inputs[0]
        img2_s = inputs[1]
        img3_s = inputs[2]
        
        label1_t = labels[0]
        label2_s = labels[1]
        label3_s = labels[2]
        
        out1_t = self.model_t(img1_t)
        out2_s = self.model_s(img2_s)
        out3_s = self.model_s(img3_s)
        
        soft_loss = self.soft_criterion(out2_s, label2_s)
        trip_loss = self.triplet_loss(out1_t, out2_s, out3_s)

        loss = soft_loss + kd_lambda*trip_loss

        loss.backward()
        self.optimizer.step()
            
            
    def test(self, inputs, labels, kd_lambda = 2.0):
        self.model_s.eval()
        
        img1_t = inputs[0]
        img2_s = inputs[1]
        img3_s = inputs[2]
        
        label1_t = labels[0]
        label2_s = labels[1]
        label3_s = labels[2]
        
        out1_t = self.model_t(img1_t)
        out2_s = self.model_s(img2_s)
        out3_s = self.model_s(img3_s)
        
        soft_loss = self.soft_criterion(out2_s, label2_s)
        trip_loss = self.triplet_loss(out1_t, out2_s, out3_s)

        loss = soft_loss + kd_lambda*trip_loss
        
        _, predicted = out2_s.max(1)
        correct = (predicted == label2_s).sum().item()
        
        return [loss.item(), soft_loss.item(), trip_loss.item()], [correct]
        
        
    

In [15]:
# fit for training and test
fit = NetworkFit(model_t, model_s, optimizer, soft_criterion, triplet_loss)

In [16]:
class Score(object):
    def __init__(self, score = 0):
        self.score = score
        
    def sum_score(self, score):
        self.score += score
    
    def set_score(self, score):
        self.score = score
    
    def init_score(self):
        self.score = 0
    
    def get_score(self):
        return self.score

In [17]:
class ScoreCalc(object):
    def __init__(self, losses, corrects, batch_size):
        self.losses = losses
        self.corrects = corrects
        
        self.batch_size = batch_size
        
        self.len_l = len(losses)
        self.len_c = len(corrects)
        
        self.train_losses = [[] for l in range(self.len_l)]
        self.train_corrects = [[] for c in range(self.len_c)]
        
        self.test_losses = [[] for l in range(self.len_l)]
        self.test_corrects = [[] for c in range(self.len_c)]

        patience = 10

        self.early_stopping = EarlyStopping(patience=patience, verbose=True)
       
    
    def calc_sum(self, losses, corrects):
        if len(losses) != len(self.losses):
            print("warning : len(losses) != len(self.losses)")
            sys.exit()
        if len(corrects) != len(self.corrects):
            print("warning : len(corrects) != len(self.corrects)")
            sys.exit()
        
        for l in range(self.len_l):
            self.losses[l].sum_score(losses[l])
        
        for c in range(self.len_c):
            self.corrects[c].sum_score(corrects[c])
        
        return self.losses, self.corrects
    
    
    def score_del(self):
        for loss in self.losses:
            loss.init_score()
        for correct in self.corrects:
            correct.init_score()

        
    def score_print(self, data_num, train = True):
        if train:
            print("train mean loss={}, accuracy={}".format(self.losses[0].get_score()*self.batch_size/data_num, float(self.corrects[0].get_score()/data_num)))
        else:
            
            print("test mean loss={}, accuracy={}".format(self.losses[0].get_score()*self.batch_size/data_num, float(self.corrects[0].get_score()/data_num)))
            self.early_stopping(self.losses[0].get_score()*self.batch_size/data_num, model_s)
        
            if self.early_stopping.early_stop:
                print("Early stopping")
                return True
            else:
              return False

            
    def score_append(self, data_num, train = True):
        if train:
            for l in range(self.len_l):
                self.train_losses[l].append(self.losses[l].get_score()*self.batch_size/data_num)
            for c in range(self.len_c):
                self.train_corrects[c].append(float(self.corrects[c].get_score()/data_num))
        else:
            for l in range(self.len_l):
                self.test_losses[l].append(self.losses[l].get_score()*self.batch_size/data_num)
            for c in range(self.len_c):
                self.test_corrects[c].append(float(self.corrects[c].get_score()/data_num))
    
    
    def get_value(self, train = True):
        if train:
            return self.train_losses, self.train_corrects
        else:
            return self.test_losses, self.test_corrects

In [18]:
# to manage all scores
loss = Score()
loss_s = Score()
loss_t = Score()
correct = Score()
score_loss = [loss, loss_s, loss_t]
score_correct = [correct]
sc = ScoreCalc(score_loss, score_correct, BATCH_SIZE)

In [None]:
# training and test
for epoch in range(EPOCH):
    print('epoch', epoch+1)
    
    c = 0
    for (inputs, labels) in tri_trainloader:
        c+=1
        print("\rIteration: {}/{}".format(c, len(tri_trainloader)), end="")
        img1_t = inputs[0].to(device)
        img2_s = inputs[0].to(device)
        img3_s = inputs[1].to(device)
        
        images = (img1_t, img2_s, img3_s)
        
        label1_t = labels[0].to(device)
        label2_s = labels[0].to(device)
        label3_s = labels[1].to(device)
        
        label = (label1_t, label2_s, label3_s)
        
        fit.train(images, label, KD_LAMBDA)
    c = 0
    print("Train Loss Calc")
    for (inputs, labels) in tri_trainloader:
        c+=1
        print("\rIteration: {}/{}".format(c, len(tri_trainloader)), end="")
        img1_t = inputs[0].to(device)
        img2_s = inputs[0].to(device)
        img3_s = inputs[1].to(device)
        
        images = (img1_t, img2_s, img3_s)
        
        label1_t = labels[0].to(device)
        label2_s = labels[0].to(device)
        label3_s = labels[1].to(device)
        
        label = (label1_t, label2_s, label3_s)
        
        losses, corrects = fit.test(images, label, KD_LAMBDA)
        
        sc.calc_sum(losses, corrects)
    
    sc.score_print(len(trainset))
    sc.score_append(len(trainset))
    sc.score_del()
    c = 0
    print("Test Loss Calc")
    for (inputs, labels) in tri_testloader:
        c+=1
        
        print("\rIteration: {}/{}".format(c, len(tri_testloader)), end="")
        img1_t = inputs[0].to(device)
        img2_s = inputs[0].to(device)
        img3_s = inputs[1].to(device)
        
        images = (img1_t, img2_s, img3_s)
        
        label1_t = labels[0].to(device)
        label2_s = labels[0].to(device)
        label3_s = labels[1].to(device)
        
        label = (label1_t, label2_s, label3_s)
        
        losses, corrects = fit.test(images, label, KD_LAMBDA)
        
        sc.calc_sum(losses, corrects)
    
    if sc.score_print(len(testset), train = False):
      sc.score_append(len(testset), train = False)
      sc.score_del()
      break

    sc.score_append(len(testset), train = False)
    sc.score_del()
    
    scheduler.step()

epoch 1
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.3392240350196759, accuracy=0.90895
Test Loss Calc
Iteration: 100/100test mean loss=0.32718781583011153, accuracy=0.9114
Validation loss decreased (inf --> 0.327188).  Saving model ...
epoch 2
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.2870872473344207, accuracy=0.9234
Test Loss Calc
Iteration: 100/100test mean loss=0.2759371783584356, accuracy=0.9243
Validation loss decreased (0.327188 --> 0.275937).  Saving model ...
epoch 3
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.26858167931437493, accuracy=0.9305666666666667
Test Loss Calc
Iteration: 100/100test mean loss=0.2620486550591886, accuracy=0.9323
Validation loss decreased (0.275937 --> 0.262049).  Saving model ...
epoch 4
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.25554347747315964, accuracy=0.9344666666666667
Test Loss Calc
Iteration: 100/100test mean loss=0.24732208984903992, a

Iteration: 600/600train mean loss=0.21573784157633782, accuracy=0.9474166666666667
Test Loss Calc
Iteration: 100/100test mean loss=0.20785614093765617, accuracy=0.9491
Validation loss decreased (0.209101 --> 0.207856).  Saving model ...
epoch 34
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.21592179962744315, accuracy=0.94645
Test Loss Calc
Iteration: 100/100test mean loss=0.21080665020272135, accuracy=0.9464
EarlyStopping counter: 1 out of 10
epoch 35
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.2148519828543067, accuracy=0.9475666666666667
Test Loss Calc
Iteration: 100/100test mean loss=0.20755405128002166, accuracy=0.9486
Validation loss decreased (0.207856 --> 0.207554).  Saving model ...
epoch 36
Iteration: 600/600Train Loss Calc
Iteration: 600/600train mean loss=0.2174828307206432, accuracy=0.9484666666666667
Test Loss Calc
Iteration: 100/100test mean loss=0.21157521226443352, accuracy=0.9494
EarlyStopping counter: 1 out of 10
epoc

In [20]:
# get the scores
train_losses, train_corrects = sc.get_value()
test_losses, test_corrects = sc.get_value(train = False)

In [21]:
def plot_score(epoch, train_data, test_data, x_lim = None, y_lim = None, x_label = 'EPOCH', y_label = 'score', title = 'score', legend = ['train', 'test'], filename = 'test'):
    plt.figure(figsize=(6,6))
    
    if x_lim is None:
        x_lim = epoch
    if y_lim is None:
        y_lim = 1
        
    plt.plot(range(epoch), train_data)
    plt.plot(range(epoch), test_data, c='#00ff00')
    plt.xlim(0, x_lim)
    plt.ylim(0, y_lim)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(legend)
    plt.title(title)
    plt.savefig(filename+'.png')
    plt.close()

    
def save_data(train_loss, test_loss, train_acc, test_acc, filename):
    with open(filename + '.txt', mode='w') as f:
        f.write("train mean loss={}\n".format(train_loss[-1]))
        f.write("test  mean loss={}\n".format(test_loss[-1]))
        f.write("train accuracy={}\n".format(train_acc[-1]))
        f.write("test  accuracy={}\n".format(test_acc[-1]))

In [22]:
torch.save(model_s.state_dict(), fnnname + '.pth')

In [23]:
save_data(train_losses[0], test_losses[0], train_corrects[0], test_corrects[0], result_name)

In [24]:
# output the glaphs of the scores

plot_score(150, train_losses[0], test_losses[0], y_lim = 5.0, y_label = 'LOSS', legend = ['train loss', 'test loss'], title = 'total loss', filename = total_loss_name)

plot_score(150, train_losses[1], test_losses[1], y_lim = 5.0, y_label = 'LOSS', legend = ['train loss', 'test loss'], title = 'softmax loss', filename = soft_loss_name)

plot_score(150, train_losses[2], test_losses[2], y_lim = 5.0, y_label = 'LOSS', legend = ['train loss', 'test loss'], title = 'triplet loss', filename = tri_loss_name)

plot_score(150, train_corrects[0], test_corrects[0], y_lim = 1, y_label = 'ACCURACY', legend = ['train acc', 'test acc'], title = 'accuracy', filename = acc_name)

