In [1]:
import sys
sys.path.append('..')
import os
import skimage
import time
import kaolin as kal
import tntorch as tn
import torch
import trimesh
import tqdm
import os.path as osp
import matplotlib.pyplot as plt
from t4dt.metrics import compute_metrics, hausdorff, MSDM2
from t4dt.utils import sdf2mesh
from t4dt.t4dt import reduce_tucker

# Constant tsdf

In [2]:
min_tsdf, max_tsdf = -0.005, 0.005
data_dir = '/scratch2/data/cape_release/'
model = '00032'
scene = 'longshort_flying_eagle'
frames = []
for frame in sorted(os.listdir(osp.join(data_dir, 'meshes', model, scene, 'posed'))):
    if frame.startswith('sdf'):
        frames.append(frame)

In [6]:
folder = osp.join(data_dir, 'meshes', model, scene, 'posed')
sdf0 = torch.load(osp.join(folder, frames[0]))
sdf1 = torch.load(osp.join(folder, frames[142]))
sdf2 = torch.load(osp.join(folder, frames[283]))

files = [(osp.join(folder, frame), osp.join(folder, frame[4:-2] + 'obj'))
            for frame in frames]

In [7]:
from typing import List
def compute_metrics_constant(
        frames: List[str],
        sdf_pred: torch.Tensor,
        min_tsdf: float,
        max_tsdf: float,
        num_sample_points: int,
        sample_frames: List[int]):
    res = torch.tensor(sdf_pred.shape)
    result = {}
    for i in tqdm.tqdm(sample_frames):
        result[i] = {}
        frame_pred = sdf_pred
        frame_pred.clamp_min_(min_tsdf)
        frame_pred.clamp_max_(max_tsdf)

        sdf_w_coords = torch.load(frames[i][0])
        sdf = sdf_w_coords['sdf']
        coords = torch.tensor(sdf_w_coords['coords'])

        tqdm.tqdm.write('Marching cube started')
        t0 = time.time()
        mesh_pred = sdf2mesh(frame_pred, coords, res)
        tqdm.tqdm.write(f'Marching cube finished. Took: {time.time() - t0} s.')

        tqdm.tqdm.write('Marching cube started')
        t0 = time.time()
        mesh_gt = sdf2mesh(sdf, coords, res)
        tqdm.tqdm.write(f'Marching cube finished. Took: {time.time() - t0} s.')

        tqdm.tqdm.write('Sampling points started')
        t0 = time.time()
        points_gt, _ = trimesh.sample.sample_surface(mesh_gt, num_sample_points)
        points_pred, _ = trimesh.sample.sample_surface(mesh_pred, num_sample_points)
        tqdm.tqdm.write(f'Sampling points finished. Took: {time.time() - t0} s.')

        points_gt = torch.tensor(points_gt[None]).cuda()
        points_pred = torch.tensor(points_pred[None]).cuda()

        chamfer_distance_error = kal.metrics.pointcloud.chamfer_distance(points_gt, points_pred)[0].detach().cpu()
        del points_gt
        del points_pred

        l2_error = torch.norm(frame_pred - sdf.clamp_min(min_tsdf).clamp_max(max_tsdf))

        tqdm.tqdm.write('Voxelgrid conversion started')
        t0 = time.time()
        vg_pred = kal.ops.conversions.trianglemeshes_to_voxelgrids(
            torch.tensor(mesh_pred.vertices[None]),
            torch.tensor(mesh_pred.faces),
            res.max().item())
        vg_gt = kal.ops.conversions.trianglemeshes_to_voxelgrids(
            torch.tensor(mesh_gt.vertices[None]),
            torch.tensor(mesh_gt.faces),
            res.max().item())
        tqdm.tqdm.write(f'Voxelgrid conversion finished. Took: {time.time() - t0} s.')

        IoU = kal.metrics.voxelgrid.iou(vg_pred, vg_gt)
        del vg_pred
        del vg_gt

        tqdm.tqdm.write('hausdorff computation started')
        t0 = time.time()
        hausdorff_dist = hausdorff(
            torch.tensor(mesh_gt.vertices), torch.tensor(mesh_gt.faces),
            torch.tensor(mesh_pred.vertices), torch.tensor(mesh_pred.faces))
        tqdm.tqdm.write(f'hausdorff computation finished. Took: {time.time() - t0} s.')

        tqdm.tqdm.write('MSDM2 computation started')
        t0 = time.time()
        MSDM2_err = MSDM2(
            torch.tensor(mesh_gt.vertices), torch.tensor(mesh_gt.faces),
            torch.tensor(mesh_pred.vertices), torch.tensor(mesh_pred.faces))
        tqdm.tqdm.write(f'MSDM2 computation finished. Took: {time.time() - t0} s.')
        result[i] = {
            'l2': l2_error,
            'chamfer_distance': chamfer_distance_error,
            'IoU': IoU[0],
            'hausdorff': hausdorff_dist,
            'MSDM2': MSDM2_err}
    return result

In [8]:
compute_metrics_constant(
    files,
    sdf0['sdf'],
    min_tsdf, max_tsdf,
    30000,
    [0, len(frames) // 2, len(frames) - 1])

  0%|                                                                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s]

Marching cube started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:01<?, ?it/s]

Marching cube finished. Took: 1.3789918422698975 s.
Marching cube started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:03<?, ?it/s]

Marching cube finished. Took: 1.3810975551605225 s.
Sampling points started
Sampling points finished. Took: 0.13541054725646973 s.


  0%|                                                                                                                                                                                                                                 | 0/3 [00:03<?, ?it/s]

Voxelgrid conversion started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:11<?, ?it/s]

Voxelgrid conversion finished. Took: 7.789614200592041 s.


  0%|                                                                                                                                                                                                                                 | 0/3 [00:12<?, ?it/s]

hausdorff computation started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:12<?, ?it/s]

hausdorff computation finished. Took: 0.753493070602417 s.
MSDM2 computation started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:31<09:03, 271.77s/it]

Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
MSDM2 computation finished. Took: 258.8096902370453 s.


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:31<09:03, 271.77s/it]

Marching cube started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:33<09:03, 271.77s/it]

Marching cube finished. Took: 1.3744697570800781 s.
Marching cube started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:34<09:03, 271.77s/it]

Marching cube finished. Took: 1.3788716793060303 s.
Sampling points started
Sampling points finished. Took: 0.13970375061035156 s.


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:35<09:03, 271.77s/it]

Voxelgrid conversion started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:43<09:03, 271.77s/it]

Voxelgrid conversion finished. Took: 7.992024660110474 s.


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:44<09:03, 271.77s/it]

hausdorff computation started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:45<09:03, 271.77s/it]

hausdorff computation finished. Took: 1.5896661281585693 s.
MSDM2 computation started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:26<04:45, 285.12s/it]

Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
MSDM2 computation finished. Took: 280.50464272499084 s.


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:26<04:45, 285.12s/it]

Marching cube started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:27<04:45, 285.12s/it]

Marching cube finished. Took: 1.3711464405059814 s.
Marching cube started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:29<04:45, 285.12s/it]

Marching cube finished. Took: 1.3747434616088867 s.
Sampling points started
Sampling points finished. Took: 0.13660573959350586 s.


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:30<04:45, 285.12s/it]

Voxelgrid conversion started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:37<04:45, 285.12s/it]

Voxelgrid conversion finished. Took: 7.801785707473755 s.


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:38<04:45, 285.12s/it]

hausdorff computation started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [09:39<04:45, 285.12s/it]

hausdorff computation finished. Took: 0.9312138557434082 s.
MSDM2 computation started


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [14:02<00:00, 280.69s/it]

Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
MSDM2 computation finished. Took: 262.7808926105499 s.





{0: {'l2': tensor(0.),
  'chamfer_distance': tensor(4.2841e-05, dtype=torch.float64),
  'IoU': tensor(0.9999),
  'hausdorff': 0.0005439726856382182,
  'MSDM2': 0.004876062542257095},
 142: {'l2': tensor(12.2017),
  'chamfer_distance': tensor(0.0067, dtype=torch.float64),
  'IoU': tensor(0.0072),
  'hausdorff': 0.19801137616898737,
  'MSDM2': 0.5154686371753484},
 283: {'l2': tensor(4.1074),
  'chamfer_distance': tensor(0.0002, dtype=torch.float64),
  'IoU': tensor(0.0311),
  'hausdorff': 0.06490128288537549,
  'MSDM2': 0.3891190929040058}}

In [15]:
compute_metrics_constant(
    files,
    (sdf0['sdf'] + sdf1['sdf']) / 2,
    min_tsdf, max_tsdf,
    30000,
    [0, len(frames) // 2, len(frames) - 1])

  0%|                                                                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s]

Marching cube started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:01<?, ?it/s]

Marching cube finished. Took: 1.369436502456665 s.
Marching cube started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:03<?, ?it/s]

Marching cube finished. Took: 1.3688263893127441 s.
Sampling points started
Sampling points finished. Took: 0.1344602108001709 s.


  0%|                                                                                                                                                                                                                                 | 0/3 [00:03<?, ?it/s]

Voxelgrid conversion started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:11<?, ?it/s]

Voxelgrid conversion finished. Took: 7.740429162979126 s.


  0%|                                                                                                                                                                                                                                 | 0/3 [00:11<?, ?it/s]

hausdorff computation started


  0%|                                                                                                                                                                                                                                 | 0/3 [00:13<?, ?it/s]

hausdorff computation finished. Took: 1.4946918487548828 s.
MSDM2 computation started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:22<08:45, 262.74s/it]

Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
MSDM2 computation finished. Took: 249.25557374954224 s.


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:22<08:45, 262.74s/it]

Marching cube started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:24<08:45, 262.74s/it]

Marching cube finished. Took: 1.3812518119812012 s.
Marching cube started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:25<08:45, 262.74s/it]

Marching cube finished. Took: 1.3923509120941162 s.
Sampling points started
Sampling points finished. Took: 0.13874220848083496 s.


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:26<08:45, 262.74s/it]

Voxelgrid conversion started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:34<08:45, 262.74s/it]

Voxelgrid conversion finished. Took: 8.061902284622192 s.


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:35<08:45, 262.74s/it]

hausdorff computation started


 33%|████████████████████████████████████████████████████████████████████████                                                                                                                                                | 1/3 [04:36<08:45, 262.74s/it]

hausdorff computation finished. Took: 0.9936776161193848 s.
MSDM2 computation started
Asking to calculate curvature


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:33<04:15, 255.89s/it]

Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
MSDM2 computation finished. Took: 237.63420391082764 s.


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:34<04:15, 255.89s/it]

Marching cube started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:35<04:15, 255.89s/it]

Marching cube finished. Took: 1.3711090087890625 s.
Marching cube started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:36<04:15, 255.89s/it]

Marching cube finished. Took: 1.3776323795318604 s.
Sampling points started
Sampling points finished. Took: 0.13610029220581055 s.


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:37<04:15, 255.89s/it]

Voxelgrid conversion started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:45<04:15, 255.89s/it]

Voxelgrid conversion finished. Took: 7.927035331726074 s.


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:46<04:15, 255.89s/it]

hausdorff computation started


 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                        | 2/3 [08:47<04:15, 255.89s/it]

hausdorff computation finished. Took: 1.5349617004394531 s.
MSDM2 computation started
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature
Asking to calculate curvature


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [12:57<00:00, 259.11s/it]

MSDM2 computation finished. Took: 249.64593958854675 s.





{0: {'l2': tensor(11.2982),
  'chamfer_distance': tensor(0.0062, dtype=torch.float64),
  'IoU': tensor(0.0080),
  'hausdorff': 0.1682107862003486,
  'MSDM2': 0.5376455052178554},
 142: {'l2': tensor(3.8179),
  'chamfer_distance': tensor(0.0002, dtype=torch.float64),
  'IoU': tensor(0.0412),
  'hausdorff': 0.03962373930620786,
  'MSDM2': 0.4032762050864984},
 283: {'l2': tensor(11.6279),
  'chamfer_distance': tensor(0.0068, dtype=torch.float64),
  'IoU': tensor(0.0060),
  'hausdorff': 0.1900046183604913,
  'MSDM2': 0.5454763465573684}}

# sequential round

In [None]:
rmax = 50
TSDF_MIN = -0.005
TSDF_MAX = 0.005
t = tn.zeros((512, 512, 512, 284))
for i, frame in enumerate(frames):
    sdf = torch.load(osp.join(folder, frame))['sdf']
    t[..., i] = sdf.clamp_min(TSDF_MIN).clamp_max(TSDF_MAX)
    if i == 1:
        break

In [23]:
t.round_tt(rmax=rmax)

In [24]:
t

4D TT tensor:

 512 512 512 284
  |   |   |   |
 (0) (1) (2) (3)
 / \ / \ / \ / \
1   50  50  1   1

# Compression experiment

In [7]:
tuckers = torch.load('../logs/sweeps/tt_tucker_00032_longshort_flying_eagle_high_ranks.pt')

In [8]:
rmax = 100

[t.round_tucker(rmax=rmax) for t in tuckers]
local_res_tucker = reduce_tucker(
    [t[..., None] for t in tuckers],
    eps=1e-16, rmax=1 * rmax, algorithm='svd')

In [9]:
local_res_tucker

4D TT-Tucker tensor:

 512 512 512 284
  |   |   |   |
 100 100 100 100
 (0) (1) (2) (3)
 / \ / \ / \ / \
1   200 1000100 1

In [None]:
max_tt_rank = 15
low_rank_tt = local_res_tucker.clone()
low_rank_tt.round_tt(rmax=low_rank_tt.ranks_tt[1:-2].tolist() + [max_tt_rank])

In [None]:
low_rank_tt

In [6]:
coords = torch.tensor(
    torch.load('/scratch2/data/cape_release/meshes/00032/longshort_flying_eagle/coords.pt')['coords'])

In [37]:
# mesh = sdf2mesh(tuckers[284 // 2].torch(), coords, 512)
# mesh = sdf2mesh(local_res_tucker[..., 284 // 2].torch(), coords, 512)
mesh = sdf2mesh(low_rank_tt[..., 284 // 2].torch(), coords, 512)
obj = trimesh.exchange.obj.export_obj(mesh, include_texture=False)
with open('./mesh_sdf.obj', 'w') as f:
    f.write(obj)

In [6]:
import torch

def get_qtt3d_reshape_plan(dim_grid_log2):
    dim_grid = 2 ** dim_grid_log2
    num_factors = dim_grid_log2 * 3

    shape_src = [dim_grid] * 3
    shape_dst = [8] * dim_grid_log2
    shape_factors = [2] * num_factors

    factor_ids = torch.arange(num_factors)
    permute_factors_src_to_dst = factor_ids.reshape(3, dim_grid_log2).T.reshape(-1).tolist()
    permute_factors_dst_to_src = factor_ids.reshape(dim_grid_log2, 3).T.reshape(-1).tolist()

    permute_factors_src_to_dst.append(num_factors)
    permute_factors_dst_to_src.append(num_factors)

    return {
        'shape_factors': shape_factors,
        'shape_src': shape_src,
        'shape_dst': shape_dst,
        'permute_factors_src_to_dst': permute_factors_src_to_dst,
        'permute_factors_dst_to_src': permute_factors_dst_to_src,
    }


def tensor_order_to_qtt(x, plan):
    x = x.reshape(plan['shape_factors'])
    x = x.permute(plan['permute_factors_src_to_dst'])
    x = x.reshape(plan['shape_dst'])
    return x


def tensor_order_from_qtt(x, plan):
    x = x.reshape(plan['shape_factors'])
    x = x.permute(plan['permute_factors_dst_to_src'])
    x = x.reshape(plan['shape_src'])
    return x

In [7]:
get_qtt3d_reshape_plan(2)

{'shape_factors': [2, 2, 2, 2, 2, 2],
 'shape_src': [4, 4, 4],
 'shape_dst': [8, 8],
 'permute_factors_src_to_dst': [0, 2, 4, 1, 3, 5, 6],
 'permute_factors_dst_to_src': [0, 3, 1, 4, 2, 5, 6]}

In [9]:
import numpy as np
np.log(4) / np.log(2)

2.0

In [10]:
def is_pow2(n: int) -> bool:
    logn = np.log(n) / np.log(2)
    return (2**logn - 2**int(logn)) < 1e-8

In [18]:
is_pow2(1024)

True

In [19]:
import itertools

In [20]:
dimentions = np.arange(8)

[0,
 8,
 16,
 1,
 9,
 17,
 2,
 10,
 18,
 3,
 11,
 19,
 4,
 12,
 20,
 5,
 13,
 21,
 6,
 14,
 22,
 7,
 15,
 23]

In [63]:
def is_pow2(n: int) -> bool:
    return n != 0 and ((n & (n - 1)) == 0)


def tensor3d2qtt(t: torch.Tensor, checks: bool = True) -> torch.Tensor:
    shape = t.shape
    if checks:
        assert len(shape) == 3
        assert shape[0] == shape[1] == shape[2], f'Only tensors with all equal dimensions are supported'
        assert is_pow2(shape[0])
    dim_grid = int(np.log(shape[0]) / np.log(2))
    num_dims = 3 * dim_grid
    qtt = t.reshape([2] * num_dims)
    dimentions = np.arange(num_dims)
    dims = list(
        itertools.chain.from_iterable(
            zip(
                dimentions[:dim_grid],
                dimentions[dim_grid:2 * dim_grid],
                dimentions[2 * dim_grid:])))
    qtt = qtt.permute(dims)
    return qtt


def qtt2tensor3d(qtt: torch.Tensor, checks: bool = True) -> torch.Tensor:
    shape = qtt.shape
    if checks:
        assert all([shape[0] == shape[i] == 2 for i in range(1, len(shape))]), f'Only qtt tensors are supported'
        assert len(shape) // 3 == len(shape) / 3
    num_dims = len(shape)
    dimentions = np.arange(num_dims)
    dims = list(np.concatenate([
        dimentions[0::3],
        dimentions[1::3],
        dimentions[2::3]]))
    t = qtt.permute(dims)
    dim_grid = 2 ** (num_dims // 3)
    return t.reshape([dim_grid] * 3)

In [64]:
t = torch.rand((32,32,32))

In [65]:
torch.norm(qtt2tensor3d(tensor3d2qtt(t, True)) - t)

tensor(0.)

In [55]:
4096 / 16/ 16/ 16

1.0