In [1]:
import yaml
import torch
import torch.nn as nn
import numpy as np
import tqdm
import os
import argparse
from torch.backends import cudnn

# Turn on cuDNN's benchmark mode: per the PyTorch docs, cuDNN then benchmarks
# several convolution algorithms and selects the fastest one.
cudnn.benchmark = True

def delete_module(name):
    """Remove the entry registered under ``name`` from ``sys.modules``.

    Args:
        name: Key of the module to drop from the import cache.

    Raises:
        KeyError: If ``name`` is not present in ``sys.modules``.
    """
    from sys import modules as loaded_modules
    # pop() without a default raises KeyError for a missing key, like `del`.
    loaded_modules.pop(name)

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Path (relative to the notebook's working directory) of the YAML config read below.
config_file = './test_config.yaml'

def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an ``argparse.Namespace``.

    Every key becomes an attribute. Values that are themselves dicts are
    converted recursively; all other values (lists included) are stored as-is.

    Args:
        config: Mapping whose keys are usable as attribute names.

    Returns:
        An ``argparse.Namespace`` mirroring the structure of ``config``.
    """
    ns = argparse.Namespace()
    for attr_name, entry in config.items():
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(ns, attr_name, converted)
    return ns

# Parse the YAML config file. `yaml.safe_load` is used instead of a bare
# `yaml.load(f)`: calling `yaml.load` without a Loader is deprecated since
# PyYAML 5.1 and raises TypeError on PyYAML >= 6. The safe loader also refuses
# to construct arbitrary Python objects from the file.
with open(config_file, 'r') as f:
    config = yaml.safe_load(f)

# Convert the nested dict into attribute-style namespaces and display it.
cfg = dict2namespace(config)
cfg

  


Namespace(data=Namespace(batch_size=32, cates=['airplane'], data_dir='data/ShapeNetCore.v2.PC15k', dataset_scale=1, dataset_type='shapenet15k', normalize_per_shape=False, normalize_std_per_axis=False, num_workers=4, recenter_per_shape=True, te_max_sample_points=2048, tr_max_sample_points=2048, type='datasets.pointflow_datasets'), inference=Namespace(num_points=2048, num_steps=10, step_size_ratio=1, weight=1), models=Namespace(scorenet=Namespace(dim=3, hidden_size=256, n_blocks=24, out_dim=3, param_likelihood=False, sigma_condition=True, type='models.decoders.resnet_add', xyz_condition=True, z_dim=128)), trainer=Namespace(epochs=2000, opt_criticnet=Namespace(beta1=0.9, beta2=0.999, lr='1e-3', momentum=0.9, scheduler='linear', step_epoch=2000, type='adam', weight_decay=0.0), opt_scorenet=Namespace(beta1=0.9, beta2=0.999, lr='1e-3', momentum=0.9, scheduler='linear', step_epoch=2000, type='adam', weight_decay=0.0), seed=100, sigma_begin=1, sigma_end=0.01, sigma_num=10, type='trainers.ae_tr

In [4]:
# Show the `data` section of the parsed config.
cfg.data

Namespace(batch_size=32, cates=['airplane'], data_dir='data/ShapeNetCore.v2.PC15k', dataset_scale=1, dataset_type='shapenet15k', normalize_per_shape=False, normalize_std_per_axis=False, num_workers=4, recenter_per_shape=True, te_max_sample_points=2048, tr_max_sample_points=2048, type='datasets.pointflow_datasets')

In [5]:
# Show the `trainer` section of the parsed config (optimizer settings, seed, sigma schedule).
cfg.trainer

Namespace(epochs=2000, opt_criticnet=Namespace(beta1=0.9, beta2=0.999, lr='1e-3', momentum=0.9, scheduler='linear', step_epoch=2000, type='adam', weight_decay=0.0), opt_scorenet=Namespace(beta1=0.9, beta2=0.999, lr='1e-3', momentum=0.9, scheduler='linear', step_epoch=2000, type='adam', weight_decay=0.0), seed=100, sigma_begin=1, sigma_end=0.01, sigma_num=10, type='trainers.ae_trainer_3D')

In [6]:
# Show the `models` section of the parsed config.
cfg.models

Namespace(scorenet=Namespace(dim=3, hidden_size=256, n_blocks=24, out_dim=3, param_likelihood=False, sigma_condition=True, type='models.decoders.resnet_add', xyz_condition=True, z_dim=128))

In [7]:
# Show the `inference` section of the parsed config.
cfg.inference

Namespace(num_points=2048, num_steps=10, step_size_ratio=1, weight=1)

In [8]:
# Show the `viz` section of the parsed config (logging/saving/validation frequencies).
cfg.viz

Namespace(log_freq=10, save_freq=100, val_freq=100, viz_freq=5000)

# Train

In [9]:
from scorenet import Scorenet
from critic import Criticnet
from utils import get_opt, set_random_seed, exact_jacobian_trace, langevin_dynamics, visualize
from data_loader import get_data
import random

In [10]:
# Seed via the project helper, using the configured seed (666 when none is set).
set_random_seed(getattr(cfg.trainer, "seed", 666))

# Load one item from the dataset.
# NOTE(review): the second argument (0) presumably selects an index — confirm in data_loader.
train_data = get_data(cfg.data, 0)

# Add a leading axis of size 1 to each point set and move it to the target device.
tr_pts = train_data['tr_points'].unsqueeze(0).to(device)
te_pts = train_data['te_points'].unsqueeze(0).to(device)

# Enable gradients with respect to the points themselves (in place).
tr_pts.requires_grad_()
# Kept as the cell's last expression so the notebook displays the test points.
te_pts.requires_grad_()

100%|██████████| 1/1 [00:01<00:00,  1.84s/it]


Total number of data:2832
Min number of points: (train)2048 (test)2048


tensor([[[-0.0832, -0.0008, -0.0106],
         [ 0.2124, -0.1133,  0.3423],
         [ 0.3726, -0.1521, -0.1293],
         ...,
         [-0.0380,  0.0209, -0.6827],
         [ 0.0835, -0.0098, -0.6379],
         [ 0.6067, -0.1337, -0.0316]]], device='cuda:0', requires_grad=True)

In [11]:
# Noise levels (sigmas): take an explicit list from the config when one is
# given, otherwise build `sigma_num` values from sigma_begin to sigma_end.
if hasattr(cfg.trainer, "sigmas"):
    np_sigmas = cfg.trainer.sigmas
else:
    sigma_begin = float(cfg.trainer.sigma_begin)
    sigma_end = float(cfg.trainer.sigma_end)
    num_classes = int(cfg.trainer.sigma_num)
    # Evenly spaced in log-space, i.e. a geometric progression.
    np_sigmas = np.exp(np.linspace(np.log(sigma_begin), np.log(sigma_end), num_classes))
print("Sigma:, ", np_sigmas)
# Column tensor of shape (num_sigmas, 1). It is placed on `device` instead of
# a hard-coded `.cuda()` so the cell also runs when CUDA is not available.
sigmas = torch.tensor(np.array(np_sigmas)).float().to(device).view(-1, 1)

Sigma:,  [1.         0.59948425 0.35938137 0.21544347 0.12915497 0.07742637
 0.04641589 0.02782559 0.01668101 0.01      ]


In [None]:
# Build the score network and the critic network and move both to `device`.
score_net = Scorenet()
critic_net = Criticnet()
critic_net.to(device)
score_net.to(device)

print(score_net)
print(critic_net)

# One optimizer (plus scheduler) per network. The critic is configured from its
# own `opt_criticnet` section; it falls back to `opt_scorenet` only when the
# config has no critic section.
opt_scorenet, scheduler_scorenet = get_opt(score_net.parameters(), cfg.trainer.opt_scorenet)
opt_criticnet, scheduler_criticnet = get_opt(
    critic_net.parameters(),
    getattr(cfg.trainer, "opt_criticnet", cfg.trainer.opt_scorenet),
)
# NOTE(review): the schedulers returned by get_opt are never stepped below —
# confirm whether that is intended.

# Training schedule: each cycle is `k_iters` critic updates followed by
# `e_iters` score-network updates.
start_epoch = 0
n_epochs = 1000
# Report the bound the loop actually uses (n_epochs), not cfg.trainer.epochs.
print("Start epoch: %d End epoch: %d" % (start_epoch, n_epochs))
k_iters = 20
e_iters = 1

# Train only on the last five noise levels. This rebinds the notebook-level
# `sigmas`, so later cells that read `sigmas` see the truncated tensor.
sigmas = sigmas[-5:]

# Nothing inside the loop switches the networks out of training mode, so
# setting it once up front is sufficient.
score_net.train()
critic_net.train()

for epoch in range(start_epoch, n_epochs):
    opt_scorenet.zero_grad()
    opt_criticnet.zero_grad()

    # sigmas is (num_sigmas, 1); the extra trailing axis lets each noise level
    # broadcast against tr_pts (leading axis of size 1), producing one
    # perturbed copy of the points per sigma.
    perturbed_points = tr_pts + torch.randn_like(tr_pts) * sigmas[..., None]

    score_pred = score_net(perturbed_points, sigmas)
    critic_output = critic_net(perturbed_points, sigmas)

    # t1: inner product of score and critic outputs over the last axis.
    t1 = (score_pred * critic_output).sum(-1)
    # t2: presumably the trace of d(critic_output)/d(perturbed_points) —
    # confirm against utils.exact_jacobian_trace.
    t2 = exact_jacobian_trace(critic_output, perturbed_points)

    stein = t1 + t2
    # Squared norm of the critic output, used to regularize the critic step.
    l2_penalty = (critic_output * critic_output).sum(-1).mean()
    loss = stein.mean()

    # Position inside the current (k_iters + e_iters) cycle.
    cycle_iter = epoch % (k_iters + e_iters)
    cpu_loss = loss.item()
    cpu_t1 = t1.mean().item()
    cpu_t2 = t2.mean().item()
    if cycle_iter < k_iters:
        # Critic step: maximize the loss (minimize its negative) with the L2 penalty.
        (-loss + l2_penalty).backward()
        opt_criticnet.step()
        print("Epoch (critic) %d Loss=%2.5f t1=%2.5f t2=%2.5f" % (epoch, cpu_loss, cpu_t1, cpu_t2))
    else:
        # Score step: minimize the loss with respect to the score network.
        loss.backward()
        opt_scorenet.step()
        print("Epoch (score) %d Loss=%2.5f t1=%2.5f t2=%2.5f" % (epoch, cpu_loss, cpu_t1, cpu_t2))

Scorenet(
  (conv_p): Conv1d(4, 256, kernel_size=(1,), stride=(1,))
  (blocks): ModuleList(
    (0): ResnetBlockConv1d(
      (bn_0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn_1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (fc_0): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (fc_1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (fc_c): Conv1d(4, 256, kernel_size=(1,), stride=(1,))
      (actvn): ReLU()
    )
    (1): ResnetBlockConv1d(
      (bn_0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn_1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (fc_0): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (fc_1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
      (fc_c): Conv1d(4, 256, kernel_size=(1,), stride=(1,))
      (actvn): ReLU()
    )
    (2): ResnetBlockConv1d(
      (bn_0): BatchN

Epoch (score) 62 Loss=7.37540 t1=0.93016 t2=6.44524
Epoch (critic) 63 Loss=7.00124 t1=0.12129 t2=6.87995
Epoch (critic) 64 Loss=7.86937 t1=0.53931 t2=7.33006
Epoch (critic) 65 Loss=9.52598 t1=0.92837 t2=8.59761
Epoch (critic) 66 Loss=11.68926 t1=0.75633 t2=10.93293
Epoch (critic) 67 Loss=7.51500 t1=0.69500 t2=6.82000
Epoch (critic) 68 Loss=8.23316 t1=0.82578 t2=7.40738
Epoch (critic) 69 Loss=8.02761 t1=0.56316 t2=7.46445
Epoch (critic) 70 Loss=10.57810 t1=0.43734 t2=10.14076
Epoch (critic) 71 Loss=12.00031 t1=0.24821 t2=11.75211
Epoch (critic) 72 Loss=15.92436 t1=0.19901 t2=15.72536
Epoch (critic) 73 Loss=7.41022 t1=0.42220 t2=6.98802
Epoch (critic) 74 Loss=7.39093 t1=0.20188 t2=7.18905
Epoch (critic) 75 Loss=12.41480 t1=0.39094 t2=12.02386
Epoch (critic) 76 Loss=8.25627 t1=0.51802 t2=7.73824
Epoch (critic) 77 Loss=12.68285 t1=0.36426 t2=12.31858
Epoch (critic) 78 Loss=10.35757 t1=0.19979 t2=10.15778
Epoch (critic) 79 Loss=12.85553 t1=0.10625 t2=12.74928
Epoch (critic) 80 Loss=14.65753

Epoch (critic) 212 Loss=7.85992 t1=-0.07073 t2=7.93065
Epoch (critic) 213 Loss=8.07143 t1=0.04872 t2=8.02271
Epoch (critic) 214 Loss=10.78274 t1=0.12619 t2=10.65656
Epoch (critic) 215 Loss=13.72184 t1=0.22058 t2=13.50126
Epoch (critic) 216 Loss=18.93133 t1=0.44212 t2=18.48922
Epoch (critic) 217 Loss=20.77901 t1=0.27883 t2=20.50019
Epoch (critic) 218 Loss=30.47858 t1=0.53504 t2=29.94354
Epoch (critic) 219 Loss=31.98213 t1=0.50704 t2=31.47509
Epoch (critic) 220 Loss=36.20663 t1=0.33235 t2=35.87429
Epoch (critic) 221 Loss=33.08251 t1=0.04379 t2=33.03872
Epoch (critic) 222 Loss=32.26554 t1=0.07961 t2=32.18593
Epoch (critic) 223 Loss=17.94314 t1=-0.37946 t2=18.32259
Epoch (critic) 224 Loss=43.43725 t1=0.90952 t2=42.52772
Epoch (critic) 225 Loss=17.86088 t1=0.60707 t2=17.25381
Epoch (critic) 226 Loss=24.63659 t1=0.30291 t2=24.33368
Epoch (critic) 227 Loss=20.37271 t1=0.22886 t2=20.14385
Epoch (critic) 228 Loss=21.20286 t1=0.34059 t2=20.86227
Epoch (critic) 229 Loss=18.39295 t1=0.33609 t2=18.

Epoch (critic) 358 Loss=36.97232 t1=2.38044 t2=34.59188
Epoch (critic) 359 Loss=38.47107 t1=1.86733 t2=36.60374
Epoch (critic) 360 Loss=19.05404 t1=2.40391 t2=16.65013
Epoch (critic) 361 Loss=34.92704 t1=1.62911 t2=33.29794
Epoch (critic) 362 Loss=33.99706 t1=1.86616 t2=32.13090
Epoch (critic) 363 Loss=31.72571 t1=1.07697 t2=30.64875
Epoch (critic) 364 Loss=45.60816 t1=0.95799 t2=44.65017
Epoch (critic) 365 Loss=31.27006 t1=1.33652 t2=29.93354
Epoch (critic) 366 Loss=26.19077 t1=1.09304 t2=25.09773
Epoch (critic) 367 Loss=38.61103 t1=0.52298 t2=38.08805
Epoch (critic) 368 Loss=45.52943 t1=-0.03197 t2=45.56140
Epoch (critic) 369 Loss=38.23691 t1=-0.13422 t2=38.37114
Epoch (critic) 370 Loss=34.25155 t1=-0.39364 t2=34.64520
Epoch (critic) 371 Loss=45.33377 t1=-1.43925 t2=46.77303
Epoch (critic) 372 Loss=49.90918 t1=-1.18700 t2=51.09618
Epoch (critic) 373 Loss=55.25904 t1=-1.55266 t2=56.81170
Epoch (critic) 374 Loss=56.29768 t1=0.06124 t2=56.23644
Epoch (critic) 375 Loss=67.44949 t1=0.4323

Epoch (critic) 504 Loss=69.31344 t1=0.27659 t2=69.03683
Epoch (critic) 505 Loss=32.49733 t1=1.26250 t2=31.23482
Epoch (critic) 506 Loss=6.96052 t1=1.84472 t2=5.11580
Epoch (critic) 507 Loss=45.40466 t1=0.54791 t2=44.85675
Epoch (critic) 508 Loss=21.36867 t1=0.67373 t2=20.69495
Epoch (critic) 509 Loss=5.71885 t1=-0.62098 t2=6.33983
Epoch (critic) 510 Loss=23.86750 t1=-0.67888 t2=24.54638
Epoch (critic) 511 Loss=47.04663 t1=0.54024 t2=46.50640
Epoch (critic) 512 Loss=34.48725 t1=2.00221 t2=32.48504
Epoch (critic) 513 Loss=68.49168 t1=2.50946 t2=65.98222
Epoch (critic) 514 Loss=32.58523 t1=2.18901 t2=30.39622
Epoch (critic) 515 Loss=42.23355 t1=0.75428 t2=41.47927
Epoch (critic) 516 Loss=7.34086 t1=2.77455 t2=4.56631
Epoch (critic) 517 Loss=3.04212 t1=2.94911 t2=0.09302
Epoch (critic) 518 Loss=1.61189 t1=3.13236 t2=-1.52047
Epoch (critic) 519 Loss=3.19660 t1=2.97476 t2=0.22184
Epoch (critic) 520 Loss=3.24234 t1=2.74644 t2=0.49590
Epoch (critic) 521 Loss=3.40223 t1=2.68678 t2=0.71546
Epoch

Epoch (critic) 651 Loss=68.72998 t1=1.25074 t2=67.47923
Epoch (critic) 652 Loss=62.12374 t1=1.69989 t2=60.42386
Epoch (critic) 653 Loss=70.35093 t1=2.45687 t2=67.89406
Epoch (critic) 654 Loss=47.43303 t1=1.08324 t2=46.34979
Epoch (critic) 655 Loss=50.25041 t1=1.14131 t2=49.10910
Epoch (critic) 656 Loss=60.03160 t1=-0.07924 t2=60.11084
Epoch (critic) 657 Loss=68.39242 t1=0.59004 t2=67.80238
Epoch (critic) 658 Loss=52.63983 t1=1.29552 t2=51.34431
Epoch (critic) 659 Loss=39.16830 t1=0.81130 t2=38.35700
Epoch (critic) 660 Loss=35.83537 t1=0.65012 t2=35.18525
Epoch (critic) 661 Loss=45.09742 t1=0.48962 t2=44.60780
Epoch (critic) 662 Loss=56.18547 t1=0.40695 t2=55.77852
Epoch (critic) 663 Loss=48.08339 t1=1.10297 t2=46.98042
Epoch (critic) 664 Loss=56.58051 t1=0.79675 t2=55.78375
Epoch (critic) 665 Loss=67.61380 t1=0.83038 t2=66.78341
Epoch (critic) 666 Loss=34.98766 t1=-0.08065 t2=35.06831
Epoch (critic) 667 Loss=52.03860 t1=0.94464 t2=51.09397
Epoch (critic) 668 Loss=54.07127 t1=0.28311 t2

In [None]:
# Run the project's langevin_dynamics helper with the trained score network and
# keep only its first return value (presumably the generated points — confirm in utils).
# NOTE: `sigmas` here is the tensor rebound by the training cell (`sigmas[-5:]`),
# so this cell depends on the training cell having been run first.
pt_cl, _ = langevin_dynamics(score_net, sigmas, eps=1e-4, num_steps=3)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib notebook

In [None]:
# Visualization: plot `pt_cl`, the first value returned by langevin_dynamics above.
visualize(pt_cl)

In [None]:
# Plot the training points `tr_pts` loaded earlier in the notebook.
visualize(tr_pts)