In [1]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np

from dataset import get_dataset
from logger import Logger
from core.models import get_model
from core import solver_dict
from init import get_cfg, setup_seed, dev_get_cfg

# preparer configuration
cfg  =dev_get_cfg()

device = "cuda:3"

# set random seed
setup_seed(cfg["rand_seed"])

# prepare dataset
DatasetClass = get_dataset(cfg)
datasets_dict = dict()
for mode in cfg["modes"]:
    datasets_dict[mode] = DatasetClass(cfg, mode=mode)

# prepare models
ModelClass = get_model(cfg["model"]["model_name"])
model = ModelClass(cfg)

# prepare logger
logger = Logger(cfg)

# register dataset, models, logger to the solver
solver = solver_dict[cfg["runner"].lower()](cfg, model, datasets_dict, logger)

| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:27 | Set GPU: 0 ...   [post_config.py:99]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:27 | Save configuration to local file...   [post_config.py:105]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:27 | Dataset train with 100.0% data, dataset len is 149, total len is 149   [shapenet_new2.py:203]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:27 | Caching train dataset...   [shapenet_new2.py:231]


Please check the configuration
--------------------------------------------------------------------------------
{'dataset': {'aug_ratio': 0.7,
             'categories': ['03797390'],
             'data_root': '../data/ShapeNetV1_SDF',
             'dataset_name': 'shapenet_new2',
             'dataset_proportion': [1.0, 1.0],
             'dataset_root': 'resource/data/XXXX',
             'dep_max_use_view': 12,
             'dep_min_use_view': 4,
             'dep_total_view': 12,
             'depth_postfix': '_dep_small',
             'field_mode': 'sdf',
             'indices': {'test_index': 'None',
                         'train_index': 'None',
                         'val_index': 'None'},
             'input_mode': 'dep',
             'n_pcl': 512,
             'n_query_eval': 10000,
             'n_query_nss': 1024,
             'n_query_uni': 1024,
             'noise_std': 0.01,
             'num_workers': 8,
             'pin_mem': True,
             'ram_cache': True,
  

100%|██████████| 149/149 [00:03<00:00, 37.33it/s]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:31 | Dataset val with 100.0% data, dataset len is 22, total len is 22   [shapenet_new2.py:203]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:31 | Caching val dataset...   [shapenet_new2.py:231]
100%|██████████| 22/22 [00:00<00:00, 37.03it/s]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:32 | DGCNN use Dynamic Graph (different from the input topology)   [vec_dgcnn_atten.py:50]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:32 | 2.016M params in encoder   [misc.py:16]
| sim3vec-mugs | user-AS-4124GS-TNR | INFO | Nov-30-15:13:32 | 0.790M params in decoder   [misc.py:16]


In [2]:
ckpt_path = "/home/ziran/se3/EFEM/weights/mugs.pt"
ckpt = torch.load(ckpt_path)
model.network.load_state_dict(ckpt['model_state_dict'])

<All keys matched successfully>

In [3]:
# codebook_path = "/home/ziran/se3/EFEM/lib_shape_prior/dev_ckpt/codebook.npz"
codebook_path = "/home/ziran/se3/EFEM/cache/mugs.npz"

with np.load(codebook_path) as data:
    # 将 npz 文件内容转换为字典
    codebook = {key: data[key] for key in data}

del codebook['id']
for k, v in codebook.items():
    if isinstance(v, np.ndarray):
        newv = torch.from_numpy(v)
        codebook[k] = newv
    print(k, v.shape)

z_so3 (149, 256, 3)
z_inv (149, 256)
center (149, 1, 3)
scale (149,)
z_so3_proj (149, 256, 3)
z_so3_basis (149, 3, 3)
z_so3_var (149, 3)
bbox (149, 3)
bbox_c (149, 3)
pcl (149, 5000, 3)
cls (149,)


In [4]:
codebook.keys()

dict_keys(['z_so3', 'z_inv', 'center', 'scale', 'z_so3_proj', 'z_so3_basis', 'z_so3_var', 'bbox', 'bbox_c', 'pcl', 'cls'])

In [5]:
codebook['scale'].min()

tensor(0.7090)

In [7]:
bs = 1
pred_so3_feat = codebook['z_so3'][:bs]
pred_inv_feat = codebook['z_inv'][:bs]
pred_scale = codebook['scale'][:bs]
# pred_scale = torch.ones_like(pred_scale) + 0.2
pred_center = codebook['center'][:bs]
# pred_center = torch.zeros_like(pred_center)
# 上面两个注释掉的, 如果用默认0和1的center和scale, 会导致decode不太行, 还是用codebook里的比较好
# 于是那就应该把这四个variable一块diffusion
pred_so3_feat = pred_so3_feat.to(device)
pred_inv_feat = pred_inv_feat.to(device)
pred_scale = pred_scale.to(device)
pred_center = pred_center.to(device)


RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
N = 64

space_dim = [N, N, N]  # 示例为一个50x50x50的网格


di = 1
# 创建一个网格，这里我们使用np.linspace来产生线性间隔的点
x = np.linspace(-di, di, space_dim[0])
y = np.linspace(-di, di, space_dim[1])
z = np.linspace(-di, di, space_dim[2])

# 用np.meshgrid得到每个维度的点阵
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

# 将这些点整理成query的形式，每行是一个点的坐标
query = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=-1)
query = torch.tensor(query,dtype=torch.float32).to(device)
query = query.repeat(bs, 1, 1)
query.shape

In [None]:
embedding = {
            "z_so3": pred_so3_feat, # [B, 256, 3]
            "z_inv": pred_inv_feat, # [B, 256]
            "s": pred_scale, # [B]
            # "t": centroid.unsqueeze(1), # [B, 1, 3]
            "t": pred_center, # [B, 1, 3]
        }

sdf_hat = model.network.decode(  # SDF must have nss sampling
            query,
            None,
            embedding,
            return_sdf=True,
        )

In [None]:
sdf_grid = sdf_hat.reshape(-1, space_dim[0], space_dim[1], space_dim[2]).to("cpu").detach().numpy()
data = sdf_grid[2]

In [None]:
%matplotlib inline
from skimage import measure
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# 使用 Marching Cubes 算法提取等值面
verts, faces, normals, values = measure.marching_cubes(data, level=0.)

# 创建一个新的图形
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# 绘制等值面
ax.plot_trisurf(verts[:, 0], verts[:,1], faces, verts[:, 2],
                cmap='Spectral', lw=1)

# 设置图形的视角和轴标签
ax.view_init(30, 60)
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_zlabel("Z-axis")

plt.show()
