In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# numpy and matplotlib
import numpy as np
import matplotlib.pyplot as plt

import os
from PIL import Image
import sys

In [2]:
# experiment hyper-parameters
DATASET = 'CIFAR10'          # one of MNIST / FashionMNIST / CIFAR10 / CIFAR100 (see Datasets.create)
BATCH_SIZE, NUM_WORKERS = 100, 2

# SGD optimizer settings
WEIGHT_DECAY, LEARNING_RATE, MOMENTUM = 0.007, 0.01, 0.9

# StepLR: the learning rate is multiplied by SCHEDULER_GAMMA every SCHEDULER_STEPS scheduler steps
# (the training loop steps the scheduler once per epoch)
SCHEDULER_STEPS, SCHEDULER_GAMMA = 100, 0.1

SEED = 1

EPOCH = 150                  # maximum number of epochs; early stopping may end training sooner

KD_LAMBDA = 2.0              # weight of the triplet term relative to the cross-entropy term

# triplet-loss margin (identifier spelling kept because later cells reference it)
TRIPLET_MARGINE = 5.0

In [3]:
# show the working directory: the dataset root ("./<DATASET>Dataset/data"), the teacher
# checkpoint "cnn_alex.pkl" and all result files in this notebook are relative paths
# NOTE(review): this prints an absolute host path into the saved notebook output
!pwd

/ssd_scratch/cvit/sashank.sridhar


In [4]:
# fix the RNGs the notebook draws from so runs are repeatable
np.random.seed(SEED)              # numpy: used by KDTripletDataset for negative sampling
torch.manual_seed(SEED)           # torch CPU (and CUDA) generators
torch.cuda.manual_seed_all(SEED)  # explicitly (re)seed every CUDA device as well

In [5]:
# run on the first GPU when one is visible, otherwise fall back to the CPU
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
print("gpu mode" if use_gpu else "cpu mode")

gpu mode


In [6]:
# base names of every artifact this run writes (extensions are added where they are saved)
codename = 'kd_example'

fnnname, total_loss_name, soft_loss_name, tri_loss_name, acc_name, result_name = (
    codename + suffix
    for suffix in ("_fnn_model", "_total_loss", "_soft_loss", "_tri_loss", "_accuracy", "_result")
)

In [7]:
class Datasets(object):
    """Builds torchvision train/test datasets and DataLoaders for one supported dataset.

    Args:
        dataset_name (str): one of "MNIST", "FashionMNIST", "CIFAR10", "CIFAR100".
        batch_size (int): batch size used by both DataLoaders.
        num_workers (int): number of DataLoader worker processes.
        transform: torchvision transform applied to every sample; ``create`` replaces
            ``None`` with ``transforms.Compose([transforms.ToTensor()])``.
        shuffle (bool): whether the *training* loader shuffles (the test loader never does).
    """

    # number of classes of each supported dataset; every key is also the name of the
    # corresponding class in ``torchvision.datasets``
    _NUM_CLASSES = {"MNIST": 10, "FashionMNIST": 10, "CIFAR10": 10, "CIFAR100": 100}

    def __init__(self, dataset_name, batch_size = 100, num_workers = 2, transform = None, shuffle = True):
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.shuffle = shuffle

    def create(self, path = None):
        """Download (if needed) the dataset and wrap both splits in DataLoaders.

        Args:
            path (str, optional): dataset root; defaults to ``./<dataset_name>Dataset/data``.

        Returns:
            list: ``[trainloader, testloader, classes, base_labels, trainset, testset]`` where
            ``classes`` is ``list(range(num_classes))`` and ``base_labels`` is ``trainset.classes``.

        Raises:
            KeyError: if ``dataset_name`` is not a supported dataset.
        """
        print("Dataset :", self.dataset_name)
        # validate first so an unknown name fails before any download/transform work
        if self.dataset_name not in self._NUM_CLASSES:
            raise KeyError("Unknown dataset: {}".format(self.dataset_name))

        if self.transform is None:
            self.transform = transforms.Compose([transforms.ToTensor()])

        if path is None:
            path = "./" + self.dataset_name + "Dataset/data"

        dataset_cls = getattr(torchvision.datasets, self.dataset_name)
        trainset = dataset_cls(root = path, train = True, download = True, transform = self.transform)
        testset = dataset_cls(root = path, train = False, download = True, transform = self.transform)
        classes = list(range(self._NUM_CLASSES[self.dataset_name]))
        base_labels = trainset.classes

        trainloader = torch.utils.data.DataLoader(trainset, batch_size = self.batch_size,
                        shuffle = self.shuffle, num_workers = self.num_workers)
        testloader = torch.utils.data.DataLoader(testset, batch_size = self.batch_size,
                        shuffle = False, num_workers = self.num_workers)

        return [trainloader, testloader, classes, base_labels, trainset, testset]

    def worker_init_fn(self, worker_id):
        """Seed numpy inside a DataLoader worker (pass as ``DataLoader(worker_init_fn=...)``).

        Uses ``torch.initial_seed()``, which the DataLoader sets differently for each
        worker and each epoch, so numpy streams differ across workers and epochs and
        follow the global torch seed. Seeding with ``worker_id`` alone would replay the
        same numpy stream every epoch. ``worker_id`` is kept for the DataLoader interface.
        """
        np.random.seed(torch.initial_seed() % 2 ** 32)

In [8]:
# load the raw datasets; the plain DataLoaders that create() also returns are not used,
# because the triplet DataLoaders are built from trainset/testset in a later cell
instance_datasets = Datasets(DATASET, BATCH_SIZE, NUM_WORKERS, shuffle = False)
data_sets = instance_datasets.create()

# data_sets == [trainloader, testloader, classes, base_labels, trainset, testset]
classes, based_labels, trainset, testset = data_sets[2:]

Dataset : CIFAR10
Files already downloaded and verified
Files already downloaded and verified


In [9]:
class KDTripletDataset(torch.utils.data.Dataset):
    """Pairs every sample of ``dataset`` with a negative sample of a different label.

    Each item is ``((anchor_img, negative_img), (anchor_label, negative_label))``.
    When ``dataset.train`` is true a fresh negative is drawn with ``np.random.choice``
    on every access (and the anchor label is returned as a tensor); otherwise one
    negative per index is drawn once at construction time, so evaluation is repeatable.

    ``dataset`` must expose ``data``, ``targets`` and ``train`` (as torchvision datasets do).
    """

    def __init__(self, dataset):
        self.dataset = dataset

        targets = dataset.targets
        if type(targets) is not torch.Tensor:
            targets = torch.tensor(targets)
        targets_np = targets.numpy()

        # label -> indices of every sample whose label is different (candidate negatives)
        others_by_label = {lbl: np.where(targets_np != lbl)[0] for lbl in set(targets_np)}

        if dataset.train:
            self.negative_indices = others_by_label
        else:
            self.negative_indices = []
            for idx in range(len(dataset.data)):
                self.negative_indices.append([np.random.choice(others_by_label[targets[idx].item()])])

    def __getitem__(self, index):
        anchor_img, anchor_label = self.dataset[index]
        if self.dataset.train:
            if type(anchor_label) is not torch.Tensor:
                anchor_label = torch.tensor(anchor_label)
            negative_idx = np.random.choice(self.negative_indices[anchor_label.item()])
        else:
            negative_idx = self.negative_indices[index][0]
        negative_img, negative_label = self.dataset[negative_idx]

        return (anchor_img, negative_img), (anchor_label, negative_label)

    def __len__(self):
        return len(self.dataset)

In [10]:
# wrap the raw datasets so every item is ((anchor, negative), (anchor_label, negative_label))
tri_trainset = KDTripletDataset(trainset)
tri_testset = KDTripletDataset(testset)

tri_trainloader = torch.utils.data.DataLoader(
    tri_trainset,
    batch_size = BATCH_SIZE,
    shuffle = True,            # new sample order every epoch
    num_workers = NUM_WORKERS,
)
tri_testloader = torch.utils.data.DataLoader(
    tri_testset,
    batch_size = BATCH_SIZE,
    shuffle = False,           # fixed order for evaluation
    num_workers = NUM_WORKERS,
)

In [11]:
class Net_teacher(nn.Module):
    """Teacher CNN: 5 conv layers (each followed by ReLU then BatchNorm), 3 max-pools, 3 FC layers.

    Layer attribute names are part of the state_dict loaded from the pretrained
    checkpoint, so they must not be renamed. For 3x32x32 inputs the feature map
    after the last pool is 128x4x4.

    Args:
        num_classes (int): size of the output logits (default 10).
    """
    def __init__(self, num_classes=10):
        super(Net_teacher, self).__init__()
        self.conv1 = nn.Conv2d(3,32,3, padding=1)
        self.conv2 = nn.Conv2d(32,32,3, padding=1)
        self.conv3 = nn.Conv2d(32,64, 3, padding=1)
        self.conv4 = nn.Conv2d(64,64, 3, padding=1)
        self.conv5 = nn.Conv2d(64,128, 3, padding=1)
        self.fc1 = nn.Linear(128*4*4, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.batchnorm3 = nn.BatchNorm2d(64)
        self.batchnorm4 = nn.BatchNorm2d(64)
        self.batchnorm5 = nn.BatchNorm2d(128)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()


    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a (batch, 3, 32, 32) input."""
        x = self.relu(self.conv1(x))
        x = self.batchnorm1(x)
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.batchnorm2(x)
        x = self.pool(x)

        x = self.relu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.relu(self.conv4(x))
        x = self.batchnorm4(x)
        x = self.relu(self.conv5(x))
        x = self.batchnorm5(x)
        x = self.pool(x)

        # keep the batch dimension: view(-1, 128*4*4) would silently fold a mis-sized
        # input into extra "samples"; flatten makes fc1 reject it instead
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


class Net_student(nn.Module):
    """Student CNN: 3 conv+ReLU+max-pool stages followed by 2 FC layers.

    For 3x32x32 inputs the feature map after the last pool is 64x4x4.
    ``self.dropout`` is constructed but not applied in ``forward`` (kept so the
    module layout stays unchanged).

    Args:
        num_classes (int): size of the output logits (default 10).
    """
    def __init__(self, num_classes=10):
        super(Net_student, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, num_classes)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a (batch, 3, 32, 32) input."""
        x = self.relu(self.conv1(x))
        x = self.pool(x)

        x = self.relu(self.conv2(x))
        x = self.pool(x)

        x = self.relu(self.conv3(x))
        x = self.pool(x)

        # keep the batch dimension so a mis-sized input errors in fc1 instead of
        # being silently reshaped into a different number of rows
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [12]:
# network and criterions
model_t = Net_teacher().to(device)
model_s = Net_student().to(device)

# pretrained teacher; map_location lets a checkpoint saved on GPU load on a CPU-only host
model_t.load_state_dict(torch.load("cnn_alex.pkl", map_location=device))
# the teacher is only a fixed target (the optimizer below holds student parameters only):
# stop autograd from computing and accumulating gradients for its parameters
model_t.requires_grad_(False)

optimizer = optim.SGD(model_s.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEPS, gamma=SCHEDULER_GAMMA)

soft_criterion = nn.CrossEntropyLoss()
triplet_loss = nn.TripletMarginLoss(margin=TRIPLET_MARGINE)

In [13]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Each call compares ``val_loss`` with the best loss seen so far. An improvement
    (the loss dropping by at least ``delta``) saves ``model.state_dict()`` to ``path``
    and resets the counter; otherwise the counter grows and ``early_stop`` becomes
    True once it reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not the np.Inf alias (removed in NumPy 2.0)
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss; checkpoint ``model`` on improvement."""
        # higher score is better, so negate the loss
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [14]:
class NetworkFit(object):
    """One optimisation / evaluation step of triplet knowledge distillation.

    The objective is
    ``soft_criterion(student(x_pos), y_pos) + kd_lambda * triplet_loss(teacher(x_anchor), student(x_pos), student(x_neg))``.
    ``inputs`` and ``labels`` are 3-element sequences ordered (teacher anchor,
    student positive, student negative). Only the parameters held by ``optimizer``
    are updated; the teacher is kept in eval mode and run without autograd.
    """
    def __init__(self, model_t, model_s, optimizer, soft_criterion, triplet_loss):
        self.model_t = model_t
        self.model_s = model_s
        self.optimizer = optimizer

        self.soft_criterion = soft_criterion
        self.triplet_loss = triplet_loss

        self.model_t.eval()

    def _losses(self, inputs, labels, kd_lambda):
        """Shared forward pass; returns (total loss, soft loss, triplet loss, student positive logits)."""
        img1_t = inputs[0]
        img2_s = inputs[1]
        img3_s = inputs[2]
        label2_s = labels[1]

        # the teacher output is a fixed target: building a graph through it would only
        # compute (and accumulate) gradients for parameters nobody optimizes
        with torch.no_grad():
            out1_t = self.model_t(img1_t)
        out2_s = self.model_s(img2_s)
        out3_s = self.model_s(img3_s)

        soft_loss = self.soft_criterion(out2_s, label2_s)
        trip_loss = self.triplet_loss(out1_t, out2_s, out3_s)
        return soft_loss + kd_lambda * trip_loss, soft_loss, trip_loss, out2_s

    def train(self, inputs, labels, kd_lambda = 2.0):
        """Run one SGD step of the student on a single batch."""
        self.optimizer.zero_grad()
        self.model_s.train()

        loss = self._losses(inputs, labels, kd_lambda)[0]

        loss.backward()
        self.optimizer.step()

    def test(self, inputs, labels, kd_lambda = 2.0):
        """Evaluate the student on one batch without building an autograd graph.

        Returns:
            ([total_loss, soft_loss, triplet_loss], [n_correct]) as Python numbers, where
            ``n_correct`` counts argmax predictions on ``inputs[1]`` equal to ``labels[1]``.
        """
        self.model_s.eval()

        with torch.no_grad():
            loss, soft_loss, trip_loss, out2_s = self._losses(inputs, labels, kd_lambda)

        _, predicted = out2_s.max(1)
        correct = (predicted == labels[1]).sum().item()

        return [loss.item(), soft_loss.item(), trip_loss.item()], [correct]
        
        
    

In [15]:
# single helper object that runs the per-batch train / evaluation steps
fit = NetworkFit(
    model_t = model_t,
    model_s = model_s,
    optimizer = optimizer,
    soft_criterion = soft_criterion,
    triplet_loss = triplet_loss,
)

In [16]:
class Score(object):
    """A mutable running total, e.g. of per-batch losses or correct-prediction counts."""

    def __init__(self, score = 0):
        self.score = score

    def sum_score(self, score):
        """Add ``score`` to the running total."""
        self.score += score

    def set_score(self, score):
        """Overwrite the running total."""
        self.score = score

    def init_score(self):
        """Reset the running total to 0."""
        self.set_score(0)

    def get_score(self):
        """Return the current running total."""
        return self.score

In [17]:
class ScoreCalc(object):
    """Accumulates per-batch losses/corrects into Score objects and keeps per-epoch history.

    Args:
        losses (list): running-total objects (``sum_score``/``get_score``/``init_score``),
            one per loss term; index 0 is the loss printed and used for early stopping.
        corrects (list): running totals of correct-prediction counts.
        batch_size (int): turns a sum of per-batch *mean* losses into a per-sample mean.
            NOTE(review): exact only when every batch is full — confirm if a dataset size
            is not a multiple of ``batch_size``.
        model (nn.Module, optional): model checkpointed by early stopping; when None the
            notebook-global ``model_s`` is used (original behaviour).
        patience (int): early-stopping patience in epochs (default 50).
    """
    def __init__(self, losses, corrects, batch_size, model = None, patience = 50):
        self.losses = losses
        self.corrects = corrects

        self.batch_size = batch_size
        self.model = model

        self.len_l = len(losses)
        self.len_c = len(corrects)

        self.train_losses = [[] for _ in range(self.len_l)]
        self.train_corrects = [[] for _ in range(self.len_c)]

        self.test_losses = [[] for _ in range(self.len_l)]
        self.test_corrects = [[] for _ in range(self.len_c)]

        self.early_stopping = EarlyStopping(patience=patience, verbose=True)

    def _mean_loss(self, idx, data_num):
        """Per-sample mean of loss term ``idx`` accumulated over ``data_num`` samples."""
        return self.losses[idx].get_score() * self.batch_size / data_num

    def _accuracy(self, idx, data_num):
        """Fraction of ``data_num`` samples counted as correct by corrects[``idx``]."""
        return float(self.corrects[idx].get_score() / data_num)

    def calc_sum(self, losses, corrects):
        """Add one batch's loss values and correct counts to the running totals.

        Raises:
            ValueError: if the number of values differs from the number of tracked scores
                (previously this called ``sys.exit()``, which only prints a warning in Jupyter).
        """
        if len(losses) != len(self.losses):
            raise ValueError("len(losses) != len(self.losses)")
        if len(corrects) != len(self.corrects):
            raise ValueError("len(corrects) != len(self.corrects)")

        for total, value in zip(self.losses, losses):
            total.sum_score(value)
        for total, value in zip(self.corrects, corrects):
            total.sum_score(value)

        return self.losses, self.corrects

    def score_del(self):
        """Reset every running total to 0 (call once per split per epoch)."""
        for loss in self.losses:
            loss.init_score()
        for correct in self.corrects:
            correct.init_score()

    def score_print(self, data_num, train = True):
        """Print the mean total loss and accuracy; for the test split also step early stopping.

        Returns:
            None for the training split; for the test split True once early stopping
            has triggered, else False.
        """
        mean_loss = self._mean_loss(0, data_num)
        accuracy = self._accuracy(0, data_num)
        if train:
            print("train mean loss={}, accuracy={}".format(mean_loss, accuracy))
            return None

        print("test mean loss={}, accuracy={}".format(mean_loss, accuracy))
        model = self.model if self.model is not None else model_s
        self.early_stopping(mean_loss, model)

        if self.early_stopping.early_stop:
            print("Early stopping")
            return True
        return False

    def score_append(self, data_num, train = True):
        """Append this epoch's per-sample mean losses and accuracies to the split's history."""
        if train:
            loss_history, acc_history = self.train_losses, self.train_corrects
        else:
            loss_history, acc_history = self.test_losses, self.test_corrects
        for l in range(self.len_l):
            loss_history[l].append(self._mean_loss(l, data_num))
        for c in range(self.len_c):
            acc_history[c].append(self._accuracy(c, data_num))

    def get_value(self, train = True):
        """Return (loss history, accuracy history) for the requested split."""
        if train:
            return self.train_losses, self.train_corrects
        else:
            return self.test_losses, self.test_corrects

In [18]:
# running totals: total loss, cross-entropy ("soft") loss, triplet loss, correct count
loss, loss_s, loss_t, correct = Score(), Score(), Score(), Score()

score_loss = [loss, loss_s, loss_t]   # same order as the loss list NetworkFit.test returns
score_correct = [correct]
sc = ScoreCalc(score_loss, score_correct, BATCH_SIZE)

In [19]:
# training and test

def _triplet_batch(inputs, labels):
    """Move one KDTripletDataset batch to ``device`` in the layout NetworkFit expects.

    The teacher anchor and the student positive are the same images/labels
    (``inputs[0]``/``labels[0]``); the student negative is ``inputs[1]``/``labels[1]``.
    """
    anchor, negative = inputs[0].to(device), inputs[1].to(device)
    anchor_label, negative_label = labels[0].to(device), labels[1].to(device)
    return (anchor, anchor, negative), (anchor_label, anchor_label, negative_label)


def _accumulate_scores(loader):
    """Evaluate the student on every batch of ``loader`` and add the results into ``sc``."""
    for step, (inputs, labels) in enumerate(loader, start=1):
        print("\rIteration: {}/{}".format(step, len(loader)), end="")
        images, label = _triplet_batch(inputs, labels)
        losses, corrects = fit.test(images, label, KD_LAMBDA)
        sc.calc_sum(losses, corrects)
    print()  # terminate the \r progress line


for epoch in range(EPOCH):
    print('epoch', epoch+1)

    for step, (inputs, labels) in enumerate(tri_trainloader, start=1):
        print("\rIteration: {}/{}".format(step, len(tri_trainloader)), end="")
        images, label = _triplet_batch(inputs, labels)
        fit.train(images, label, KD_LAMBDA)
    print()  # terminate the \r progress line

    # full pass over the training set with the student in eval mode
    print("Train Loss Calc")
    _accumulate_scores(tri_trainloader)
    sc.score_print(len(trainset))
    sc.score_append(len(trainset))
    sc.score_del()

    print("Test Loss Calc")
    _accumulate_scores(tri_testloader)
    # score_print also steps early stopping (and checkpoints on improvement)
    stop = sc.score_print(len(testset), train = False)
    sc.score_append(len(testset), train = False)
    sc.score_del()
    if stop:
        break

    scheduler.step()

epoch 1
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=6.488371030807495, accuracy=0.45498
Test Loss Calc
Iteration: 100/100test mean loss=6.4390854644775395, accuracy=0.4573
Validation loss decreased (inf --> 6.439085).  Saving model ...
epoch 2
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=5.101740472316742, accuracy=0.57018
Test Loss Calc
Iteration: 100/100test mean loss=5.228070287704468, accuracy=0.5644
Validation loss decreased (6.439085 --> 5.228070).  Saving model ...
epoch 3
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=4.518249933719635, accuracy=0.62788
Test Loss Calc
Iteration: 100/100test mean loss=4.66085821390152, accuracy=0.6186
Validation loss decreased (5.228070 --> 4.660858).  Saving model ...
epoch 4
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=4.079832461357117, accuracy=0.67084
Test Loss Calc
Iteration: 100/100test mean loss=4.296921043395996, accuracy=0.6567
Validation loss de

Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.734136913537979, accuracy=0.78008
Test Loss Calc
Iteration: 100/100test mean loss=3.314646668434143, accuracy=0.7323
EarlyStopping counter: 4 out of 50
epoch 36
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.7277489817142486, accuracy=0.7769
Test Loss Calc
Iteration: 100/100test mean loss=3.260294930934906, accuracy=0.7259
EarlyStopping counter: 5 out of 50
epoch 37
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.499296494960785, accuracy=0.8004
Test Loss Calc
Iteration: 100/100test mean loss=3.034727191925049, accuracy=0.7473
Validation loss decreased (3.036364 --> 3.034727).  Saving model ...
epoch 38
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.639820502281189, accuracy=0.78386
Test Loss Calc
Iteration: 100/100test mean loss=3.1533086824417116, accuracy=0.7398
EarlyStopping counter: 1 out of 50
epoch 39
Iteration: 500/500Train Loss Calc
Iteratio

Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.4985981373786927, accuracy=0.80016
Test Loss Calc
Iteration: 100/100test mean loss=3.0220175838470458, accuracy=0.7529
EarlyStopping counter: 9 out of 50
epoch 71
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.7113001449108123, accuracy=0.77848
Test Loss Calc
Iteration: 100/100test mean loss=3.2481325268745422, accuracy=0.73
EarlyStopping counter: 10 out of 50
epoch 72
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.5643739671707153, accuracy=0.79272
Test Loss Calc
Iteration: 100/100test mean loss=3.143607475757599, accuracy=0.7427
EarlyStopping counter: 11 out of 50
epoch 73
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.59882865023613, accuracy=0.79344
Test Loss Calc
Iteration: 100/100test mean loss=3.153661539554596, accuracy=0.743
EarlyStopping counter: 12 out of 50
epoch 74
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=2.61

Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.593693876504898, accuracy=0.87418
Test Loss Calc
Iteration: 100/100test mean loss=2.3360882008075716, accuracy=0.8028
Validation loss decreased (2.352602 --> 2.336088).  Saving model ...
epoch 106
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.5980384271144867, accuracy=0.87582
Test Loss Calc
Iteration: 100/100test mean loss=2.3665697371959684, accuracy=0.8001
EarlyStopping counter: 1 out of 50
epoch 107
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.55987118434906, accuracy=0.87798
Test Loss Calc
Iteration: 100/100test mean loss=2.3430869722366334, accuracy=0.7992
EarlyStopping counter: 2 out of 50
epoch 108
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.593987108707428, accuracy=0.87762
Test Loss Calc
Iteration: 100/100test mean loss=2.372348563671112, accuracy=0.8031
EarlyStopping counter: 3 out of 50
epoch 109
Iteration: 500/500Train Loss Calc
I

Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.3932449791431427, accuracy=0.90046
Test Loss Calc
Iteration: 100/100test mean loss=2.4478215837478636, accuracy=0.7937
EarlyStopping counter: 15 out of 50
epoch 141
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.3229564212560654, accuracy=0.90896
Test Loss Calc
Iteration: 100/100test mean loss=2.3775306165218355, accuracy=0.8036
EarlyStopping counter: 16 out of 50
epoch 142
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.323344449520111, accuracy=0.90878
Test Loss Calc
Iteration: 100/100test mean loss=2.3775234150886537, accuracy=0.8006
EarlyStopping counter: 17 out of 50
epoch 143
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean loss=1.3039958382844925, accuracy=0.91174
Test Loss Calc
Iteration: 100/100test mean loss=2.382228889465332, accuracy=0.8033
EarlyStopping counter: 18 out of 50
epoch 144
Iteration: 500/500Train Loss Calc
Iteration: 500/500train mean

In [20]:
# per-epoch histories: *_losses[i] is loss term i (total, soft, triplet), *_corrects[0] is accuracy
(train_losses, train_corrects), (test_losses, test_corrects) = sc.get_value(), sc.get_value(train = False)

In [21]:
def plot_score(epoch, train_data, test_data, x_lim = None, y_lim = None, x_label = 'EPOCH', y_label = 'score', title = 'score', legend = ('train', 'test'), filename = 'test'):
    """Plot a train curve and a test curve against epoch index and save ``<filename>.png``.

    Args:
        epoch (int): default x-axis upper limit (used when ``x_lim`` is None).
        train_data, test_data (sequence of float): one value per completed epoch; they
            may be shorter than ``epoch`` (e.g. when early stopping ended training).
        x_lim, y_lim (float, optional): axis upper limits; default ``epoch`` and 1.
        x_label, y_label, title (str): axis labels and figure title.
        legend (sequence of str): legend entries for (train, test).
        filename (str): output path without the ``.png`` extension.
    """
    if x_lim is None:
        x_lim = epoch
    if y_lim is None:
        y_lim = 1

    fig, ax = plt.subplots(figsize=(6, 6))
    # x values come from each series' own length, so an early-stopped history shorter
    # than `epoch` no longer raises a shape-mismatch error
    ax.plot(range(len(train_data)), train_data)
    ax.plot(range(len(test_data)), test_data, c='#00ff00')
    ax.set_xlim(0, x_lim)
    ax.set_ylim(0, y_lim)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.legend(list(legend))
    ax.set_title(title)
    fig.savefig(filename + '.png')
    plt.close(fig)

    
def save_data(train_loss, test_loss, train_acc, test_acc, filename):
    """Write the last entry of each history to ``<filename>.txt``, one ``name=value`` per line."""
    lines_to_write = (
        ("train mean loss={}\n", train_loss),
        ("test  mean loss={}\n", test_loss),
        ("train accuracy={}\n", train_acc),
        ("test  accuracy={}\n", test_acc),
    )
    with open(filename + '.txt', mode='w') as out:
        for template, history in lines_to_write:
            out.write(template.format(history[-1]))

In [22]:
# NOTE(review): these are the student's weights after the last epoch that ran; the
# best-validation weights are the ones EarlyStopping wrote to its checkpoint path
model_path = fnnname + '.pth'
torch.save(model_s.state_dict(), model_path)

In [23]:
# final-epoch total loss (loss index 0) and accuracy for both splits -> <result_name>.txt
save_data(
    train_loss = train_losses[0],
    test_loss = test_losses[0],
    train_acc = train_corrects[0],
    test_acc = test_corrects[0],
    filename = result_name,
)

In [24]:
# output the graphs of the scores

# x-axis length = number of epochs actually recorded (fewer than EPOCH if early
# stopping fired); a hard-coded 150 breaks plotting for shorter histories
n_epochs = len(train_losses[0])

loss_plots = [
    (train_losses[0], test_losses[0], 'total loss', total_loss_name),
    (train_losses[1], test_losses[1], 'softmax loss', soft_loss_name),
    (train_losses[2], test_losses[2], 'triplet loss', tri_loss_name),
]
for train_curve, test_curve, plot_title, out_name in loss_plots:
    plot_score(n_epochs, train_curve, test_curve, y_lim = 5.0, y_label = 'LOSS',
               legend = ['train loss', 'test loss'], title = plot_title, filename = out_name)

plot_score(n_epochs, train_corrects[0], test_corrects[0], y_lim = 1, y_label = 'ACCURACY',
           legend = ['train acc', 'test acc'], title = 'accuracy', filename = acc_name)

