In [None]:
!git clone https://github.com/samay-patel-2110/MeshCNN_3DVSS/tree/main

In [None]:
!pip install -r requirements.txt

## Get the SHREC'16 dataset

In [None]:
# get data
!mkdir -p $DATADIR && cd $DATADIR
!wget https://www.dropbox.com/s/w16st84r6wc57u7/shrec_16.tar.gz
!tar -xzvf shrec_16.tar.gz && rm shrec_16.tar.gz
!echo "downloaded the data and putting it in: " $DATADIR

# ## Dataset structure
# The dataset should be organized as follows:
# ```
# datasets/
# └── shrec_16/
#     ├── laptop/
#     │   ├── train/
#     │   ├── test/
#     │   └── val/
#     └── ...
# ```

## Feature extraction

In [2]:
import trimesh
import numpy as np
from models.layers.mesh_prepare import from_scratch

In [3]:
# Build the mesh data object for one training mesh; later cells inspect its
# `sides`, `edges` and `gemm_edges` attributes. Note that `build_gemm_explain`
# below is not applied to `x`; these arrays come from `from_scratch` itself.
# NOTE(review): upstream MeshCNN's from_scratch(file, opt) reads `opt.num_aug`;
# passing `[0]` only works if this fork changed that parameter — confirm against
# models/layers/mesh_prepare.py.
x = from_scratch("./datasets/shrec_16/laptop/train/T17.obj",[0])

In [4]:
def build_gemm_explain(mesh, faces, face_areas, verbose=True):
    """Build the per-edge neighborhood ("gemm") arrays of a mesh, tracing each step.

    Faces are visited in order. Every undirected edge gets an index the first
    time it is seen, and each face then records, for each of its three edges,
    the next and previous edge of that face as two neighbor slots.

    Args:
        mesh: object with a ``vs`` sequence of vertices. Its ``ve``, ``edges``,
            ``gemm_edges``, ``sides``, ``edges_count`` and ``edge_areas``
            attributes are (re)written by this function.
        faces: sequence of faces, each a sequence of 3 vertex indices. Every
            edge may belong to at most two faces (manifold mesh).
        face_areas: per-face areas, indexable by face id.
        verbose: if True (default), print a step-by-step trace. Pass False for
            real meshes to avoid flooding the notebook output.

    Sets on ``mesh``:
        gemm_edges: array (#E x 4) of the 4 one-ring neighbors for each edge
        sides: array (#E x 4) indices (values of: 0,1,2,3) indicating where an
            edge is in the gemm_edge entry of the 4 neighboring edges,
            for example edge i -> gemm_edges[gemm_edges[i], sides[i]] == [i, i, i, i]
            (for every filled slot). Slots of edges with a single face stay -1.
        edges: array (#E x 2) of sorted vertex-index pairs.
        edge_areas: array (#E,) with a third of each incident face's area,
            normalized by the total face area.

    Raises:
        ValueError: if an edge is shared by more than two faces.
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    log("Initializing data structures...")
    mesh.ve = [[] for _ in mesh.vs]  # vertex -> indices of its incident edges
    # Start from an empty list rather than appending to whatever the mesh holds,
    # so the function also works on a fresh object or an already-built mesh.
    mesh.edge_areas = []
    edge_nb = []       # per edge: 4 neighbor edge indices (-1 = unset)
    sides = []         # per edge: slot of this edge inside each neighbor's row
    edge2key = {}      # sorted vertex pair -> edge index
    edges = []         # edge index -> [v0, v1]
    nb_count = []      # per edge: number of neighbor slots filled so far
    edges_count = 0

    log(f"Processing {len(faces)} faces...")
    for face_id, face in enumerate(faces):
        log(f"\nProcessing face {face_id}: {face}")
        raw_edges = [(face[i], face[(i + 1) % 3]) for i in range(3)]
        log(f"Face edges before sorting: {raw_edges}")

        faces_edges = []
        for edge in raw_edges:
            # Sort the endpoints so (a, b) and (b, a) map to the same key
            edge = tuple(sorted(edge))
            faces_edges.append(edge)
            log(f"Processing edge {edge}")

            if edge not in edge2key:
                log(f"New edge found: {edge}")
                edge2key[edge] = edges_count
                edges.append(list(edge))
                edge_nb.append([-1, -1, -1, -1])
                sides.append([-1, -1, -1, -1])
                mesh.ve[edge[0]].append(edges_count)
                mesh.ve[edge[1]].append(edges_count)
                mesh.edge_areas.append(0)
                nb_count.append(0)
                edges_count += 1

            # Each face gives a third of its area to each of its three edges
            mesh.edge_areas[edge2key[edge]] += face_areas[face_id] / 3
            log(f"Updated edge areas for edge {edge}: {mesh.edge_areas[edge2key[edge]]}")

        # Only 4 neighbor slots exist per edge (two faces x two slots); fail with a
        # clear message instead of an IndexError on non-manifold input.
        for edge in faces_edges:
            if nb_count[edge2key[edge]] >= 4:
                raise ValueError(
                    f"Edge {edge} belongs to more than two faces (seen again in face "
                    f"{face_id}); build_gemm_explain requires a manifold mesh"
                )

        # This face adds two slots to each of its edges: next edge, then previous edge
        for idx, edge in enumerate(faces_edges):
            edge_key = edge2key[edge]
            log(f"\nSetting up neighbors for edge {edge} (key: {edge_key})")
            slot = nb_count[edge_key]

            next_edge = faces_edges[(idx + 1) % 3]
            edge_nb[edge_key][slot] = edge2key[next_edge]
            log(f"Connected to next edge {next_edge} (key: {edge2key[next_edge]})")

            prev_edge = faces_edges[(idx + 2) % 3]
            edge_nb[edge_key][slot + 1] = edge2key[prev_edge]
            log(f"Connected to previous edge {prev_edge} (key: {edge2key[prev_edge]})")

            nb_count[edge_key] += 2

        # Each edge of this face now has its new slots at nb_count-2 (next) and
        # nb_count-1 (previous). In the next edge's row this edge sits in that
        # edge's "previous" slot (its nb_count-1); in the previous edge's row it
        # sits in that edge's "next" slot (its nb_count-2).
        for idx, edge in enumerate(faces_edges):
            edge_key = edge2key[edge]
            log(f"\nSetting up sides for edge {edge} (key: {edge_key})")
            next_key = edge2key[faces_edges[(idx + 1) % 3]]
            prev_key = edge2key[faces_edges[(idx + 2) % 3]]
            sides[edge_key][nb_count[edge_key] - 2] = nb_count[next_key] - 1
            sides[edge_key][nb_count[edge_key] - 1] = nb_count[prev_key] - 2
            log(f"Updated sides: {sides[edge_key]}")

    log("\nFinalizing mesh data...")
    # reshape keeps the documented 2-D shapes even when there are no faces
    mesh.edges = np.array(edges, dtype=np.int32).reshape(-1, 2)
    mesh.gemm_edges = np.array(edge_nb, dtype=np.int64).reshape(-1, 4)
    mesh.sides = np.array(sides, dtype=np.int64).reshape(-1, 4)
    mesh.edges_count = edges_count
    mesh.edge_areas = np.array(mesh.edge_areas, dtype=np.float32) / np.sum(face_areas)

    log("Final statistics:")
    log(f"Total edges: {edges_count}")
    log(f"Edge areas shape: {mesh.edge_areas.shape}")
    log(f"Gemm edges shape: {mesh.gemm_edges.shape}")
    log(f"Sides shape: {mesh.sides.shape}")

In [None]:
# First rows of the per-edge `sides` array (layout as in the build_gemm_explain
# docstring — assumes from_scratch builds it the same way). Showing a slice keeps the output compact.
x.sides[:10]

In [None]:
# First rows of the edge list (vertex-index pairs, presumably sorted as in
# build_gemm_explain — confirm for from_scratch). Showing a slice keeps the output compact.
x.edges[:10]

In [None]:
# First rows of `gemm_edges`: the neighbor edges of each edge (layout as in the
# build_gemm_explain docstring — assumes from_scratch matches). Showing a slice keeps the output compact.
x.gemm_edges[:10]

## Classification

In [None]:
#!/usr/bin/env bash

CHECKPOINT='checkpoints/shrec16'

!mkdir -p $CHECKPOINT
!wget https://www.dropbox.com/s/wqq1qxj4fjbpfas/shrec16_wts.tar.gz
!tar -xzvf shrec16_wts.tar.gz && rm shrec16_wts.tar.gz
!mv latest_net.pth $CHECKPOINT
!echo "downloaded pretrained weights to" $CHECKPOINT

In [None]:
## run the test and export collapses

# Processing batch of 16 meshes
# Forward pass -> compute correct classification -> outputs confsion matrix

!python test.py \
--dataroot datasets/shrec_16 \
--name shrec16 \
--ncf 64 128 256 256 \
--pool_res 600 450 300 180 \
--norm group \
--resblocks 1 \
--export_folder meshes \

## Visualize the edge collapses

In [None]:
import trimesh

# Which exported test mesh to inspect; the plotting cell below reuses `item`.
item = 1

# Exports are named T{item}_{k}.obj under checkpoints/shrec16/meshes; open suffix 0
# and wrap it in a trimesh scene.
mesh = trimesh.load(f'checkpoints/shrec16/meshes/T{item}_0.obj')
scene = trimesh.Scene(geometry=mesh)

# Back-face culling off and wireframe on, so every triangle stays visible.
scene.show(flags={'cull': False, 'wireframe': True})


In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

from mpl_toolkits.mplot3d.art3d import Line3DCollection

# Export suffixes to compare for mesh T{item} (file T{item}_{k}.obj).
# NOTE(review): suffix 1 is skipped, as in the original selection — confirm why.
STAGES = [0, 2, 3]

# Load the selected meshes
meshes = [trimesh.load(f'checkpoints/shrec16/meshes/T{item}_{k}.obj') for k in STAGES]

# One 3D subplot per mesh, side by side
fig, axes = plt.subplots(
    1, len(STAGES), figsize=(20, 5), subplot_kw={'projection': '3d'}, squeeze=False
)

for ax, stage, mesh in zip(axes[0], STAGES, meshes):
    vertices = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.faces)

    # Every face contributes its three edges (v0-v1, v1-v2, v2-v0) as 2-point
    # segments. Drawing them as one collection replaces three plot() calls per face.
    segments = vertices[faces[:, [0, 1, 1, 2, 2, 0]]].reshape(-1, 2, 3)
    ax.add_collection3d(Line3DCollection(segments, colors='k', linewidths=0.5))

    # Collections are not autoscaled, so fit a cube around the mesh. Together
    # with a 1:1:1 box this gives a true equal aspect ratio.
    lo, hi = vertices.min(axis=0), vertices.max(axis=0)
    center = (lo + hi) / 2
    half = (hi - lo).max() / 2 or 1.0  # guard against a zero-extent mesh
    ax.set_xlim(center[0] - half, center[0] + half)
    ax.set_ylim(center[1] - half, center[1] + half)
    ax.set_zlim(center[2] - half, center[2] + half)
    ax.set_box_aspect([1, 1, 1])

    # Remove axis labels and ticks
    ax.set_axis_off()

    # Label with the file's export suffix, not the subplot position
    ax.set_title(f'Mesh {stage}')

plt.tight_layout()
plt.show()
