In [1]:
# prerequisites
import os
import torch
import time
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import *
from models import SITE

# Reproducibility and device configuration.
# Seed every RNG the notebook imports, not only torch: the project helpers
# (get_theta, sample_prototype, ...) may draw from Python's `random` or NumPy.
# NOTE(review): confirm which RNGs utils.py actually uses.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Training / data configuration
batch_size = 64
model_path = './models/CIFAR'
dataset_path = '../data'

image_size = 128
n_epoch = 100

# Standard ImageNet channel statistics (not CIFAR-10's own mean/std).
normalize = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                 std = [0.229, 0.224, 0.225])

# Training pipeline: random horizontal flip (augmentation), resize, normalize.
Transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.Resize((image_size,image_size)),
     transforms.ToTensor(),
     normalize])

# Evaluation pipeline: same preprocessing but deterministic (no random flip),
# so validation accuracy is not perturbed by augmentation.
eval_transform = transforms.Compose(
    [transforms.Resize((image_size,image_size)),
     transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True,
                                        download=True, transform=Transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=8)

valset = torchvision.datasets.CIFAR10(root=dataset_path, train=False,
                                       download=True, transform=eval_transform)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                         shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Build the SITE model on the selected device and report its parameter count.
site = SITE().to(device)
print(f'Models are properly built! There are totally {get_n_params(site)} parameters.')

# Loss functions: cross-entropy for the class prediction, binary cross-entropy
# for the reconstruction term.
celoss = nn.CrossEntropyLoss()
bceloss = nn.BCELoss()

# Adam starting at lr 1e-2; MultiStepLR multiplies the lr by 0.5 when the
# scheduler has been stepped 50 and 80 times.
optimizer = optim.Adam(params=site.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50, 80],
    gamma=0.5,
)

Models are properly built! There are totally 7819592 parameters.


In [None]:
# Main loop: each epoch trains on the (randomly transformed) training set with a
# classification + prototype-reconstruction loss, then measures accuracy on the
# validation set.
for epoch in range(n_epoch):
    Loss_cls = []   # per-batch classification losses, stored as Python floats
    Loss_rec = []   # per-batch reconstruction losses, stored as Python floats
    n_correct, n_seen = 0, 0
    start = time.time()

    # generate new prototypes
    prototype = get_prototype(train_loader)
    site.train()

    for image, label in tqdm(train_loader, desc=f'epoch {epoch + 1}/{n_epoch}'):

        optimizer.zero_grad()

        image, label = image.to(device), label.to(device)

        # Random per-sample transformation; gamma is requested via get_reverse.
        theta, gamma = get_theta(image.shape[0], get_reverse = True)
        t_image = transform(image, theta, dataset = 'CIFAR')

        t_feature, t_W, t_pred = site.for_training(t_image)
        # NOTE(review): t_W_t is computed but never used -- the reconstruction
        # loss below compares t_W, not t_W_t. Confirm which one is intended.
        t_W_t = transform_W(t_W, gamma, dataset = 'CIFAR')

        # Classification loss
        loss_cls = celoss(t_pred, label)
        # Store a float, not the tensor: keeping the tensors would hold their
        # autograd nodes and device memory alive for the whole epoch.
        Loss_cls.append(loss_cls.item())

        # Reconstruction loss against backbone features of sampled prototypes.
        target = sample_prototype(prototype, label, dataset = 'CIFAR')
        with torch.no_grad():
            target_feature, _ = site.backbone(target.view(-1, 3, image_size, image_size))
            target_feature = target_feature.view(-1, 10, 10, 16, 16)
        # Presumably t_W / target_feature lie in [-1, 1] (TODO confirm);
        # (x + 1) / 2 shifts them into [0, 1], the range BCELoss requires.
        loss_rec = 5*bceloss((t_W + 1)/2, (target_feature + 1)/2)
        Loss_rec.append(loss_rec.item())

        loss = loss_cls + loss_rec
        loss.backward()
        optimizer.step()

        # Count correct samples so the epoch accuracy is exact even when the
        # final batch is smaller than batch_size.
        n_correct += (t_pred.argmax(1) == label).sum().item()
        n_seen += label.size(0)

    train_accuracy = n_correct / max(n_seen, 1)
    scheduler.step()

    n_correct, n_seen = 0, 0
    site.eval()
    with torch.no_grad():
        for image, label in val_loader:

            image, label = image.to(device), label.to(device)

            # NOTE(review): t_image is never used -- accuracy is measured on the
            # untransformed images. Either evaluate on t_image or drop these two
            # lines (they still advance whatever RNG get_theta uses).
            theta, gamma = get_theta(image.shape[0], get_reverse = True)
            t_image = transform(image, theta, dataset = 'CIFAR')

            t_pred = site(image)

            n_correct += (t_pred.argmax(1) == label).sum().item()
            n_seen += label.size(0)

    val_accuracy = n_correct / max(n_seen, 1)

    print('epoch: {}, loss: {:.3f}/{:.3f}, train_acc: {:.4f}, val_acc: {:.4f}, time: {:.2f}'.format(
        epoch + 1,
        np.mean(Loss_cls),
        np.mean(Loss_rec),
        train_accuracy,
        val_accuracy,
        time.time() - start))

13it [00:12,  1.06it/s]