In [None]:
import torch
from torchvision.datasets import FashionMNIST, MNIST
from torchvision.datasets import SVHN, CIFAR10, CIFAR100
from torchvision import transforms
from utils.datasets import SiameseDataset, TripletDataset, CrossEntropyDataset
import models.resnet, models.siamesenet, models.triplenet, models.resnet_sis
from utils.losses import ContrastiveLoss, TripletLoss
from utils.trainer import fit
import torch.optim as optim
from datetime import datetime
import warnings
import torch.nn as nn
import os

# Training hyper-parameters shared by every (dataset, model) run below;
# ModelConfig copies them onto each config.
epoch = 100
lr = 0.001
batch_size = 32

# Configuration: every dataset in input_data is trained with every model in model_list.
input_data = ['fashionmnist', 'mnist']
model_list = ['siam', 'baseline', 'triplet']


cuda = torch.cuda.is_available()
if cuda:
    # Only pin a GPU when one exists: torch.cuda.set_device raises on CPU-only hosts.
    torch.cuda.set_device(0)
# Extra DataLoader options; pinned memory only helps host->GPU copies.
kwargs = {'num_workers': 2, 'pin_memory': True} if cuda else {}
warnings.filterwarnings(action='ignore')  # warning message off (silences ALL warnings process-wide)
which_resnet = models.resnet.ResNet34  # backbone class built for every run


class DataConfig():
    """Load a torchvision dataset together with its normalization statistics.

    Both the train and test splits are constructed eagerly in ``__init__``
    (with ``download=True``, so they are fetched into ``path`` when missing)
    and transformed with ``ToTensor`` followed by ``Normalize(mean, std)``.

    Attributes:
        mean, std: per-channel normalization statistics. The training loop
            uses ``len(mean)`` as the number of input channels.
        path: root directory the dataset is stored under.
        dataset_f: torchvision dataset class to instantiate.
        train_dataset, test_dataset: the constructed splits.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset name.
    """

    def __init__(self, data_name):
        if data_name == 'fashionmnist':
            self.mean, self.std = (0.28604059698879553,), (0.35302424451492237,)
            self.path = '../data/FashionMNIST'
            self.dataset_f = FashionMNIST

        elif data_name == 'mnist':
            self.mean, self.std = (0.1307,), (0.3081,)
            self.path = '../data/MNIST'
            self.dataset_f = MNIST

        elif data_name == 'svhn':
            self.mean, self.std = [0.4380, 0.4440, 0.4730], [0.1751, 0.1771, 0.1744]
            self.path = '../data/SVHN'
            self.dataset_f = SVHN

        elif data_name == 'cifar10':
            self.mean, self.std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
            self.path = '../data/CIFAR10'
            self.dataset_f = CIFAR10

        elif data_name == 'cifar100':
            # NOTE(review): these statistics are identical to the CIFAR-10 ones
            # above; confirm whether CIFAR-100-specific values were intended.
            self.mean, self.std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
            self.path = '../data/CIFAR100'
            self.dataset_f = CIFAR100

        else:
            # Fail fast: otherwise the unset attributes only surface later as
            # an AttributeError inside data_init.
            raise ValueError(
                f"unknown data_name {data_name!r}; expected one of "
                "'fashionmnist', 'mnist', 'svhn', 'cifar10', 'cifar100'"
            )

        self.train_dataset = self.data_init(True)
        self.test_dataset = self.data_init(False)

    def data_init(self, train_opt):
        """Build the training split if ``train_opt`` is true, else the test split."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std),
        ])
        if self.dataset_f is SVHN:
            # SVHN has no ``train`` argument; it selects the split by name.
            split_kwargs = {'split': 'train' if train_opt else 'test'}
        else:
            split_kwargs = {'train': train_opt}
        return self.dataset_f(self.path, download=True, transform=transform, **split_kwargs)

            
class ModelConfig():
    """Per-architecture training configuration.

    Selects the pair/triplet/plain dataset wrapper, the model class, and the
    loss for ``model_name``, and copies the module-level hyper-parameters
    (``batch_size``, ``lr``, ``epoch``) onto the instance.

    Attributes:
        dataset_for_model: wrapper applied to the raw torchvision dataset.
        model: model class, constructed around the ResNet backbone.
        loss_f: loss instance handed to ``fit``.
        model_flag: forwarded to the ResNet builder and to ``fit``; True for
            'siam'/'triplet', False for 'baseline'/'odin'.
        margin: margin passed to the contrastive/triplet losses (also set,
            unused, for the cross-entropy models).
        batch_size, lr, n_epochs, log_interval: training settings.

    Raises:
        ValueError: if ``model_name`` is not 'siam', 'triplet', 'baseline'
            or 'odin'.
    """

    def __init__(self, model_name):
        # Set before the branches because the metric-learning losses consume it.
        self.margin = 1.

        if model_name == 'siam':
            self.dataset_for_model = SiameseDataset
            self.model = models.siamesenet.SiameseNet
            self.loss_f = ContrastiveLoss(self.margin)
            self.model_flag = True

        elif model_name == 'triplet':
            self.dataset_for_model = TripletDataset
            self.model = models.triplenet.TripletNet
            self.loss_f = TripletLoss(self.margin)
            self.model_flag = True

        elif model_name in ('baseline', 'odin'):
            self.dataset_for_model = CrossEntropyDataset
            self.model = models.resnet_sis.ResNet
            self.loss_f = nn.CrossEntropyLoss()
            self.model_flag = False

        else:
            # Fail fast instead of returning a config with no attributes.
            raise ValueError(
                f"unknown model_name {model_name!r}; expected one of "
                "'siam', 'triplet', 'baseline', 'odin'"
            )

        # Settings shared by every architecture, taken from the module-level
        # hyper-parameters.
        self.batch_size = batch_size
        self.lr = lr
        self.n_epochs = epoch
        self.log_interval = 50

In [None]:
# Execution (each data, each model)


def _count_classes(dataset):
    """Return the number of distinct class labels in a torchvision dataset.

    Depending on the dataset class and torchvision version, labels may be a
    tensor, a list, or a numpy array, so they are normalized to a tensor
    before counting. ``set()`` must not be applied to a tensor directly:
    tensor elements hash by identity, so every sample would be counted as
    its own class. Falls back to ``len(dataset.classes)`` when no label
    container is exposed.
    """
    for attr in ('targets', 'labels', 'train_labels'):
        labels = getattr(dataset, attr, None)
        if labels is not None:
            return torch.unique(torch.as_tensor(labels)).numel()
    return len(dataset.classes)


for data_name in input_data:
    print(f"epoch={epoch}, lr={lr}")
    # creating dataset
    dc = DataConfig(data_name)
    num_of_channel = len(dc.mean)
    num_of_classes = _count_classes(dc.train_dataset)

    for model_name in model_list:
        mc = ModelConfig(model_name)

        # creating data loader
        train_dataset = mc.dataset_for_model(dc.train_dataset)
        test_dataset = mc.dataset_for_model(dc.test_dataset)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=mc.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=mc.batch_size, shuffle=False, **kwargs)

        # training model
        embedding_resnet = which_resnet(input_channels=num_of_channel, num_c=num_of_classes, model_flag=mc.model_flag)
        model = mc.model(embedding_resnet)
        if cuda:
            model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=mc.lr)

        model_nname = f"{model_name}_resnet_{data_name}"
        print(f"[{datetime.now()}] {model_nname} Training Start")
        fit(train_loader, test_loader, model, mc.loss_f, optimizer, mc.n_epochs, cuda, mc.log_interval,
            model_nname=model_nname, model_flag=mc.model_flag)

        # torch.save does not create missing directories.
        # NOTE(review): this pickles the whole module; loading it requires the
        # same `models` package on the import path.
        save_dir = 'trained_models'
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model, os.path.join(save_dir, model_nname + '.pth'))
        print(f"[{datetime.now()}] {model_nname} Training End", end='\n\n')