In [1]:
import os
import math
import logging

import datetime
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import torchvision.datasets

from utils import data
import torchvision.transforms as transforms

from tensorboardX import SummaryWriter
import shutil
import _pickle as cPickle
from sklearn import preprocessing
import subprocess

from IPython.core.debugger import Tracer

In [2]:
# set up parameters
class Options:
    def __init__(self):
        self.seed_val = 0 # random seed val
        self.num_train_sup = 228 # number of labeled train samples
        self.batch_size = 6 # batch size
        self.labeled_batch_size = 3 # number of labeled samples in a batch
        self.device = 0 # gpu id
        
        self.lr = [0.001, 0.15] # learning rate for adam and then and then initial learning rate for sgd respectively
        self.num_epochs = 500 # number of training epochs
        self.weight_decay = 5e-4 # weight decay
        
        self.mount_point = '/home/onu/Desktop/' # change this to your mount_point
        self.datadir = 'data-local/images/cifar/cifar10/by-image' # dataset directory
        self.labels = 'data-local/labels/cifar10/4000_balanced_labels/00.txt' # label directory
        self.log_dir = os.path.join(self.mount_point,'logs') # log directory
        self.model_dir = os.path.join(self.mount_point,'models') # log 
        self.exp_name = 'cifar10_nl_%i_allconv13_seed_%i'%(self.num_train_sup, self.seed_val) # name of experiments
        
        self.train_subdir = 'train+val'
        self.eval_subdir = 'test'
        self.num_classes = 17
        self.workers = 4
        
        self.alpha_reconst = 0.5 # weight for reconstruction loss
        self.alpha_pn = 1.0 # weight for path normalization loss
        self.alpha_kl = 0.5 # weight for kl loss
        self.alpha_bnmm = 0.5 # weight for moment matching loss when doing batchnorm
        
        self.use_bias = True # add bias after batchnorm
        self.use_bn = True # use batch norm
        self.do_topdown = True # do topdown
        self.do_pn = True # do path normalization
        self.do_bnmm = True # do moment matching for batchnorm

opt = Options()

In [3]:
# set device
th.cuda.set_device(opt.device)
device = th.device("cuda:%i"%opt.device if (opt.device >= 0) else "cpu")

def gpu_device(device=0):
    try:
        _ = th.tensor([1, 2, 3], device=th.device('cuda', device))
    except ValueError:
        return None
    return th.device('cuda', device)

assert gpu_device(opt.device), 'No GPU device found!'

In [4]:
# make required folders
if not os.path.exists(opt.log_dir):
    os.makedirs(opt.log_dir)
log_dir = os.path.join(opt.log_dir, opt.exp_name)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
if not os.path.exists(opt.model_dir):
    os.makedirs(opt.model_dir)
if not os.path.exists(os.path.join(opt.mount_point,'datasets')):
    os.makedirs(os.path.join(opt.mount_point,'datasets'))

In [5]:
# set logging option
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)

hdlr = logging.FileHandler(os.path.join(opt.log_dir, '{}.log'.format(opt.exp_name)))
hdlr.setFormatter(formatter)
logger.addHandler(hdlr)
logging.info(opt)

writer = SummaryWriter(os.path.join(opt.log_dir, opt.exp_name))

2019-04-09 17:34:52,992 - <__main__.Options object at 0x7fc910b27e80>


In [6]:
# set losses
NO_LABEL = -1
criterion = nn.CrossEntropyLoss(size_average=False, ignore_index=NO_LABEL).cuda()
L2_loss = nn.MSELoss(size_average=False, reduce=False, reduction='mean').cuda()



In [7]:
# some util functions
def get_acc(output, label):
    pred = th.argmax(output, dim=1, keepdim=False)
    correct = th.mean((pred == label).type(th.FloatTensor))
    return correct

class AverageMeterSet:
    def __init__(self):
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def update(self, name, value, n=1):
        if not name in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        for meter in self.meters.values():
            meter.reset()

    def values(self, postfix=''):
        return {name + postfix: meter.val for name, meter in self.meters.items()}

    def averages(self, postfix='/avg'):
        return {name + postfix: meter.avg for name, meter in self.meters.items()}

    def sums(self, postfix='/sum'):
        return {name + postfix: meter.sum for name, meter in self.meters.items()}

    def counts(self, postfix='/count'):
        return {name + postfix: meter.count for name, meter in self.meters.items()}


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)

In [8]:
from data.dataset import HoromaDataset

from comet_ml import OfflineExperiment
import json
import argparse
from models import *
from models.clustering import *
from utils.ali_utils import *
from utils.utils import *
from utils.utils import load_datasets
from utils.constants import Constants
from data.dataset import HoromaDataset
import torch
path_to_model = None
config_key = 'HALI'
config = 'HALI'

with open(Constants.CONFIG_PATH, 'r') as f:
    configuration = json.load(f)[config_key]

# Parse configuration file
clustering_model = configuration['cluster_model']
encoding_model = configuration['enc_model']
batch_size = configuration['batch_size']
seed = configuration['seed']
n_epochs = configuration['n_epochs']
train_subset = configuration['train_subset']
train_split = configuration['train_split']
valid_split = configuration['valid_split']
train_labeled_split = configuration['train_labeled_split']
encode = configuration['encode']
cluster = configuration['cluster']
flattened = False  # Default
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set all seeds for full reproducibility
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

datapath = Constants.DATAPATH
parser = argparse.ArgumentParser()

train = HoromaDataset(datapath, split=train_split, subset=train_subset,
                      flattened=flattened)
labeled = HoromaDataset(datapath, split=train_labeled_split, subset=train_subset,
                        flattened=flattened)
valid_data = HoromaDataset(
    datapath, split=valid_split, subset=train_subset, flattened=flattened)

train_label_indices = labeled.targets
valid_indices = valid_data.targets

print("Shape of training set: ", train.data.shape)
print("Shape of validation set: ", labeled.data.shape)
print("Shape of validation set: ", valid_data.data.shape)

Shape of training set:  (152228, 3, 32, 32)
Shape of validation set:  (228, 3, 32, 32)
Shape of validation set:  (252, 3, 32, 32)


In [9]:
train_loader = DataLoader(train, batch_size=opt.batch_size, shuffle=True)
labeled_loader = DataLoader(labeled, batch_size=opt.labeled_batch_size, shuffle=True)
eval_loader = DataLoader(valid_data, batch_size=opt.labeled_batch_size, shuffle=True)


In [12]:
n_iterations = np.floor(labeled.data.shape[0]/opt.labeled_batch_size)
print(n_iterations)

76.0


In [None]:
opt.

In [15]:
net = NRM('AllConv13', batch_size=opt.batch_size//2, num_class=17, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_pn=opt.do_pn, do_bnmm=opt.do_bnmm).to(device)
net.apply(weights_init)
                


  after removing the cwd from sys.path.


NRM(
  (features): Sequential(
    (conv0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batchnorm0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bias0): BiasAdder()
    (relu0): LeakyReLU(negative_slope=0.1)
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batchnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bias1): BiasAdder()
    (relu1): LeakyReLU(negative_slope=0.1)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batchnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bias2): BiasAdder()
    (relu2): LeakyReLU(negative_slope=0.1)
    (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (dropout3): Dropout(p=0.5)
    (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding

In [16]:
trainer = th.optim.Adam(net.parameters(), opt.lr[0], weight_decay=opt.weight_decay)

prev_time = datetime.datetime.now()
best_valid_acc = 0
iter_indx = 0

epoch = 1

In [None]:
unsup_batch = next(iter(train_loader))
sup_batch,target = next(iter(labeled_loader))




In [None]:
unsup_batch.size()

In [None]:
# set up unlabeled input and labeled input with the corresponding labels
input_unsup_var = th.autograd.Variable(unsup_batch[0:(opt.batch_size - opt.labeled_batch_size)]).to(device)
input_sup_var = th.autograd.Variable(sup_batch).to(device)
target_sup_var = th.autograd.Variable(target.data.long()).to(device)



In [None]:
minibatch_unsup_size = opt.batch_size
minibatch_sup_size = opt.labeled_batch_size

In [None]:
# compute loss for unlabeled input
[output_unsup, xhat_unsup, loss_pn_unsup, loss_bnmm_unsup] = net(input_unsup_var)


In [None]:
loss_reconst_unsup = L2_loss(xhat_unsup, input_unsup_var).mean()
softmax_unsup = F.softmax(output_unsup)
loss_kl_unsup = -th.sum(th.log(10.0*softmax_unsup + 1e-8) * softmax_unsup) / minibatch_unsup_size
loss_unsup = opt.alpha_reconst * loss_reconst_unsup + opt.alpha_kl * loss_kl_unsup + opt.alpha_bnmm * loss_bnmm_unsup + opt.alpha_pn * loss_pn_unsup


In [None]:
loss_unsup

In [None]:
# compute loss for labeled input
[output_sup, xhat_sup, loss_pn_sup, loss_bnmm_sup] = net(input_sup_var, target_sup_var)


In [None]:
loss_xentropy_sup = criterion(output_sup, target_sup_var.squeeze_()) 

In [None]:
loss_xentropy_sup = criterion(output_sup, target_sup_var.squeeze_()) / minibatch_sup_size
loss_reconst_sup = L2_loss(xhat_sup, input_sup_var).mean()
softmax_sup = F.softmax(output_sup)
loss_kl_sup = -th.sum(th.log(10.0*softmax_sup + 1e-8) * softmax_sup)/ minibatch_sup_size
loss_sup = loss_xentropy_sup + opt.alpha_reconst * loss_reconst_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_bnmm * loss_bnmm_sup + opt.alpha_pn * loss_pn_sup


In [None]:
# main training loop
# def train(net, train_loader,labeled_loader, eval_loader, num_epochs, wd):

wd = 5e-4
num_epochs = 1000
best_f1 = 0
for epoch in range(30,num_epochs):
    train_loss = 0; train_loss_xentropy = 0; train_loss_reconst = 0; train_loss_pn = 0; train_loss_kl = 0; train_loss_bnmm = 0
    correct = 0
    num_batch_train = 0

    # start with adam optimizer but switch sgd optimizer with exponential decay learning rate since epoch 20 
    if epoch == 250:
        sgd_lr = opt.lr[1]
        decay_val = np.exp(np.log(sgd_lr / 0.0001) / (num_epochs - 2))
        sgd_lr = sgd_lr * decay_val
        trainer = th.optim.SGD(net.parameters(), sgd_lr, weight_decay=wd)

    if epoch >= 250:
        for param_group in trainer.param_groups:
            param_group['lr'] = param_group['lr']/decay_val

    for param_group in trainer.param_groups:
        learning_rate = param_group['lr']

    meters = AverageMeterSet()

    # switch to train mode
    net.train()

#     end = time.time()
    for i in range(int(n_iterations)):
#         meters.update('data_time', time.time() - end)

        unsup_batch = next(iter(train_loader))
        sup_batch,target = next(iter(labeled_loader))

        # set up unlabeled input and labeled input with the corresponding labels
        input_unsup_var = th.autograd.Variable(unsup_batch[0:(opt.batch_size - opt.labeled_batch_size)]).to(device)
        input_sup_var = th.autograd.Variable(sup_batch).to(device)
        target_sup_var = th.autograd.Variable(target.data.long()).to(device)

        minibatch_unsup_size = opt.batch_size - opt.labeled_batch_size
        minibatch_sup_size = opt.labeled_batch_size

        # compute loss for unlabeled input
        [output_unsup, xhat_unsup, loss_pn_unsup, loss_bnmm_unsup] = net(input_unsup_var)
        loss_reconst_unsup = L2_loss(xhat_unsup, input_unsup_var).mean()
        softmax_unsup = F.softmax(output_unsup)
        loss_kl_unsup = -th.sum(th.log(10.0*softmax_unsup + 1e-8) * softmax_unsup) / minibatch_unsup_size
        loss_unsup = opt.alpha_reconst * loss_reconst_unsup + opt.alpha_kl * loss_kl_unsup + opt.alpha_bnmm * loss_bnmm_unsup + opt.alpha_pn * loss_pn_unsup



        # compute loss for labeled input
        [output_sup, xhat_sup, loss_pn_sup, loss_bnmm_sup] = net(input_sup_var, target_sup_var.squeeze_())
        loss_xentropy_sup = criterion(output_sup, target_sup_var) / minibatch_sup_size
        loss_reconst_sup = L2_loss(xhat_sup, input_sup_var).mean()
        softmax_sup = F.softmax(output_sup)
        loss_kl_sup = -th.sum(th.log(10.0*softmax_sup + 1e-8) * softmax_sup)/ minibatch_sup_size
        loss_sup = loss_xentropy_sup + opt.alpha_reconst * loss_reconst_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_bnmm * loss_bnmm_sup + opt.alpha_pn * loss_pn_sup

        loss = th.mean(loss_unsup + loss_sup)

        # compute the grads and update the parameters
        trainer.zero_grad()
        loss.backward()
        trainer.step()

        # accumulate all the losses for visualization
        loss_reconst = loss_reconst_unsup + loss_reconst_sup
        loss_pn = loss_pn_unsup + loss_pn_sup
        loss_xentropy = loss_xentropy_sup
        loss_kl = loss_kl_unsup + loss_kl_sup
        loss_bnmm = loss_bnmm_unsup + loss_bnmm_sup

        train_loss_xentropy += th.mean(loss_xentropy).cpu().detach().numpy()
        train_loss_reconst += th.mean(loss_reconst).cpu().detach().numpy()
        train_loss_pn += th.mean(loss_pn).cpu().detach().numpy()
        train_loss_kl += th.mean(loss_kl).cpu().detach().numpy()
        train_loss_bnmm += th.mean(loss_bnmm).cpu().detach().numpy()
        train_loss += th.mean(loss).cpu().detach().numpy()
        correct += get_acc(output_sup, target_sup_var).cpu().detach().numpy()

        num_batch_train += 1
        iter_indx += 1
    writer.add_scalars('loss', {'train': train_loss / num_batch_train}, epoch)
    writer.add_scalars('loss_xentropy', {'train': train_loss_xentropy / num_batch_train}, epoch)
    writer.add_scalars('loss_reconst', {'train': train_loss_reconst / num_batch_train}, epoch)
    writer.add_scalars('loss_pn', {'train': train_loss_pn / num_batch_train}, epoch)
    writer.add_scalars('loss_kl', {'train': train_loss_kl / num_batch_train}, epoch)
    writer.add_scalars('loss_bnmm', {'train': train_loss_bnmm / num_batch_train}, epoch)
    writer.add_scalars('acc', {'train': correct / num_batch_train}, epoch)

    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)

    # Validation
    valid_loss = 0; valid_loss_xentropy = 0; valid_loss_reconst = 0; valid_loss_pn = 0; valid_loss_kl = 0; valid_loss_bnmm = 0
    valid_correct = 0
    num_batch_valid = 0
    valid_accuracy = 0
    valid_f1 = 0

    net.eval()

    for i, (batch, target) in enumerate(eval_loader):
        with th.no_grad():
            input_var = th.autograd.Variable(batch).to(device)
            target_var = th.autograd.Variable(target.data.long()).to(device)

            minibatch_size = len(target_var)

            [output, xhat, loss_pn, loss_bnmm] = net(input_var, target_var)

            loss_xentropy = criterion(output, target_var.squeeze_())/minibatch_size
            loss_reconst = L2_loss(xhat, input_var).mean()
            softmax_val = F.softmax(output)
            loss_kl = -th.sum(th.log(10.0*softmax_val + 1e-8) * softmax_val)/minibatch_size
            loss = loss_xentropy + opt.alpha_reconst * loss_reconst + opt.alpha_kl * loss_kl + opt.alpha_bnmm * loss_bnmm + opt.alpha_pn * loss_pn

            valid_loss_xentropy += th.mean(loss_xentropy).cpu().detach().numpy()
            valid_loss_reconst += th.mean(loss_reconst).cpu().detach().numpy()
            valid_loss_pn += th.mean(loss_pn).cpu().detach().numpy()
            valid_loss_kl += th.mean(loss_kl).cpu().detach().numpy()
            valid_loss_bnmm += th.mean(loss_bnmm).cpu().detach().numpy()
            valid_loss += th.mean(loss).cpu().detach().numpy()
            valid_correct += get_acc(output, target_var).cpu().detach().numpy()
            
            accuracy, f1 = compute_metrics(target_var.cpu(), th.argmax(output, dim=1, keepdim=False).cpu())
            valid_accuracy+=accuracy
            valid_f1+=f1
            num_batch_valid += 1

    valid_acc = valid_correct / num_batch_valid
    f1_s = valid_f1/num_batch_valid
    if f1_s > best_f1:
        best_f1 = f1_s
        
#         th.save(net.state_dict(), '%s/%s_best.pth'%(opt.model_dir, opt.exp_name))
#     writer.add_scalars('loss', {'valid': valid_loss / num_batch_valid}, epoch)
#     writer.add_scalars('loss_xentropy', {'valid': valid_loss_xentropy / num_batch_valid}, epoch)
#     writer.add_scalars('loss_reconst', {'valid': valid_loss_reconst / num_batch_valid}, epoch)
#     writer.add_scalars('loss_pn', {'valid': valid_loss_pn / num_batch_valid}, epoch)
#     writer.add_scalars('loss_kl', {'valid': valid_loss_kl / num_batch_valid}, epoch)
#     writer.add_scalars('loss_bnmm', {'valid': valid_loss_bnmm / num_batch_valid}, epoch)
#     writer.add_scalars('acc', {'valid': valid_acc}, epoch)
    epoch_str = ("Epoch %d. Train Loss: %f, Train Xent: %f, Train Reconst: %f, Train Pn: %f, Train acc %f, Valid Loss: %f, Valid acc %f, Best f1 acc %f,f1 %f, acc %f "
                 % (epoch, train_loss / num_batch_train, train_loss_xentropy / num_batch_train, train_loss_reconst / num_batch_train, train_loss_pn / num_batch_train,
                    correct / num_batch_train, valid_loss / num_batch_valid, valid_acc, best_f1,valid_f1/num_batch_valid,valid_accuracy/num_batch_valid))
    if not epoch % 20:
        th.save(net.state_dict(), '%s/%s_epoch_%i.pth'%(opt.model_dir, opt.exp_name, epoch))

    prev_time = cur_time
#     logging.info(epoch_str + time_str + ', lr ' + str(learning_rate))
    print(epoch_str)   
#     return best_valid_acc



Epoch 30. Train Loss: 2.956289, Train Xent: 2.608154, Train Reconst: 0.016058, Train Pn: 0.071542, Train acc 0.298246, Valid Loss: 5.133420, Valid acc 0.329365, Best f1 acc 0.279762,f1 0.279762, acc 0.329365 
Epoch 31. Train Loss: 5.218093, Train Xent: 2.674425, Train Reconst: 0.018138, Train Pn: 0.069681, Train acc 0.302632, Valid Loss: 3.653008, Valid acc 0.301587, Best f1 acc 0.279762,f1 0.237302, acc 0.301587 
Epoch 32. Train Loss: 3.397070, Train Xent: 2.719332, Train Reconst: 0.016569, Train Pn: 0.066221, Train acc 0.263158, Valid Loss: 4.131181, Valid acc 0.305556, Best f1 acc 0.280423,f1 0.280423, acc 0.305556 
Epoch 33. Train Loss: 3.622095, Train Xent: 2.650633, Train Reconst: 0.016285, Train Pn: 0.067332, Train acc 0.285088, Valid Loss: 6.664385, Valid acc 0.281746, Best f1 acc 0.280423,f1 0.212434, acc 0.281746 
Epoch 34. Train Loss: 3.856143, Train Xent: 2.632510, Train Reconst: 0.014388, Train Pn: 0.064078, Train acc 0.293860, Valid Loss: 4.133035, Valid acc 0.301587, Bes

Epoch 70. Train Loss: 2.191177, Train Xent: 2.396460, Train Reconst: 0.015215, Train Pn: 0.039288, Train acc 0.328947, Valid Loss: 13.241214, Valid acc 0.317460, Best f1 acc 0.316270,f1 0.302910, acc 0.317460 
Epoch 71. Train Loss: 2.272240, Train Xent: 2.521175, Train Reconst: 0.016584, Train Pn: 0.037985, Train acc 0.333333, Valid Loss: 14.464527, Valid acc 0.273810, Best f1 acc 0.316270,f1 0.225000, acc 0.273810 
Epoch 72. Train Loss: 2.109194, Train Xent: 2.364879, Train Reconst: 0.014938, Train Pn: 0.038272, Train acc 0.324561, Valid Loss: 6.535020, Valid acc 0.297619, Best f1 acc 0.316270,f1 0.273016, acc 0.297619 
Epoch 73. Train Loss: 2.409133, Train Xent: 2.592926, Train Reconst: 0.014622, Train Pn: 0.036958, Train acc 0.289474, Valid Loss: 4.445194, Valid acc 0.293651, Best f1 acc 0.316270,f1 0.246032, acc 0.293651 
Epoch 74. Train Loss: 2.382578, Train Xent: 2.456031, Train Reconst: 0.015951, Train Pn: 0.037167, Train acc 0.328947, Valid Loss: 3.372244, Valid acc 0.361111, B

In [None]:
from sklearn.metrics import accuracy_score, f1_score
from utils.constants import Constants
def compute_metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="weighted")
    return accuracy, f1

In [None]:
pred = 

In [None]:
target_var

In [None]:
accuracy,f1

In [14]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        th.nn.init.xavier_uniform(m.weight)
        # m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Bias') != -1:
        m.bias.data.fill_(0)
        
    


#     acc = train(model, train_loader, eval_loader, opt.num_epochs, opt.weight_decay)
#     logging.info('Validation Accuracy - Run %i = %f'%(i, acc))
#     valid_acc += acc

#     logging.info('Validation Accuracy = %f'%(valid_acc/num_exp))

In [11]:
import numpy as np

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from collections import OrderedDict

cfg = {
    'AllConv13': [128, 128, 128, 'M', 256, 256, 256, 'M', 512, 256, 128, 'A'],
}

#################### Some utils class ####################
class Reshape(nn.Module):
    """
    Flatten the output of the convolutional layer
    Parameters
    ----------
    Input shape: (N, C * W * H)
    Output shape: (N, C, W, H)
    """
    def __init__(self, shape, **kwargs):
        super(Reshape, self).__init__(**kwargs)
        self._shape = shape

    def forward(self, x):
        return x.reshape(x.size()[0], self._shape[0], self._shape[1], self._shape[2])
    
class BiasAdder(nn.Module):
    """
    Add a bias into the input
    """
    def __init__ (self, channels, **kwargs):
        super(BiasAdder, self).__init__(**kwargs)
        self.bias = nn.Parameter(th.Tensor(1,channels,1,1))
        self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, x):
        return x + self.bias
    
class Flatten(nn.Module):
    """
    Flatten 4D tensor into 2D tensor
    """
    def forward(self, x):
        return x.view(x.size(0), -1)
    
class Upaverage(nn.Module):
    """
    Upsample to reverse the avg pooling layer
    """
    def __init__(self, scale_factor, **kwargs):
        super(Upaverage, self).__init__(**kwargs)
        self.scale_factor = scale_factor
        self.upsample_layer = nn.Upsample(scale_factor=self.scale_factor, mode='nearest')
        
    def forward(self, x):
        return self.upsample_layer(x) * (1./self.scale_factor)**2
    
def make_one_hot(labels, C=2):
    """
    Converts an integer label torch.autograd.Variable to a one-hot Variable.
    """
    target = th.eye(C)[labels.data]
    target = target.to(labels.get_device())      
    return target

#################### Main NRM class ####################
class NRM(nn.Module):
    def __init__(self, net_name, batch_size, num_class, use_bias=False, use_bn=False, do_topdown=False, do_pn=False, do_bnmm=False):
        super(NRM, self).__init__()
        self.num_class = num_class
        self.do_topdown = do_topdown
        self.do_pn = do_pn
        self.do_bnmm = do_bnmm
        self.use_bn = use_bn
        self.use_bias = use_bias
        self.batch_size = batch_size
        
        # create:
        # feature extractor in the forward cnn step: self.features
        # corresponding layer inm the top-down reconstruction nrm step: layers_nrm
        # instance norm used in the top-down reconstruction nrm step: insnorms_nrm
        # instance norm used in the forward cnn step: insnorms_cnn
        self.features, layers_nrm, insnorms_nrm, insnorms_cnn = self._make_layers(cfg[net_name], use_bias, use_bn, self.do_topdown)
        
        # create the classifer in the forward cnn step
        conv_layer = nn.Conv2d(in_channels=cfg[net_name][-2], out_channels=self.num_class, kernel_size=(1,1), bias=True)
        flatten_layer = Flatten()
        self.classifier = nn.Sequential(OrderedDict([('conv',conv_layer), ('flatten', flatten_layer)]))
        
        # create the nrm
        if self.do_topdown:
            # add layers corresponding to the classifer in the forward step
            convtd_layer = nn.ConvTranspose2d(out_channels=cfg[net_name][-2], in_channels=self.num_class, kernel_size=(1,1), stride=(1, 1), bias=False)
            convtd_layer.weight.data = conv_layer.weight.data
            layers_nrm += [('convtd',convtd_layer), ('reshape', Reshape(shape=(self.num_class, 1, 1)))]
            
            self.nrm = nn.Sequential(OrderedDict(layers_nrm[::-1]))
            
            # if use path normalization, then also use instance normalization
            if self.do_pn:
                self.insnorms_nrm = nn.Sequential(OrderedDict(insnorms_nrm[::-1]))
                self.insnorms_cnn = nn.Sequential(OrderedDict(insnorms_cnn))


    def forward(self, x, y=None):
        ahat = []; that = []; bcnn = []; apn = []; meancnn = []; varcnn = []
        xbias = th.zeros([1, x.shape[1], x.shape[2], x.shape[3]], device=x.get_device()) if self.do_pn else []
        insnormcnn_indx = 0
        
        # if do top-down reconstruction, we need to keep track of relu state, maxpool state,
        # mean and var of the activations, and the bias terms in the forward cnn step
        if self.do_topdown: 
            for name, layer in self.features.named_children():
                if name.find('pool') != -1 and not name.find('average') != -1: # keep track of the maxpool state
                    F.interpolate(layer(x), scale_factor=2, mode='nearest')
                    that.append(th.gt(x-F.interpolate(layer(x), scale_factor=2, mode='nearest'),0))
                    x = layer(x)
                    if self.do_pn:
                        xbias = layer(xbias)
                else:
                    x = layer(x)
                    
                    if self.do_pn: # get the forward results to compute the path normalization later
                        if name.find('batchnorm') != -1:
                            xbias = self.insnorms_cnn[insnormcnn_indx](xbias)
                            insnormcnn_indx += 1
                        else:
                            xbias = layer(xbias)
                    if name.find('relu') != -1: # keep track of the relu state
                        ahat.append(th.gt(x,0) + th.le(x,0)*0.1)
                        if self.do_pn:
                            apn.append(th.gt(xbias,0) + th.le(xbias,0)*0.1)
                    
                    if self.use_bn:
                        if name.find('conv') != -1: # keep track of the mean and var of the activations
                            meancnn.append(th.mean(x, dim=(0,2,3), keepdim=True))
                            varcnn.append(th.mean((x - th.mean(x, dim=(0,2,3), keepdim=True))**2, dim=(0,2,3), keepdim=True))
                        if self.use_bias: # keep track of the bias terms when adding bias
                            if name.find('bias') != -1: 
                                bcnn.append(layer.bias)
                        else: # otherwise, keep track of the bias terms inside the batch norm
                            if name.find('batchnorm') != -1:
                                bcnn.append(layer.bias)
                    else:
                        if self.use_bias:
                            if name.find('conv') != -1:
                                bcnn.append(layer.bias)
            
            # reverse the order of the parameters/variables that we keep track to use in the top-down reconstruction nrm step since nrm is the reverse of cnn
            ahat = ahat[::-1]
            that = that[::-1]
            bcnn = bcnn[::-1]
            apn = apn[::-1]
            meancnn = meancnn[::-1]
            varcnn = varcnn[::-1]
        else:
            x =  self.features(x)
        
        # send the features into the classifier
        z = self.classifier(x)
        
        # do reconstruction via nrm
        # xhat: the reconstruction image
        # loss_pn: path normalization loss
        # loss_bnmm: batch norm moment matching loss
        if self.do_topdown:
            xhat, _, loss_pn, loss_bnmm = self.topdown(self.nrm, make_one_hot(y, self.num_class), ahat, that, bcnn, th.ones([1, z.size()[1]], device=z.get_device()), apn, meancnn, varcnn) if y is not None else self.topdown(self.nrm, make_one_hot(th.argmax(z.detach(), dim=1), self.num_class), ahat, that, bcnn, th.ones([1, z.size()[1]], device=z.get_device()), apn, meancnn, varcnn)
        else:
            xhat = None
            loss_pn = None
            loss_bnmm = None


        return [z, xhat, loss_pn, loss_bnmm]

    def _make_layers(self, cfg, use_bias, use_bn, do_topdown):
        layers = []
        layers_nrm = []
        insnorms_nrm = []
        insnorms_cnn = []
        in_channels = 3

        for i, x in enumerate(cfg):
            if x == 'M': # if max pooling layer, then add max pooling and dropout into the cnn. Add upsample layers, dropout, batchnorm, and instance norm - for path normaliztion - into the nrm.
                layers += [('pool%i'%i, nn.MaxPool2d(2, stride=2)), ('dropout%i'%i, nn.Dropout(0.5))]
                if do_topdown:
                    if use_bn:
                        layers_nrm += [('upsample%i'%i, nn.Upsample(scale_factor=2, mode='nearest')), ('dropout%i'%i, nn.Dropout(0.5)), ('batchnorm%i'%i, nn.BatchNorm2d(cfg[i-1]))]
                        insnorms_nrm += [('instancenormtd%i'%i, nn.InstanceNorm2d(cfg[i-1], affine=True))]
                    else:
                        layers_nrm += [('upsample%i'%i, nn.Upsample(scale_factor=2, mode='nearest')), ('dropout%i'%i, nn.Dropout(0.5))]
                        
            elif x == 'A': # if avg pooling layer, then add average pooling layer into the cnn. Add up average layers, batchnorm and instance norm - for path normaliztion - into the nrm.
                layers += [('average%i'%i, nn.AvgPool2d(6, stride=1))]
                if do_topdown:
                    if use_bn:
                        layers_nrm += [('upaverage%i'%i, Upaverage(scale_factor=6)), ('batchnorm%i'%i, nn.BatchNorm2d(cfg[i-1]))]
                        insnorms_nrm += [('instancenormtd%i'%i, nn.InstanceNorm2d(cfg[i-1], affine=True))]
                    else:
                        layers_nrm += [('upaverage%i'%i, Upaverage(scale_factor=6))]
                        
            else: # add other layers into the cnn and the nrm
                padding_cnn = (0,0) if x == 512 else (1,1)
                padding_nrm = (0,0) if x == 512 else (1,1)
                if use_bn:
                    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=x, kernel_size=(3,3), padding=padding_cnn, bias=False)
                    if use_bias:
                        layers += [('conv%i'%i, conv_layer),
                                   ('batchnorm%i'%i, nn.BatchNorm2d(x)),
                                   ('bias%i'%i,BiasAdder(channels=x)),
                                   ('relu%i'%i,nn.LeakyReLU(0.1))]
                    else:
                        layers += [('conv%i'%i, conv_layer),
                                   ('batchnorm%i'%i, nn.BatchNorm2d(x)),
                                   ('relu%i'%i, nn.LeakyReLU(0.1))]
                    
                    insnorms_cnn += [('instancenormcnn%i'%i, nn.InstanceNorm2d(x, affine=True))]
                    if do_topdown:
                        if (cfg[i-1] == 'M' or cfg[i-1] == 'A') and not i == 0:
                            layers_nrm += [('convtd%i'%i, nn.ConvTranspose2d(out_channels=in_channels, in_channels=x, kernel_size=3, stride=(1, 1),
                                                              padding=padding_nrm, bias=False))]
                            layers_nrm[-1][-1].weight.data = conv_layer.weight.data
                        else:
                            layers_nrm += [('batchnormtd%i'%i, nn.BatchNorm2d(in_channels)), ('convtd%i'%i, nn.ConvTranspose2d(out_channels=in_channels, in_channels=x, kernel_size=3, stride=(1, 1), padding=padding_nrm, bias=False))]
                            layers_nrm[-1][-1].weight.data = conv_layer.weight.data
                            insnorms_nrm += [('instancenormtd%i'%i, nn.InstanceNorm2d(in_channels, affine=True))]
                    
                elif use_bias:
                    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=x, kernel_size=(3,3), padding=padding_cnn, use_bias=True)
                    layers += [('conv%i'%i, conv_layer), ('relu%i'%i, nn.LeakyReLU(0.1))]
                    if do_topdown:
                        layers_nrm += [('convtd%i'%i, nn.ConvTranspose2d(out_channels=in_channels, in_channels=x, kernel_size=3, stride=(1, 1),
                                                          padding=padding_nrm, bias=False))]
                        layers_nrm[-1][-1].weight.data = conv_layer.weight.data
                    
                else:
                    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=x, kernel_size=(3,3), padding=padding_cnn, bias=False)
                    layers += [('conv%i'%i, conv_layer), ('relu%i'%i, nn.LeakyReLU(0.1))]
                    if do_topdown:
                        layers_nrm += [('convtd%i'%i, nn.ConvTranspose2d(out_channels=in_channels, in_channels=x, kernel_size=3, stride=(1,1),
                                                          padding=padding_nrm, bias=False))]
                        layers_nrm[-1][-1].weight.data = conv_layer.weight.data
                        
                in_channels = x

        model = nn.Sequential(OrderedDict(layers))

        return model, layers_nrm, insnorms_nrm, insnorms_cnn

    def topdown(self, net, xhat, ahat, that, bcnn, xpn, apn, meancnn, varcnn):
        mu = xhat
        mupn = xpn
        loss_pn = th.zeros([self.batch_size,], device=mu.get_device())
        loss_bnmm = th.zeros([self.batch_size,], device=mu.get_device())

        ahat_indx = 0; that_indx = 0; meanvar_indx = 0; insnormtd_indx = 0
        prev_name = ''
        
        for i, (name, layer) in enumerate(net.named_children()):
            if name.find('conv') != -1 and i > 1: 
                mu = mu * ahat[ahat_indx].type(th.FloatTensor).to(mu.get_device()) # mask the intermediate rendered images by the relu states in the forward step
                
                if self.do_pn: # compute the path normalization loss
                    mupn = mupn * apn[ahat_indx].type(th.FloatTensor).to(mu.get_device())
                    mu_b = bcnn[ahat_indx].data.reshape((1, -1, 1, 1)) * mu
                    mupn_b = bcnn[ahat_indx].data.reshape((1, -1, 1, 1)) * mupn
                    
                    loss_pn_layer = th.mean(th.abs(mu_b - mupn_b), dim=(1,2,3))
                    loss_pn = loss_pn + loss_pn_layer

                ahat_indx += 1

            if prev_name.find('upsamplelayer') != -1 and not prev_name.find('avg') != -1:
                mu = mu * that[that_indx].type(th.FloatTensor).to(mu.get_device()) # mask the intermediate rendered images by the maxpool states in the forward step
                if self.do_pn:
                    mupn = mupn * that[that_indx].type(th.FloatTensor).to(mu.get_device())
                that_indx += 1
          
            # compute the next intermediate rendered images
            mu = layer(mu)
            
            # compute the next intermediate rendered results for computing the path normalization loss in the next layer
            if (name.find('batchnorm') != -1) and (i < len(net) - 1):
                if self.do_pn:
                    mupn = self.insnorms_nrm[insnormtd_indx](mupn)
                    insnormtd_indx += 1
            else:
                if self.do_pn:
                    mupn = layer(mupn)
            
            if (name.find('conv') != -1) and (i != (len(net)-2)):
                if self.do_bnmm and self.use_bn:
                    # compute the KL distance between two Gaussians - the intermediate rendered images and the mean/var from the forward step
                    loss_bnmm = loss_bnmm + 0.5*th.mean(((th.mean(mu, dim=(0,2,3)) - meancnn[meanvar_indx])**2)/varcnn[meanvar_indx]) + 0.5*th.mean(th.mean((mu - th.mean(mu, dim=(0,2,3), keepdim=True))**2, dim=(0,2,3))/varcnn[meanvar_indx]) - 0.5*th.mean(th.log(th.mean((mu - th.mean(mu, dim=(0,2,3), keepdim=True))**2, dim=(0,2,3)) + 1e-8) - th.log(varcnn[meanvar_indx])) - 0.5
                    meanvar_indx += 1
                    
            prev_name = name
            
        return mu, mupn, loss_pn, loss_bnmm