In [None]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np

from dataset import get_dataset
from logger import Logger
from core.models import get_model
from core import solver_dict
from init import get_cfg, setup_seed, dev_get_cfg

# Prepare the experiment configuration. The development variant is active;
# switch to `cfg = get_cfg()` for the other loader
# (NOTE(review): confirm how the two differ in init.py).
cfg = dev_get_cfg()


In [None]:

# Seed the RNGs before anything stochastic is constructed.
setup_seed(cfg["rand_seed"])

# Datasets: one instance per mode listed in the config, keyed by mode name.
DatasetClass = get_dataset(cfg)
datasets_dict = {mode: DatasetClass(cfg, mode=mode) for mode in cfg["modes"]}

# Model: resolve the class from its configured name, then build it.
ModelClass = get_model(cfg["model"]["model_name"])
model = ModelClass(cfg)

# Logger.
logger = Logger(cfg)

# Solver: looked up by the lower-cased runner name, then given the config,
# model, datasets and logger.
SolverClass = solver_dict[cfg["runner"].lower()]
solver = SolverClass(cfg, model, datasets_dict, logger)

In [None]:

# Checkpoint to restore (earlier candidates kept for reference).
# ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/log/shape_prior_mugs_old/shape_prior_mugs_dup_old_rename_at_2023-11-05-21-20-55/checkpoint/141_latest.pt"
# ckpt_path = "/home/ziran/se3/EFEM/weights/mugs.pt"
ckpt_path = "/home/ziran/se3/EFEM/lib_shape_prior/log/shape_prior_mugs_old/shape_prior_mugs_dup_old_rename_at_2023-11-06-22-25-35/checkpoint/5823_latest.pt"
# map_location="cpu" lets the file load even without the GPU it was saved
# from. load_state_dict copies the tensors into the network's existing
# parameters, whatever device those live on.
# torch.load unpickles the file, so only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location="cpu")

# Note: the weights are loaded into model.network, NOT via model.load_state_dict.
# See lib_shape_prior/core/solver_v2.py and lib_shape_prior/core/models/model_base.py.
model.network.load_state_dict(ckpt['model_state_dict'])

In [None]:

# Inspect the parameter names the network expects (compare with the keys of
# ckpt['model_state_dict'] if loading reports mismatches).
model.network.state_dict().keys()

In [None]:
codebook_path = "/home/ziran/se3/EFEM/lib_shape_prior/mugs.npz"
# Copy every array out of the .npz archive into a plain dict, so the values
# remain usable after the archive is closed.
with np.load(codebook_path) as npz:
    codebook = {name: npz[name] for name in npz.files}

print(codebook.keys())
for name, arr in codebook.items():
    print(name, arr.shape)

In [None]:
# Decode the first `bs` codebook entries together as one batch.
bs = 3
pred_so3_feat = codebook['z_so3'][:bs]
pred_inv_feat = codebook['z_inv'][:bs]
pred_scale = codebook['scale'][:bs]
pred_center = codebook['center'][:bs]
# The stored point cloud codebook['pcl'] is not taken here: the next cell
# builds a dense query grid and assigns it to `query`.

In [None]:
# Fall back to the CPU when no GPU is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Resolution of the dense SDF query grid: N samples per axis, N**3 queries total.
N = 64

space_dim = [N, N, N]  # an N x N x N grid (64 x 64 x 64 here)


# Half-extent of the cube [-di, di]^3 covered by the grid.
di = 1
# Evenly spaced sample coordinates along each axis.
x = np.linspace(-di, di, space_dim[0])
y = np.linspace(-di, di, space_dim[1])
z = np.linspace(-di, di, space_dim[2])

# 'ij' indexing keeps axis 0 = x, axis 1 = y, axis 2 = z, so the flattened
# queries can be reshaped back to (N, N, N) in the same order later.
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

# One row per grid point: shape (N**3, 3).
query = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=-1)

# Same grid for every shape in the batch: shape (bs, N**3, 3).
query = torch.as_tensor(query, dtype=torch.float32, device=device)
query = query.repeat(bs, 1, 1)
query.shape

In [None]:
# query = codebook['pcl'][:bs]
# query = torch.tensor(query).float().to(device)
# query.shape

In [None]:
# Convert the codebook slices into float32 tensors on the query device.
pred_so3_feat, pred_inv_feat, pred_scale, pred_center = [
    torch.tensor(arr).float().to(device)
    for arr in (pred_so3_feat, pred_inv_feat, pred_scale, pred_center)
]

pred_so3_feat.shape


In [None]:
# Inspect the scale entries selected for this batch.
pred_scale

In [None]:





# Latent embedding consumed by the decoder. The per-key shapes are the
# original author's annotations — TODO confirm against the codebook arrays.
embedding = {
    "z_so3": pred_so3_feat,  # [B, 256, 3]
    "z_inv": pred_inv_feat,  # [B, 256]
    "s": pred_scale,  # [B]
    "t": pred_center,  # [B, 1, 3]
}

# Pure inference: no_grad stops autograd from recording a graph over all
# bs * N**3 query points, which would otherwise keep large activations alive.
# NOTE(review): if decode() relies on autograd internally, drop the no_grad.
with torch.no_grad():
    sdf_hat = model.network.decode(  # (training-code note: SDF must have nss sampling)
        query,
        None,
        embedding,
        return_sdf=True,
    )

In [None]:
# Decoded SDF; the next cell reshapes it to (-1, N, N, N), so it is expected
# to hold bs * N**3 values.
sdf_hat.shape

In [None]:
# Move the SDF to NumPy and fold it into one (N, N, N) volume per batch entry.
sdf_grid = sdf_hat.detach().cpu().numpy().reshape(-1, *space_dim)

In [None]:
# Select the third decoded volume of the batch (index 2) for meshing.
data = sdf_grid[2]

# Grid indices within ±0.1 of the zero level set. NumPy arrays need the
# element-wise `&`; Python's `and` raises on multi-element arrays:
# np.where((0-1e-1<=data) & (data<=0+1e-1))

In [None]:
# Number of grid points in the narrow band |sdf| <= band around the surface
# (the check only bounded the value from below, which also counted every
# point far outside the shape).
band = 1e-1
(np.abs(data) <= band).sum()

In [None]:
%matplotlib inline
from skimage import measure
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# 使用 Marching Cubes 算法提取等值面
verts, faces, normals, values = measure.marching_cubes(data, level=0.001)


In [None]:

# Render the extracted iso-surface as a colour-mapped triangle surface.
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"projection": "3d"})

ax.plot_trisurf(verts[:, 0], verts[:, 1], verts[:, 2], triangles=faces,
                cmap='Spectral', lw=1)

# Viewing angle and axis labels.
ax.view_init(elev=30, azim=60)
ax.set(xlabel="X-axis", ylabel="Y-axis", zlabel="Z-axis")

plt.show()


In [None]:
import ipyvolume as ipv

# Interactive WebGL view of the same mesh via ipyvolume.
fig = ipv.figure(width=600, height=600)
vx, vy, vz = verts.T
mesh = ipv.plot_trisurf(vx, vy, vz, triangles=faces)
ipv.style.use('minimal')  # minimal axes/box styling
ipv.show()


In [None]:
import pythreejs as p3

# WebGL has no float64 support, so positions are cast to float32.
# pythreejs infers a BufferAttribute's itemSize from the array's last
# dimension. An (F, 3) index array would therefore get itemSize 3, and
# three.js would draw only F indices (a third of the triangles). The index
# buffer must be flat; uint32 also allows more than 65535 vertices.
positions = np.asarray(verts, dtype=np.float32)
indices = np.asarray(faces, dtype=np.uint32).ravel()

# Indexed geometry + wireframe material.
geom = p3.BufferGeometry(attributes={
    'position': p3.BufferAttribute(array=positions, normalized=False),
    'index': p3.BufferAttribute(array=indices, normalized=False),
})
material = p3.MeshBasicMaterial(color='gray', wireframe=True)
mesh = p3.Mesh(geometry=geom, material=material)

# Scene.
scene = p3.Scene(children=[mesh, p3.AmbientLight(color='#777777')])

# Frame the mesh: orbit around its centre from a distance proportional to its
# size, so the view works whether verts are in voxel or world units.
center = positions.mean(axis=0)
extent = float((positions.max(axis=0) - positions.min(axis=0)).max()) or 1.0
camera = p3.PerspectiveCamera(position=(center + 1.5 * extent).tolist(),
                              up=[0, 0, 1], aspect=600/400)
controller = p3.OrbitControls(controlling=camera, target=center.tolist())
scene.add(camera)

# Render the scene in the notebook.
renderer = p3.Renderer(camera=camera, scene=scene, controls=[controller],
                       width=600, height=400)

renderer


## Load a saved mesh

In [None]:
import trimesh

# Mesh exported to the training log (earlier candidate kept for reference).
# mesh_path = '/home/ziran/se3/EFEM/lib_shape_prior/log/shape_prior_mugs/mesh/epoch_1990/mesh_val_03797390_3d1754b7cb46c0ce5c8081810641ef6_0.obj'
mesh_path = '/home/ziran/se3/EFEM/lib_shape_prior/log/shape_prior_mugs/mesh/epoch_5469/mesh_train_03797390_b9004dcda66abf95b99d2a3bbaea842a_115.obj'
# force='mesh' makes trimesh return a single Trimesh. Without it, an OBJ can
# load as a multi-geometry Scene, which has no .faces / .is_watertight.
mesh = trimesh.load(mesh_path, force='mesh')

# Basic properties of the loaded mesh.
mesh_details = {
    'vertices': mesh.vertices,
    'faces': mesh.faces,
    'is_watertight': mesh.is_watertight,
    'area': mesh.area,
    'extents': mesh.extents,
    'bounds': mesh.bounds
}

mesh_details['vertices'].shape, mesh_details['faces'].shape, mesh_details['is_watertight']

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# Extracting the vertices and faces for plotting
vertices = mesh.vertices
faces = mesh.faces

# Create a new figure
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Fancy indexing: `vertices[faces]` to generate a collection of triangles
mesh_collection = Poly3DCollection(vertices[faces])
mesh_collection.set_edgecolor('k')

ax.add_collection3d(mesh_collection)

# Auto scale to the mesh size
scale = np.concatenate([vertices.min(axis=0), vertices.max(axis=0)]).flatten()
ax.auto_scale_xyz(scale, scale, scale)

# Show the plot
plt.show()

## Load the input point cloud

In [None]:
# Input point cloud saved next to the mesh (earlier candidate kept for reference).
# file_path = '/home/ziran/se3/EFEM/lib_shape_prior/log/shape_prior_mugs/mesh/epoch_1990/input_train_03797390_2037531c43448c3016329cbc378d2a2_106.txt'
file_path = '/home/ziran/se3/EFEM/lib_shape_prior/log/shape_prior_mugs/mesh/epoch_5469/input_train_03797390_b9004dcda66abf95b99d2a3bbaea842a_115.txt'

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# One point per row; the first three columns are plotted as x, y, z.
points = np.loadtxt(file_path)

fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

px, py, pz = points[:, 0], points[:, 1], points[:, 2]
ax.scatter(px, py, pz)

ax.set(xlabel='X Label', ylabel='Y Label', zlabel='Z Label')

plt.show()



In [None]:
# Plot the point cloud from three viewpoints around the vertical axis.

def plot_3d_scatter(ax, azim, elev, pts=None):
    """Draw a 3D scatter of a point cloud on `ax` from the given viewpoint.

    Args:
        ax: a matplotlib 3D axes; it is cleared before drawing.
        azim: azimuth angle in degrees.
        elev: elevation angle in degrees.
        pts: array whose first three columns are x, y, z. Defaults to the
            notebook-level `points` loaded in the previous cell.
    """
    if pts is None:
        pts = points

    ax.cla()
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2])

    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')

    ax.view_init(azim=azim, elev=elev)


# (azimuth, elevation) pairs: the view is rotated 0°, 45° and 90° around the
# vertical axis, always from 30° above the horizontal plane.
angles = [(0, 30), (45, 30), (90, 30)]

fig = plt.figure(figsize=(15, 15))

# One 3D subplot per viewpoint.
for i, (azim, elev) in enumerate(angles, start=1):
    ax = fig.add_subplot(2, 2, i, projection='3d')
    plot_3d_scatter(ax, azim, elev)

plt.tight_layout()
plt.show()


## Codebook
