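# Sampling shape latents with a latent diffusion model and decoding them to SDF meshes

Reference notebook: https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=cc57b01f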

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

from ddpm import *

In [None]:
# Select the compute device. GPU 7 is preferred; fall back to GPU 0 when this
# machine has fewer GPUs (the hard-coded 'cuda:7' used to crash there), and to CPU
# when CUDA is unavailable.
GPU_INDEX = 7
if torch.cuda.is_available():
    device = torch.device(f'cuda:{GPU_INDEX}' if torch.cuda.device_count() > GPU_INDEX else 'cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)
torch.manual_seed(1984)

# Initialize the latent diffusion model
latent_dim = 256
hidden_dims = [2048, 2048, 2048, 2048]
max_freq = 4  # Example max frequency for Fourier features
num_bands = 4  # Number of frequency bands
scalar_hidden_dims = [256, 256, 256, 256]
diffusion_model = LatentDiffusionModel(latent_dim, hidden_dims, scalar_hidden_dims, max_freq, num_bands).to(device)
diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_20k_test1/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/mugs_ddpm_cos_20k_l2loss/model.pth"
# diffusion_ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/chairs_vnhead_residual_2048*4_cos_10k/model.pth"
# map_location lets a checkpoint saved on a GPU load when `device` is a different GPU or the CPU.
diffusion_ckpt = torch.load(diffusion_ckpt_path, map_location=device)
diffusion_model.load_state_dict(diffusion_ckpt['model'])
# Inference only: disable train-time behaviour (dropout / batch-norm statistics), if the model has any.
diffusion_model.eval()
print('Diffusion Model parameters:', sum(p.numel() for p in diffusion_model.parameters()))

In [None]:
# Draw `bs` unconditional samples and keep the whole denoising trajectory.
# Each latent has `latent_dim` rows with 4 channels; a later cell splits them
# into channels [:3] (equivariant features) and channel 3 (invariant feature).
# Reuse `latent_dim` so the sample shape stays in sync with the model config.
bs = 3
with torch.no_grad():  # sampling needs no autograd graph
    trajs = uncond_p_sample_loop(model=diffusion_model, shape=(bs, latent_dim, 4), return_traj=True)

In [None]:
from core.models import get_model
from dataset import get_dataset
from init import get_cfg, setup_seed, dev_get_cfg

# Prepare configuration
category = "mugs"
# category = "kit4cates"
# category = "chairs"
cfg = dev_get_cfg(category)

# Prepare the shape-prior model whose network decodes latents into SDF values below.
ModelClass = get_model(cfg["model"]["model_name"])
model = ModelClass(cfg)
ckpt_path = f"/home/ziran/se3/EFEM/weights/{category}.pt"
# map_location keeps the checkpoint loadable when `device` is the CPU or a different GPU.
ckpt = torch.load(ckpt_path, map_location=device)
model.network.load_state_dict(ckpt['model_state_dict'])
model.network.to(device)
# Inference only: make batch-norm / dropout layers (if the network has any) use eval behaviour,
# so decoded SDFs do not depend on the composition of the decoded batch.
model.network.eval()

# Prepare datasets (one per configured mode)
DatasetClass = get_dataset(cfg)
datasets_dict = dict()
for mode in cfg["modes"]:
    datasets_dict[mode] = DatasetClass(cfg, mode=mode)

train_ds = datasets_dict["train"]

In [None]:
# Take the last stored trajectory step for every sample and split its channels:
# [:3] -> equivariant (so3) features, [3] -> invariant scalar feature.
final_latents = torch.from_numpy(trajs[-1])
fakes_x, fakes_s = final_latents[:, :, :3], final_latents[:, :, 3]

# Decode samples [query_start, query_end) with a fixed scale of 1.2 and zero translation.
query_start, query_end = 0, bs
n_query = query_end - query_start
pred_so3_feat = fakes_x[query_start:query_end].to(device)
pred_inv_feat = fakes_s[query_start:query_end].to(device)
pred_scale = torch.full((n_query,), 1.2, device=device)
pred_center = torch.zeros(n_query, 1, 3, device=device)

print(pred_so3_feat.shape, pred_inv_feat.shape, pred_scale.shape, pred_center.shape)

In [None]:
# Number of stored states in the sampling trajectory; the time slice used below must fit inside it.
num_traj_states = len(trajs)
num_traj_states

In [None]:
# Select a window of trajectory states for one sample so intermediate denoising steps can be decoded.
time_slice = slice(980, 1001, 1)
sample_idx = 0

selected_states = trajs[time_slice]
if len(selected_states) == 0:
    raise ValueError(f"time_slice {time_slice} selects no states from a trajectory of length {len(trajs)}")

# Stack once (previously the same window was stacked twice), then split channels.
traj_window = np.stack(selected_states)
fakes_x, fakes_s = traj_window[:, sample_idx, :, :3], traj_window[:, sample_idx, :, 3]
# fakes_x.shape

In [None]:
# torch.as_tensor accepts both the NumPy arrays produced by the previous cell and
# tensors, so re-running this cell no longer raises (torch.from_numpy rejects tensors).
fakes_x = torch.as_tensor(fakes_x, device=device)
fakes_s = torch.as_tensor(fakes_s, device=device)

# Decode every selected trajectory state. A hard-coded 20 silently dropped the last
# state whenever the window holds 21 states (e.g. 980..1000, the final denoised sample).
query_start = 0
query_end = fakes_x.shape[0]
n_query = query_end - query_start
pred_so3_feat = fakes_x[query_start:query_end]
pred_inv_feat = fakes_s[query_start:query_end]
# Fixed scale 1.2 and zero translation for every decoded state.
pred_scale = torch.full((n_query,), 1.2, device=device)
pred_center = torch.zeros(n_query, 1, 3, device=device)

print(pred_so3_feat.shape, pred_inv_feat.shape, pred_scale.shape, pred_center.shape)

In [None]:
torch.cuda.empty_cache()

# Dense cubic grid of query points spanning [-di, di] on every axis (50x50x50 here).
N = 50
space_dim = [N] * 3
di = 1

# 'ij' indexing plus C-order flattening puts axis 0 first, matching the
# reshape(-1, N, N, N) applied to the decoded SDF.
axes_1d = [np.linspace(-di, di, n) for n in space_dim]
grid_points = np.stack(np.meshgrid(*axes_1d, indexing='ij'), axis=-1).reshape(-1, 3)

# One copy of the grid per latent being decoded: [B, N*N*N, 3].
viz_query = torch.tensor(grid_points, dtype=torch.float32, device=device)
viz_query = viz_query.repeat(pred_so3_feat.shape[0], 1, 1)
viz_query.shape

In [None]:
# Latent codes plus pose (scale `s`, translation `t`) for each item in the batch.
embedding = dict(
    z_so3=pred_so3_feat,  # [B, 256, 3]
    z_inv=pred_inv_feat,  # [B, 256]
    s=pred_scale,  # [B]
    t=pred_center,  # [B, 1, 3]
)

# Evaluate the SDF at every grid query and fold it back into a [B, N, N, N] NumPy volume.
with torch.no_grad():
    sdf_hat = model.network.decode(viz_query, None, embedding, return_sdf=True)  # SDF must have nss sampling
    sdf_grid = sdf_hat.detach().cpu().numpy().reshape(-1, *space_dim)


In [None]:
# Decoded SDF batches to visualise; presumably more can be added to compare runs side by side.
sdf_grid_list = [sdf_grid]

In [None]:
%matplotlib inline
from skimage import measure
import plotly
import plotly.graph_objs as go

# Initialise plotly's notebook renderer once instead of on every loop iteration.
plotly.offline.init_notebook_mode()

# For each sample index, extract and render an iso-surface from every decoded SDF batch.
# Loop variables use fresh names so the globals `sdf_grid`, `x`, `y`, `z` are not clobbered.
for idx in range(len(sdf_grid_list[0])):
    for grid_batch in sdf_grid_list:
        data = grid_batch[idx]
        print("Max:", data.max(), "Min", data.min())

        # Iso-level 0.01, or the volume minimum if that is larger. marching_cubes raises
        # ValueError when the level is above the data range, so skip such samples.
        level = max(data.min() + 0., 0.01)
        if level > data.max():
            print(f"Skipping sample {idx}: iso-level {level} is above the SDF maximum {data.max()}")
            continue

        # Extract the iso-surface with Marching Cubes (vertices are in voxel-index units).
        verts, faces, normals, values = measure.marching_cubes(data, level=level)
        # Map voxel indices back to the query coordinates, which were built with
        # np.linspace(-di, di, n); this also stays correct if `di` is changed.
        voxel_size = 2 * di / (np.array(data.shape) - 1)
        verts = verts * voxel_size - di

        # Build the mesh3d trace
        mesh = go.Mesh3d(
            x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            opacity=0.5,
            name='Mesh'
        )

        # Build the figure layout
        layout = go.Layout(
            title='3D Mesh and Point Cloud Visualization',
            scene=dict(
                xaxis=dict(title='X'),
                yaxis=dict(title='Y'),
                zaxis=dict(title='Z')
            )
        )

        # Assemble and display the figure
        fig = go.Figure(data=[mesh], layout=layout)
        plotly.offline.iplot(fig)


