In [1]:
import sys
import os
import numpy as np
from matplotlib import pyplot as plt
import torch
# Make modules in the parent of the notebook's working directory importable
# (utils, train_utils, upsampling_network, analysis_helper). The previous
# os.path.join('../', <absolute path>) silently discarded '../' because the
# second argument is absolute, so this is the directory that was always added.
sys.path.append(os.path.dirname(os.path.abspath('')))
from utils import load_checkpoint, save_checkpoint
from train_utils import visualize_pointcloud 
from upsampling_network import SRNet
from analysis_helper import load_pos
import os
import time
import collections

INFO - 2022-04-24 18:26:30,976 - utils - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO - 2022-04-24 18:26:30,976 - utils - NumExpr defaulting to 8 threads.
Using backend: pytorch


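### Upsample coarse-grained simulation
Run the trained SRNet on each of the 800 low-resolution bunny frames. The first cell uses the checkpoint that takes positions plus velocities as input; the second uses the position-only checkpoint. Refined point clouds are saved as `.npy` files.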

In [2]:
from analysis_helper import write_bgeo_from_numpy
import time

def normalize_point_cloud(pcd_pos):
    """Translate a point cloud so its centroid sits at the origin.

    The scale factor is fixed at 1 (float32), so points are only shifted,
    never resized. It is still returned so callers can undo the transform
    uniformly: ``original = normalized * scale + centroid``.

    Args:
        pcd_pos: array of points; the mean is taken over axis 0
            (presumably shape (N, 3) — TODO confirm against the data files).

    Returns:
        (normalized, centroid, scale): the centered points (a new array; the
        input is not modified), the per-axis mean with the reduced axis kept
        (shape (1, D)), and the scale ``np.float32(1.)``.
    """
    scale = np.float32(1.)
    centroid = np.mean(pcd_pos, axis=0, keepdims=True)
    normalized = pcd_pos - centroid
    # In-place division keeps the dtype of `normalized` unchanged.
    normalized /= scale
    return normalized, centroid, scale

In [3]:
# Upsample every frame with the checkpoint trained WITH velocity features
# (positions concatenated with scaled velocities as node features).
NUM_FRAMES = 800   # frames read: ../data/bunny/data_0.npz … data_799.npz
VEL_SCALE = 0.025  # velocity feature scale; presumably must match training — TODO confirm

path_to_resume = './train_vel/tpugan_vel_checkpoint.ckpt'
ckpt = load_checkpoint(path_to_resume)  # custom method for loading last checkpoint
net = SRNet(in_feats=6,  # position channels + velocity channels
             node_emb_dim=128)
net.load_state_dict(ckpt['sr_net'])
print("last checkpoint restored")
net = net.cuda()
net = net.eval()

out_dir = './bunny_demo'
os.makedirs(out_dir, exist_ok=True)
mask_lst = []  # returned by forward_with_context and fed back on the next frame
start_time = time.perf_counter()  # monotonic clock, suited to elapsed-time measurement
for i in range(NUM_FRAMES):
    # The context manager closes the .npz archive instead of relying on GC.
    with np.load(f'../data/bunny/data_{i}.npz') as input_data:
        lowres_input = input_data['pos']
        lowres_vel = input_data['vel']
    with torch.no_grad():
        lowres_input, centroid, h = normalize_point_cloud(lowres_input)
        lowres_input = torch.from_numpy(lowres_input).cuda().unsqueeze(0)
        lowres_vel = torch.from_numpy(lowres_vel).cuda().unsqueeze(0)

        feature = torch.cat([lowres_input, lowres_vel * VEL_SCALE], dim=2)
        refined_pos, mask_lst = net.forward_with_context(feature, lowres_input, mask_lst)
        # Drop the batch dim and undo the normalization back to world coordinates.
        refined_pos = refined_pos[0].cpu().numpy()
        refined_pos *= h
        refined_pos += centroid
        torch.cuda.empty_cache()
    npy_pth = os.path.join(out_dir, f'pcd_{i}.npy')
    np.save(npy_pth, refined_pos)

end_time = time.perf_counter()
used_time = end_time - start_time
print(f'Used :{used_time}')

 [*] Loading checkpoint from ./train_vel/tpugan_vel_checkpoint.ckpt succeed!
last checkpoint restored
Used :534.9671199321747


In [8]:
# Upsample every frame with the checkpoint trained WITHOUT velocity features
# (the normalized positions alone are the node features).
NUM_FRAMES = 800  # frames read: ../data/bunny/data_0.npz … data_799.npz

path_to_resume = './train_novel/tpugan_novel_checkpoint.ckpt'
ckpt = load_checkpoint(path_to_resume)  # custom method for loading last checkpoint
net = SRNet(in_feats=3,  # position channels only
             node_emb_dim=128)
net.load_state_dict(ckpt['sr_net'])
print("last checkpoint restored")
net = net.cuda()
net = net.eval()

out_dir = './bunny_demo_no_vel'
os.makedirs(out_dir, exist_ok=True)
mask_lst = []  # returned by forward_with_context and fed back on the next frame
start_time = time.perf_counter()  # monotonic clock, suited to elapsed-time measurement
for i in range(NUM_FRAMES):
    # The context manager closes the .npz archive instead of relying on GC.
    with np.load(f'../data/bunny/data_{i}.npz') as input_data:
        lowres_input = input_data['pos']
    with torch.no_grad():
        lowres_input, centroid, h = normalize_point_cloud(lowres_input)
        lowres_input = torch.from_numpy(lowres_input).cuda().unsqueeze(0)

        feature = lowres_input
        refined_pos, mask_lst = net.forward_with_context(feature, lowres_input, mask_lst)
        # Drop the batch dim and undo the normalization back to world coordinates.
        refined_pos = refined_pos[0].cpu().numpy()
        refined_pos *= h
        refined_pos += centroid
        torch.cuda.empty_cache()
    npy_pth = os.path.join(out_dir, f'pcd_{i}.npy')
    np.save(npy_pth, refined_pos)

end_time = time.perf_counter()
used_time = end_time - start_time
print(f'Used :{used_time}')

 [*] Loading checkpoint from ./train_novel/tpugan_novel_checkpoint.ckpt succeed!
last checkpoint restored
Used :531.0246877670288


### Write the data as .bgeo for visualization/rendering
Converts the upsampled frames in `./bunny_demo_no_vel` to `.bgeo` files (requires the partio library: https://github.com/wdas/partio).

In [2]:
import numpy as np
import os
from analysis_helper import write_bgeo_from_numpy

# Convert each upsampled frame pcd_{i}.npy in out_dir into pcd_{i:04d}.bgeo
# in the same directory. Set out_dir to './bunny_demo' to convert the
# velocity-conditioned run instead.
NUM_FRAMES = 800  # frames written by the upsampling cells (pcd_0 … pcd_799)
out_dir = './bunny_demo_no_vel'
for i in range(NUM_FRAMES):
    # Read through out_dir rather than a second hard-coded copy of the path,
    # so input and output can never point at different runs.
    pos_np = np.load(os.path.join(out_dir, f'pcd_{i}.npy'))
    bgeo_pth = os.path.join(out_dir, f'pcd_{i:04d}.bgeo')  # zero-padded for frame-sequence tools
    write_bgeo_from_numpy(bgeo_pth, pos_np)