In [1]:
import os
import torch
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as tforms
from torchvision.utils import save_image
import numpy as np
import time

import modelling.odenvp as odenvp
from modelling.layers.cnf import CNF
from modelling.layers.odefunc import ODEfunc
from application_examples.helpers.training import RunningAverageMeter
from modelling.utils import count_total_time

In [2]:
args = {
    'resume': None,     # Change to directory where the model is stored to resume training
    'begin_epoch': 1,   # Change to last epoch to resume training
    'dims': "16, 16, 16",
    'strides': "1, 1, 1, 1",
    'num_blocks': 1,
    'num_epochs': 500,
    'val_freq': 5,
    'batch_size': 200,
    'test_batch_size': 200,
    'nonlinearity': 'softplus',
    'alpha': 1e-6,
    'time_length': 1.0,
    'warmup_iters': 1000,
    'lr': 1e-3,
    'solver': 'dopri5',
    'atol': 1e-5,
    'rtol': 1e-5,
    'step_size': None,
    'weight_decay': 0.0,
}

# Load data

In [3]:
def add_noise(x):
    noise = x.new().resize_as_(x).uniform_()
    x = x * 255 + noise
    x = x / 256
    return x

In [4]:
def get_dataset():
    trans = lambda im_size: tforms.Compose([
        tforms.Resize(im_size),
        tforms.ToTensor(),
        add_noise
    ])
    
    im_dim = 1
    im_size = 28
    train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
    test_set = dset.MNIST(root="./data", train=False, transform=trans(im_size), download=True)
    
    data_shape = (im_dim, im_size, im_size)
    
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args['test_batch_size'],
        shuffle=False
    )
    
    return train_set, test_loader, data_shape
train_set, test_loader, data_shape = get_dataset()

In [5]:
def get_train_loader(train_set):
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args['batch_size'],
        shuffle=True,
        drop_last=True,
        pin_memory=True
    )
    print("Using batch size {}. Total {} iterations/epoch.".format(args['batch_size'], len(train_loader)))
    return train_loader

# Define model

In [6]:
hidden_dims = tuple(map(int, args['dims'].split(',')))
strides = tuple(map(int, args['strides'].split(',')))

model = odenvp.ODENVP(
    input_size=(args['batch_size'], *data_shape),
    n_blocks=args['num_blocks'],
    intermediate_dims=hidden_dims,
    nonlinearity=args['nonlinearity'],
    alpha=args['alpha'],
    cnf_kwargs={'T': args['time_length'], 'train_T': True } #, 'regularization_fns': ()
)

Using 3 scales


In [7]:
def set_cnf_options(args, model):

    def _set(module):
        if isinstance(module, CNF):
            # Set training settings
            module.solver = args['solver']
            module.atol = args['atol']
            module.rtol = args['rtol']
            if args['step_size'] is not None:
                module.solver_options['step_size'] = args['step_size']

            # Set the test settings
            module.test_solver = args.get('test_solver', args['solver'])
            module.test_atol = args.get('test_atol', args['atol'])
            module.test_rtol = args.get('test_rtol', args['rtol'])

        if isinstance(module, ODEfunc):
            module.rademacher = args.get('rademacher', True)
            module.residual = args.get('residual', False)

    model.apply(_set)

# Define loss

In [8]:
def standard_normal_logprob(z):
    log_z = -0.5 * np.log(2 * np.pi)
    return log_z - z.pow(2) / 2

In [9]:
def bits_per_dim(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)
    
    logpz = standard_normal_logprob(z).view(x.shape[0], -1).sum(1, keepdim=True)
    logpx = logpz - delta_logp
    
    logpx_per_dim = torch.sum(logpx) / (x.shape[0] * x.nelement())
    
    return -(logpx_per_dim - np.log(256)) / np.log(2)
    

# Train

In [10]:
def update_lr(optimizer, itr):
    iter_frac = min(float(itr + 1) / max(args['warmup_iters'], 1), 1.)
    lr = args['lr'] * iter_frac
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
# Get device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device {}.".format(device))

cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

set_cnf_options(args, model)

print("Number of trainable parameters: {}.".format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

optimizer = optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

if args['resume'] is not None:
    checkpoint = torch.load(args['resume'], map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    if 'optim_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        # Manually move optimizer state to device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
                    
# If possible, parallelize the computation
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model)
    print("Parallelizing computation.")
    
# Get 100 random samples from normal distribution to visualize the model
fixed_z = cvt(torch.randn(100, *data_shape))

# Set up the average meters
time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)
#steps_meter = RunningAverageMeter(0.97)
grad_meter = RunningAverageMeter(0.97)
tt_meter = RunningAverageMeter(0.97)

# Keep track of current best performance
best_loss = float('inf')

# Start training
itr = 0
for epoch in range(args['begin_epoch'], args['num_epochs'] + 1):
    # Set model to training mode
    model.train()
    
    # Get train loader
    train_loader = get_train_loader(train_set)
    
    for x, y in train_loader:
        # Set starting time of epoch iteration
        start = time.time()
        
        # Update learning rate
        update_lr(optimizer, itr)
        optimizer.zero_grad()
        
        # Move data to device
        x = cvt(x)
        #Compute loss
        loss = bits_per_dim(x, model)
        
        total_time = count_total_time(model)
        
        # Backpropagate and update parameters
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 100) # 100 is the max norm
        
        optimizer.step()
        
        # Update average meters
        time_meter.update(time.time() - start)
        loss_meter.update(loss.item())
        grad_meter.update(grad_norm)
        #steps_meter.update(count_nfe(model))
        tt_meter.update(total_time)
        
        # Log progress
        if itr % args.get('log_freq', 10) == 0:
            print("Epoch {} Iter {} Loss {:.4f}({:.4f}) Time {:.4f}({:.4f}) Total time {:.2f}({:.2f}) Grad Norm {:.2f}".format(
                epoch, itr, loss_meter.val, loss_meter.avg, time_meter.val, time_meter.avg, tt_meter.val, tt_meter.avg, grad_meter.val
            ))
            
        itr += 1
        
    # Compute validation loss
    model.eval()
    
    if epoch % args['val_freq'] == 0:
        with torch.no_grad():
            start = time.time()
            print("Validating...")
            losses = []
            for x, y in test_loader:
                x = cvt(x)
                losses.append(bits_per_dim(x, model))
                
            loss = torch.mean(torch.stack(losses))
            print("Epoch {} Validation Loss {:.4f} Time {:.4f}".format(epoch, loss, time.time() - start))
            
            # Save model if it is the best so far
            if loss < best_loss:
                best_loss = loss
                if not os.path.exists('results'):
                    os.makedirs('results')
                torch.save({
                    'args': args,
                    'state_dict': model.state_dict(),
                    'optim_state_dict': optimizer.state_dict()
                }, 'models/mnist_odenvp_best.pth')

    # Visualize the model
    with torch.no_grad():
        fig_filename = os.path.join('mnist_results', "{}.jpg".format(epoch))
        if not os.path.exists('mnist_results'):
            os.makedirs('mnist_results')
        generated_samples = model(fixed_z, reverse=True).view(-1, *data_shape)
        save_image(generated_samples, fig_filename, nrow=10)

Using device cpu.
Number of trainable parameters: 31107.
Using batch size 200. Total 300 iterations/epoch.
Epoch 1 Iter 0 Loss 8.1062(8.1062) Time 23.8059(23.8059) Total time 5.00(5.00) Grad Norm 0.40
Epoch 1 Iter 10 Loss 8.1046(8.1060) Time 25.5334(24.0784) Total time 5.00(5.00) Grad Norm 0.40
Epoch 1 Iter 20 Loss 8.1004(8.1052) Time 24.8206(24.2436) Total time 5.00(5.00) Grad Norm 0.38
Epoch 1 Iter 30 Loss 8.0964(8.1034) Time 25.4941(24.4282) Total time 5.00(5.00) Grad Norm 0.38
Epoch 1 Iter 40 Loss 8.0885(8.1002) Time 24.5258(24.5307) Total time 5.00(5.00) Grad Norm 0.36
Epoch 1 Iter 50 Loss 8.0794(8.0959) Time 24.5279(24.5523) Total time 5.00(5.00) Grad Norm 0.34
Epoch 1 Iter 60 Loss 8.0698(8.0902) Time 24.6228(24.6278) Total time 5.00(5.00) Grad Norm 0.31
Epoch 1 Iter 70 Loss 8.0587(8.0831) Time 24.9202(24.6708) Total time 5.00(5.00) Grad Norm 0.29
Epoch 1 Iter 80 Loss 8.0474(8.0750) Time 24.5255(24.6434) Total time 5.00(5.00) Grad Norm 0.26
Epoch 1 Iter 90 Loss 8.0345(8.0658) Tim