In [1]:
import torch

In [2]:
from itertools import chain

In [3]:
from new.models import *
from new.configs import *
from new.utils import * 
from utils.util_torch import *

In [4]:
opt = parse_config()
opt.batch_size = batch_size
opt.swap_axis = True
if len(opt.checkpoint_path) == 0:
    opt.checkpoint_path = None 
opt.device = "cuda:%s" % opt.cuda if opt.cuda!="" else "cpu"
opt.shuffle = not opt.warm_start
print(opt)

Namespace(activate_eval=0, activation='ReLU', argment_mode=0, argment_noise=0.01, batch_norm='ln', batch_size=16, beta1_des=0.9, category='chair', checkpoint_path=None, cuda='-1', data_path='data', data_size=10000, debug=99, device='cuda:-1', do_evaluation=1, drop_last=False, eval_step=50, fp16='None', gradient_accumulation_steps=1, langevin_clip=1, langevin_decay=0, learning_mode=0, lr=0.0005, lr_decay=0.998, mode='train', net_type='default_medium', noise_decay=0, normalize='ebp', num_chain=1, num_point=2048, num_steps=2000, output_dir='default', point_dim=3, random_sample=1, ref_sigma=0.3, sample_step=64, seed=666, shuffle=True, stable_check=1, step_size=0.01, swap_axis=True, test_size=16, visualize_mode=0, warm_start=0)


In [5]:
train_data = PointCloudDataSet(opt)
data_collator = PointCloudDataCollator(opt)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True)
# torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, drop_last=opt.drop_last, 
#     shuffle=opt.shuffle, collate_fn = data_collator, num_workers=torch.cuda.device_count() * 4)

In [None]:
G = NetG().cuda()
E = Encoder().cuda()

In [None]:
EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()),
                    lr = 1e-4)

In [None]:
a = torch.randn(16, 3, 2048).to("cuda")

In [None]:
E(a)[0].shape

In [None]:
from metrics.evaluation_metrics import distChamferCUDA, distChamfer

In [None]:
def loss_fun(x:torch.Tensor, y:torch.Tensor, loss_type = "chamfer distance"):
    if x.is_cuda:
        dl, dr = distChamferCUDA(x, y)
    else:
        dl, dr = distChamfer(x, y)

    cd = torch.mean(dl + dr)
    return cd

In [None]:
#
# Float Tensors
#
fixed_noise = torch.FloatTensor(16, 2048, 1)
fixed_noise.normal_(mean=0, std=0.2)
std_assumed = torch.tensor(0.2)

fixed_noise = fixed_noise.to("cuda")
std_assumed = std_assumed.to("cuda")

In [None]:

total_step = 0

In [None]:
%matplotlib inline

for epoch in range(10):
    G.train()
    E.train()

    total_loss = 0.0
    for i, point_data in enumerate(train_loader):
        X = point_data
        X = X.to("cuda")

        total_step += 1

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        codes, mu, logvar = E(X)
        X_rec = G(codes)

        loss_e = torch.mean(
            0.05 *
            loss_fun(X.permute(0, 2, 1).contiguous() + 0.5,
                                X_rec.permute(0, 2, 1).contiguous() + 0.5))

        loss_kld = -0.5 * torch.mean(
            1 - 2.0 * torch.log(std_assumed) + logvar -
            (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2))

        loss_eg = loss_e + loss_kld
        EG_optim.zero_grad()
        E.zero_grad()
        G.zero_grad()

        loss_eg.backward()
        total_loss += loss_eg.item()
        EG_optim.step()

        if total_step % 10 == 0:
            print(f'[{epoch}: ({i})] '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}'
                      f' KLD: {loss_kld.item(): .4f})')

    ############################## EVAL #####################################
    G.eval()
    E.eval()
    with torch.no_grad():
        fake = G(fixed_noise.squeeze(2))
        codes, _, _ = E(X)
        X_rec = G(codes).data.cpu().numpy()

        print("fake")
        show_point_clouds(fake.data.cpu().numpy())
        print("X_rec")
        show_point_clouds(X_rec)