In [3]:
import math
import os
import time
import warnings
from functools import partial

import h5py
import kaolin
import opt_einsum
import torch
import torch.nn.functional as F
import trimesh
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import pyvista as pv
pv.set_jupyter_backend('pythreejs')
from t4dt.utils import sdf2mesh
import numpy as np

import sys
sys.path.append('..')

from t4dt.utils import sdf2mesh


def get_qtt_reshape_plan(dim_grid_log2, qtt_group_mode_size=3):
    """Build the reshape/permute recipe between a dense grid and its QTT layout.

    The source layout has ``qtt_group_mode_size`` modes of size
    ``2 ** dim_grid_log2`` each. Reshaping a mode into ``dim_grid_log2`` binary
    factors (row-major, so the most significant bit comes first) and then
    regrouping the factors by bit level yields the destination layout: one mode
    of size ``2 ** qtt_group_mode_size`` per bit level.

    Args:
        dim_grid_log2: number of binary factors per source mode.
        qtt_group_mode_size: number of source modes whose bits are interleaved.

    Returns:
        dict with keys ``shape_factors``, ``shape_src``, ``shape_dst``,
        ``permute_factors_src_to_dst`` and ``permute_factors_dst_to_src``;
        every value is a plain Python list of ints.
    """
    n_bits = dim_grid_log2
    n_modes = qtt_group_mode_size
    side = 2 ** n_bits
    n_factors = n_bits * n_modes
    factor_ids = torch.arange(n_factors)

    def _column_major_order(rows, cols):
        # Lay the factor ids out row-major as (rows, cols) and read them column by column.
        return factor_ids.view(rows, cols).t().flatten().tolist()

    plan = {}
    plan['shape_factors'] = [2] * n_factors
    plan['shape_src'] = [side] * n_modes
    plan['shape_dst'] = [2 ** n_modes] * n_bits
    # src factors are grouped by mode; dst factors are grouped by bit level.
    plan['permute_factors_src_to_dst'] = _column_major_order(n_modes, n_bits)
    plan['permute_factors_dst_to_src'] = _column_major_order(n_bits, n_modes)
    return plan


def tensor_order_to_qtt(x, plan):
    """Rearrange a dense tensor from the source layout into the QTT layout.

    ``plan`` carries the keys built by ``get_qtt_reshape_plan``: ``x`` is split
    into binary factors (``plan['shape_factors']``), the factors are reordered
    with ``plan['permute_factors_src_to_dst']``, and the result is reshaped to
    ``plan['shape_dst']``.
    """
    return (
        x.reshape(plan['shape_factors'])
        .permute(plan['permute_factors_src_to_dst'])
        .reshape(plan['shape_dst'])
    )


def tensor_order_from_qtt(x, plan):
    """Rearrange a tensor from the QTT layout back into the dense source layout.

    Inverse of ``tensor_order_to_qtt`` for the same ``plan``: ``x`` is split into
    binary factors, reordered with ``plan['permute_factors_dst_to_src']``, and
    reshaped to ``plan['shape_src']``.
    """
    as_factors = x.reshape(plan['shape_factors'])
    regrouped = as_factors.permute(plan['permute_factors_dst_to_src'])
    return regrouped.reshape(plan['shape_src'])


def get_tt_ranks(shape, max_rank=None):
    """Return the maximal tensor-train ranks for a tensor of the given shape.

    Rank ``k`` is the smaller of the product of the first ``k`` mode sizes and
    the product of the remaining ones, optionally clipped to ``max_rank``. The
    result has ``len(shape) + 1`` entries. For positive mode sizes (and
    ``max_rank >= 1``) it starts and ends with ``1``.

    Args:
        shape: non-empty tuple or list of mode sizes (``torch.Size`` is accepted).
        max_rank: optional upper bound applied to every rank.

    Returns:
        list of ranks (Python ints for int mode sizes).

    Raises:
        ValueError: if ``shape`` is not a non-empty tuple/list.
    """
    # isinstance rather than an exact type check, so torch.Size (a tuple subclass) is accepted.
    if not isinstance(shape, (tuple, list)) or len(shape) == 0:
        raise ValueError(f'Invalid shape: {shape}')
    if len(shape) == 1:
        return [1, 1]
    # Accumulate with arbitrary-precision Python ints: torch.cumprod on int64
    # silently overflows for long QTT shapes and corrupts the boundary ranks.
    ranks_left = [1]
    for size in shape:
        ranks_left.append(ranks_left[-1] * size)
    ranks_right = [1]
    for size in reversed(shape):
        ranks_right.append(ranks_right[-1] * size)
    ranks_right.reverse()
    ranks_tt = [min(a, b) for a, b in zip(ranks_left, ranks_right)]
    if max_rank is not None:
        ranks_tt = [min(r, max_rank) for r in ranks_tt]
    return ranks_tt


def gen_letter():
    """Yield einsum index symbols without end.

    The ``n``-th value yielded is ``opt_einsum.get_symbol(n)`` for ``n = 0, 1, 2, ...``.
    """
    symbol_index = -1
    while True:
        symbol_index += 1
        yield opt_einsum.get_symbol(symbol_index)


def shapes(input):
    """Return the ``.shape`` of every element of ``input``, as a list."""
    return list(map(lambda tensor: tensor.shape, input))


def is_tt_shapes(
        input_shapes,
        inputs_with_batch_dim=None,
        batch_size=None,
        allow_loose_rank_left=False,
        allow_loose_rank_right=False,
):
    """Check whether a sequence of core shapes forms a valid tensor train.

    Every shape must have exactly three entries ``(rank_left, mode, rank_right)``,
    and the right rank of each core must equal the left rank of the next. The
    outer ranks must be 1 unless relaxed by ``allow_loose_rank_left`` /
    ``allow_loose_rank_right``. If ``inputs_with_batch_dim`` is given, it must be
    a tuple/list of bools with one entry per core, and ``batch_size`` must be a
    positive int.

    Returns:
        bool. An empty sequence is not a tensor train and yields ``False``.
    """
    # Reject the empty sequence up front: the boundary-rank checks below index
    # input_shapes[0] / input_shapes[-1] and would otherwise raise IndexError.
    if type(input_shapes) not in (tuple, list) or len(input_shapes) == 0:
        return False
    # Every core must be 3-D, and neighbouring ranks must agree.
    if any(len(s) != 3 for s in input_shapes) or \
            any(input_shapes[i - 1][-1] != input_shapes[i][0] for i in range(1, len(input_shapes))):
        return False
    if not (allow_loose_rank_left or input_shapes[0][0] == 1):
        return False
    if not (allow_loose_rank_right or input_shapes[-1][-1] == 1):
        return False
    if inputs_with_batch_dim is not None and not (
            type(batch_size) is int and
            batch_size > 0 and
            type(inputs_with_batch_dim) in (tuple, list) and
            len(inputs_with_batch_dim) == len(input_shapes) and
            all(type(b) is bool for b in inputs_with_batch_dim)
    ):
        return False
    return True


def is_list_of_tensors(input):
    """Return True if ``input`` is exactly a list or tuple whose elements are all torch tensors."""
    if type(input) not in (list, tuple):
        return False
    return all(map(torch.is_tensor, input))


def is_tt(input):
    """Return True if ``input`` is a list/tuple of torch tensors whose shapes form a tensor train."""
    if not is_list_of_tensors(input):
        return False
    return is_tt_shapes(shapes(input))


def perf_report(equation, *shapes, einsum_opt_method='dp'):
    """Summarise the cost of an einsum contraction as planned by ``opt_einsum``.

    ``shapes`` are operand shapes (not tensors); the contraction path is
    searched with ``einsum_opt_method``. Returns a dict holding the planned
    cost (``flops``), the size of the largest intermediate, the summed size of
    all intermediates, and the equation and shapes echoed back.
    """
    path_info = opt_einsum.contract_path(equation, *shapes, shapes=True, optimize=einsum_opt_method)[1]
    report = dict()
    report['flops'] = int(path_info.opt_cost)
    report['size_max_intermediate'] = int(path_info.largest_intermediate)
    report['size_all_intermediate'] = int(sum(path_info.size_list))
    report['equation'] = equation
    report['input_shapes'] = shapes
    return report


def compile_tt_contraction_fn(
        input_shapes,
        inputs_with_batch_dim=None,
        batch_size=None,
        allow_loose_rank_left=False,
        allow_loose_rank_right=False,
        last_core_is_payload=False,
        output_modes_squeeze=False,
        output_last_rank_keep=False,
        einsum_opt_method='dp',
        report_flops=False
):
    """Compile an ``opt_einsum`` expression that contracts a tensor train.

    The cores are chained along their shared ranks. Schematically, for two
    cores the equation is ``r0 m0 r1, r1 m1 r2 -> m0 m1`` (letters come from
    ``gen_letter``). It is compiled once for the given operand shapes with
    ``opt_einsum.contract_expression``.

    Args:
        input_shapes: per-core shapes ``(rank_left, mode, rank_right)``, without
            any batch dimension; validated with ``is_tt_shapes``.
        inputs_with_batch_dim: optional per-core bools. A ``True`` core is
            expected to carry an extra leading dimension of size ``batch_size``.
            That dimension uses one shared letter across cores and is placed
            first in the output.
        batch_size: size of the leading batch dimension.
        allow_loose_rank_left: permit a first left rank other than 1; when it is
            greater than 1 it appears in the output.
        allow_loose_rank_right: permit a last right rank other than 1; when it is
            greater than 1 it appears in the output.
        last_core_is_payload: with ``output_modes_squeeze``, keep every mode of
            the last core in the output even when its size is 1.
        output_modes_squeeze: drop size-1 modes from the output.
        output_last_rank_keep: keep the last core's right rank in the output
            even when it equals 1.
        einsum_opt_method: ``optimize`` strategy passed to ``opt_einsum``.
        report_flops: if True, also compute a ``perf_report`` for the equation.

    Returns:
        The compiled contraction callable (cores passed as positional
        arguments), or ``(callable, report)`` when ``report_flops`` is True.

    Raises:
        ValueError: if the shapes do not form a tensor train.
    """
    if not is_tt_shapes(input_shapes, inputs_with_batch_dim, batch_size, allow_loose_rank_left, allow_loose_rank_right):
        raise ValueError(f'Operand shapes do not form a tensor train: {input_shapes=} {inputs_with_batch_dim=} '
                         f'{batch_size=} {allow_loose_rank_left=} {allow_loose_rank_right=}')

    # The batch letter (if any) is drawn first so it is the leading index of the output.
    have_batch_dim = inputs_with_batch_dim is not None and any(inputs_with_batch_dim)
    letter_batch = None
    letter = gen_letter()
    if have_batch_dim:
        letter_batch = next(letter)

    equation_left = ''
    equation_right = letter_batch if have_batch_dim else ''

    letter_core_last_rank_right = None
    input_shapes_with_batch_dim = []

    for i in range(len(input_shapes)):
        # Shapes handed to opt_einsum must include the batch dimension where present.
        if inputs_with_batch_dim is not None and inputs_with_batch_dim[i]:
            input_shapes_with_batch_dim.append([batch_size] + list(input_shapes[i]))
        else:
            input_shapes_with_batch_dim.append(list(input_shapes[i]))

        # A core's left-rank letter reuses the previous core's right-rank letter, which chains the train.
        letter_rank_left = next(letter) if i == 0 else letter_core_last_rank_right
        # One letter per middle (mode) dimension; is_tt_shapes currently enforces exactly one.
        letters_modes = [next(letter) for _ in range(len(input_shapes[i]) - 2)]
        letter_rank_right = next(letter)
        letter_core_last_rank_right = letter_rank_right
        if i > 0:
            equation_left += ','
        if inputs_with_batch_dim is not None and inputs_with_batch_dim[i]:
            equation_left += letter_batch
        equation_left += letter_rank_left
        equation_left += ''.join(letters_modes)
        equation_left += letter_rank_right
        # A loose (>1) first left rank cannot be summed away, so it is kept in the output.
        if i == 0 and input_shapes[i][0] > 1:
            equation_right += letter_rank_left
        if output_modes_squeeze:
            for c, si in zip(letters_modes, input_shapes[i][1:-1]):
                if si > 1 or (last_core_is_payload and i == len(input_shapes) - 1):
                    equation_right += c
        else:
            equation_right += ''.join(letters_modes)
        # The last right rank is kept if requested or if it is loose (>1).
        if i == len(input_shapes) - 1 and (output_last_rank_keep or input_shapes[i][-1] > 1):
            equation_right += letter_rank_right

    equation = equation_left + '->' + equation_right
    contraction_fn = opt_einsum.contract_expression(equation, *input_shapes_with_batch_dim, optimize=einsum_opt_method)

    if report_flops:
        report = perf_report(equation, *input_shapes_with_batch_dim, einsum_opt_method=einsum_opt_method)
        return contraction_fn, report

    return contraction_fn

def convert_qtt_to_tensor(input, qtt_reshape_plan=None, fn_contract=None, checks=False):
    """Contract a (Q)TT given as a list of cores into a dense tensor.

    Args:
        input: list/tuple of TT cores.
        qtt_reshape_plan: optional plan dict (see ``get_qtt_reshape_plan``). When
            given, the contracted result is reordered from QTT layout back to the
            source layout with ``tensor_order_from_qtt``.
        fn_contract: optional precompiled contraction callable. When omitted,
            one is compiled from the cores' shapes.
        checks: if True, verify that ``input`` is a tensor train first.

    Raises:
        ValueError: if ``checks`` is set and ``input`` is not a tensor train.
    """
    if checks and not is_tt(input):
        raise ValueError('Operand is not a tensor train')
    contract = fn_contract if fn_contract is not None else compile_tt_contraction_fn(shapes(input))
    dense = contract(*input)
    return dense if qtt_reshape_plan is None else tensor_order_from_qtt(dense, qtt_reshape_plan)

# Checkpoint whose values are the QTT cores (their shapes are listed in a later cell).
# Override the location with the QTTNF_MODEL_PATH environment variable instead of editing the notebook.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files. If the installed
# torch supports it, weights_only=True is the safer option here.
QTTNF_MODEL_PATH = os.environ.get('QTTNF_MODEL_PATH', '/scratch2/data/qttnf_model.pt')
loaded = torch.load(QTTNF_MODEL_PATH, map_location='cpu')


def extract_frame_sdf_from_ttnf(
        cores,
        dim_grid,
        frame_id,
        checks=False,
):
    """Decode one time frame of a 4D QTT field into a dense 3D grid.

    Each core's middle mode is split as ``(2, 8)``. The size-2 factor is one bit
    of the frame index, with the most significant bit in the first core. The
    size-8 factor holds the remaining three interleaved coordinate bits.
    Fixing the frame bits leaves a 3D QTT, which is contracted and reordered
    to shape ``[dim_grid] * 3``.

    Args:
        cores: list of ``log2(dim_grid)`` TT cores with 16 middle entries per
            rank pair, e.g. ``(r_left, 16, r_right)``.
        dim_grid: grid resolution per axis; must be a power of two.
        frame_id: frame index in ``[0, dim_grid)``.
        checks: forwarded to ``convert_qtt_to_tensor`` to validate the TT structure.

    Returns:
        Dense tensor of shape ``[dim_grid] * 3`` built from CPU copies of the cores.

    Raises:
        ValueError: if ``dim_grid`` is not a power of two or ``frame_id`` is out of range.
        AssertionError: if the number of cores differs from ``log2(dim_grid)``.
    """
    dim_grid_log2 = int(math.log2(dim_grid))
    # int(log2(.)) silently truncates, so a non-power-of-two grid would be decoded at the wrong size.
    if 2 ** dim_grid_log2 != dim_grid:
        raise ValueError(f'dim_grid must be a power of two, got {dim_grid}')
    # Only dim_grid_log2 frame bits exist; larger or negative ids would otherwise alias to another frame.
    if not 0 <= frame_id < dim_grid:
        raise ValueError(f'frame_id must be in [0, {dim_grid}), got {frame_id}')
    assert len(cores) == dim_grid_log2, shapes(cores)
    # Binary digits of frame_id, most significant first, one per core.
    frame_id_bits = [(frame_id >> (dim_grid_log2 - 1 - n)) & 1 for n in range(dim_grid_log2)]
    cores = [c.reshape(c.shape[0], 2, 8, c.shape[-1]) for c in cores]
    # Select the slice for this frame's bit in every core, leaving (r_left, 8, r_right) cores.
    cores = [c[:, i, :, :].cpu() for c, i in zip(cores, frame_id_bits)]
    plan = get_qtt_reshape_plan(dim_grid_log2, qtt_group_mode_size=3)
    out = convert_qtt_to_tensor(cores, qtt_reshape_plan=plan, checks=checks)
    return out

# Render several decoded frames of the field into one scene.
DIM_GRID = 512  # per-axis resolution of each decoded SDF volume
FRAME_IDS = [0, 142, 283]  # frames to render
FRAME_OFFSET = np.array([-1, 0, 1])  # the j-th rendered frame is translated by j * FRAME_OFFSET
ttnf_cores = list(loaded.values())  # materialise the core list once instead of on every call

# NOTE(review): `rec` is not used in this cell; kept in case later cells read it.
rec = extract_frame_sdf_from_ttnf(ttnf_cores, DIM_GRID, 0)
# NOTE(review): six values, presumably the (min_xyz, max_xyz) bounds expected by sdf2mesh; confirm.
coords = torch.tensor([-0.9714, -0.6217, -1.0168,  0.9179,  1.2920,  0.6641])
pl = pv.Plotter()
pl.camera_position = [0, 5, 10]
pl.camera.elevation = 0
pl.camera.roll = 0
pl.camera.azimuth = 0
pl.camera.zoom(1.5)

for j, i in enumerate(FRAME_IDS):
    framei = extract_frame_sdf_from_ttnf(ttnf_cores, DIM_GRID, i)
    tmeshi = sdf2mesh(framei, coords)
    tmeshi.vertices += j * FRAME_OFFSET
    mesh = pv.wrap(tmeshi)

    pl.add_mesh(mesh)
pl.show(screenshot='perda.png', jupyter_backend='none')

In [28]:
# Re-render the three-frame comparison (same frames, camera and output file as the earlier cell).
import pyvista as pv
pv.set_jupyter_backend('pythreejs')

rec = extract_frame_sdf_from_ttnf(list(loaded.values()), 512, 0)
coords = torch.tensor([-0.9714, -0.6217, -1.0168, 0.9179, 1.2920, 0.6641])

pl = pv.Plotter()
pl.camera_position = [0, 5, 10]
# Reset the camera angles in the same order as before: elevation, roll, azimuth.
for angle_attr in ('elevation', 'roll', 'azimuth'):
    setattr(pl.camera, angle_attr, 0)
pl.camera.zoom(1.5)

for j, i in enumerate([0, 142, 283]):
    framei = extract_frame_sdf_from_ttnf(list(loaded.values()), 512, i)
    tmeshi = sdf2mesh(framei, coords)
    # Shift the j-th frame by j * (-1, 0, 1) so the meshes are laid out side by side.
    tmeshi.vertices += np.array([-1, 0, 1]) * j
    mesh = pv.wrap(tmeshi)
    pl.add_mesh(mesh)

pl.show(screenshot='perda.png', jupyter_backend='none')

In [30]:
# Show the shape of every stored TT core.
for core in loaded.values():
    print(core.shape)

torch.Size([1, 16, 16])
torch.Size([16, 16, 256])
torch.Size([256, 16, 256])
torch.Size([256, 16, 256])
torch.Size([256, 16, 256])
torch.Size([256, 16, 256])
torch.Size([256, 16, 256])
torch.Size([256, 16, 16])
torch.Size([16, 16, 1])
