In [1]:
# Training-mode switch used by the cells below:
#   True  -> build and train the NetWrapper model (net.netG / net.netE) on
#            latents produced by net(...) (the "Langevin" branch);
#   False -> build and train the Encoder / Generator pair (E, G) instead.
train_langevine = True

In [2]:
import argparse
import json
import logging
import os
import random
from datetime import datetime
from importlib import import_module
from itertools import chain
from os.path import join, exists

import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch.utils.data import DataLoader

from new.aae.pcutil import plot_3d_point_cloud
#from utils.util import find_latest_epoch, prepare_results_dir, cuda_setup, setup_logging

# Enable cuDNN benchmark mode (cuDNN auto-tunes its convolution algorithms
# for the input shapes it encounters).
cudnn.benchmark = True


In [3]:
def _xavier_relu_init(weight, bias):
    """Xavier-uniform ``weight`` with the ReLU gain and zero ``bias``; None entries are skipped."""
    if weight is not None:
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_uniform_(weight, gain)
    if bias is not None:
        torch.nn.init.constant_(bias, 0)


def weights_init(m):
    """Initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply`` so that it is called once for
    every sub-module. The module kind is detected from its class *name*, in
    this order of precedence:

    * name contains ``'Conv'``      -> Xavier-uniform weight (ReLU gain), bias 0
    * name contains ``'BatchNorm'`` -> weight 1, bias 0
    * name contains ``'Linear'``    -> Xavier-uniform weight (ReLU gain), bias 0

    Any other module is left untouched.

    Parameters that are missing or ``None`` are skipped instead of raising.
    This covers layers built with ``bias=False``, batch-norm layers built with
    ``affine=False`` (whose ``weight`` and ``bias`` are both ``None``), and
    container modules whose class name merely happens to match.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default: wrapper modules may match by name but own no
    # `weight` / `bias` attribute at all.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if classname.find('Conv') != -1:
        _xavier_relu_init(weight, bias)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            torch.nn.init.constant_(weight, 1)
        if bias is not None:
            torch.nn.init.constant_(bias, 0)
    elif classname.find('Linear') != -1:
        _xavier_relu_init(weight, bias)


In [4]:
# Seed Python's RNG, torch's default generator and every CUDA generator with
# the same fixed value so that runs are repeatable.
_SEED = 123
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(_SEED)

In [5]:
# Root folder for everything this notebook writes.
results_dir = "results/"
# The training loop saves its figures to <results_dir>/samples, and savefig
# does not create missing directories, so make sure the folder exists up front.
os.makedirs(join(results_dir, "samples"), exist_ok=True)

In [6]:
# Prefer the GPU, but fall back to the CPU so that the `.to(device)` calls in
# the following cells still work on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Name of the dataset in use.
# NOTE(review): not referenced by any of the visible cells (the dataset cell
# hard-codes root_dir="shapenet") - confirm whether a later cell needs it.
dataset_name = "shapenet"

In [8]:
from new.shapenet import ShapeNetDataset

In [9]:
# ShapeNet dataset restricted to the "chair" class, with root_dir "shapenet".
dataset = ShapeNetDataset(root_dir="shapenet", classes=["chair"])

# Shuffled mini-batches of 16 samples loaded by 8 worker processes into pinned
# memory; drop_last=True discards a trailing incomplete batch, so every batch
# holds exactly 16 samples.
points_dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    pin_memory=True,
)

In [10]:
# Build the model(s) for the selected training mode, move them to `device`
# and run weights_init over every sub-module.
if not train_langevine:
    from new.aae.aae import Generator, Encoder

    # Both networks are constructed first (generator, then encoder) and only
    # afterwards initialised, in that same order.
    G = Generator().to(device)
    E = Encoder().to(device)

    G.apply(weights_init)
    E.apply(weights_init)
else:
    from new.models import NetWrapper

    net = NetWrapper().to(device)
    net.apply(weights_init)

In [11]:
from new.params import *
from new.utils import *

In [12]:
#
# Float Tensors
#
# 16 latent vectors of size z_dim, drawn once from N(0, 0.2^2); they are fed to
# the generator after every epoch to produce the "fixed" sample plots.
fixed_noise = (
    torch.empty(16, z_dim, dtype=torch.float32, device="cpu")
    .normal_(mean=0, std=0.2)
    .to(device)
)

# Standard deviation used by the KLD term of the encoder/generator loss.
std_assumed = torch.tensor(0.2).to(device)

In [13]:
#
# Optimizers
#
# Adam settings shared by every optimizer created in this cell.
optim_params = dict(
    lr=0.0005,
    weight_decay=0,
    betas=[0.9, 0.999],
    amsgrad=False,
)

if not train_langevine:
    # One optimizer updating encoder and generator jointly.
    EG_optim = torch.optim.Adam(
        chain(E.parameters(), G.parameters()), **optim_params
    )
else:
    # Separate optimizers for the two sub-networks of `net`.
    optE = torch.optim.Adam(net.netE.parameters(), **optim_params)
    optG = torch.optim.Adam(net.netG.parameters(), **optim_params)

In [14]:
from new.champfer_loss import ChamferLoss

# Reconstruction criterion used by the encoder/generator branch of the training
# loop (the Langevin branch calls net.loss_fun instead).
reconstruction_loss = ChamferLoss().to(device)

In [None]:
def _save_sample_plots(clouds, epoch_tag, suffix, n_samples=5, **plot_kwargs):
    """Plot the first ``n_samples`` entries of ``clouds`` and save them as PNGs.

    Each entry is indexed as ``cloud[0]``, ``cloud[1]``, ``cloud[2]`` and those
    three rows are handed to ``plot_3d_point_cloud``. Figures are written to
    ``<results_dir>/samples/{epoch_tag}_{k}_{suffix}.png`` and closed right
    away so they do not pile up in memory. Extra keyword arguments (for
    example ``title``) are forwarded to ``plot_3d_point_cloud``.
    """
    for k in range(n_samples):
        cloud = clouds[k]
        fig = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                  in_u_sphere=True, show=False, **plot_kwargs)
        fig.savefig(
            join(results_dir, 'samples', f'{epoch_tag}_{k}_{suffix}.png'))
        plt.close(fig)


n_epochs = 400   # number of passes over points_dataloader
log_every = 30   # print the running losses every `log_every` batches

# Main loop: one training pass per epoch, followed by an evaluation step that
# saves plots of real, generated ("fixed") and reconstructed point clouds.
for epoch in range(n_epochs):
    start_epoch_time = datetime.now()

    if train_langevine:
        net.train()
    else:
        G.train()
        E.train()

    total_loss = 0.0
    n_batches = 0
    for i, point_data in enumerate(points_dataloader, 1):
        n_batches = i

        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        if train_langevine:
            batch_num = X.shape[0]

            # Initialize chains
            # NOTE(review): the chain fed to the posterior call (z_g_0) is
            # scaled by e_init_sig while the prior chain (z_e_0) is scaled by
            # g_init_sig; this looks possibly swapped - confirm against
            # new.params before relying on it.
            z_g_0 = sample_p_0(n=batch_num, sig=e_init_sig, device=X.device)
            z_e_0 = sample_p_0(n=batch_num, sig=g_init_sig, device=X.device)

            # Langevin posterior and prior
            z_g_k = net(Variable(z_g_0), X, prior=False)
            z_e_k = net(Variable(z_e_0), prior=True)

            # reconstruction
            X_hat = net.netG(z_g_k.detach())
            loss_g = net.loss_fun(X_hat.transpose(1, 2).contiguous() + 0.5,
                                  X.transpose(1, 2).contiguous() + 0.5)

            # energy prior
            # TODO(nijkamp): why mean() here and in Langevin sum() over energy? constant is absorbed into Adam adaptive lr
            en_neg = net.netE(z_e_k.detach()).mean()
            en_pos = net.netE(z_g_k.detach()).mean()
            loss_e = en_pos - en_neg

            # Learn generator
            optG.zero_grad()
            loss_g.backward()
            optG.step()

            # Learn energy network
            optE.zero_grad()
            loss_e.backward()
            optE.step()

            # Accumulate the generator loss so the per-epoch summary below
            # reports a real average instead of a constant 0.
            total_loss += loss_g.item()

            if i % log_every == 0:
                print(f'[{epoch}: ({i})] '
                      f'loss_g: {loss_g.item():.4f} '
                      f'(loss_e: {loss_e.item(): .4f})'
                      f' Time: {datetime.now() - start_epoch_time}')

        else:
            codes, mu, logvar = E(X)
            X_rec = G(codes)

            # Reconstruction term, weighted by 0.05
            loss_e = torch.mean(
                0.05 *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # KLD term computed from mu / logvar and the assumed std
            loss_kld = -0.5 * torch.mean(
                1 - 2.0 * torch.log(std_assumed) + logvar -
                (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2))

            loss_eg = loss_e + loss_kld

            # EG_optim owns the parameters of both E and G, so one zero_grad()
            # clears every gradient that backward() is about to fill.
            EG_optim.zero_grad()
            loss_eg.backward()
            total_loss += loss_eg.item()
            EG_optim.step()

            if i % log_every == 0:
                print(f'[{epoch}: ({i})] '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}'
                      f' KLD: {loss_kld.item(): .4f})'
                      f' Time: {datetime.now() - start_epoch_time}')

    if n_batches == 0:
        # Without this the code below would die with a NameError on `i` / `X`.
        raise RuntimeError(
            'points_dataloader produced no batches; the dataset may be empty '
            'or smaller than one batch while drop_last=True')

    print(
        f'[{epoch}/{n_epochs}] '
        f'Loss_G: {total_loss / n_batches:.4f} '
        f'Time: {datetime.now() - start_epoch_time}'
    )

    ################################### eval ######################################
    #
    # Save intermediate results
    #
    # `X` (and `X_hat` in the Langevin branch) are the tensors of the last
    # training batch of this epoch.
    if train_langevine:
        net.eval()
        with torch.no_grad():
            fake = net.netG(fixed_noise).detach().cpu().numpy()
            X_rec = X_hat.detach().cpu().numpy()
            X = X.detach().cpu().numpy()
    else:
        G.eval()
        E.eval()
        with torch.no_grad():
            fake = G(fixed_noise).detach().cpu().numpy()
            codes, _, _ = E(X)
            X_rec = G(codes).detach().cpu().numpy()
            X = X.detach().cpu().numpy()

    _save_sample_plots(X, f'{epoch}', 'real')
    _save_sample_plots(fake, f'{epoch:05}', 'fixed', title=str(epoch))
    _save_sample_plots(X_rec, f'{epoch}', 'reconstructed')