In [None]:
import h5py
import numpy as np
import os, sys
import open3d as o3d
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import time

import torch
from itertools import product

In [None]:
def store_data(data_names, data, path):
    """Write arrays to an HDF5 file, one dataset per name.

    The file at `path` is opened with mode 'w' (created or truncated).
    Dataset `data_names[i]` receives `data[i]`; entries of `data` beyond
    `len(data_names)` are ignored, and an IndexError is raised if `data`
    is shorter than `data_names`.

    Args:
        data_names: sequence of dataset names.
        data: sequence of array-like values, parallel to `data_names`.
        path: output file path.
    """
    # The context manager closes the file even if a dataset write raises.
    with h5py.File(path, 'w') as hf:
        for i, name in enumerate(data_names):
            hf.create_dataset(name, data=data[i])

def load_data(data_names, path):
    """Read the named datasets from an HDF5 file into NumPy arrays.

    Args:
        data_names: sequence of dataset names to read.
        path: path of the HDF5 file, opened read-only.

    Returns:
        List of np.ndarray in the same order as `data_names`. Lookups use
        `hf.get`, which returns None for a missing name, so a missing dataset
        comes back as `np.array(None)` instead of raising.
    """
    # np.array(...) materializes each dataset in memory before the file closes;
    # the context manager closes the file even if a read raises.
    with h5py.File(path, 'r') as hf:
        return [np.array(hf.get(name)) for name in data_names]

In [None]:
def p2g(x, size=64, p_mass=1.):
    """Scatter particle mass onto a dense cubic grid (MPM-style particle-to-grid).

    Each particle deposits `p_mass` onto the 3x3x3 block of grid nodes around
    it, weighted by quadratic B-spline weights that sum to 1 per particle.

    Args:
        x: particle positions of shape (N, 3) or (B, N, 3). Positions are
            multiplied by `size`, so node n sits at coordinate n / size;
            presumably positions are normalized to [0, 1) -- confirm with the data.
        size: number of grid nodes per axis.
        p_mass: mass deposited by each particle.

    Returns:
        Tensor of shape (B, size, size, size) (B = 1 for 2-D input) with the
        same dtype and device as `x`. Stencil nodes that fall outside the grid
        are clamped to the nearest boundary node, so no mass is dropped.
    """
    if x.dim() == 2:
        x = x[None, :]
    batch = x.shape[0]
    grid_m = torch.zeros(batch, size * size * size, dtype=x.dtype, device=x.device)

    x_grid = x * size  # positions in grid units (inv_dx == size)
    # floor rather than .long() truncation toward zero: truncation yields
    # fx < 0.5 when x_grid < 0.5, where the quadratic kernel gives negative weights.
    base = torch.floor(x_grid - 0.5).long()
    # fx lies in [0.5, 1.5); casting base to x.dtype keeps the weights in the
    # same dtype as grid_m, which scatter_add_ requires.
    fx = x_grid - base.to(x.dtype)
    w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]

    for i, j, k in product(range(3), repeat=3):
        weight = w[i][..., 0] * w[j][..., 1] * w[k][..., 2] * p_mass
        # Build the offset on x's device instead of a hard-coded 'cuda:0'.
        offset = torch.tensor([i, j, k], dtype=torch.long, device=x.device)
        target = (base + offset).clamp(0, size - 1)
        idx = (target[..., 0] * size + target[..., 1]) * size + target[..., 2]
        grid_m.scatter_add_(1, idx, weight)
    return grid_m.reshape(batch, size, size, size)

In [None]:
def compute_sdf(density, eps=1e-4, inf=1e10, max_iters=None):
    """Approximate the distance from every voxel to the nearest occupied voxel.

    A voxel is occupied when density > eps; occupied voxels get distance 0.
    Distances are propagated by sweeping the 26-neighbourhood repeatedly:
    each voxel adopts a neighbour's nearest occupied point whenever that point
    is closer than its current one. Despite the name, the result is unsigned
    (occupied voxels are 0, not negative).

    Args:
        density: tensor of shape (D, H, W) or (B, D, H, W); presumably the
            output of p2g -- confirm with callers.
        eps: occupancy threshold.
        inf: sentinel distance for voxels not yet reached. Voxels still at or
            above `inf` on return had no occupied voxel to reach them (e.g. an
            all-empty grid).
        max_iters: maximum number of full sweeps; defaults to
            2 * density.shape[1]. Sweeping stops early once a full sweep
            changes nothing, because further sweeps could not change the result.

    Returns:
        Tensor of shape (B, D, H, W) of distances, measured in units where one
        voxel step along an axis is 1 / density.shape[1].
    """
    if density.dim() == 3:
        density = density[None, :, :]
    if max_iters is None:
        max_iters = density.shape[1] * 2
    dx = 1. / density.shape[1]

    def _axis_slices(step):
        # (source, target) slices along one axis for a neighbour offset code:
        # 0 -> same index, 1 -> target = source + 1, 2 -> target = source - 1.
        if step == 0:
            return slice(None), slice(None)
        if step == 1:
            return slice(0, -1), slice(1, None)
        return slice(1, None), slice(0, -1)

    # Precompute the 26 (source, target) index tuples once; the order matches the
    # original product() order, which matters because sweeps update in place.
    neighbour_views = []
    for ox, oy, oz in product(range(3), repeat=3):
        if ox + oy + oz == 0:
            continue
        (f1, t1), (f2, t2), (f3, t3) = _axis_slices(ox), _axis_slices(oy), _axis_slices(oz)
        neighbour_views.append(((slice(None), f1, f2, f3), (slice(None), t1, t2, t3)))

    with torch.no_grad():
        coords = torch.stack(torch.meshgrid(
            torch.arange(density.shape[1]),
            torch.arange(density.shape[2]),
            torch.arange(density.shape[3]),
            indexing='ij',
        ), dim=-1)
        # `* dx` materializes a fresh tensor, so the in-place writes below are safe
        # despite the expand().
        nearest_points = coords[None].to(density.device).expand(density.shape[0], -1, -1, -1, -1) * dx
        mesh_points = nearest_points.clone()

        # Occupied voxels (density > eps) start at 0, empty ones at the sentinel.
        sdf = (density <= eps) * inf

        for _ in range(max_iters):
            changed = torch.zeros((), dtype=torch.bool, device=density.device)
            for fr, to in neighbour_views:
                dist = ((mesh_points[to] - nearest_points[fr]) ** 2).sum(dim=-1) ** 0.5
                # A source that has not been reached yet has no valid nearest point.
                dist += (sdf[fr] >= inf) * inf
                sdf_to = sdf[to]
                better = dist < sdf_to
                sdf[to] = torch.where(better, dist, sdf_to)
                nearest_points[to] = torch.where(better[..., None], nearest_points[fr], nearest_points[to])
                changed |= better.any()
            # One host sync per sweep; an unchanged state is a fixed point.
            if not changed:
                break
        return sdf

In [None]:
# Voxelize the first n_particles ground-truth positions of every rollout frame,
# compute the occupancy distance field, and write it next to the input frame
# as sdf_gt_<t>.h5. Set the ROLLOUT_DIR environment variable to use another data root.
rollout_dir = os.environ.get(
    "ROLLOUT_DIR",
    "/home/haochen/projects/deformable/VGPL-Dynamics-Prior/data/data_ngrip_new/train",
)
n_vid = 50
n_frame = 89
# Only the first 300 rows of 'positions' are voxelized -- presumably the
# deformable object's particles; confirm against the dataset layout.
n_particles = 300
data_names = ['positions', 'shape_quats', 'scene_params']
for i in range(n_vid):
    vid_dir = os.path.join(rollout_dir, str(i).zfill(3))
    for t in range(n_frame):
        print(f"vid {i} frame {t}")
        frame_path = os.path.join(vid_dir, f'gt_{t}.h5')
        this_data = load_data(data_names, frame_path)
        states = this_data[0]
        states_tmp = states[:n_particles]

        gt_pos = torch.from_numpy(states_tmp).cuda()
        grid = p2g(gt_pos)
        sdf = compute_sdf(grid)

        store_data(['sdf'], [sdf.cpu().numpy()], os.path.join(vid_dir, f'sdf_gt_{t}.h5'))
