In [None]:
import argparse
import json
import logging
import random
from datetime import datetime
from importlib import import_module
from itertools import chain
from os.path import join, exists

import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader

from new.aae.pcutil import plot_3d_point_cloud
#from utils.util import find_latest_epoch, prepare_results_dir, cuda_setup, setup_logging

# Let cuDNN benchmark and cache the fastest convolution algorithms for the input
# shapes it sees. NOTE(review): algorithm selection can make runs
# non-deterministic even with the seeds set further down.
torch.backends.cudnn.benchmark = True


In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``net.apply(weights_init)``.

    Dispatch is on the module's class name, checked in this order:
      * name contains 'Conv'      -> Xavier-uniform weight (ReLU gain), zero bias
      * name contains 'BatchNorm' -> weight 1, bias 0
      * name contains 'Linear'    -> Xavier-uniform weight (ReLU gain), zero bias
    Any other module is left untouched.

    Modules that match by name but carry no ``weight`` tensor are skipped.
    Examples are a container class such as ``ConvBlock`` and ``BatchNorm``
    built with ``affine=False``. Before this check, those raised.

    Args:
        m: a ``torch.nn.Module`` (called once per submodule by ``Module.apply``).
    """
    classname = type(m).__name__
    if 'Conv' in classname:
        kind = 'xavier'
    elif 'BatchNorm' in classname:
        kind = 'batchnorm'
    elif 'Linear' in classname:
        kind = 'xavier'
    else:
        return

    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if weight is None:
        # Nothing to initialise (container module or non-affine normalisation).
        return

    if kind == 'xavier':
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_uniform_(weight, gain=gain)
    else:
        torch.nn.init.constant_(weight, 1)
    if bias is not None:
        torch.nn.init.constant_(bias, 0)


In [None]:
# Seed the Python, torch-CPU and all-CUDA-device RNGs with one shared value.
SEED = 123
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# Root folder for artefacts written by this notebook (sample renders go under 'samples/').
results_dir = 'results/'

In [None]:
# Train on the GPU when one is visible; otherwise fall back to CPU instead of
# crashing on the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Human-readable dataset identifier. NOTE(review): not referenced by any visible
# cell; ShapeNetDataset receives its root_dir as a separate literal.
dataset_name = 'shapenet'

In [None]:
from new.shapenet import ShapeNetDataset

In [None]:
# ShapeNet 'chair' samples, served in shuffled batches of 16. The trailing
# incomplete batch is dropped, so every batch holds exactly 16 samples.
dataset = ShapeNetDataset(root_dir="shapenet", classes=["chair"])
points_dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    pin_memory=True,
)

In [None]:
from new.aae.aae import Generator, Encoder

# Build both networks on the target device, then apply the custom init.
# The order (construct G, construct E, init G, init E) is kept because every
# step draws from the global torch RNG.
G = Generator().to(device)
E = Encoder().to(device)
for net in (G, E):
    net.apply(weights_init)

In [None]:
from new.params import z_dim

In [None]:
#
# Float Tensors
#
# Fixed latent codes are decoded every epoch so generated samples stay comparable.
# std_assumed is the prior std used by the KL term. Both share one value so the
# fixed noise is drawn from the same prior the encoder is regularised towards.
prior_std = 0.2
# Drawn on CPU (same generator as the original legacy FloatTensor path), then moved.
fixed_noise = torch.empty(16, z_dim, 1, dtype=torch.float32, device='cpu')
fixed_noise.normal_(mean=0, std=prior_std)
std_assumed = torch.tensor(prior_std)

fixed_noise = fixed_noise.to(device)
std_assumed = std_assumed.to(device)

#
# Optimizers
#
# One Adam optimizer updates the parameters of E and G jointly.
optim_params = {
    "lr": 0.0005,
    "weight_decay": 0,
    "betas": [0.9, 0.999],
    "amsgrad": False,
}

EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()),
                            **optim_params)

In [None]:
from new.champfer_loss import ChamferLoss

# Point-set reconstruction criterion, placed on the same device as the models.
reconstruction_loss = ChamferLoss()
reconstruction_loss = reconstruction_loss.to(device)

In [None]:
import os

# Number of training epochs (previously repeated as a bare literal in two places).
n_epochs = 400
# How many clouds from each batch are rendered to PNG after every epoch
# (batch_size is 16 with drop_last=True, so at least this many are available).
n_plot_samples = 5

# Figure.savefig does not create directories; without this the first save fails.
samples_dir = join(results_dir, 'samples')
os.makedirs(samples_dir, exist_ok=True)

# Fail fast with a clear message instead of a NameError on `i`/`X` further down.
if len(points_dataloader) == 0:
    raise ValueError('points_dataloader yields no batches '
                     '(dataset smaller than batch_size with drop_last=True?)')


def _save_sample_clouds(clouds, epoch_tag, suffix, **plot_kwargs):
    """Render the first `n_plot_samples` clouds and save each as a PNG.

    Args:
        clouds: array indexed as clouds[k][axis] -> per-axis coordinates
            (presumably [BATCH, 3, N_POINTS] — TODO confirm for G's output).
        epoch_tag: epoch label used in the file name (int or preformatted str).
        suffix: file-name suffix, e.g. 'real', 'fixed', 'reconstructed'.
        **plot_kwargs: extra keyword arguments for plot_3d_point_cloud.
    """
    for k in range(n_plot_samples):
        fig = plot_3d_point_cloud(clouds[k][0], clouds[k][1], clouds[k][2],
                                  in_u_sphere=True, show=False, **plot_kwargs)
        fig.savefig(join(samples_dir, f'{epoch_tag}_{k}_{suffix}.png'))
        plt.close(fig)


for epoch in range(n_epochs):
    start_epoch_time = datetime.now()

    G.train()
    E.train()

    total_loss = 0.0
    for i, point_data in enumerate(points_dataloader, 1):

        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        codes, mu, logvar = E(X)
        X_rec = G(codes)

        # Both clouds are permuted back to point-major layout and shifted by +0.5
        # before the Chamfer loss (presumably the layout ChamferLoss expects).
        loss_e = torch.mean(
            0.05 *
            reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                X_rec.permute(0, 2, 1) + 0.5))

        # Closed-form KL( N(mu, exp(logvar)) || N(0, std_assumed^2) ), averaged
        # over all elements.
        loss_kld = -0.5 * torch.mean(
            1 - 2.0 * torch.log(std_assumed) + logvar -
            (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2))

        loss_eg = loss_e + loss_kld
        # EG_optim was built over chain(E.parameters(), G.parameters()), so this
        # clears every gradient of both networks.
        EG_optim.zero_grad()
        loss_eg.backward()
        total_loss += loss_eg.item()
        EG_optim.step()

        if i % 30 == 0:
            print(f'[{epoch}: ({i})] '
                  f'Loss_EG: {loss_eg.item():.4f} '
                  f'(REC: {loss_e.item(): .4f}'
                  f' KLD: {loss_kld.item(): .4f})'
                  f' Time: {datetime.now() - start_epoch_time}')

    # `i` is the number of batches this epoch (non-empty loader checked above).
    # The label was previously 'Loss_G', although the value is the mean E+G loss.
    print(
        f'[{epoch}/{n_epochs}] '
        f'Loss_EG: {total_loss / i:.4f} '
        f'Time: {datetime.now() - start_epoch_time}'
    )

    #
    # Save intermediate results
    #
    G.eval()
    E.eval()
    with torch.no_grad():
        fake = G(fixed_noise).cpu().numpy()
        # Reconstruct the last training batch of this epoch in eval mode.
        codes, _, _ = E(X)
        X_rec = G(codes).cpu().numpy()
        X = X.cpu().numpy()

    # File names are kept as before: only the 'fixed' renders use a
    # zero-padded epoch.
    _save_sample_clouds(X, epoch, 'real')
    _save_sample_clouds(fake, f'{epoch:05}', 'fixed', title=str(epoch))
    _save_sample_clouds(X_rec, epoch, 'reconstructed')