In [1]:
import os
import math
import logging

import datetime
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import torchvision.datasets

from utils import data
import torchvision.transforms as transforms

from tensorboardX import SummaryWriter
import shutil
import _pickle as cPickle
from sklearn import preprocessing
import subprocess

from IPython.core.debugger import Tracer

In [2]:
# set up parameters
class Options:
    def __init__(self):
        self.seed_val = 0 # random seed val
        self.num_train_sup = 4000 # number of labeled train samples
        self.batch_size = 100 # batch size
        self.labeled_batch_size = 50 # number of labeled samples in a batch
        self.device = 0 # gpu id
        
        self.lr = [0.001, 0.15] # learning rate for adam and then and then initial learning rate for sgd respectively
        self.num_epochs = 500 # number of training epochs
        self.weight_decay = 5e-4 # weight decay
        
        self.mount_point = '/tan' # change this to your mount_point
        self.datadir = 'data-local/images/cifar/cifar10/by-image' # dataset directory
        self.labels = 'data-local/labels/cifar10/4000_balanced_labels/00.txt' # label directory
        self.log_dir = os.path.join(self.mount_point,'logs') # log directory
        self.model_dir = os.path.join(self.mount_point,'models') # log 
        self.exp_name = 'cifar10_nl_%i_allconv13_seed_%i'%(self.num_train_sup, self.seed_val) # name of experiments
        
        self.train_subdir = 'train+val'
        self.eval_subdir = 'test'
        self.num_classes = 10
        self.workers = 4
        
        self.alpha_reconst = 0.5 # weight for reconstruction loss
        self.alpha_pn = 1.0 # weight for path normalization loss
        self.alpha_kl = 0.5 # weight for kl loss
        self.alpha_bnmm = 0.5 # weight for moment matching loss when doing batchnorm
        
        self.use_bias = True # add bias after batchnorm
        self.use_bn = True # use batch norm
        self.do_topdown = True # do topdown
        self.do_pn = True # do path normalization
        self.do_bnmm = True # do moment matching for batchnorm

opt = Options()

In [3]:
# set device
th.cuda.set_device(opt.device)
device = th.device("cuda:%i"%opt.device if (opt.device >= 0) else "cpu")

def gpu_device(device=0):
    try:
        _ = th.tensor([1, 2, 3], device=th.device('cuda', device))
    except ValueError:
        return None
    return th.device('cuda', device)

assert gpu_device(opt.device), 'No GPU device found!'

In [4]:
# make required folders
if not os.path.exists(opt.log_dir):
    os.makedirs(opt.log_dir)
log_dir = os.path.join(opt.log_dir, opt.exp_name)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
if not os.path.exists(opt.model_dir):
    os.makedirs(opt.model_dir)
if not os.path.exists(os.path.join(opt.mount_point,'datasets')):
    os.makedirs(os.path.join(opt.mount_point,'datasets'))

In [None]:
# set logging option
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)

hdlr = logging.FileHandler(os.path.join(opt.log_dir, '{}.log'.format(opt.exp_name)))
hdlr.setFormatter(formatter)
logger.addHandler(hdlr)
logging.info(opt)

writer = SummaryWriter(os.path.join(opt.log_dir, opt.exp_name))

2019-03-04 09:20:52,573 - <__main__.Options object at 0x7f1e957f3940>


In [None]:
# prepare data loaders
channel_stats = dict(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470,  0.2435,  0.2616])
train_transformation = transforms.Compose([
    data.RandomTranslateWithReflect(4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats)
])
eval_transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**channel_stats)
])

traindir = os.path.join(opt.datadir, opt.train_subdir)
evaldir = os.path.join(opt.datadir, opt.eval_subdir)

dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

with open(opt.labels) as f:
    labels = dict(line.split(' ') for line in f.read().splitlines())
labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

batch_sampler = data.TwoStreamBatchSampler(unlabeled_idxs, labeled_idxs, opt.batch_size, opt.labeled_batch_size)

train_loader = th.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
                                        num_workers=opt.workers,
                                        pin_memory=True)

eval_loader = th.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(evaldir, eval_transformation),
    batch_size=opt.batch_size//2,
    shuffle=False,
    num_workers=2 * opt.workers,  # Needs images twice as fast
    pin_memory=True,
    drop_last=False)

In [None]:
# set losses
NO_LABEL = -1
criterion = nn.CrossEntropyLoss(size_average=False, ignore_index=NO_LABEL).cuda()
L2_loss = nn.MSELoss(size_average=False, reduce=False, reduction='mean').cuda()



In [None]:
# import the NRM
from nrm import NRM

In [None]:
# some util functions
def get_acc(output, label):
    pred = th.argmax(output, dim=1, keepdim=False)
    correct = th.mean((pred == label).type(th.FloatTensor))
    return correct

class AverageMeterSet:
    def __init__(self):
        self.meters = {}

    def __getitem__(self, key):
        return self.meters[key]

    def update(self, name, value, n=1):
        if not name in self.meters:
            self.meters[name] = AverageMeter()
        self.meters[name].update(value, n)

    def reset(self):
        for meter in self.meters.values():
            meter.reset()

    def values(self, postfix=''):
        return {name + postfix: meter.val for name, meter in self.meters.items()}

    def averages(self, postfix='/avg'):
        return {name + postfix: meter.avg for name, meter in self.meters.items()}

    def sums(self, postfix='/sum'):
        return {name + postfix: meter.sum for name, meter in self.meters.items()}

    def counts(self, postfix='/count'):
        return {name + postfix: meter.count for name, meter in self.meters.items()}


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __format__(self, format):
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format)

In [None]:
# main training loop
def train(net, train_loader, eval_loader, num_epochs, wd):
    trainer = th.optim.Adam(net.parameters(), opt.lr[0], weight_decay=wd)
    
    prev_time = datetime.datetime.now()
    best_valid_acc = 0
    iter_indx = 0
    
    for epoch in range(num_epochs):
        train_loss = 0; train_loss_xentropy = 0; train_loss_reconst = 0; train_loss_pn = 0; train_loss_kl = 0; train_loss_bnmm = 0
        correct = 0
        num_batch_train = 0
        
        # start with adam optimizer but switch sgd optimizer with exponential decay learning rate since epoch 20 
        if epoch == 20:
            sgd_lr = opt.lr[1]
            decay_val = np.exp(np.log(sgd_lr / 0.0001) / (num_epochs - 2))
            sgd_lr = sgd_lr * decay_val
            trainer = th.optim.SGD(net.parameters(), sgd_lr, weight_decay=wd)
            
        if epoch >= 20:
            for param_group in trainer.param_groups:
                param_group['lr'] = param_group['lr']/decay_val
                
        for param_group in trainer.param_groups:
            learning_rate = param_group['lr']
        
        meters = AverageMeterSet()
        
        # switch to train mode
        net.train()
        
        end = time.time()
        for i, (batch, target) in enumerate(train_loader):
            meters.update('data_time', time.time() - end)
            
            # set up unlabeled input and labeled input with the corresponding labels
            input_unsup_var = th.autograd.Variable(batch[0:(opt.batch_size - opt.labeled_batch_size)]).to(device)
            input_sup_var = th.autograd.Variable(batch[(opt.batch_size - opt.labeled_batch_size):]).to(device)
            target_sup_var = th.autograd.Variable(target[(opt.batch_size - opt.labeled_batch_size):].cuda(async=True)).to(device)
            
            minibatch_unsup_size = opt.batch_size - opt.labeled_batch_size
            minibatch_sup_size = opt.labeled_batch_size
            
            # compute loss for unlabeled input
            [output_unsup, xhat_unsup, loss_pn_unsup, loss_bnmm_unsup] = net(input_unsup_var)
            loss_reconst_unsup = L2_loss(xhat_unsup, input_unsup_var).mean()
            softmax_unsup = F.softmax(output_unsup)
            loss_kl_unsup = -th.sum(th.log(10.0*softmax_unsup + 1e-8) * softmax_unsup) / minibatch_unsup_size
            loss_unsup = opt.alpha_reconst * loss_reconst_unsup + opt.alpha_kl * loss_kl_unsup + opt.alpha_bnmm * loss_bnmm_unsup + opt.alpha_pn * loss_pn_unsup
            
            # compute loss for labeled input
            [output_sup, xhat_sup, loss_pn_sup, loss_bnmm_sup] = net(input_sup_var, target_sup_var)
            loss_xentropy_sup = criterion(output_sup, target_sup_var) / minibatch_sup_size
            loss_reconst_sup = L2_loss(xhat_sup, input_sup_var).mean()
            softmax_sup = F.softmax(output_sup)
            loss_kl_sup = -th.sum(th.log(10.0*softmax_sup + 1e-8) * softmax_sup)/ minibatch_sup_size
            loss_sup = loss_xentropy_sup + opt.alpha_reconst * loss_reconst_sup + opt.alpha_kl * loss_kl_sup + opt.alpha_bnmm * loss_bnmm_sup + opt.alpha_pn * loss_pn_sup

            loss = th.mean(loss_unsup + loss_sup)
            
            # compute the grads and update the parameters
            trainer.zero_grad()
            loss.backward()
            trainer.step()
            
            # accumulate all the losses for visualization
            loss_reconst = loss_reconst_unsup + loss_reconst_sup
            loss_pn = loss_pn_unsup + loss_pn_sup
            loss_xentropy = loss_xentropy_sup
            loss_kl = loss_kl_unsup + loss_kl_sup
            loss_bnmm = loss_bnmm_unsup + loss_bnmm_sup
            
            train_loss_xentropy += th.mean(loss_xentropy).cpu().detach().numpy()
            train_loss_reconst += th.mean(loss_reconst).cpu().detach().numpy()
            train_loss_pn += th.mean(loss_pn).cpu().detach().numpy()
            train_loss_kl += th.mean(loss_kl).cpu().detach().numpy()
            train_loss_bnmm += th.mean(loss_bnmm).cpu().detach().numpy()
            train_loss += th.mean(loss).cpu().detach().numpy()
            correct += get_acc(output_sup, target_sup_var).cpu().detach().numpy()
            
            num_batch_train += 1
            iter_indx += 1
        
        writer.add_scalars('loss', {'train': train_loss / num_batch_train}, epoch)
        writer.add_scalars('loss_xentropy', {'train': train_loss_xentropy / num_batch_train}, epoch)
        writer.add_scalars('loss_reconst', {'train': train_loss_reconst / num_batch_train}, epoch)
        writer.add_scalars('loss_pn', {'train': train_loss_pn / num_batch_train}, epoch)
        writer.add_scalars('loss_kl', {'train': train_loss_kl / num_batch_train}, epoch)
        writer.add_scalars('loss_bnmm', {'train': train_loss_bnmm / num_batch_train}, epoch)
        writer.add_scalars('acc', {'train': correct / num_batch_train}, epoch)
        
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        
        # Validation
        valid_loss = 0; valid_loss_xentropy = 0; valid_loss_reconst = 0; valid_loss_pn = 0; valid_loss_kl = 0; valid_loss_bnmm = 0
        valid_correct = 0
        num_batch_valid = 0

        net.eval()

        for i, (batch, target) in enumerate(eval_loader):
            with th.no_grad():
                input_var = th.autograd.Variable(batch).to(device)
                target_var = th.autograd.Variable(target.cuda(async=True)).to(device)

                minibatch_size = len(target_var)

                [output, xhat, loss_pn, loss_bnmm] = net(input_var, target_var)

                loss_xentropy = criterion(output, target_var)/minibatch_size
                loss_reconst = L2_loss(xhat, input_var).mean()
                softmax_val = F.softmax(output)
                loss_kl = -th.sum(th.log(10.0*softmax_val + 1e-8) * softmax_val)/minibatch_size
                loss = loss_xentropy + opt.alpha_reconst * loss_reconst + opt.alpha_kl * loss_kl + opt.alpha_bnmm * loss_bnmm + opt.alpha_pn * loss_pn

                valid_loss_xentropy += th.mean(loss_xentropy).cpu().detach().numpy()
                valid_loss_reconst += th.mean(loss_reconst).cpu().detach().numpy()
                valid_loss_pn += th.mean(loss_pn).cpu().detach().numpy()
                valid_loss_kl += th.mean(loss_kl).cpu().detach().numpy()
                valid_loss_bnmm += th.mean(loss_bnmm).cpu().detach().numpy()
                valid_loss += th.mean(loss).cpu().detach().numpy()
                valid_correct += get_acc(output, target_var).cpu().detach().numpy()

                num_batch_valid += 1
        
        valid_acc = valid_correct / num_batch_valid
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            th.save(net.state_dict(), '%s/%s_best.pth'%(opt.model_dir, opt.exp_name))
        writer.add_scalars('loss', {'valid': valid_loss / num_batch_valid}, epoch)
        writer.add_scalars('loss_xentropy', {'valid': valid_loss_xentropy / num_batch_valid}, epoch)
        writer.add_scalars('loss_reconst', {'valid': valid_loss_reconst / num_batch_valid}, epoch)
        writer.add_scalars('loss_pn', {'valid': valid_loss_pn / num_batch_valid}, epoch)
        writer.add_scalars('loss_kl', {'valid': valid_loss_kl / num_batch_valid}, epoch)
        writer.add_scalars('loss_bnmm', {'valid': valid_loss_bnmm / num_batch_valid}, epoch)
        writer.add_scalars('acc', {'valid': valid_acc}, epoch)
        epoch_str = ("Epoch %d. Train Loss: %f, Train Xent: %f, Train Reconst: %f, Train Pn: %f, Train acc %f, Valid Loss: %f, Valid acc %f, Best valid acc %f, "
                     % (epoch, train_loss / num_batch_train, train_loss_xentropy / num_batch_train, train_loss_reconst / num_batch_train, train_loss_pn / num_batch_train,
                        correct / num_batch_train, valid_loss / num_batch_valid, valid_acc, best_valid_acc))
        if not epoch % 20:
            th.save(net.state_dict(), '%s/%s_epoch_%i.pth'%(opt.model_dir, opt.exp_name, epoch))

        prev_time = cur_time
        logging.info(epoch_str + time_str + ', lr ' + str(learning_rate))
        
    return best_valid_acc

In [None]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        th.nn.init.xavier_uniform(m.weight)
        # m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Bias') != -1:
        m.bias.data.fill_(0)
        
def run_train(num_exp):
    valid_acc = 0
    for i in range(num_exp):
        model = NRM('AllConv13', batch_size=opt.batch_size // 2, num_class=10, use_bias=opt.use_bias, use_bn=opt.use_bn, do_topdown=opt.do_topdown, do_pn=opt.do_pn, do_bnmm=opt.do_bnmm).to(device)
        model.apply(weights_init)
                
        acc = train(model, train_loader, eval_loader, opt.num_epochs, opt.weight_decay)
        logging.info('Validation Accuracy - Run %i = %f'%(i, acc))
        valid_acc += acc

    logging.info('Validation Accuracy = %f'%(valid_acc/num_exp))

In [None]:
run_train(1)

  after removing the cwd from sys.path.
