In [1]:
import sys
import os
import numpy as np
import torch
sys.path.append(os.path.join('../',os.path.dirname(os.path.abspath(''))))
from utils import load_checkpoint, save_checkpoint
from train_utils import visualize_pointcloud, dump_pointcloud_visualization
from upsampling_network import NoMaskSRNet
from msr_dataset import get_test_dataloader, MSRAction3D

INFO - 2022-04-25 00:45:44,024 - utils - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO - 2022-04-25 00:45:44,024 - utils - NumExpr defaulting to 8 threads.
Using backend: pytorch


In [3]:
# Rebuild the super-resolution network, load its trained weights from the
# checkpoint's 'sr_net' entry, then move it to the GPU in inference mode.
path_to_resume = './train_dir/tpugan_checkpoint.ckpt'
ckpt = load_checkpoint(path_to_resume)

sr_net = NoMaskSRNet(3, 128, 16)
sr_net.load_state_dict(ckpt['sr_net'])
print("last checkpoint restored")

# nn.Module.cuda() and nn.Module.eval() both return the module itself,
# so the two calls can be chained.
sr_net = sr_net.cuda().eval()
# MSR-Action3D evaluation split (train=False), loaded as 24-frame clips.
dataset = MSRAction3D(root='../data/MSR-Action3D', frames_per_clip=24, train=False, num_points=2048)
dat_loader = get_test_dataloader(dataset)
dat_iter = iter(dat_loader)

# Output folder for the per-sequence predictions; reused if it already exists.
sample_dir = 'action_test'
os.makedirs(sample_dir, exist_ok=True)

# Generate some sequences: run the network on the first NUM_SEQUENCES test batches and
# save one pcd_<k>.npz per sequence in the batch, with the predictions under key 'pred'.
NUM_SEQUENCES = 10  # number of test batches to process

# Follow whatever device the model was moved to instead of hard-coding CUDA.
device = next(sr_net.parameters()).device

# Running file index. Unlike 8*i+b, it does not assume every batch holds exactly
# 8 sequences, so files never overlap or leave gaps.
sample_idx = 0

# zip() stops after NUM_SEQUENCES batches, or earlier if the loader is exhausted.
# A bare next() would raise StopIteration in that case.
for _, dat in zip(range(NUM_SEQUENCES), dat_iter):
    highres_pos_lst, lowres_pos_lst, c_lst, label_lst = dat
    frame_preds = []
    with torch.no_grad():
        # one entry per frame: low-res positions and the matching offset
        for lowres_pos, offset in zip(lowres_pos_lst, c_lst):
            lowres_pos = lowres_pos.to(device)
            feature = lowres_pos  # the positions double as the input features
            pred, _ = sr_net(feature, lowres_pos)
            # add the offset back, viewed as [B, 1, 3] so it broadcasts over points
            pred = pred + offset.view(-1, 1, 3).to(device)
            frame_preds.append(pred.cpu().numpy())
    # Stack the frames on axis 1 -> [batch, frames, points, 3]
    # (e.g. [8, 24, 2048, 3] for the configuration above).
    pred_pos_arr = np.stack(frame_preds, axis=1)

    for seq_pred in pred_pos_arr:
        np.savez(os.path.join(sample_dir, f'pcd_{sample_idx}.npz'), pred=seq_pred)
        sample_idx += 1

 [*] Loading checkpoint from ./train_dir/tpugan_checkpoint.ckpt succeed!
last checkpoint restored


### Export the predictions as `.bgeo` files for visualization/rendering

Converts the predicted point-cloud sequences saved above into `.bgeo` files. Requires [partio](https://github.com/wdas/partio) to be installed.

In [None]:
import numpy as np
import os
from analysis_helper import write_bgeo_from_numpy

# Convert every prediction saved by the generation cell into pcd_<kkkk>.bgeo.
# Each input is pcd_<k>.npz with the array under key 'pred'. The sequence index k
# is kept and zero-padded to 4 digits.
# The files present on disk are enumerated instead of assuming exactly 80 exist.
out_dir = 'action_test'
prefix, suffix = 'pcd_', '.npz'
saved_samples = sorted(
    (int(name[len(prefix):-len(suffix)]), name)
    for name in os.listdir(out_dir)
    if name.startswith(prefix)
    and name.endswith(suffix)
    and name[len(prefix):-len(suffix)].isdecimal()
)

for i, npz_name in saved_samples:
    # The generation cell writes .npz archives (np.savez), not .npy files. The
    # `with` block closes the archive once the array is read.
    with np.load(os.path.join(out_dir, npz_name)) as sample:
        pos_np = sample['pred']
    # NOTE(review): pos_np is one whole sequence (frames x points x 3). Confirm that
    # write_bgeo_from_numpy expects this layout rather than a single frame.
    bgeo_pth = os.path.join(out_dir, 'pcd_{0:04d}.bgeo'.format(i))
    write_bgeo_from_numpy(bgeo_pth, pos_np)