# VAE6
## Changes to VAE3
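* Add robustness to blurriness: the network is fed Gaussian-blurred voxels while the loss is computed against the unblurred originals
* Improved beta-VAE objective: beta and the KL capacity target C are ramped up between epochs 100 and 400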

In [None]:
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import pytorch3d
from skimage.filters import gaussian

import easydict
from tqdm import tqdm

from vae_3d import VAE3D
from gaussian_smoothing import GaussianSmoothing

# Check whether GPU is available; `device` is the first CUDA device if so,
# otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Seed the global PyTorch and NumPy random number generators.
torch.manual_seed(1337)
np.random.seed(1337)

In [None]:
# Set hyperparameters.
args = easydict.EasyDict({
    'train': True,           # True: train; False: run a single train/test evaluation pass (also used as the DataLoader shuffle flag)
    'batch_size': 64,        # input batch size
    'n_epochs': 500,         # number of epochs
    'n_workers': 4,          # number of data loading workers
    'learning_rate': 0.0005, # Adam learning rate
    'beta1': 0.9,            # Adam beta 1
    'beta2': 0.999,          # Adam beta 2
    'milestones': [1],       # MultiStepLR milestones (epoch indices at which the LR is multiplied by 'gamma')
    'gamma': 1,              # MultiStepLR multiplicative factor; 1 leaves the learning rate unchanged
    'beta': 1000,            # upper bound of the KL weight in compute_loss
    'maxC': 5,               # upper bound of the KL capacity target C in compute_loss
    'bce_weight': 0.9,       # weight of the target term in the weighted BCE; the (1 - target) term gets 1 - bce_weight
    'l2_reg': 0.0,           # weight_decay passed to the optimizer
    'sigma': 2,              # passed to GaussianSmoothing; also the padding width and half the kernel size of the blur
    'model': '',             # path of a state dict to load; '' skips loading
    'out_dir': 'outputs_vae6'    # output directory (TensorBoard logs and checkpoints)
})

In [None]:
from dataset.shapenet_voxel import ShapeNetVoxel, ShapeNetVoxel32

# Synset IDs used to select categories; only SYNSET_CHAIR is used below.
SYNSET_CHAIR = '03001627'
SYNSET_JAR = '03593526'

# Machine-specific dataset roots. Only R2N2_PATH is used in this cell.
SHAPENET_PATH = '/home/ubuntu/voxel-autoencoder/shapenet/ShapeNetCore.v2'
R2N2_PATH = '/home/ubuntu/voxel-autoencoder/shapenet/ShapeNetVox32'

# Build the dataset restricted to the chair synset.
# NOTE(review): the meaning of `version` and `load_textures` is defined by
# ShapeNetVoxel32, which is not visible here — confirm they are required.
shapenet_dataset = ShapeNetVoxel32(R2N2_PATH, synsets=[SYNSET_CHAIR], version=2, load_textures=True)

# Bare expression: displays the dataset size when run as a notebook cell.
len(shapenet_dataset)

In [None]:
# Load the data and create dataloaders.
def create_datasets_and_dataloaders():
    """Split ``shapenet_dataset`` 80/20 and wrap both parts in DataLoaders.

    The split is drawn from a dedicated, fixed-seed generator so that every
    call returns the same partition. This function is called again by the
    later analysis cells; with the global RNG each call produced a different
    split, so a "test" sample there could be one the model was trained on.

    Reads the module-level ``shapenet_dataset`` and ``args``
    (``batch_size``, ``train``, ``n_workers``).

    Returns:
        (train_data, train_dataloader, test_data, test_dataloader)
    """
    full_size = len(shapenet_dataset)
    train_size = int(0.8 * full_size)
    test_size = full_size - train_size

    # A private generator keeps the partition independent of how much of the
    # global RNG stream has already been consumed.
    split_generator = torch.Generator().manual_seed(1337)
    train_data, test_data = torch.utils.data.random_split(
        shapenet_dataset, [train_size, test_size], generator=split_generator)

    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=args.train,
        num_workers=int(args.n_workers))

    # Evaluation does not benefit from shuffling; keep the order fixed.
    test_dataloader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.n_workers))

    return train_data, train_dataloader, test_data, test_dataloader

In [None]:
# Define the cross entropy loss function.
def weighted_binary_cross_entropy(output, target):
    """Return the mean weighted binary cross entropy of ``output`` vs ``target``.

    The ``target * log(output)`` term is weighted by ``args.bce_weight`` and
    the ``(1 - target) * log(1 - output)`` term by ``1 - args.bce_weight``.
    No clamping is applied, so ``output`` must lie strictly inside (0, 1).
    """
    pos_weight = args.bce_weight
    pos_term = pos_weight * target * torch.log(output)
    neg_term = (1 - pos_weight) * (1 - target) * torch.log(1 - output)
    return -torch.mean(pos_term + neg_term)

def compute_loss(in_voxels, out_voxels, mean, logvar, epoch):
    """Return the training objective: reconstruction loss plus weighted KL term.

    Args:
        in_voxels: (batch_size, 1, V, V, V) reference voxels.
        out_voxels: (batch_size, 1, V, V, V) network output.
        mean: (batch_size, Z) latent means.
        logvar: (batch_size, Z) latent log-variances.
        epoch: current (possibly fractional) epoch; ``None`` means
            ``args.n_epochs``, i.e. the end of the schedule.

    Returns:
        ``recon_loss + beta * |kl_loss - C|`` as a scalar tensor.
    """
    if epoch is None:
        epoch = args.n_epochs

    # The output is squeezed into [0.1, 0.9999] so both logs stay finite; the
    # target is mapped from {0, 1} to {-1, 5}.
    scaled_output = 0.1 + 0.8999 * out_voxels
    scaled_target = -1 + 6*in_voxels
    recon_loss = weighted_binary_cross_entropy(scaled_output, scaled_target)

    # KL divergence to a standard normal, averaged over batch and latent dims.
    kl_loss = 0.5 * torch.mean(torch.exp(logvar) + mean**2 - 1 - logvar)

    # Schedule position: 0 until epoch 100, rising linearly to 1 at epoch 400.
    ramp = np.clip((epoch - 100) / 300, 0, 1)
    C = args.maxC * ramp
    # beta grows exponentially from 0 to args.beta along the ramp.
    beta = np.clip(-1 + ((args.beta+1) ** ramp), 0, args.beta)

    return recon_loss + beta*torch.abs(kl_loss - C)

In [None]:
# Define the accuracy function.
def compute_accuracy(in_voxels, out_voxels):
    """Return the batch-mean intersection-over-union, in percent.

    Both inputs are binarised at 0.5 and compared per sample.

    Args:
        in_voxels: (batch_size, 1, voxel_size, voxel_size, voxel_size)
        out_voxels: (batch_size, 1, voxel_size, voxel_size, voxel_size)

    Returns:
        Scalar tensor in [0, 100]. A sample whose two grids are both empty
        counts as a perfect match (IoU = 1) instead of producing 0/0 = NaN,
        which would otherwise poison the running epoch average.
    """
    batch_size = in_voxels.shape[0]
    in_bin = (in_voxels > 0.5).reshape(batch_size, -1)
    out_bin = (out_voxels > 0.5).reshape(batch_size, -1)

    intersection = (in_bin & out_bin).sum(dim=1)
    union = (in_bin | out_bin).sum(dim=1)

    # clamp(min=1) only changes rows with an empty union, which torch.where
    # then overrides with 1.0.
    ratio = intersection.float() / union.clamp(min=1).float()
    iou = torch.where(union > 0, ratio, torch.ones_like(ratio))
    return iou.mean() * 100

In [None]:
# Define one-step training function.
# Blur applied to every batch before it is fed to the network: zero-pad each
# spatial side by `sigma` voxels, then smooth with a kernel of size
# 2 * sigma + 1.
# NOTE(review): presumably GaussianSmoothing(channels, kernel_size, sigma, dim)
# is an unpadded convolution, so the padding keeps the grid size — confirm.
# Placed on `device` (not an unconditional .cuda()) so the notebook also runs
# on CPU-only machines, matching the device selection at the top of the file.
blur = torch.nn.Sequential(
    torch.nn.ConstantPad3d(args.sigma, 0),
    GaussianSmoothing(1, args.sigma * 2 + 1, args.sigma, 3)
).to(device)

def run_train(data, net, optimizer, writer=None, iter_epoch=None):
    """Run one optimisation step on a single batch.

    The network sees the blurred batch; loss and accuracy are computed
    against the unblurred originals.

    Args:
        data: batch of voxel grids, (batch_size, 1, V, V, V).
        net: model returning ``(out_voxels, mean, logvar)``.
        optimizer: optimizer over ``net``'s parameters.
        writer: unused; kept for a uniform signature with ``run_eval``.
        iter_epoch: (possibly fractional) epoch forwarded to ``compute_loss``.

    Returns:
        (loss, acc): scalar tensors. ``loss`` is detached so callers that
        accumulate or log it do not keep the autograd graph alive.
    """
    # Parse data. Use the globally selected device instead of assuming CUDA.
    original_voxels = data.to(device)
    in_voxels = blur(original_voxels)
    # in_voxels: (batch_size, 1, V, V, V)

    # Reset gradients.
    # https://pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html#zero-the-gradients-while-training-the-network
    optimizer.zero_grad()

    # Predict.
    out_voxels, mean, logvar = net.train()(in_voxels)

    # Compute the loss.
    loss = compute_loss(original_voxels, out_voxels, mean, logvar, iter_epoch)

    with torch.no_grad():
        # Compute the accuracy.
        acc = compute_accuracy(original_voxels, out_voxels)

    # Backprop.
    loss.backward()
    optimizer.step()

    return loss.detach(), acc

In [None]:
# Define one-step evaluation function.
def run_eval(data, net, optimizer, writer=None, iter_epoch=None):
    """Evaluate the network on a single batch without updating it.

    Args:
        data: batch of voxel grids, (batch_size, 1, V, V, V).
        net: model returning ``(out_voxels, mean, logvar)``.
        optimizer: unused; kept for a uniform signature with ``run_train``.
        writer: unused; kept for a uniform signature with ``run_train``.
        iter_epoch: (possibly fractional) epoch forwarded to ``compute_loss``.

    Returns:
        (loss, acc): scalar tensors computed under ``torch.no_grad()``.
    """
    with torch.no_grad():
        # Parse data. Use the globally selected device instead of assuming CUDA.
        original_voxels = data.to(device)
        in_voxels = blur(original_voxels)
        # in_voxels: (batch_size, 1, V, V, V)

        # Predict.
        out_voxels, mean, logvar = net.eval()(in_voxels)

        # Compute the loss.
        loss = compute_loss(original_voxels, out_voxels, mean, logvar, iter_epoch)

        # Compute the accuracy.
        acc = compute_accuracy(original_voxels, out_voxels)

    return loss, acc

In [None]:
# Define one-epoch training/evaluation function.
def run_epoch(dataset, dataloader, train, epoch=None, writer=None):
    """Run one pass over ``dataloader`` in training or evaluation mode.

    Uses the module-level ``net`` and ``optimizer``.

    Args:
        dataset: dataset behind ``dataloader``; only its length is used.
        dataloader: yields batches of voxel grids.
        train: True to call ``run_train`` per batch, False for ``run_eval``.
        epoch: zero-based epoch index, or None when not inside a training
            schedule (the loss then uses ``args.n_epochs``).
        writer: optional TensorBoard writer; per-step training scalars are
            logged only when ``train`` is True.

    Returns:
        (mean_loss, mean_acc): per-sample averages over the pass.
    """
    total_loss = 0.0
    total_acc = 0.0
    n_data = len(dataset)
    n_batches = len(dataloader)

    mode = 'Train' if train else 'Test'
    epoch_str = '' if epoch is None else '[Epoch {}/{}]'.format(
            str(epoch).zfill(len(str(args.n_epochs))), args.n_epochs)
    step_fn = run_train if train else run_eval

    # The context manager closes the progress bar even if a step raises.
    with tqdm(total=n_data, leave=False) as pbar:
        for i, data in enumerate(dataloader):
            # Run one step. The fractional epoch drives the loss schedule.
            iter_epoch = args.n_epochs if epoch is None else epoch + i/n_batches
            loss, acc = step_fn(data, net, optimizer, writer, iter_epoch)

            if train and writer is not None:
                # Write results if training.
                assert(epoch is not None)
                step = epoch * n_batches + i
                writer.add_scalar('Loss/Train', loss, step)
                writer.add_scalar('Accuracy/Train', acc, step)

            # Weight by batch size so a smaller last batch is averaged correctly.
            batch_size = data.shape[0]
            total_loss += (loss * batch_size)
            total_acc += (acc * batch_size)

            pbar.set_description('{} {} Loss: {:f}, Acc : {:.2f}%'.format(
                epoch_str, mode, loss, acc))
            pbar.update(batch_size)

    mean_loss = total_loss / float(n_data)
    mean_acc = total_acc / float(n_data)
    return mean_loss, mean_acc

In [None]:
# Define one-epoch function for both training and evaluation.
def run_epoch_train_and_test(
    train_dataset, train_dataloader, test_dataset, test_dataloader, epoch=None,
        writer=None):
    """Run one pass over the training split, then one over the test split.

    The training split is processed in training mode only when ``args.train``
    is set; the test split is always evaluated. Test metrics are logged to
    ``writer`` (if given) and a one-line summary is printed.
    """
    train_loss, train_acc = run_epoch(
        train_dataset, train_dataloader, train=args.train, epoch=epoch,
        writer=writer)
    test_loss, test_acc = run_epoch(
        test_dataset, test_dataloader, train=False, epoch=epoch, writer=None)

    if writer is not None:
        # Write test results at the step count reached by the end of the epoch.
        assert(epoch is not None)
        global_step = (epoch + 1) * len(train_dataloader)
        for tag, value in (('Loss/Test', test_loss),
                           ('Accuracy/Test', test_acc)):
            writer.add_scalar(tag, value, global_step)

    if epoch is None:
        prefix = ''
    else:
        width = len(str(args.n_epochs))
        prefix = '[Epoch {}/{}]'.format(str(epoch).zfill(width), args.n_epochs)

    pieces = [
        prefix + ' ',
        'Train Loss: {:f}, '.format(train_loss),
        'Train Acc: {:.2f}%, '.format(train_acc),
        'Test Loss: {:f}, '.format(test_loss),
        'Test Acc: {:.2f}%.'.format(test_acc),
    ]
    print(''.join(pieces))

In [None]:
# Main function.
if __name__ == "__main__":
    # Training entry point. Leaves `net`, `optimizer`, `scheduler` and the
    # datasets/dataloaders in the module namespace for the later cells.
    print(args)

    # Load datasets.
    train_dataset, train_dataloader, test_dataset, test_dataloader = create_datasets_and_dataloaders()

    # Create the network on the selected device.
    net = VAE3D(32)
    net.to(device)

    # Load a model if given. map_location lets a checkpoint saved on GPU be
    # loaded on a CPU-only machine.
    if args.model != '':
        net.load_state_dict(torch.load(args.model, map_location=device))

    # Set an optimizer and a scheduler.
    optimizer = torch.optim.Adam(
        net.parameters(), lr=args.learning_rate,
        betas=(args.beta1, args.beta2), weight_decay=args.l2_reg)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)

    # Create the output directory (no error if it already exists).
    os.makedirs(args.out_dir, exist_ok=True)

    # Train.
    if args.train:
        writer = SummaryWriter(args.out_dir)
        try:
            for epoch in range(args.n_epochs):
                run_epoch_train_and_test(
                    train_dataset, train_dataloader, test_dataset, test_dataloader,
                    epoch, writer)

                if (epoch + 1) % 50 == 0:
                    # Save the model every 50 epochs.
                    model_file = os.path.join(
                        args.out_dir, 'model_{:d}.pth'.format(epoch + 1))
                    torch.save(net.state_dict(), model_file)
                    print("Saved '{}'.".format(model_file))

                scheduler.step()
        finally:
            # Flush and close the event file even if training is interrupted.
            writer.close()
    else:
        run_epoch_train_and_test(
            train_dataset, train_dataloader, test_dataset, test_dataloader)

In [None]:
# Main function.
if __name__ == "__main__":
    # Reconstruct one random sample and plot it next to the original.
    # NOTE: `net` is not created in this cell (the constructor call below is
    # commented out); it must already exist in the notebook namespace.
    print(args)

    # Load datasets (calls create_datasets_and_dataloaders() again).
    train_dataset, train_dataloader, test_dataset, test_dataloader = create_datasets_and_dataloaders()

    # Create the network.
#     net = VAE3D(100)
    if torch.cuda.is_available():
        net.cuda()

    # Load a model if given.
    if args.model != '':
        net.load_state_dict(torch.load(args.model))

    with torch.no_grad():
        import random
        # Despite the name, the sample is drawn from test_dataset.
        # unsqueeze(0) adds the batch dimension.
        train_data = random.choice(test_dataset).unsqueeze(0)
        # The network is fed the blurred sample, as during training.
        train_out, _, _ = net.eval()(blur(train_data.cuda()))
#         train_out, _, _ = net.eval()(train_data.cuda())

    # Drop the batch and channel dimensions for plotting.
    sample_origin = train_data[0,0,:,:,:].cpu().numpy()
    sample_recon = train_out[0,0,:,:,:].cpu().numpy()

    # Value range of the reconstruction.
    print(sample_recon.max(), sample_recon.min())

    import matplotlib.pyplot as plt
    print('drawing original')
    ax1 = plt.figure().add_subplot(projection='3d')
    ax1.voxels(sample_origin > 0.5)
    print('drawing reconstruction')
    ax2 = plt.figure().add_subplot(projection='3d')
    # The reconstruction is thresholded at 0.1, the original at 0.5.
    ax2.voxels(sample_recon > 0.1)

    plt.show()

In [None]:
# Main function.
if __name__ == "__main__":
    # Encode every sample of both splits and store the first encoder output
    # (named `mean` here) in `z`, with the split label in `c`. Both arrays are
    # read by later cells. `net` must already exist in the notebook namespace.
    print(args)

    # Load datasets (calls create_datasets_and_dataloaders() again).
    train_dataset, train_dataloader, test_dataset, test_dataloader = create_datasets_and_dataloaders()

    # Create the network.
#     net = VAE3D(128)
#     if torch.cuda.is_available():
#         net.cuda()

#     # Load a model if given.
#     args.model = 'outputs_vae5/model_200.pth'
#     if args.model != '':
#         net.load_state_dict(torch.load(args.model))

    num_points = len(train_dataset) + len(test_dataset)
    # One row per sample. The width 32 is hard-coded; presumably it is the
    # latent size of VAE3D(32) — TODO confirm against vae_3d.
    z = np.zeros((num_points, 32))
    # Split label per row: 0 = train, 1 = test.
    c = np.zeros((num_points))
    with torch.no_grad():
        idx = 0
        for train_data in tqdm(train_dataset):
            # Add the batch dimension; samples are encoded one at a time.
            train_data = train_data.unsqueeze(0)
            mean, _ = net.encoder.eval()(blur(train_data.cuda()))
            z[idx,:] = mean[0,:].cpu().numpy()
            c[idx] = 0
            idx += 1

        # Test rows follow the training rows in `z`.
        for test_data in tqdm(test_dataset):
            test_data = test_data.unsqueeze(0)
            mean, _ = net.encoder.eval()(blur(test_data.cuda()))
            z[idx,:] = mean[0,:].cpu().numpy()
            c[idx] = 1
            idx += 1

In [None]:
from sklearn.manifold import TSNE
# Embed the rows of `z` with t-SNE using the library defaults; the first two
# output columns are plotted in the next cell.
z_embedded = TSNE().fit_transform(z)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Scatter the t-SNE embedding, coloured by the label array `c`.
plt.figure(figsize=(20, 20))
plt.scatter(z_embedded[:,0], z_embedded[:,1], c=c, cmap='tab10')
plt.colorbar()
plt.show()

In [None]:
def compute_kl_loss(in_voxels, out_voxels, mean, logvar, epoch):
    """Return the KL term of the loss, averaged over all elements.

    Args:
        in_voxels: unused; kept so the signature mirrors ``compute_loss``.
        out_voxels: unused; kept so the signature mirrors ``compute_loss``.
        mean: (batch_size, Z) latent means.
        logvar: (batch_size, Z) latent log-variances.
        epoch: unused; kept so the signature mirrors ``compute_loss``.

    Returns:
        Scalar tensor ``0.5 * mean(exp(logvar) + mean**2 - 1 - logvar)``.
    """
    variance = torch.exp(logvar)
    per_element = variance + mean**2 - 1 - logvar
    return 0.5 * torch.mean(per_element)

# Main function.
if __name__ == "__main__":
    # Report the KL term averaged over every sample of both splits.
    # `net` must already exist in the notebook namespace.
    print(args)

    # Load datasets.
    train_dataset, train_dataloader, test_dataset, test_dataloader = create_datasets_and_dataloaders()

    num_points = len(train_dataset) + len(test_dataset)
    total_kl = 0.0
    encoder = net.encoder.eval()
    with torch.no_grad():
        # Samples are encoded one at a time, one progress bar per split.
        for split in (train_dataset, test_dataset):
            for sample in tqdm(split):
                # Add the batch dimension and blur, as during training.
                blurred = blur(sample.unsqueeze(0).to(device))
                mean, logvar = encoder(blurred)
                # Reuse the helper instead of repeating the formula; only
                # `mean` and `logvar` are read by it.
                total_kl += compute_kl_loss(None, None, mean, logvar, None).item()
    loss = total_kl / num_points
    # NOTE(review): compute_kl_loss averages over latent dimensions, so this
    # is a per-dimension mean rather than a sum over dimensions — confirm the
    # intended unit.
    print(f'KL Divergence = {loss:.5f} nats')

In [None]:
# Compute a PCA basis of the collected latent codes and save it.
# NOTE(review): this expects `z` to be the NumPy array filled by the latent
# collection cell; the interpolation cell further down rebinds `z` to a
# tensor, so cell order matters.
with torch.no_grad():
    # unsqueeze(0) adds a leading batch dimension, so the third value returned
    # by torch.pca_lowrank (kept as `pc_basis`) is indexed as
    # pc_basis[0, :, k] by the traversal cell. 32 is the requested rank.
    _, _, pc_basis = torch.pca_lowrank(torch.tensor(z).unsqueeze(0).cuda(), 32)
    model_file = os.path.join(
        args.out_dir, 'model_basis.pth')
    torch.save(pc_basis, model_file)

In [None]:
# Main function.
def visualize(data, threshold=0.5, desc=''):
    """Draw a 3-D voxel plot of the cells of ``data`` above ``threshold``.

    Args:
        data: 3-D grid of values, either a NumPy array or a torch tensor on
            any device.
        threshold: cells strictly greater than this value are drawn.
        desc: title of the plot.
    """
    import matplotlib.pyplot as plt

    # Axes3D.voxels expects a NumPy boolean array. Callers in this notebook
    # pass decoder outputs that live on the GPU, so move tensors to host
    # memory once up front instead of handing matplotlib a device tensor.
    if torch.is_tensor(data):
        data = data.detach().cpu().numpy()
    occupied = np.asarray(data) > threshold

    ax1 = plt.figure().add_subplot(projection='3d')
    ax1.set_title(desc)
    ax1.voxels(occupied)

if __name__ == "__main__":
    # Linearly interpolate between the latent codes of two random samples and
    # decode each intermediate code. `net` must already exist in the notebook
    # namespace.
    print(args)

    # Load datasets (calls create_datasets_and_dataloaders() again).
    train_dataset, train_dataloader, test_dataset, test_dataloader = create_datasets_and_dataloaders()

    # Create the network.
#     net = VAE3D(128)
#     if torch.cuda.is_available():
#         net.cuda()

#     # Load a model if given.
#     args.model = 'outputs_vae5/model_200.pth'
#     if args.model != '':
#         net.load_state_dict(torch.load(args.model))

    with torch.no_grad():
        import random
        # Two random samples from test_dataset, stacked into a batch of 2.
        data1, data2 = random.choice(test_dataset).unsqueeze(0), random.choice(test_dataset).unsqueeze(0)
        data = torch.cat((data1, data2), 0)
        visualize(data[0,0,...], desc='Data 1')
        visualize(data[1,0,...], desc='Data 2')

        # NOTE(review): this rebinds the global `z`, replacing the NumPy array
        # of latent codes built by the latent collection cell. The encoder is
        # also called without .eval() here, so it runs in whatever mode it was
        # last left in.
        z, _ = net.encoder(blur(data.cuda()))
        # 11 interpolation weights from 0 to 1, as a column so they broadcast
        # over the latent dimension.
        t = torch.linspace(0, 1, 11).unsqueeze(1).cuda()
        # z[[0],:] keeps the row dimension, giving one code per weight.
        zt = (1-t)*z[[0],:] + t*z[[1],:]
        recon = net.decoder(zt)
        for i in tqdm(range(11)):
            visualize(recon[i,0,...], threshold=0.1, desc=f'Mixture (t = {t[i].item()})')

In [None]:
if __name__ == "__main__":
    # Decode points along one PCA direction of the latent space. `net` and
    # `pc_basis` must already exist in the notebook namespace.
    print(args)

    # Load datasets (calls create_datasets_and_dataloaders() again).
    train_dataset, train_dataloader, test_dataset, test_dataloader = create_datasets_and_dataloaders()

    # Create the network.
#     net = VAE3D(128)
#     if torch.cuda.is_available():
#         net.cuda()

#     # Load a model if given.
#     args.model = 'outputs_vae5/model_200.pth'
#     if args.model != '':
#         net.load_state_dict(torch.load(args.model))

    with torch.no_grad():
        import random
#         data = random.choice(test_dataset).unsqueeze(0)
#         visualize(data[0,0,...], desc='Data - original')

        # Index of the basis column to traverse.
        BASIS = 2
        print(f'Using {BASIS}-th basis')
        # pc_basis has a leading batch dimension of 1.
        basis = pc_basis[0,:,BASIS]

#         z, _ = net.encoder(data.cuda())
        # 7 offsets from -10 to 10; the `** 1` exponent leaves them unchanged.
        t = (torch.linspace(-1, 1, 7).unsqueeze(1).cuda() ** 1) * 10

        # Debug output. `z` is whatever an earlier cell left in the namespace;
        # it is not assigned in this cell.
        print(z, '\n', basis)
        deltaz = t*basis[None,:]
        # The offsets are decoded on their own, i.e. relative to the latent
        # origin; the variant that adds an encoded sample is commented out.
        zt = deltaz
#         zt = deltaz + z
        # Cast to float32 before decoding.
        recon = net.decoder(zt.float())
        for i in tqdm(range(7)):
#             visualize(recon[i,0,...], threshold=0.2, desc=f'Mixture (t = {t[i].item()})')
            visualize(recon[i,0,...], threshold=0.1, desc='')