In [1]:
use_tensorboard = False

In [2]:
import torch
from torch.autograd import Variable

if use_tensorboard:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()

In [3]:
from utils.util_torch import PointCloudDataSet, PointCloudDataCollator

In [4]:
torch.__version__

'1.10.0'

In [5]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 2080'

In [6]:
torch.version.cuda

'11.3'

In [7]:
from new.configs import *
from new.utils import * 
from new.models import *

In [8]:
opt = parse_config()

In [9]:
if use_tensorboard:
    writer.add_hparams(hparam_dict = vars(opt),metric_dict = {})

In [10]:
opt.batch_size = batch_size
opt.swap_axis = True
if len(opt.checkpoint_path) == 0:
    opt.checkpoint_path = None 
opt.device = "cuda:%s" % opt.cuda if opt.cuda!="" else "cpu"
opt.shuffle = not opt.warm_start

In [11]:
print(opt)

Namespace(activate_eval=0, activation='ReLU', argment_mode=0, argment_noise=0.01, batch_norm='ln', batch_size=16, beta1_des=0.9, category='chair', checkpoint_path=None, cuda='-1', data_path='data', data_size=10000, debug=99, device='cuda:-1', do_evaluation=1, drop_last=False, eval_step=50, fp16='None', gradient_accumulation_steps=1, langevin_clip=1, langevin_decay=0, learning_mode=0, lr=0.0005, lr_decay=0.998, mode='train', net_type='default_medium', noise_decay=0, normalize='ebp', num_chain=1, num_point=2048, num_steps=2000, output_dir='default', point_dim=3, random_sample=1, ref_sigma=0.3, sample_step=64, seed=666, shuffle=True, stable_check=1, step_size=0.01, swap_axis=True, test_size=16, visualize_mode=0, warm_start=0)


In [12]:
train_data = PointCloudDataSet(opt)
data_collator = PointCloudDataCollator(opt)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, drop_last=opt.drop_last, 
    shuffle=opt.shuffle, collate_fn = data_collator, num_workers=torch.cuda.device_count() * 4)

In [13]:
# test_dataset = torch.utils.data.TensorDataset(np.load("data/%s_test.npy" % opt.category))
# test_loader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, drop_last=opt.drop_last, 
#     shuffle=False, collate_fn = data_collator, num_workers=torch.cuda.device_count() * 4)

TypeError: 'int' object is not callable

In [None]:
x = next(iter(train_loader))

In [None]:
x.shape

In [None]:
net = NetWrapper()

In [None]:
if torch.cuda.is_available():
    net = net.cuda()

In [None]:
# set_gpu(args.device)
set_cuda(deterministic=gpu_deterministic)
set_seed(123)

In [None]:
optE = torch.optim.Adam(net.netE.parameters(), lr=e_lr, weight_decay=e_decay, betas=(e_beta1, e_beta2))
optG = torch.optim.Adam(net.netG.parameters(), lr=g_lr, weight_decay=g_decay, betas=(g_beta1, g_beta2))

lr_scheduleE = torch.optim.lr_scheduler.ExponentialLR(optE, e_gamma)
lr_scheduleG = torch.optim.lr_scheduler.ExponentialLR(optG, g_gamma)

# Train

In [None]:
net.train()

In [None]:
total_step = 0

In [None]:
for epoch in range(100):
    for c, x in enumerate(train_loader):
        total_step += 1
        batch_num = x.shape[0]

        if torch.cuda.is_available():
            x = x.to("cuda")

        # Initialize chains
        z_g_0 = sample_p_0(n = batch_num,sig=e_init_sig, device = x.device)
        z_e_0 = sample_p_0(n = batch_num,sig=e_init_sig, device = x.device)

    #     print("shape log")
    #     print(x.shape)
    #     print(z_g_0.shape)
    #     print(z_e_0.shape)

        # Langevin posterior and prior
        z_g_k = net(Variable(z_g_0), x, prior=False)
        z_e_k = net(Variable(z_e_0), prior=True)

        # Learn generator
        optG.zero_grad()
        x_hat = net.netG(z_g_k.detach())
        loss_g = net.loss_fun(x_hat.transpose(1,2).contiguous(), x.transpose(1,2).contiguous())
        loss_g.backward()
        optG.step()

        # Learn prior EBM
        optE.zero_grad()
        en_neg = energy(net.netE(z_e_k.detach())).mean() # TODO(nijkamp): why mean() here and in Langevin sum() over energy? constant is absorbed into Adam adaptive lr
        en_pos = energy(net.netE(z_g_k.detach())).mean()
        loss_e = en_pos - en_neg
        loss_e.backward()
        # grad_norm_e = get_grad_norm(net.netE.parameters())
        # if args.e_is_grad_clamp:
        #    torch.nn.utils.clip_grad_norm_(net.netE.parameters(), args.e_max_norm)
        optE.step()

        # Printout
        if total_step % 10 == 0:
            with torch.no_grad():
                x_0 = net.netG(z_e_0)
                x_k = net.netG(z_e_k)

                en_neg_2 = energy(net.netE(z_e_k)).mean()
                en_pos_2 = energy(net.netE(z_g_k)).mean()

                prior_moments = '[{:8.2f}, {:8.2f}, {:8.2f}]'.format(z_e_k.mean(), z_e_k.std(), z_e_k.abs().max())
                posterior_moments = '[{:8.2f}, {:8.2f}, {:8.2f}]'.format(z_g_k.mean(), z_g_k.std(), z_g_k.abs().max())
                
                
                if use_tensorboard:
                    writer.add_scalar('loss/loss_g',loss_g, total_step)
                    writer.add_scalar('loss/loss_e',loss_e, total_step)

                    writer.add_scalars('energy/en_pos', {'pos_1':en_pos,
                                        'pose_2':en_pos_2,
                                        'diff': en_pos_2 - en_pos}, total_step)
                    writer.add_scalars('energy/en_neg', {'pos_1':en_neg,
                                        'pose_2':en_neg_2,
                                        'diff': en_neg_2 - en_neg}, total_step)

                    writer.add_scalar('value/|z_g_0|',z_g_0.view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('value/|z_g_k|',z_g_k.view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('value/|z_e_0|',z_e_0.view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('value/|z_e_k|',z_e_k.view(batch_num, -1).norm(dim=1).mean(), total_step)

                    writer.add_scalar('disp/z_e_disp',(z_e_k-z_e_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('disp/z_g_disp',(z_g_k-z_g_0).view(batch_num, -1).norm(dim=1).mean(), total_step)
                    writer.add_scalar('disp/x_e_disp',(x_k-x_0).view(batch_num, -1).norm(dim=1).mean(), total_step)

                    writer.add_scalars('moment/prior_moments', {'mean':z_e_k.mean(),
                                        'std':z_e_k.std(),
                                        'max abs': z_e_k.abs().max()}, total_step)
                    writer.add_scalars('moment/posterior_moments', {'mean':z_g_k.mean(),
                                        'std':z_g_k.std(),
                                        'max abs': z_g_k.abs().max()}, total_step)
                else:
                    print(
                        '{} {}/{} {}/{} \n'.format(0, epoch, n_epochs, total_step, len(train_loader)) +
                        'loss_g={:8.5f}, \n'.format(loss_g) +
                        'loss_e={:8.5f}, \n'.format(loss_e) +
                        'en_pos=[{:9.5f}, {:9.5f}, {:9.5f}], \n'.format(en_pos, en_pos_2, en_pos_2-en_pos) +
                        'en_neg=[{:9.5f}, {:9.5f}, {:9.5f}], \n'.format(en_neg, en_neg_2, en_neg_2-en_neg) +
                        '|z_g_0|={:6.3f}, \n'.format(z_g_0.view(batch_num, -1).norm(dim=1).mean()) +
                        '|z_g_k|={:6.3f}, \n'.format(z_g_k.view(batch_num, -1).norm(dim=1).mean()) +
                        '|z_e_0|={:6.3f}, \n'.format(z_e_0.view(batch_num, -1).norm(dim=1).mean()) +
                        '|z_e_k|={:6.3f}, \n'.format(z_e_k.view(batch_num, -1).norm(dim=1).mean()) +
                        'z_e_disp={:6.3f}, \n'.format((z_e_k-z_e_0).view(batch_num, -1).norm(dim=1).mean()) +
                        'z_g_disp={:6.3f}, \n'.format((z_g_k-z_g_0).view(batch_num, -1).norm(dim=1).mean()) +
                        'x_e_disp={:6.3f}, \n'.format((x_k-x_0).view(batch_num, -1).norm(dim=1).mean()) +
                        'prior_moments={}, \n'.format(prior_moments) +
                        'posterior_moments={}, \n'.format(posterior_moments) +
                        #'fid={:8.2f}, '.format(fid) +
                        #'fid_best={:8.2f}'.format(fid_best)
                        "\n\n\n ---------------------"
                    )


In [None]:
#torch.save(net.state_dict(),"runs/Oct22_12-14-52_yizhou-Z370-AORUS-Gaming-5/net.pth")

In [None]:
len(train_loader)