In [1]:
import torch
from torchvision.datasets import FashionMNIST, MNIST
from torchvision.datasets import SVHN, CIFAR10, CIFAR100
from torchvision import transforms
from utils.datasets import SiameseDataset, TripletDataset, CrossEntropyDataset
import models.resnet, models.siamesenet, models.triplenet, models.resnet_sis
from utils.losses import ContrastiveLoss, TripletLoss
from utils.trainer_sch import fit
import torch.optim as optim
from datetime import datetime
import warnings
import torch.nn as nn
import os

# Configuration
epoch = 10          # training epochs per (dataset, model) run
lr = 0.001          # Adam learning rate
batch_size = 256    # default batch size (ModelConfig's 'triplet' entry uses its own value)
save_type = "test"  # forwarded to fit(); its meaning is defined in utils.trainer_sch

input_data = ['cifar10', 'cifar100', 'svhn']  # datasets to train on, in order
model_list = ['siam', 'triplet']              # model families trained on each dataset

cuda = torch.cuda.is_available()
gpu_id = 1  # CUDA device index used when a GPU is available
if cuda:
    # Select the device only when CUDA exists: torch.cuda.set_device raises
    # on CPU-only machines, which previously crashed the notebook at startup.
    torch.cuda.set_device(gpu_id)
kwargs = {'num_workers': 2, 'pin_memory': True} if cuda else {}  # extra DataLoader options
warnings.filterwarnings(action='ignore')  # warning messages off
which_resnet = models.resnet.ResNet34  # backbone constructor used by the training loop

class DataConfig():
    """Bundle a torchvision dataset with its normalisation statistics.

    Supported ``data_name`` values: 'fashionmnist', 'mnist', 'svhn',
    'cifar10', 'cifar100'.

    Attributes:
        mean, std: per-channel normalisation statistics; ``len(mean)`` is the
            number of image channels.
        path: root directory handed to the dataset constructor.
        dataset_f: torchvision dataset class to instantiate.
        train_dataset, test_dataset: the two splits, built with
            ``download=True`` and a ToTensor + Normalize transform.

    Raises:
        ValueError: if ``data_name`` is not one of the supported names.
    """

    def __init__(self, data_name):
        if data_name == 'fashionmnist':
            self.mean, self.std = (0.28604059698879553,), (0.35302424451492237,)
            self.path = './data/FashionMNIST'
            self.dataset_f = FashionMNIST

        elif data_name == 'mnist':
            self.mean, self.std = (0.1307,), (0.3081,)
            self.path = './data/MNIST'
            self.dataset_f = MNIST

        elif data_name == 'svhn':
            self.mean, self.std = [0.4380, 0.4440, 0.4730], [0.1751, 0.1771, 0.1744]
            self.path = './data/SVHN'
            self.dataset_f = SVHN

        elif data_name == 'cifar10':
            self.mean, self.std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
            self.path = './data/CIFAR10'
            self.dataset_f = CIFAR10

        elif data_name == 'cifar100':
            # NOTE(review): these are the same statistics as the CIFAR-10 entry;
            # CIFAR-100 has its own per-channel mean/std — confirm this reuse
            # is intentional before relying on it.
            self.mean, self.std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
            self.path = './data/CIFAR100'
            self.dataset_f = CIFAR100

        else:
            # Fail fast instead of leaving a half-initialised object that
            # crashes later with an unrelated AttributeError.
            raise ValueError(
                f"Unknown data_name {data_name!r}; expected one of "
                "'fashionmnist', 'mnist', 'svhn', 'cifar10', 'cifar100'"
            )

        if data_name == 'svhn':
            # SVHN selects its split with a string instead of a boolean flag.
            self.train_dataset = self.data_init('train')
            self.test_dataset = self.data_init('test')
        else:
            self.train_dataset = self.data_init(True)
            self.test_dataset = self.data_init(False)

    def data_init(self, train_opt):
        """Instantiate one split of ``self.dataset_f`` with normalisation.

        Args:
            train_opt: split selector — a bool for datasets taking ``train=``,
                or a split name such as 'train'/'test' for datasets taking
                ``split=`` (e.g. SVHN).

        Returns:
            The constructed dataset object.
        """
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(self.mean, self.std)])
        try:
            return self.dataset_f(self.path,
                                  train=train_opt,
                                  download=True,
                                  transform=transform)
        except TypeError:
            # Only an unexpected ``train=`` keyword should trigger the fallback;
            # download / I/O errors must propagate unchanged instead of being
            # masked by a second, misleading constructor failure.
            return self.dataset_f(self.path,
                                  split=train_opt,
                                  download=True,
                                  transform=transform)

class ModelConfig():
    """Training components and hyper-parameters for one model family.

    Supported ``model_name`` values: 'siam', 'triplet', 'baseline', 'odin'
    ('baseline' and 'odin' share one configuration).

    Attributes:
        dataset_for_model: wrapper class applied to the raw train/test splits.
        model: model constructor; the training loop calls it with the backbone.
        batch_size, lr, n_epochs, log_interval: training hyper-parameters.
        margin: margin passed to the contrastive / triplet losses.
        loss_f: instantiated loss module.
        model_flag: True for the metric-learning models, False for the
            cross-entropy models; forwarded to fit().

    Raises:
        ValueError: if ``model_name`` is not supported.
    """

    _SUPPORTED = ('siam', 'triplet', 'baseline', 'odin')

    def __init__(self, model_name):
        # Validate first so an unknown name fails immediately rather than
        # producing a config object with no attributes.
        if model_name not in self._SUPPORTED:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {self._SUPPORTED}"
            )

        # Settings shared by every model family.
        self.margin = 1.
        self.lr = lr
        self.n_epochs = epoch
        self.log_interval = 50

        if model_name == 'siam':
            self.dataset_for_model = SiameseDataset
            self.model = models.siamesenet.SiameseNet
            self.batch_size = batch_size
            self.loss_f = ContrastiveLoss(self.margin)
            self.model_flag = True

        elif model_name == 'triplet':
            self.dataset_for_model = TripletDataset
            self.model = models.triplenet.TripletNet
            # Fixed at 128 instead of the global batch_size; presumably because
            # each sample carries three images — confirm.
            self.batch_size = 128
            self.loss_f = TripletLoss(self.margin)
            self.model_flag = True

        else:  # 'baseline' or 'odin'
            self.dataset_for_model = CrossEntropyDataset
            self.model = models.resnet_sis.ResNet
            self.batch_size = batch_size
            self.loss_f = nn.CrossEntropyLoss()
            self.model_flag = False
            
# Execution (each data, each model)
def _count_classes(dataset):
    """Return the number of distinct labels/classes exposed by ``dataset``.

    Probes the ``labels``, ``classes`` and legacy ``train_labels`` attributes
    in that order and counts the distinct values of the first one present.

    Raises:
        AttributeError: if the dataset exposes none of those attributes.
    """
    missing = object()
    for attr in ('labels', 'classes', 'train_labels'):
        values = getattr(dataset, attr, missing)
        if values is not missing:
            return len(set(values))
    raise AttributeError(
        f"{type(dataset).__name__} has none of 'labels', 'classes', 'train_labels'"
    )


for data_name in input_data:
    print(epoch)
    print(lr)
    # creating dataset
    dc = DataConfig(data_name)
    num_of_channel = len(dc.mean)
    num_of_classes = _count_classes(dc.train_dataset)
    for model_name in model_list:
        mc = ModelConfig(model_name)
        model_nname = model_name + '_resnet_' + data_name  # run tag used for logs and fit()

        # creating data loaders: wrap the raw splits for the chosen model family
        train_dataset = mc.dataset_for_model(dc.train_dataset)
        test_dataset = mc.dataset_for_model(dc.test_dataset)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=mc.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=mc.batch_size, shuffle=False, **kwargs)

        # training model
        # NOTE(review): num_c is fixed at 100 for every dataset (num_of_classes
        # is not used); for siam/triplet it presumably sets the embedding size —
        # confirm before tying it to num_of_classes.
        pre_resnet = which_resnet(num_c=100)
        embedding_resnet = pre_resnet
        model = mc.model(embedding_resnet)
        if cuda:
            model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=mc.lr, weight_decay=0.0001)
        # T_max is counted in batches over this model's own epoch budget, so fit()
        # presumably steps the scheduler once per batch — confirm in utils.trainer_sch.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, mc.n_epochs * len(train_loader), eta_min=0)
        print(f"[{datetime.now()}] {model_nname} Training Start")
        fit(train_loader, test_loader, model, mc.loss_f, optimizer, scheduler, mc.n_epochs, cuda, mc.log_interval, model_nname=model_nname, save_type=save_type, model_flag=mc.model_flag)
        print(f"[{datetime.now()}] {model_nname} Training End", end='\n\n')
        

10
0.001
Files already downloaded and verified
Files already downloaded and verified
[2021-02-21 16:25:52.519959] siam_resnet_cifar10 Training Start
Epoch: 1/10. Train loss: 0.1409,  Validation loss: 0.1415
Epoch: 2/10. Train loss: 0.0007,  Validation loss: 0.1269
Epoch: 3/10. Train loss: 0.0006,  Validation loss: 0.1253
Epoch: 4/10. Train loss: 0.0006,  Validation loss: 0.1243
Epoch: 5/10. Train loss: 0.0006,  Validation loss: 0.1237
Epoch: 6/10. Train loss: 0.0006,  Validation loss: 0.1229
Epoch: 7/10. Train loss: 0.0006,  Validation loss: 0.1228
Epoch: 8/10. Train loss: 0.0006,  Validation loss: 0.1226
Epoch: 9/10. Train loss: 0.0006,  Validation loss: 0.1225
Epoch: 10/10. Train loss: 0.0006,  Validation loss: 0.1225
[2021-02-21 16:40:27.850204] siam_resnet_cifar10 Training End

[2021-02-21 16:40:28.424667] triplet_resnet_cifar10 Training Start
Epoch: 1/10. Train loss: 0.0260,  Validation loss: 0.9391
Epoch: 2/10. Train loss: 0.0024,  Validation loss: 0.9056
Epoch: 3/10. Train loss:

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


Using downloaded and verified file: ./data/SVHN\test_32x32.mat
[2021-02-21 17:41:35.484012] siam_resnet_svhn Training Start
Epoch: 1/10. Train loss: 0.0338,  Validation loss: 24.9264
Epoch: 2/10. Train loss: 0.0016,  Validation loss: 0.2995
Epoch: 3/10. Train loss: 0.0005,  Validation loss: 4.4504
Epoch: 4/10. Train loss: 0.0005,  Validation loss: 3.2789
Epoch: 5/10. Train loss: 0.0005,  Validation loss: 4.7312
Epoch: 6/10. Train loss: 0.0005,  Validation loss: 2.5667
Epoch: 7/10. Train loss: 0.0005,  Validation loss: 1.2172
Epoch: 8/10. Train loss: 0.0005,  Validation loss: 0.4355
Epoch: 9/10. Train loss: 0.0005,  Validation loss: 0.3472
Epoch: 10/10. Train loss: 0.0005,  Validation loss: 0.4000
[2021-02-21 18:00:10.153345] siam_resnet_svhn Training End

[2021-02-21 18:00:10.917301] triplet_resnet_svhn Training Start
Epoch: 1/10. Train loss: 0.0117,  Validation loss: 0.9903
Epoch: 2/10. Train loss: 0.0017,  Validation loss: 0.9368
Epoch: 3/10. Train loss: 0.0016,  Validation loss: 0.