In [1]:
use_tensorboard = True

if use_tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter() 

In [2]:
# Training-mode switches:
#   train_langevine -> train the NetWrapper latent model with Langevin posterior/prior chains
#   mixed_sampling  -> (only when not train_langevine) wrap E/G in LangevinEncoderDecoder
train_langevine, mixed_sampling = False, True

In [3]:
import argparse
import json
import logging
import os
import random
from datetime import datetime
from importlib import import_module
from itertools import chain
from os.path import join, exists

import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.autograd import Variable

from new.aae.pcutil import plot_3d_point_cloud
#from utils.util import find_latest_epoch, prepare_results_dir, cuda_setup, setup_logging

# Let cuDNN auto-tune and cache the fastest convolution algorithms for the
# input shapes it sees; faster, but not bitwise-deterministic across runs.
cudnn.benchmark = True


In [4]:
def _xavier_relu_init(weight, bias):
    """Xavier-uniform ``weight`` scaled by the ReLU gain and zero ``bias``; either may be None."""
    if weight is not None:
        torch.nn.init.xavier_uniform_(weight, gain=torch.nn.init.calculate_gain('relu'))
    if bias is not None:
        torch.nn.init.constant_(bias, 0)


def weights_init(m):
    """Initialize one module's parameters in place; meant for ``model.apply(weights_init)``.

    Dispatch is by class-name substring, checked in this order:

    * ``Conv*`` (Conv1d, ConvTranspose2d, ...): Xavier-uniform weight with ReLU gain, zero bias.
    * ``BatchNorm*``: weight 1, bias 0.
    * ``Linear*``: Xavier-uniform weight with ReLU gain, zero bias.

    Matching modules that carry no ``weight``/``bias`` tensor, such as
    ``BatchNorm1d(affine=False)``, ``Linear(bias=False)`` or a custom container
    named e.g. ``ConvBlock``, have the missing tensors skipped instead of
    raising AttributeError.

    Args:
        m: a ``torch.nn.Module``.
    """
    classname = m.__class__.__name__
    # getattr with a default: containers whose name contains "Conv"/"Linear"
    # have no `weight` attribute at all, and non-affine BatchNorm stores None.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        _xavier_relu_init(weight, bias)
    elif 'BatchNorm' in classname:
        if weight is not None:
            torch.nn.init.constant_(weight, 1)
        if bias is not None:
            torch.nn.init.constant_(bias, 0)
    elif 'Linear' in classname:
        _xavier_relu_init(weight, bias)


In [5]:
def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm over ``parameters`` as a Python float.

    Computes ``(sum_p ||grad_p||_k ** k) ** (1 / k)`` for ``k = norm_type``,
    the same quantity ``torch.nn.utils.clip_grad_norm_`` reports.

    Args:
        parameters: an iterable of tensors (e.g. ``model.parameters()``) or a
            single tensor.
        norm_type: order of the norm; ``float('inf')`` gives the largest
            absolute gradient entry.

    Returns:
        The norm as a float, or 0.0 when no parameter has a gradient.
        Parameters whose ``.grad`` is None (frozen, or unused in this step)
        are skipped rather than raising.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        # x ** inf does not aggregate; the inf-norm is a max over all entries.
        return max(g.abs().max().item() for g in grads)
    total_norm = 0.0
    for g in grads:
        total_norm += g.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)

In [6]:
# Set to an int (e.g. 1234) for reproducible runs; None keeps the run unseeded.
# Note: with `cudnn.benchmark = True` GPU results can still differ between runs.
seed = None

if seed is not None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [7]:
results_dir = "results/"

In [8]:
device = torch.device("cuda")

In [9]:
dataset_name = "shapenet"

In [10]:
from new.shapenet import ShapeNetDataset

In [11]:
# ShapeNet point clouds, restricted to the "chair" class.
dataset = ShapeNetDataset(root_dir="shapenet", classes=["chair"])

# drop_last keeps every batch at exactly 16 clouds.
loader_kwargs = dict(batch_size=16, shuffle=True, num_workers=8,
                     drop_last=True, pin_memory=True)
points_dataloader = DataLoader(dataset, **loader_kwargs)

In [12]:
# Build the networks: an encoder/generator pair (optionally wrapped for mixed
# Langevin sampling), or the NetWrapper model trained with Langevin chains.
if not train_langevine:
    from new.aae.aae import Generator, Encoder
    from new.models import LangevinEncoderDecoder

    # Construct both nets first, then re-initialize them (same order as before).
    G = Generator().to(device)
    E = Encoder().to(device)
    G.apply(weights_init)
    E.apply(weights_init)

    # NOTE(review): `net` is only defined in this case; the training loop only
    # uses it in the mixed-sampling branch. E and G are already on `device`.
    if mixed_sampling:
        net = LangevinEncoderDecoder(E, G)
else:
    from new.models import NetWrapper

    net = NetWrapper().to(device)
    net.apply(weights_init)

In [13]:
# NOTE(review): star imports hide where names come from. Later cells use names
# not defined anywhere in this notebook, presumably provided here: z_dim,
# sample_p_0, e_init_sig, g_init_sig, batch_size. Consider importing them
# explicitly.
from new.params import *
from new.utils import *

In [14]:
#
# Float Tensors
#
# Fixed latent codes (reused every epoch so sample images are comparable) and
# the prior std assumed by the KL term; both share one std value.
noise_std = 0.2
fixed_noise = torch.empty(16, z_dim, dtype=torch.float32)
fixed_noise.normal_(mean=0, std=noise_std)
fixed_noise = fixed_noise.to(device)
std_assumed = torch.tensor(noise_std).to(device)

In [15]:
#
# Optimizers
#
# Adam hyper-parameters shared by the generator / encoder-decoder optimizers.
optim_params = dict(lr=0.0005, weight_decay=0, betas=[0.9, 0.999], amsgrad=False)

if not train_langevine:
    # Encoder and generator are optimized jointly by one Adam instance.
    EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()), **optim_params)
else:
    # NOTE(review): the energy model ignores optim_params (fixed lr=1e-6,
    # default betas); confirm this is intentional.
    optE = torch.optim.Adam(net.netE.parameters(), lr=1e-6)
    optG = torch.optim.Adam(net.netG.parameters(), **optim_params)

In [16]:
from new.champfer_loss import ChamferLoss

# Chamfer distance between point clouds, used as the reconstruction loss.
reconstruction_loss = ChamferLoss()
reconstruction_loss = reconstruction_loss.to(device)

In [17]:
# Forwarded as `verbose=` to the Langevin sampling calls (net(...)) in the
# train_langevine branch of the training loop.
verbose = True

In [18]:
from tqdm.auto import tqdm
# Global optimization-step counter, used as the x-axis of TensorBoard scalars.
# It keeps counting if the training cell is re-run, unless this cell is re-run too.
total_step = 0

In [None]:
n_epochs = 400
samples_dir = join(results_dir, 'samples')
# fig.savefig() does not create missing folders; create the output folder up
# front instead of crashing at the end of the first (long) epoch.
os.makedirs(samples_dir, exist_ok=True)


def _save_point_cloud_samples(clouds, epoch, suffix, out_dir, title=None,
                              n_samples=5):
    """Save up to ``n_samples`` point clouds from ``clouds`` as PNG files.

    ``clouds[k][0]``, ``[1]``, ``[2]`` are passed as the three coordinate rows
    (assumes the [N_DIM, N_POINTS] layout the loop transposes batches into).
    Files are written to ``out_dir`` as ``{epoch:05}_{k}_{suffix}.png``.
    """
    # Only pass `title` when given, so plot_3d_point_cloud keeps its own default.
    extra = {} if title is None else {'title': title}
    for k in range(min(n_samples, len(clouds))):
        fig = plot_3d_point_cloud(clouds[k][0], clouds[k][1], clouds[k][2],
                                  in_u_sphere=True, show=False, **extra)
        fig.savefig(join(out_dir, f'{epoch:05}_{k}_{suffix}.png'))
        # Close every figure so hundreds of epochs don't pile up open figures.
        plt.close(fig)


for epoch in range(n_epochs):
    start_epoch_time = datetime.now()

    if train_langevine:
        net.train()
    else:
        G.train()
        E.train()

    total_loss = 0.0
    i = 0  # stays 0 if the dataloader yields nothing (guards the summary below)
    for i, point_data in tqdm(enumerate(points_dataloader, 1),
                              total=len(points_dataloader)):
        total_step += 1

        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        if train_langevine:
            batch_num = X.shape[0]

            # Initialize chains. z_g_* is the posterior chain that feeds the
            # generator and z_e_* the prior chain fed to the energy model, so
            # each starts from its own init sigma (previously swapped).
            z_g_0 = sample_p_0(n=batch_num, sig=g_init_sig, device=X.device)
            z_e_0 = sample_p_0(n=batch_num, sig=e_init_sig, device=X.device)

            # Langevin posterior and prior (plain tensors; autograd.Variable
            # is a deprecated no-op wrapper).
            z_g_k = net(z_g_0, X, prior=False, verbose=verbose)
            z_e_k = net(z_e_0, prior=True, verbose=verbose)

            # Reconstruction from the posterior samples.
            X_hat = net.netG(z_g_k.detach())
            loss_g = net.loss_fun(X_hat.transpose(1, 2).contiguous(),
                                  X.transpose(1, 2).contiguous())

            # Energy prior: posterior ("positive") vs prior ("negative") samples.
            # TODO(nijkamp): why mean() here and in Langevin sum() over energy?
            # constant is absorbed into Adam adaptive lr
            en_neg = net.netE(z_e_k.detach()).mean()
            en_pos = net.netE(z_g_k.detach()).mean()
            loss_e = (en_pos - en_neg) / batch_num

            # Learn generator. Optional clipping before step():
            # torch.nn.utils.clip_grad_norm_(net.netG.parameters(), 10)
            optG.zero_grad()
            loss_g.backward()
            optG.step()

            # Learn energy-based prior.
            optE.zero_grad()
            loss_e.backward()
            optE.step()

            # Previously never accumulated, so the epoch summary always printed 0.
            total_loss += loss_g.item()

            if i % 30 == 0:
                print(f'[{epoch}: ({i})] '
                      f'loss_g: {loss_g.item():.4f} '
                      f'(loss_e: {loss_e.item(): .4f}'
                      f' Time: {datetime.now() - start_epoch_time}')

        else:
            codes, mu, logvar = E(X)

            if mixed_sampling:
                # Replace the encoder codes with codes from net.reparameterize
                # (use_lagivine=True; presumably Langevin-refined, see new.models).
                codes = net.reparameterize(X, mu, logvar, use_lagivine=True,
                                           verbose=verbose)
                X_rec = G(codes)
                loss_g = torch.mean(
                    0.05 * reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                               X_rec.permute(0, 2, 1) + 0.5))

                # Squared distance between the sampled codes and the encoder mean.
                loss_e = 0.05 * torch.mean((codes - mu) ** 2)

                loss_eg = loss_g + loss_e

                EG_optim.zero_grad()
                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                if i % 5 == 0:
                    if use_tensorboard:
                        writer.add_scalar('loss/loss_eg', loss_eg.item(), total_step)
                        writer.add_scalar('loss/loss_g', loss_g.item(), total_step)
                        writer.add_scalar('loss/loss_e', loss_e.item(), total_step)
                    else:
                        print(f'[{epoch}: ({i})] '
                              f'Loss_EG: {loss_eg.item():.4f} '
                              f'(E loss: {loss_e.item(): .4f}'
                              f' G loss: {loss_g.item(): .4f})'
                              f' Time: {datetime.now() - start_epoch_time}')

            else:
                X_rec = G(codes)

                loss_e = torch.mean(
                    0.05 *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                # Mean per-dimension KL(N(mu, exp(logvar)) || N(0, std_assumed^2)),
                # assuming logvar is the log-variance.
                loss_kld = -0.5 * torch.mean(
                    1 - 2.0 * torch.log(std_assumed) + logvar -
                    (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2))

                loss_eg = loss_e + loss_kld
                # EG_optim covers all E and G parameters, so this zeroes both.
                EG_optim.zero_grad()

                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                if i % 30 == 0:
                    print(f'[{epoch}: ({i})] '
                          f'Loss_EG: {loss_eg.item():.4f} '
                          f'(REC: {loss_e.item(): .4f}'
                          f' KLD: {loss_kld.item(): .4f})'
                          f' Time: {datetime.now() - start_epoch_time}')

    print(
        f'[{epoch}/{n_epochs}] '
        f'Loss_G: {total_loss / max(i, 1):.4f} '
        f'Time: {datetime.now() - start_epoch_time}'
    )
    if use_tensorboard:
        writer.flush()

    if i == 0:
        # Empty dataloader: nothing was trained and there is no batch to plot.
        continue

    ################################### eval ######################################
    #
    # Save intermediate results (reconstructions use the last training batch X).
    #
    if train_langevine:
        net.eval()
        with torch.no_grad():
            fake_np = net.netG(fixed_noise).cpu().numpy()
            # X_hat was computed with grad enabled during training; detach first.
            rec_np = X_hat.detach().cpu().numpy()
            real_np = X.cpu().numpy()
    else:
        G.eval()
        E.eval()
        with torch.no_grad():
            fake_np = G(fixed_noise).cpu().numpy()
            codes, _, _ = E(X)
            rec_np = G(codes).cpu().numpy()
            real_np = X.cpu().numpy()

    _save_point_cloud_samples(real_np, epoch, 'real', samples_dir)
    _save_point_cloud_samples(fake_np, epoch, 'fixed', samples_dir,
                              title=str(epoch))
    _save_point_cloud_samples(rec_np, epoch, 'reconstructed', samples_dir)

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Langevin posterior   1/ 10: LOSS G=3915.051 z_norm:   2.195
Langevin posterior  10/ 10: LOSS G=1617.180 z_norm:   2.086
Langevin posterior   1/ 10: LOSS G=1254.213 z_norm:   2.328
Langevin posterior  10/ 10: LOSS G=1046.287 z_norm:   2.303
Langevin posterior   1/ 10: LOSS G=1552.479 z_norm:   2.451
Langevin posterior  10/ 10: LOSS G= 773.151 z_norm:   2.400
Langevin posterior   1/ 10: LOSS G=1105.519 z_norm:   2.297
Langevin posterior  10/ 10: LOSS G= 560.049 z_norm:   2.277
Langevin posterior   1/ 10: LOSS G= 971.609 z_norm:   2.219
Langevin posterior  10/ 10: LOSS G= 789.440 z_norm:   2.222
Langevin posterior   1/ 10: LOSS G= 661.131 z_norm:   1.999
Langevin posterior  10/ 10: LOSS G= 595.530 z_norm:   1.998
Langevin posterior   1/ 10: LOSS G= 903.476 z_norm:   1.896
Langevin posterior  10/ 10: LOSS G= 601.678 z_norm:   1.857
Langevin posterior   1/ 10: LOSS G= 990.773 z_norm:   1.860
Langevin posterior  10/ 10: LOSS G= 691.632 z_norm:   1.814
Langevin posterior   1/ 10: LOSS G= 984.

Langevin posterior  10/ 10: LOSS G= 406.674 z_norm:   3.221
Langevin posterior   1/ 10: LOSS G= 607.571 z_norm:   3.282
Langevin posterior  10/ 10: LOSS G= 344.114 z_norm:   3.270
Langevin posterior   1/ 10: LOSS G= 903.167 z_norm:   3.382
Langevin posterior  10/ 10: LOSS G= 478.639 z_norm:   3.357
Langevin posterior   1/ 10: LOSS G= 637.406 z_norm:   3.459
Langevin posterior  10/ 10: LOSS G= 375.885 z_norm:   3.439
Langevin posterior   1/ 10: LOSS G= 592.404 z_norm:   3.401
Langevin posterior  10/ 10: LOSS G= 382.409 z_norm:   3.382
Langevin posterior   1/ 10: LOSS G= 632.379 z_norm:   3.369
Langevin posterior  10/ 10: LOSS G= 429.128 z_norm:   3.356
Langevin posterior   1/ 10: LOSS G= 595.665 z_norm:   3.346
Langevin posterior  10/ 10: LOSS G= 397.566 z_norm:   3.347
Langevin posterior   1/ 10: LOSS G= 554.163 z_norm:   3.251
Langevin posterior  10/ 10: LOSS G= 306.767 z_norm:   3.247
Langevin posterior   1/ 10: LOSS G= 688.782 z_norm:   3.279
Langevin posterior  10/ 10: LOSS G= 380.

Langevin posterior  10/ 10: LOSS G= 275.479 z_norm:   6.652
Langevin posterior   1/ 10: LOSS G= 671.961 z_norm:   6.598
Langevin posterior  10/ 10: LOSS G= 299.830 z_norm:   6.596
Langevin posterior   1/ 10: LOSS G= 741.339 z_norm:   6.445
Langevin posterior  10/ 10: LOSS G= 367.914 z_norm:   6.427
Langevin posterior   1/ 10: LOSS G= 570.161 z_norm:   6.632
Langevin posterior  10/ 10: LOSS G= 362.585 z_norm:   6.630
Langevin posterior   1/ 10: LOSS G= 454.903 z_norm:   6.402
Langevin posterior  10/ 10: LOSS G= 317.141 z_norm:   6.396
Langevin posterior   1/ 10: LOSS G= 581.728 z_norm:   6.215
Langevin posterior  10/ 10: LOSS G= 307.614 z_norm:   6.210
Langevin posterior   1/ 10: LOSS G= 738.037 z_norm:   6.328
Langevin posterior  10/ 10: LOSS G= 375.412 z_norm:   6.319
Langevin posterior   1/ 10: LOSS G= 606.287 z_norm:   6.421
Langevin posterior  10/ 10: LOSS G= 330.723 z_norm:   6.421
Langevin posterior   1/ 10: LOSS G= 417.087 z_norm:   5.741
Langevin posterior  10/ 10: LOSS G= 244.

Langevin posterior  10/ 10: LOSS G= 298.338 z_norm:   3.293
Langevin posterior   1/ 10: LOSS G= 421.695 z_norm:   3.216
Langevin posterior  10/ 10: LOSS G= 322.621 z_norm:   3.219
Langevin posterior   1/ 10: LOSS G= 428.503 z_norm:   3.035
Langevin posterior  10/ 10: LOSS G= 320.464 z_norm:   3.029
Langevin posterior   1/ 10: LOSS G= 472.879 z_norm:   2.923
Langevin posterior  10/ 10: LOSS G= 309.576 z_norm:   2.921
Langevin posterior   1/ 10: LOSS G= 486.543 z_norm:   3.113
Langevin posterior  10/ 10: LOSS G= 374.611 z_norm:   3.112
Langevin posterior   1/ 10: LOSS G= 510.689 z_norm:   2.937
Langevin posterior  10/ 10: LOSS G= 368.020 z_norm:   2.941
Langevin posterior   1/ 10: LOSS G= 347.719 z_norm:   2.818
Langevin posterior  10/ 10: LOSS G= 315.572 z_norm:   2.826
Langevin posterior   1/ 10: LOSS G= 426.718 z_norm:   2.913
Langevin posterior  10/ 10: LOSS G= 373.588 z_norm:   2.956
Langevin posterior   1/ 10: LOSS G= 435.497 z_norm:   2.752
Langevin posterior  10/ 10: LOSS G= 324.

Langevin posterior  10/ 10: LOSS G= 318.740 z_norm:   2.847
Langevin posterior   1/ 10: LOSS G= 384.187 z_norm:   2.863
Langevin posterior  10/ 10: LOSS G= 251.068 z_norm:   2.857
Langevin posterior   1/ 10: LOSS G= 602.852 z_norm:   2.827
Langevin posterior  10/ 10: LOSS G= 305.499 z_norm:   2.829
Langevin posterior   1/ 10: LOSS G= 463.696 z_norm:   2.932
Langevin posterior  10/ 10: LOSS G= 339.080 z_norm:   2.949
Langevin posterior   1/ 10: LOSS G= 428.384 z_norm:   2.910
Langevin posterior  10/ 10: LOSS G= 362.511 z_norm:   2.915
Langevin posterior   1/ 10: LOSS G= 384.594 z_norm:   2.918
Langevin posterior  10/ 10: LOSS G= 270.912 z_norm:   2.924
Langevin posterior   1/ 10: LOSS G= 506.147 z_norm:   2.886
Langevin posterior  10/ 10: LOSS G= 343.906 z_norm:   2.896
Langevin posterior   1/ 10: LOSS G= 394.758 z_norm:   2.891
Langevin posterior  10/ 10: LOSS G= 259.671 z_norm:   2.887
Langevin posterior   1/ 10: LOSS G= 438.829 z_norm:   2.905
Langevin posterior  10/ 10: LOSS G= 326.

In [None]:
# Scratch check: a batch of 16 standard-normal 128-dim vectors.
a = torch.randn((16, 128))

In [None]:
# Mean L2 norm of the rows of `a` (about sqrt(128) ≈ 11.3 for standard normals).
torch.linalg.vector_norm(a, dim=1).mean()

In [None]:
# Scratch: display `batch_size`. It is not defined in this notebook; presumably
# it comes from the new.params / new.utils star imports. The DataLoader above
# hard-codes batch_size=16.
batch_size