In [1]:
import torch
from torch.autograd import Variable
from chamferdist import ChamferDistance
from torch.utils.tensorboard import SummaryWriter

In [2]:
from new.configs import *
from new.utils import * 
from new.models import *


In [3]:
writer = SummaryWriter()

In [4]:
opt = parse_config()

In [5]:
writer.add_hparams(hparam_dict = vars(opt),metric_dict = {})

In [6]:
opt.batch_size = batch_size
opt.swap_axis = True
if len(opt.checkpoint_path) == 0:
    opt.checkpoint_path = None 
opt.device = "cuda:%s" % opt.cuda if opt.cuda!="" else "cpu"
opt.shuffle = not opt.warm_start

In [7]:
from utils.util_torch import PointCloudDataSet, PointCloudDataCollator

Eval not available.


In [8]:
train_data = PointCloudDataSet(opt)
data_collator = PointCloudDataCollator(opt)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, drop_last=opt.drop_last, 
shuffle=opt.shuffle, collate_fn = data_collator, num_workers=torch.cuda.device_count() * 4)

In [9]:
x = next(iter(train_loader))

In [10]:
x.shape

torch.Size([16, 3, 2048])

In [11]:
net = NetWrapper()

In [12]:
if torch.cuda.is_available():
    net = net.cuda()

In [13]:
# set_gpu(args.device)
set_cuda(deterministic=gpu_deterministic)
set_seed(123)

In [14]:
optE = torch.optim.Adam(net.netE.parameters(), lr=e_lr, weight_decay=e_decay, betas=(e_beta1, e_beta2))
optG = torch.optim.Adam(net.netG.parameters(), lr=g_lr, weight_decay=g_decay, betas=(g_beta1, g_beta2))

lr_scheduleE = torch.optim.lr_scheduler.ExponentialLR(optE, e_gamma)
lr_scheduleG = torch.optim.lr_scheduler.ExponentialLR(optG, g_gamma)

# Train

In [15]:
net.train()

NetWrapper(
  (netE): NetE(
    (ebm): Sequential(
      (0): Linear(in_features=100, out_features=512, bias=True)
      (1): GELU()
      (2): Linear(in_features=512, out_features=64, bias=True)
      (3): GELU()
      (4): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (netG): NetG(
    (gen): Sequential(
      (0): Linear(in_features=100, out_features=256, bias=True)
      (1): GELU()
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): GELU()
      (4): Linear(in_features=512, out_features=1024, bias=True)
      (5): GELU()
      (6): Linear(in_features=1024, out_features=6144, bias=True)
    )
  )
  (loss_fun): ChamferDistance()
)

In [16]:
total_step = 0

In [17]:
for epoch in range(100):
    for c, x in enumerate(train_loader):
        total_step += 1
        batch_num = x.shape[0]

        if torch.cuda.is_available():
            x = x.to("cuda")

        # Initialize chains
        z_g_0 = sample_p_0(n = batch_num,sig=e_init_sig, device = x.device)
        z_e_0 = sample_p_0(n = batch_num,sig=e_init_sig, device = x.device)

    #     print("shape log")
    #     print(x.shape)
    #     print(z_g_0.shape)
    #     print(z_e_0.shape)

        # Langevin posterior and prior
        z_g_k = net(Variable(z_g_0), x, prior=False)
        z_e_k = net(Variable(z_e_0), prior=True)

        # Learn generator
        optG.zero_grad()
        x_hat = net.netG(z_g_k.detach())
        loss_g = net.loss_fun(x_hat.transpose(1,2), x.transpose(1,2)) / batch_num
        loss_g.backward()
        optG.step()

        # Learn prior EBM
        optE.zero_grad()
        en_neg = energy(net.netE(z_e_k.detach())).mean() # TODO(nijkamp): why mean() here and in Langevin sum() over energy? constant is absorbed into Adam adaptive lr
        en_pos = energy(net.netE(z_g_k.detach())).mean()
        loss_e = en_pos - en_neg
        loss_e.backward()
        # grad_norm_e = get_grad_norm(net.netE.parameters())
        # if args.e_is_grad_clamp:
        #    torch.nn.utils.clip_grad_norm_(net.netE.parameters(), args.e_max_norm)
        optE.step()

        # Printout
        if total_step % 10 == 0:
            with torch.no_grad():
                x_0 = net.netG(z_e_0)
                x_k = net.netG(z_e_k)

                en_neg_2 = energy(net.netE(z_e_k)).mean()
                en_pos_2 = energy(net.netE(z_g_k)).mean()

                prior_moments = '[{:8.2f}, {:8.2f}, {:8.2f}]'.format(z_e_k.mean(), z_e_k.std(), z_e_k.abs().max())
                posterior_moments = '[{:8.2f}, {:8.2f}, {:8.2f}]'.format(z_g_k.mean(), z_g_k.std(), z_g_k.abs().max())
                
                writer.add_scalar('loss/loss_g',loss_g, total_step)
                writer.add_scalar('loss/loss_e',loss_e, total_step)
                
                writer.add_scalars('energy/en_pos', {'pos_1':en_pos,
                                    'pose_2':en_pos_2,
                                    'diff': en_pos_2 - en_pos}, total_step)
                writer.add_scalars('energy/en_neg', {'pos_1':en_neg,
                                    'pose_2':en_neg_2,
                                    'diff': en_neg_2 - en_neg}, total_step)
                
                writer.add_scalar('value/|z_g_0|',z_g_0.view(batch_num, -1).norm(dim=1).mean(), total_step)
                writer.add_scalar('value/|z_g_k|',z_g_k.view(batch_num, -1).norm(dim=1).mean(), total_step)
                writer.add_scalar('value/|z_e_0|',z_e_0.view(batch_num, -1).norm(dim=1).mean(), total_step)
                writer.add_scalar('value/|z_e_k|',z_e_k.view(batch_num, -1).norm(dim=1).mean(), total_step)
                
                writer.add_scalar('disp/z_e_disp',(z_e_k-z_e_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                writer.add_scalar('disp/z_g_disp',(z_g_k-z_g_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                writer.add_scalar('disp/x_e_disp',(x_k-x_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                
                writer.add_scalars('moment/prior_moments', {'mean':z_e_k.mean(),
                                    'std':z_e_k.std(),
                                    'max abs': z_e_k.abs().max()}, total_step)
                writer.add_scalars('moment/posterior_moments', {'mean':z_g_k.mean(),
                                    'std':z_g_k.std(),
                                    'max abs': z_g_k.abs().max()}, total_step)
                
                
#                 print(
#                     '{} {}/{} {}/{} \n'.format(0, epoch, n_epochs, total_step, len(train_loader)) +
#                     'loss_g={:8.3f}, \n'.format(loss_g) +
#                     'loss_e={:8.3f}, \n'.format(loss_e) +
#                     'en_pos=[{:9.4f}, {:9.4f}, {:9.4f}], \n'.format(en_pos, en_pos_2, en_pos_2-en_pos) +
#                     'en_neg=[{:9.4f}, {:9.4f}, {:9.4f}], \n'.format(en_neg, en_neg_2, en_neg_2-en_neg) +
#                     '|z_g_0|={:6.2f}, \n'.format(z_g_0.view(batch_num, -1).norm(dim=1).mean()) +
#                     '|z_g_k|={:6.2f}, \n'.format(z_g_k.view(batch_num, -1).norm(dim=1).mean()) +
#                     '|z_e_0|={:6.2f}, \n'.format(z_e_0.view(batch_num, -1).norm(dim=1).mean()) +
#                     '|z_e_k|={:6.2f}, \n'.format(z_e_k.view(batch_num, -1).norm(dim=1).mean()) +
#                     'z_e_disp={:6.2f}, \n'.format((z_e_k-z_e_0).view(batch_num, -1).norm(dim=1).mean()) +
#                     'z_g_disp={:6.2f}, \n'.format((z_g_k-z_g_0).view(batch_num, -1).norm(dim=1).mean()) +
#                     'x_e_disp={:6.2f}, \n'.format((x_k-x_0).view(batch_num, -1).norm(dim=1).mean()) +
#                     'prior_moments={}, \n'.format(prior_moments) +
#                     'posterior_moments={}, \n'.format(posterior_moments) +
#                     #'fid={:8.2f}, '.format(fid) +
#                     #'fid_best={:8.2f}'.format(fid_best)
#                     "\n\n\n ---------------------"
#                 )


In [18]:
add_hparams

NameError: name 'add_hparams' is not defined