In [None]:
# Experiment settings: the run name (used for the TensorBoard log directory)
# and whether to log to TensorBoard instead of printing to stdout.
exp_name = "Oct_26_school"
use_tensorboard = True

In [None]:
import torch
from torch.autograd import Variable
from tqdm.auto import tqdm

import matplotlib
import matplotlib.pyplot as plt

# Record which PyTorch / CUDA build and device produced this run.
print(torch.__version__)

# get_device_name(0) raises when no CUDA device is present; later cells already
# fall back to CPU via torch.cuda.is_available(), so do the same here.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; running on CPU")

print(torch.version.cuda)

if use_tensorboard:
    # One SummaryWriter per experiment; event files go to runs/<exp_name>.
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter("runs/" + exp_name)
    


In [None]:
from utils.util_torch import *

from new.configs import *
from new.utils import * 
from new.models import *

In [None]:
opt = parse_config()

# Notebook-level overrides of the parsed configuration.
# NOTE(review): `batch_size` is not defined in this notebook; presumably it is
# provided by the star imports (new.configs) — confirm.
opt.batch_size = batch_size * 3
opt.swap_axis = True
# Treat both "" and None as "no checkpoint" (len(None) would raise TypeError).
if not opt.checkpoint_path:
    opt.checkpoint_path = None
opt.device = "cuda:%s" % opt.cuda if opt.cuda != "" else "cpu"
# Shuffle the data only when not warm-starting.
opt.shuffle = not opt.warm_start

if use_tensorboard:
    # Log hyper-parameters *after* the overrides above so TensorBoard records the
    # values actually used. add_hparams only accepts int/float/str/bool/Tensor
    # values, so anything else (None, lists, ...) is left out.
    # Overrides made in later cells (e.g. opt.category) are not captured here.
    loggable_hparams = {
        key: value
        for key, value in vars(opt).items()
        if isinstance(value, (int, float, str, bool, torch.Tensor))
    }
    writer.add_hparams(hparam_dict=loggable_hparams, metric_dict={})

In [None]:
# Force the dataset category to ModelNet10, overriding the parsed config.
setattr(opt, "category", "modelnet10")

In [None]:
# Echo the final configuration for the notebook record.
print(f"{opt}")

In [None]:
# Training dataset and its batch collator, both configured from `opt`.
train_data = PointCloudDataSet(opt)
data_collator = PointCloudDataCollator(opt)

# Four loader workers per visible GPU (0, i.e. in-process loading, without GPUs).
loader_workers = torch.cuda.device_count() * 4
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=opt.batch_size,
    shuffle=opt.shuffle,
    drop_last=opt.drop_last,
    collate_fn=data_collator,
    num_workers=loader_workers,
)

In [None]:
# Number of training samples, shown as the cell output.
n_train_samples = len(train_data)
n_train_samples

In [None]:
# test_dataset = torch.utils.data.TensorDataset(np.load("data/%s_test.npy" % opt.category))
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, drop_last=opt.drop_last, 
#     shuffle=False, collate_fn = data_collator, num_workers=torch.cuda.device_count() * 4)

In [None]:
#x = next(iter(train_loader))

In [None]:
#x.shape

In [None]:
# Seed before building the model so weight initialisation is reproducible.
# The seeding cell further down still re-seeds with the same value before
# training, so the RNG state at the start of training is unchanged.
set_seed(123)
net = NetWrapper()

In [None]:
# Place the model on the default GPU when one exists; otherwise keep it on CPU.
net = net.cuda() if torch.cuda.is_available() else net

In [None]:
# set_gpu(args.device)
# Configure CUDA determinism and fix the global random seed.
# NOTE(review): `gpu_deterministic` is not defined in this notebook; presumably it
# comes from the star-imported config — confirm.
# NOTE(review): this cell runs after the dataset and NetWrapper cells, so their
# construction is not covered by this seed.
set_cuda(deterministic=gpu_deterministic)
set_seed(123)

In [None]:
# One Adam optimiser each for the prior EBM (netE) and the generator (netG),
# with hyper-parameters taken from the star-imported config names.
optE, optG = (
    torch.optim.Adam(net.netE.parameters(), lr=e_lr, weight_decay=e_decay, betas=(e_beta1, e_beta2)),
    torch.optim.Adam(net.netG.parameters(), lr=g_lr, weight_decay=g_decay, betas=(g_beta1, g_beta2)),
)

# Exponential learning-rate decay schedules, one per optimiser.
lr_scheduleE, lr_scheduleG = (
    torch.optim.lr_scheduler.ExponentialLR(optE, e_gamma),
    torch.optim.lr_scheduler.ExponentialLR(optG, g_gamma),
)

In [None]:
# Test reference
# ref_pcs = np.load("data/%s_test.npy" % opt.category)

# Train

In [None]:
# Global optimisation-step counter. It is initialised in its own cell so that
# re-running the training cell keeps counting instead of restarting at 0.
total_step = 0

In [None]:
%matplotlib inline

for epoch in tqdm(range(100)):
#     if epoch > 0:
#         break
    # Train phase
    net.train()
    for c, x in enumerate(train_loader):
        total_step += 1
        batch_num = x.shape[0]

        if torch.cuda.is_available():
            x = x.to("cuda")

        # Initialize chains
        z_g_0 = sample_p_0(n = batch_num,sig=e_init_sig, device = x.device)
        z_e_0 = sample_p_0(n = batch_num,sig=e_init_sig, device = x.device)

#         print("shape log")
#         print(x.shape)
#         print(z_g_0.shape)
#         print(z_e_0.shape)

        # Langevin posterior and prior
        z_g_k = net(Variable(z_g_0), x, prior=False)
        z_e_k = net(Variable(z_e_0), prior=True)
        
        # print("z_g_k", z_g_k)
        # print("z_e_k shape", z_e_k.shape)
        
        # Learn generator
        optG.zero_grad()
        x_hat = net.netG(z_g_k.detach())
        
        #print("x_hat shape", x_hat.shape)
        
        loss_g = net.loss_fun(x_hat.transpose(1,2).contiguous(), x.transpose(1,2).contiguous())
        loss_g.backward()
        optG.step()

        # Learn prior EBM
        optE.zero_grad()
        en_neg = energy(net.netE(z_e_k.detach())).mean() # TODO(nijkamp): why mean() here and in Langevin sum() over energy? constant is absorbed into Adam adaptive lr
        en_pos = energy(net.netE(z_g_k.detach())).mean()
        loss_e = en_pos - en_neg
        loss_e.backward()
        # grad_norm_e = get_grad_norm(net.netE.parameters())
        # if args.e_is_grad_clamp:
        #    torch.nn.utils.clip_grad_norm_(net.netE.parameters(), args.e_max_norm)
        optE.step()
        
        # break
        
        # Printout
        if total_step % 15 == 0:
            with torch.no_grad():
                x_0 = net.netG(z_e_0)
                x_k = net.netG(z_e_k)

                en_neg_2 = energy(net.netE(z_e_k)).mean()
                en_pos_2 = energy(net.netE(z_g_k)).mean()

                prior_moments = '[{:8.2f}, {:8.2f}, {:8.2f}]'.format(z_e_k.mean(), z_e_k.std(), z_e_k.abs().max())
                posterior_moments = '[{:8.2f}, {:8.2f}, {:8.2f}]'.format(z_g_k.mean(), z_g_k.std(), z_g_k.abs().max())
                
                
                if use_tensorboard:
                    writer.add_scalar('loss/loss_g',loss_g, total_step)
                    writer.add_scalar('loss/loss_e',loss_e, total_step)

                    writer.add_scalars('energy/en_pos', {'pos_1':en_pos,
                                        'pose_2':en_pos_2,
                                        'diff': en_pos_2 - en_pos}, total_step)
                    writer.add_scalars('energy/en_neg', {'pos_1':en_neg,
                                        'pose_2':en_neg_2,
                                        'diff': en_neg_2 - en_neg}, total_step)

                    writer.add_scalar('value/|z_g_0|',z_g_0.view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('value/|z_g_k|',z_g_k.view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('value/|z_e_0|',z_e_0.view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('value/|z_e_k|',z_e_k.view(batch_num, -1).norm(dim=1).mean(), total_step)

                    writer.add_scalar('disp/z_e_disp',(z_e_k-z_e_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('disp/z_g_disp',(z_g_k-z_g_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('disp/x_e_disp',(x_k-x_0).view(batch_num, -1).norm(dim=1).mean(), total_step)

                    writer.add_scalars('moment/prior_moments', {'mean':z_e_k.mean(),
                                        'std':z_e_k.std(),
                                        'max abs': z_e_k.abs().max()}, total_step)
                    writer.add_scalars('moment/posterior_moments', {'mean':z_g_k.mean(),
                                        'std':z_g_k.std(),
                                        'max abs': z_g_k.abs().max()}, total_step)
                else:
                    print(
                        '{} {}/{} {}/{} \n'.format(0, epoch, n_epochs, total_step, len(train_loader)) +
                        'loss_g={:8.5f}, \n'.format(loss_g) +
                        'loss_e={:8.5f}, \n'.format(loss_e) +
                        'en_pos=[{:9.5f}, {:9.5f}, {:9.5f}], \n'.format(en_pos, en_pos_2, en_pos_2-en_pos) +
                        'en_neg=[{:9.5f}, {:9.5f}, {:9.5f}], \n'.format(en_neg, en_neg_2, en_neg_2-en_neg) +
                        '|z_g_0|={:6.3f}, \n'.format(z_g_0.view(batch_num, -1).norm(dim=1).mean()) +
                        '|z_g_k|={:6.3f}, \n'.format(z_g_k.view(batch_num, -1).norm(dim=1).mean()) +
                        '|z_e_0|={:6.3f}, \n'.format(z_e_0.view(batch_num, -1).norm(dim=1).mean()) +
                        '|z_e_k|={:6.3f}, \n'.format(z_e_k.view(batch_num, -1).norm(dim=1).mean()) +
                        'z_e_disp={:6.3f}, \n'.format((z_e_k-z_e_0).view(batch_num, -1).norm(dim=1).mean()) +
                        'z_g_disp={:6.3f}, \n'.format((z_g_k-z_g_0).view(batch_num, -1).norm(dim=1).mean()) +
                        'x_e_disp={:6.3f}, \n'.format((x_k-x_0).view(batch_num, -1).norm(dim=1).mean()) +
                        'prior_moments={}, \n'.format(prior_moments) +
                        'posterior_moments={}, \n'.format(posterior_moments) +
                        #'fid={:8.2f}, '.format(fid) +
                        #'fid_best={:8.2f}'.format(fid_best)
                        "\n\n\n ---------------------"
                    )
                    
    # Eval phase
    
    net.eval()
    syn_pcs = net.sample_x(n=16)
#     res = quantitative_analysis(syn_pcs.data.numpy(), ref_pcs, 16, full=False)
#     if use_tensorboard:
#         writer.add_scalar('test_record/jsd', res['jsd'], epoch)
#         writer.add_scalar('test_record/mmd-CD', res['mmd-CD'], epoch)
#         writer.add_scalar('test_record/mmd-EMD', res['mmd-EMD'], epoch)
#         writer.add_scalar('test_record/cov-CD', res['cov-CD'], epoch)
#         writer.add_scalar('test_record/cov-EMD', res['cov-EMD'], epoch)
#     else:
#         print("epoch {} test record {}".format(epoch ,res))
        
    show_point_clouds(syn_pcs)
    print(epoch)
    plt.show()

In [None]:
# Preview a few raw training batches: print the batch index, then plot it.
# Iterating the full loader would create one figure per batch for the whole
# dataset, so the preview is capped; raise N_PREVIEW_BATCHES to see more.
import itertools

N_PREVIEW_BATCHES = 4
for c, x in itertools.islice(enumerate(train_loader), N_PREVIEW_BATCHES):
    print(c)
    show_point_clouds(x)
    plt.show()

In [None]:
# Plot the most recent input batch `x` as a host-side NumPy array.
show_point_clouds(x.detach().cpu().numpy())

In [None]:
# Plot the most recent generator output `x_hat` (grad-free, on host, as NumPy).
show_point_clouds(x_hat.detach().cpu().numpy())

In [None]:
#torch.save(net.state_dict(),"runs/Oct23_11-51-39_yizhou-Z370-AORUS-Gaming-5/model.pth")

In [None]:
# Batches per epoch, shown as the cell output.
n_train_batches = len(train_loader)
n_train_batches

In [None]:
# Draw a fixed number of synthetic point clouds via NetWrapper.sample_x.
n_samples = 16
syn_pcs = net.sample_x(n=n_samples)

In [None]:
# Dimensions of the sampled batch (same torch.Size as `.shape`).
syn_pcs.size()

In [None]:
# Quantitative evaluation of the synthesized point clouds against the test split.
# The reference set is loaded here so this cell does not depend on any other
# cell having defined `ref_pcs`. `.detach().cpu()` is required because
# `.numpy()` raises on CUDA tensors.
import numpy as np

ref_pcs = np.load("data/%s_test.npy" % opt.category)
res = quantitative_analysis(syn_pcs.detach().cpu().numpy(), ref_pcs, 16, full=False)
res