In [None]:
# Reload all imported modules (e.g. the hiegan package) before executing each cell,
# so edits to project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

from hiegan.config import Config
from hiegan.dataset import ShapeNetMVDataset
from hiegan.utils.mesh_utils import load_and_normalize_mesh, mesh_to_pointcloud
from hiegan.utils.render_utils import simple_renderer, render_mesh_rgb
from hiegan.models.generator import HIEGenerator


In [None]:
cfg = Config()

# Seed torch's global RNGs (CPU and CUDA) so the DataLoader shuffle and the random
# query points used later are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Honor cfg.DEVICE only when its backend is actually available; otherwise fall back
# to CPU. Previously "mps" was selected unconditionally (crashing off Apple silicon),
# and any requested device was accepted whenever CUDA happened to be present.
_requested_device = str(cfg.DEVICE)
if _requested_device.startswith("cuda"):
    _backend_ok = torch.cuda.is_available()
elif _requested_device == "mps":
    _mps = getattr(torch.backends, "mps", None)
    _backend_ok = _mps is not None and _mps.is_available()
else:
    _backend_ok = True  # e.g. "cpu" — always usable
device = torch.device(_requested_device if _backend_ok else "cpu")
print("Using device:", device)

In [None]:
# Dataset + Loader
ds = ShapeNetMVDataset(cfg.DATASET_ROOT, image_size=cfg.IMAGE_SIZE, multi_view=cfg.MULTI_VIEW)
print("Dataset size:", len(ds))
# Fail fast: an empty dataset would otherwise only surface later as a bare
# StopIteration from next(iter(dl)).
assert len(ds) > 0, f"No samples found under DATASET_ROOT={cfg.DATASET_ROOT!r}"
# batch_size=1: later cells take the first (only) item of each batch with [0].
dl = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)

In [None]:
# Fetch one sample. The DataLoader prepends a batch dimension (batch_size=1), and the
# default collate turns the per-sample path into a length-1 batch of paths.
imgs, mesh_path = next(iter(dl))
print("Images tensor shape (B,V,3,H,W):", imgs.shape)
print("Mesh path:", mesh_path)

In [None]:
# Show the input image(s).
# make_grid expects a 4-D (N, C, H, W) stack, so drop the DataLoader's batch dim when
# the batch is (B, V, 3, H, W) and tile the V views side by side in one row.
views = imgs[0] if imgs.dim() == 5 else imgs
# value_range=(-1, 1) assumes the dataset normalizes pixels to [-1, 1] — TODO confirm.
grid = make_grid(views, nrow=views.shape[0], normalize=True, value_range=(-1, 1))

fig, ax = plt.subplots()
ax.set_title("Input view(s)")
ax.axis("off")
ax.imshow(grid.permute(1, 2, 0).cpu())  # (3, H, W) -> (H, W, 3) for imshow
plt.show()

In [None]:
# Load the ground-truth mesh for this sample and render a quick RGB preview.
dev_str = str(device)
mesh = load_and_normalize_mesh(mesh_path[0], device=dev_str)
renderer = simple_renderer(image_size=256, device=dev_str)
rendered = render_mesh_rgb(mesh, renderer, device=dev_str)
rgb = rendered[0].cpu().numpy()  # first image of the render output

fig, ax = plt.subplots()
ax.set_title("GT Mesh quick render")
ax.axis("off")
ax.imshow(rgb)
plt.show()

In [None]:
# Sample a point cloud from the GT mesh for later metrics.
N_GT_POINTS = 2048
pts = mesh_to_pointcloud(mesh, num_samples=N_GT_POINTS)
print(f"Sampled pointcloud: {pts.shape}")

In [None]:
# Build the generator on the selected device and switch it to evaluation mode.
LATENT_DIM = 512
G = HIEGenerator(latent_dim=LATENT_DIM, device=device).to(device)
G = G.eval()  # eval() returns the module itself
G  # display the module summary as the cell output

In [None]:
# Reuse the sample fetched above rather than drawing a new random batch (the loader
# shuffles), so the reconstruction below corresponds to the GT mesh rendered earlier.
mesh_paths = mesh_path
imgs = imgs.to(device)  # expected (B, n_views, 3, H, W); re-running this cell is a no-op
print("Images shape:", imgs.shape)

In [None]:
# Use the sample's GT mesh vertices as the template passed to the generator.
template_mesh = load_and_normalize_mesh(mesh_paths[0], device=device)
template_vertices = template_mesh.verts_list()[0]

# Placeholder connectivity for smoke-testing only: one edge between vertices 0 and 1,
# stored in both directions. TODO: replace with the mesh's real edges for the GCN.
edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long, device=device)

In [None]:
# Random query points for the implicit branch, uniform in [-1, 1)^3.
# The leading dim follows the image batch instead of a hard-coded 1, so the
# generator receives matching batch sizes for images and query points.
N_QUERY_POINTS = 1024
sample_xyz = torch.rand(imgs.shape[0], N_QUERY_POINTS, 3, device=device) * 2 - 1

# Forward pass (inference only — no autograd graph is recorded)
with torch.no_grad():
    out = G(imgs, template_mesh_vertices=template_vertices,
            template_mesh_edges=edge_index, sample_xyz=sample_xyz)

fused_vertices = out['fused_vertices']  # (B, V, 3)

In [None]:
# Treat the first sample's fused vertices as a point cloud.
# Index [0] rather than squeeze(0): squeeze is a silent no-op when B > 1, which would
# leave a (B, V, 3) tensor and break the per-axis indexing in the plot below.
pc = fused_vertices[0].detach().cpu()
print("Point cloud shape:", pc.shape)

In [None]:
# Visualize the fused vertices as a 3D scatter plot
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib

fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={"projection": "3d"})
xs, ys, zs = pc[:, 0], pc[:, 1], pc[:, 2]
ax.scatter(xs, ys, zs, s=1, c='blue')
ax.set_title("Fused Mesh Point Cloud Preview")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
plt.show()