In [1]:
!pip install trimesh

Collecting trimesh
  Downloading trimesh-4.3.1-py3-none-any.whl (693 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m693.8/693.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: trimesh
Successfully installed trimesh-4.3.1


In [2]:
!pip install libigl

Collecting libigl
  Downloading libigl-2.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.2/16.2 MB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: libigl
Successfully installed libigl-2.5.1


In [3]:
import os

import numpy as np

import trimesh

import torch
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from tqdm import tqdm

from scipy.spatial import KDTree

from skimage import measure

import igl

import csv

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
#@title dataloader

class MeshDataset(Dataset):
    def __init__(self, mesh_path, device=torch.device('cuda:0'), subset="train"):
        """
        Initialize the dataset with the path to the mesh file, device configuration, subset selection,
        and debug mode.

        Parameters:
            mesh_path (str): Path to the mesh file.
            device (torch.device): Device on which tensors will be created.
            subset (str): Specify whether to load the 'train' or 'validation' part of the dataset.
            debug (bool): If True, also load colors data for debugging purposes.
        """

        self.device = device
        self.subset = subset

        # Load data
        self.points, self.labels = self.load_data(mesh_path)
        total_points = self.shuffle_data()
        self.split_data(total_points)


    def load_data(self, mesh_path):
        """
        Load the mesh data from the file and extract vertices as points and colors for labeling.

        Parameters:
            mesh_path (str): Path to the mesh file.

        Returns:
            points, labels: Tensors representing vertices, their labels.
        """

        pcd = trimesh.load(mesh_path)

        points = torch.tensor(pcd.vertices, dtype=torch.float32, device=self.device)
        colors = torch.tensor(pcd.visual.vertex_colors, dtype=torch.float32, device=self.device)[:, :3] / 255.

        red_threshold = torch.tensor([1, 0, 0], device=self.device)
        green_threshold = torch.tensor([0, 1, 0], device=self.device)

        self.points = points
        self.labels = torch.zeros(points.shape[0], dtype=torch.float32, device=self.device)
        self.labels[(colors == green_threshold).all(dim=1)] = 1
        self.labels[(colors == red_threshold).all(dim=1)] = 0

        return self.points, self.labels

    def shuffle_data(self):
        """
        Shuffle the data to randomize the order of samples.

        Returns:
            int: Total number of points (samples) after shuffling.
        """

        total_points = self.points.shape[0]
        permutation = torch.randperm(total_points)
        self.points = self.points[permutation]
        self.labels = self.labels[permutation]

        return total_points

    def split_data(self, total_points):
        """
        Split the data into training or validation subsets based on the specified subset type.

        Parameters:
            total_points (int): Total number of points in the dataset.
        """

        if self.subset == 'train':
            indices = torch.arange(0, int(0.8 * total_points))
        else:
            indices = torch.arange(int(0.8 * total_points), total_points)

        # Apply the indices to subset the data.
        self.points = self.points[indices]
        self.labels = self.labels[indices]

    def __len__(self):
        return self.points.shape[0]

    def __getitem__(self, idx=None):
        if idx is None:
            idx = torch.randint(0, self.points.shape[0], (1,)).item()

        return self.points[idx], self.labels[idx]

In [6]:
#@title utils

def plot_points(path):
    ax = plt.figure().add_subplot(projection="3d")
    obj = trimesh.load(path)
    x, y, z = obj.vertices[:, 0], obj.vertices[:, 1], obj.vertices[:, 2]
    mask = obj.colors[:, 1] == 255
    ax.scatter(
        x[mask], y[mask], zs=z[mask], zdir="y", alpha=1, c=obj.colors[mask] / 255
    )
    ax.scatter(
        x[~mask], y[~mask], zs=z[~mask], zdir="y", alpha=0.01, c=obj.colors[~mask] / 255
    )
    plt.show()


def download_data():
    import gdown

    if not os.path.exists("./data"):
        gdown.download_folder(
            "https://drive.google.com/drive/folders/1EKWU_daQL3pxFkjFUomGs25_qekyfeAd",
            quiet=False,
        )

    if not os.path.exists("./processed"):
        gdown.download_folder(
            "https://drive.google.com/drive/folders/175_LtuWh1LknbbMjUumPjGzeSzgQ4ett",
            quiet=False,
        )


In [7]:
EPSILON = 1e-15
IS_GOOGLE_COLAB = True

CONFIG_HASH = "hash"
CONFIG_1LOD = "one_lod"
CONFIG_MLOD = "m_lod"

In [8]:
#@title trainer

from torch.optim import Adam
from torch.nn import BCELoss
from torch.optim.lr_scheduler import StepLR

import json

class OccConfig:
    """
    A configuration class to manage training options and hyperparameters, loaded from a JSON file.
    """
    def __init__(self, name):
        self.name = name
        self.load_config(name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_config(self, filename):
        """
        Load configuration from a JSON file.
        """
        try:
            if IS_GOOGLE_COLAB:
              with open(f"drive/MyDrive/3dv_hw3/{filename}.json", "r") as file:
                  data = json.load(file)
                  for key, value in data.items():
                      setattr(self, key, value)
            else:
              with open(f"{filename}.json", "r") as file:
                  data = json.load(file)
                  for key, value in data.items():
                      setattr(self, key, value)

        except FileNotFoundError:
            raise FileNotFoundError(f"no configuration file found for the name '{filename}'")
        except json.JSONDecodeError:
            raise ValueError("error decoding JSON")

    def log_config(self):
        """
        Log all configurations of the OccConfig instance.
        """
        print("config settings:")
        for attr, value in self.__dict__.items():
            print(f"  {attr}: {value}")


class OccTrainer:
    def __init__(self, config):
        """
        Initializes the Trainer class with specified configuration options.

        Parameters:
        config (OccConfig): Configuration options with training parameters and device settings.
        """
        self.config = config
        self.device = config.device

        self.train_dataset = MeshDataset(self.config.current_obj_path, device=self.device)
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=self.config.batch_size, shuffle=True)

        self.model = OCC(self.config).to(self.device)

        self.optimizer = Adam(self.model.parameters(), lr=self.config.lr, betas=(0.9, 0.99), eps=EPSILON, weight_decay=self.config.weight_decay)
        self.criterion = BCELoss()
        self.scheduler = StepLR(self.optimizer, step_size=self.config.lr_decay_step, gamma=self.config.lr_decay_gamma)

        self.num_epochs = self.config.epochs

    def run(self):
        """
        Run training for a specified number of epochs and save the best model based on loss.
        """

        self.model.train()
        best_loss = float('inf')

        for epoch in range(self.num_epochs):
            total_loss = 0.0
            with tqdm(self.train_dataloader, unit="batch") as pbar:
                for points, labels in pbar:
                    pbar.set_description(f"epoch {epoch + 1}")
                    labels = labels.view(-1, 1)

                    # Zero out gradients
                    self.optimizer.zero_grad()

                    # Form predictions
                    pred = self.model(points)

                    # Calculate loss
                    loss = self.criterion(pred, labels)

                    # Calculate gradients
                    loss.backward()

                    # Step and backpropagate
                    self.optimizer.step()
                    self.scheduler.step()

                    total_loss += loss.item()

                    pbar.set_postfix(loss=total_loss / len(self.train_dataloader))

                # Save if necessary
                if total_loss < best_loss:
                    best_loss = total_loss
                    current_obj_name = self.config.current_obj.replace(".obj", "")
                    if IS_GOOGLE_COLAB:
                        torch.save(self.model.state_dict(), f'drive/MyDrive/3dv_hw3/{self.config.name}_{current_obj_name}.pth')
                        print("model saved to: {}".format(f'drive/MyDrive/3dv_hw3/{self.config.name}_{current_obj_name}.pth'))
                    else:
                        torch.save(self.model.state_dict(), f'{self.config.name}_{current_obj_name}.pth')
                    print("final model saved to: {}".format(f'{self.config.name}_{current_obj_name}.pth'))
                elif epoch == self.num_epochs - 1:
                    current_obj_name = self.config.current_obj.replace(".obj", "")
                    torch.save(self.model.state_dict(), f'{self.config.name}_{current_obj_name}_final.pth')
                    print("final model saved to: {}".format(f'{self.config.name}_{current_obj_name}_final.pth'))


    def get_num_params(self):
        """
        Get the number of trainable parameters in the model.

        Returns:
        int: Number of trainable parameters.
        """
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

In [9]:
#@title model

class Baseline():
    def __init__(self, x, y):
        self.y = y
        self.tree = KDTree(x)

    def __call__(self, x):
        _, idx = self.tree.query(x, k=3)
        return np.sign(self.y[idx].mean(axis=1))


def trilinear_interpolate(grid, pts, res, grid_type='dense'):
    """
    Perform trilinear interpolation on a 3D grid at specified points. This function supports both
    dense grid structures and hashed grid representations.

    Parameters:
        grid (torch.Tensor): The grid containing data values, either dense or indexed by hash values.
        pts (torch.Tensor): Coordinates of the points for which interpolation is desired. Points should
                            be normalized between -1 and 1.
        res (int): Resolution of the grid, assumed to be cubic (res x res x res).
        grid_type (str): Type of grid storage, 'dense' for direct storage or any other string for hashed storage.

    Returns:
        torch.Tensor: Interpolated values at the input points.
    """

    PRIMES = [1, 265443567, 805459861]

    # Resize grid
    if grid_type == 'dense':
        grid = grid.reshape(res, res, res, -1)

    # Normalize
    xs = (pts[:, 0] + 1) * 0.5 * (res - 1)
    ys = (pts[:, 1] + 1) * 0.5 * (res - 1)
    zs = (pts[:, 2] + 1) * 0.5 * (res - 1)

    # Base of voxel
    x0 = torch.floor(torch.clip(xs, 0, res - 1 - 1e-5)).long()
    y0 = torch.floor(torch.clip(ys, 0, res - 1 - 1e-5)).long()
    z0 = torch.floor(torch.clip(zs, 0, res - 1 - 1e-5)).long()

    # Other corner
    x1 = x0 + 1
    y1 = y0 + 1
    z1 = z0 + 1

    # Calculate weights
    w1 = ((x1 - xs) * (y1 - ys) * (z1 - zs)).unsqueeze(1)
    w2 = ((xs - x0) * (y1 - ys) * (z1 - zs)).unsqueeze(1)
    w3 = ((x1 - xs) * (ys - y0) * (z1 - zs)).unsqueeze(1)
    w4 = ((xs - x0) * (ys - y0) * (z1 - zs)).unsqueeze(1)
    w5 = ((x1 - xs) * (y1 - ys) * (zs - z0)).unsqueeze(1)
    w6 = ((xs - x0) * (y1 - ys) * (zs - z0)).unsqueeze(1)
    w7 = ((x1 - xs) * (ys - y0) * (zs - z0)).unsqueeze(1)
    w8 = ((xs - x0) * (ys - y0) * (zs - z0)).unsqueeze(1)

    # Get values, which uses hashing function if hash case
    if grid_type == 'dense':
        v1 = grid[x0, y0, z0]
        v2 = grid[x1, y0, z0]
        v3 = grid[x0, y1, z0]
        v4 = grid[x1, y1, z0]
        v5 = grid[x0, y0, z1]
        v6 = grid[x1, y0, z1]
        v7 = grid[x0, y1, z1]
        v8 = grid[x1, y1, z1]
    else:
        id1 = (x0 * PRIMES[0] ^ y0 * PRIMES[1] ^ z0 * PRIMES[2]) % grid.shape[0]
        id2 = (x1 * PRIMES[0] ^ y0 * PRIMES[1] ^ z0 * PRIMES[2]) % grid.shape[0]
        id3 = (x0 * PRIMES[0] ^ y1 * PRIMES[1] ^ z0 * PRIMES[2]) % grid.shape[0]
        id4 = (x1 * PRIMES[0] ^ y1 * PRIMES[1] ^ z0 * PRIMES[2]) % grid.shape[0]
        id5 = (x0 * PRIMES[0] ^ y0 * PRIMES[1] ^ z1 * PRIMES[2]) % grid.shape[0]
        id6 = (x1 * PRIMES[0] ^ y0 * PRIMES[1] ^ z1 * PRIMES[2]) % grid.shape[0]
        id7 = (x0 * PRIMES[0] ^ y1 * PRIMES[1] ^ z1 * PRIMES[2]) % grid.shape[0]
        id8 = (x1 * PRIMES[0] ^ y1 * PRIMES[1] ^ z1 * PRIMES[2]) % grid.shape[0]

        v1 = grid[id1]
        v2 = grid[id2]
        v3 = grid[id3]
        v4 = grid[id4]
        v5 = grid[id5]
        v6 = grid[id6]
        v7 = grid[id7]
        v8 = grid[id8]

    # Lerp
    out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + w5 * v5 + w6 * v6 + w7 * v7 + w8 * v8

    return out


class Baseline():
    def __init__(self, x, y):
        self.y = y
        self.tree = KDTree(x)

    def __call__(self, x):
        _, idx = self.tree.query(x, k=3)
        return np.sign(self.y[idx].mean(axis=1))


class DenseGrid(nn.Module):
    def __init__(self, base_lod, num_lods, feature_dimension, device='cuda'):
        """
        Initializes the DenseGrid module for handling multiple levels of detail (LOD) in a dense grid.

        Parameters:
        base_lod (int): The base level of detail, where each LOD corresponds to a grid of size (2^LOD)^3.
        num_lods (int): Number of levels of detail to generate starting from the base_lod.
        feature_dimension (int): The dimensionality of features at each point in the grid.
        device (str): Device to which the grid tensors will be allocated ('cuda' or 'cpu').
        """
        super(DenseGrid, self).__init__()

        # Define grid resolutions based on lod size args
        self.lod_sizes = [2 ** l for l in range(base_lod, base_lod + num_lods)]

        self.feature_dimension = feature_dimension
        self.device = device

        self.initialize_feature_grids()

    def initialize_feature_grids(self):
        """
        Initialize the feature grids for each level of detail as a Parameter in ParameterList.
        Each grid is initialized to have a normal distribution around mean 0 with a standard deviation of 0.01.
        """
        self.feature_grids = nn.ParameterList()

        for grid_size in self.lod_sizes:
            grid_features = nn.Parameter(torch.zeros(grid_size ** 3, self.feature_dimension, dtype=torch.float32, device=self.device))
            nn.init.normal_(grid_features, mean=0, std=0.01)
            self.feature_grids.append(grid_features)

    def forward(self, points):
        """
        Define the forward pass for interpolating features at the given points from multiple levels of detail.

        Parameters:
        points (Tensor): Tensor of size (num_points, 3) containing the coordinates of points where features are to be interpolated.

        Returns:
        Tensor: Concatenated features from all levels of detail for the input points.
        """

        interpolated_features = []
        for lod_size, grid_features in zip(self.lod_sizes, self.feature_grids):
            interpolated_feature = trilinear_interpolate(grid_features, points, lod_size, grid_type='dense')
            interpolated_features.append(interpolated_feature)

        return torch.cat(interpolated_features, dim=-1)


class HashGrid(nn.Module):
    def __init__(self, minimum_resolution, maximum_resolution, num_lods, hash_bandwidth, feature_dimension, device='cuda'):
        """
        Initializes the HashGrid module for spatial hashing at multiple levels of detail.

        Parameters:
        minimum_resolution (int): Minimum resolution size at the lowest level of detail.
        maximum_resolution (int): Maximum resolution size at the highest level of detail.
        num_lods (int): Number of levels of detail to manage.
        hash_bandwidth (int): Log base 2 of the number of buckets in the hash table.
        feature_dimension (int): The dimensionality of features at each hash grid point.
        device (str): Device to which the grid tensors will be allocated ('cuda' or 'cpu').
        """
        super(HashGrid, self).__init__()

        self.minimum_resolution = minimum_resolution
        self.maximum_resolution = maximum_resolution

        self.num_lods = num_lods
        self.feature_dimension = feature_dimension
        self.device = device
        self.hash_table_size = 2 ** hash_bandwidth

        # Calculate the base for exponential growth of resolutions across LODs
        base_growth = np.exp((np.log(self.maximum_resolution) - np.log(self.minimum_resolution)) / (self.num_lods - 1))
        self.lod_resolutions = [int(1 + np.floor(self.minimum_resolution * (base_growth ** l))) for l in range(self.num_lods)]

        self.initialize_feature_grids()

    def initialize_feature_grids(self):
        """
        Initialize the feature grids for each level of detail as a Parameter in ParameterList.
        Each grid is limited by the hash table size and initialized to have a normal distribution with a very small standard deviation.
        """

        self.feature_grids = nn.ParameterList()

        for lod_size in self.lod_resolutions:
            grid_features = nn.Parameter(
                torch.zeros(min(lod_size ** 3, self.hash_table_size), self.feature_dimension, dtype=torch.float32, device=self.device))
            nn.init.normal_(grid_features, mean=0, std=0.001)
            self.feature_grids.append(grid_features)

    def forward(self, points):
        """
        Define the forward pass for interpolating features at the given points from multiple levels of detail.

        Parameters:
        points (Tensor): Tensor of size (num_points, 3) containing the coordinates of points where features are to be interpolated.

        Returns:
        Tensor: Concatenated features from all levels of detail for the input points.
        """

        interpolated_features = []

        for lod_size, grid_features in zip(self.lod_resolutions, self.feature_grids):
            if points.dim() != 2 or points.shape[1] != 3:
              raise ValueError(f"expected points to be [num_points, 3], got: {points.shape}")

            interpolated_feature = trilinear_interpolate(grid_features, points, lod_size, grid_type='hash')
            interpolated_features.append(interpolated_feature)

        concatenated_features = torch.cat(interpolated_features, dim=-1)

        return concatenated_features


class MLP(nn.Module):
    def __init__(self, num_layers, layer_width, feature_dimension, num_lods):
        """
        Initializes a multilayer perceptron (MLP) with specified parameters.

        Parameters:
        num_layers (int): Number of layers in the MLP, excluding the input and output layers.
        layer_width (int): The number of neurons in each hidden layer.
        feature_dimension (int): The dimensionality of the input features per level of detail.
        num_lods (int): The number of different levels of detail, which influences input dimension.

        The network architecture follows this sequence: Input Layer -> (num_layers) Hidden Layers -> Output Layer.
        """
        super(MLP, self).__init__()

        self.num_layers = num_layers
        self.layer_width = layer_width

        self.initialize_layers(feature_dimension, num_lods)
        self.initialize_weights()

    def initialize_layers(self, feature_dimension, num_lods):
        """
        Construct the layers of the MLP from the input to the output layer.

        Parameters:
            feature_dimension (int): Dimensionality of input features per level of detail.
            num_lods (int): Number of levels of detail, determining the input size.

        This method builds each layer and adds ReLU activations after each hidden layer.
        """

        self.layers = nn.ModuleList()
        input_dimension = feature_dimension * num_lods

        self.layers.append(nn.Linear(input_dimension, self.layer_width))
        self.layers.append(nn.ReLU())

        for _ in range(self.num_layers - 1):
            self.layers.append(nn.Linear(self.layer_width, self.layer_width))
            self.layers.append(nn.ReLU())

        self.layers.append(nn.Linear(self.layer_width, 1))
        self.layers = nn.Sequential(*self.layers)

    def initialize_weights(self):
        """
        Initialize weights using the Xavier uniform initializer for better initial weights distribution.
        """
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)

    def forward(self, inputs):
        """
        Defines the forward pass of the MLP using the Sequential model defined.

        Parameters:
        inputs (Tensor): The input tensor to the MLP.

        Returns:
        Tensor: Output tensor after processing through MLP and applying sigmoid activation on the output layer.
        """
        outputs = self.layers(inputs)
        outputs = torch.sigmoid(outputs)
        return outputs



class OCC(nn.Module):
    def __init__(self, config):
        """
        Initializes the OCC model which includes either a DenseGrid or HashGrid and an MLP for processing.

        Parameters:
        config (object): Configuration object containing options like grid type, dimensions, and model parameters.
        """
        super(OCC, self).__init__()

        self.config = config
        self.initialize_model()

    def initialize_model(self):
        # Initialize the appropriate grid based on the configuration.
        if self.config.grid_type == 'dense':
            self.grid = DenseGrid(base_lod=self.config.base_lod, num_lods=self.config.num_lods,
                                  feature_dimension=self.config.grid_feature_dimension)

            self.mlp = MLP(num_layers=self.config.num_mlp_layers, layer_width=self.config.mlp_width,
                       feature_dimension=self.config.grid_feature_dimension, num_lods=self.config.num_lods)

        elif self.config.grid_type == 'hash':
            self.grid = HashGrid(minimum_resolution=2**self.config.base_lod,
                                 maximum_resolution=2**(self.config.base_lod + self.config.num_lods - 1),
                                 num_lods=self.config.num_lods, hash_bandwidth=13,
                                 feature_dimension=self.config.grid_feature_dimension)

            self.mlp = MLP(num_layers=self.config.num_mlp_layers, layer_width=self.config.mlp_width,
                       feature_dimension=self.config.grid_feature_dimension, num_lods=self.config.num_lods)
        else:
            raise NotImplementedError('grid type "{}" not implemented'.format(self.config.grid_type))

    def forward(self, inputs):
        """
        Defines the forward pass through the grid and MLP.

        Parameters:
        inputs (np.ndarray or Tensor): Input data, if numpy array, it will be converted to a Torch Tensor.

        Returns:
        Tensor: The output of the MLP after processing inputs through the grid.
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.from_numpy(inputs).float().cuda()

        grid_output = self.grid(inputs)
        final_output = self.mlp(grid_output)
        return final_output


In [10]:
#@title reconstruction

def generate_grid(point_cloud, resolutions):
    """Generate grid over the point cloud with given resolution
    Args:
        point_cloud (np.array, [N, 3]): 3D coordinates of N points in space
        resolutions (int): grid resolution
    Returns:
        coords (np.array, [resolutions*resolutions*resolutions, 3]): grid vertices
        coords_matrix (np.array, [4, 4]): transform matrix: [0,resolutions]x[0,resolutions]x[0,resolutions] -> [x_min, x_max]x[y_min, y_max]x[z_min, z_max]
    """
    b_min = np.min(point_cloud, axis=0)
    b_max = np.max(point_cloud, axis=0)

    coords = np.mgrid[:resolutions, :resolutions, :resolutions]
    coords = coords.reshape(3, -1)
    coords_matrix = np.eye(4)
    length = b_max - b_min
    length += length / resolutions
    coords_matrix[0, 0] = length[0] / resolutions
    coords_matrix[1, 1] = length[1] / resolutions
    coords_matrix[2, 2] = length[2] / resolutions
    coords_matrix[0:3, 3] = b_min
    coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4]
    coords = coords.T

    return coords, coords_matrix


def batch_eval(points, eval_func, num_samples):
    """Predict occupancy of values batch-wise
    Args:
        points (np.array, [N, 3]): 3D coordinates of N points in space
        eval_func (function): function that takes a batch of points and returns occupancy values
        num_samples (int): number of points to evaluate at once
    Returns:
        occ (np.array, [N,]): occupancy values for each point
    """

    num_pts = points.shape[0]
    occ = np.zeros(num_pts)

    num_batches = num_pts // num_samples
    for i in range(num_batches):
        occ[i * num_samples: i * num_samples + num_samples] = eval_func(
            points[i * num_samples: i * num_samples + num_samples]
        ).detach().cpu().numpy().squeeze()
    if num_pts % num_samples:
        occ[num_batches * num_samples:] = eval_func(
            points[num_batches * num_samples:]
        ).detach().cpu().numpy().squeeze()

    return occ


def eval_grid(coords, eval_func, num_per_sample=1024):
    """Predict occupancy of values on a grid
    Args:
        coords (np.array, [N, 3]): 3D coordinates of N points in space
        eval_func (function): function that takes a batch of points and returns occupancy values
        num_per_sample (int): number of points to evaluate at once

    Returns:
        occ (np.array, [N,]): occupancy values for each point
    """
    coords = coords.reshape([-1, 3])
    occ = batch_eval(coords, eval_func, num_samples=num_per_sample)
    return occ


def reconstruct(model, grid, resolutions, transform):
    """Reconstruct mesh by predicting occupancy values on a grid
    Args:
        model (function): function that takes a batch of points and returns occupancy values
        grid (np.array, [N, 3]): 3D coordinates of N points in space
        resolutions (int): grid resolution
        transform (np.array, [4, 4]): transform matrix: [0,resolutions]x[0,resolutions]x[0,resolutions] -> [x_min, x_max]x[y_min, y_max]x[z_min, z_max]

    Returns:
        verts (np.array, [M, 3]): 3D coordinates of M vertices
        faces (np.array, [K, 3]): indices of K faces
    """

    occ = eval_grid(grid, model)
    occ = occ.reshape([resolutions, resolutions, resolutions])

    # Correct surface level
    if occ.max() < 0.5:
        surface_level = None
    else:
        surface_level = 0.5

    verts, faces, normals, values = measure.marching_cubes(occ, surface_level)

    verts = np.matmul(transform[:3, :3], verts.T) + transform[:3, 3:4]
    verts = verts.T

    return verts, faces


def compute_metrics(reconstr_path, gt_path, num_samples=1000000):
    """Compute chamfer and hausdorff distances between the reconstructed mesh and the ground truth mesh
    Args:
        reconstr_path (str): path to the reconstructed mesh
        gt_path (str): path to the ground truth mesh
        num_samples (int): number of points to sample from each mesh

    Returns:
        chamfer_dist (float): chamfer distance between the two meshes
        hausdorff_dist (float): hausdorff distance between the two meshes
    """
    reconstr = trimesh.load(reconstr_path)
    gt = trimesh.load(gt_path)

    # Sample points on the mesh surfaces using trimesh
    reconstr_pts = reconstr.sample(num_samples)
    gt_pts = gt.sample(num_samples)

    # Compute chamfer distance between the two point clouds
    reconstr_tree = KDTree(reconstr_pts)
    gt_tree = KDTree(gt_pts)
    dist1, _ = reconstr_tree.query(gt_pts)
    dist2, _ = gt_tree.query(reconstr_pts)
    chamfer_dist = (dist1.mean() + dist2.mean()) / 2
    hausdorff_dist = max(dist1.max(), dist2.max())

    return chamfer_dist, hausdorff_dist




In [11]:
#@title run - download data

download_data()

Retrieving folder contents


Processing file 1Ci7Az0sL16E3qmyJHhWKYtTRHPEhzEhu bunny.obj
Processing file 1yLoRg5vyAiZ4LXZqePD52uW5aSglapjZ column.obj
Processing file 1ceiguf2Hi9cLXddT5tozHOpnJeIQtVEs dragon_original.obj
Processing file 1-IsqDAFAseW5g_xaYGU5djm_-yOydURS serapis.obj
Processing file 1-0i_FDkwB39zUrPlU4zHR9KL7X5VPAsd utah_teapot.obj


Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From: https://drive.google.com/uc?id=1Ci7Az0sL16E3qmyJHhWKYtTRHPEhzEhu
To: /content/data/bunny.obj
100%|██████████| 5.55M/5.55M [00:00<00:00, 44.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1yLoRg5vyAiZ4LXZqePD52uW5aSglapjZ
To: /content/data/column.obj
100%|██████████| 4.04M/4.04M [00:00<00:00, 208MB/s]
Downloading...
From: https://drive.google.com/uc?id=1ceiguf2Hi9cLXddT5tozHOpnJeIQtVEs
To: /content/data/dragon_original.obj
100%|██████████| 74.9M/74.9M [00:00<00:00, 142MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-IsqDAFAseW5g_xaYGU5djm_-yOydURS
To: /content/data/serapis.obj
100%|██████████| 7.03M/7.03M [00:00<00:00, 98.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-0i_FDkwB39zUrPlU4zHR9KL7X5VPAsd
To: /content/data/utah_teapot.obj
100%|██████████| 699k/699k [00:00<00:00, 107MB/s]
Download completed
Retrieving folder contents


Processing file 1kKAM7Ba1or2Kjj4dwBlxjMBsM_QcKgbs bunny.obj
Processing file 1S0J22xgRxjbgC-dRkQTB8kaMZDAiVhtl column.obj
Processing file 1-YPoz8mmY5OS-P8kbBrzK9Wtbk60LpFJ dragon_original.obj
Processing file 1Dki_b4F3AxHuEe9Hi6nydrbjRxKRg_cu serapis.obj
Processing file 1lZhpryLdhyUksUD_McERjj5l4naM5Uxt utah_teapot.obj


Retrieving folder contents completed
Building directory structure
Building directory structure completed
Downloading...
From: https://drive.google.com/uc?id=1kKAM7Ba1or2Kjj4dwBlxjMBsM_QcKgbs
To: /content/processed/bunny.obj
100%|██████████| 18.3M/18.3M [00:00<00:00, 49.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1S0J22xgRxjbgC-dRkQTB8kaMZDAiVhtl
To: /content/processed/column.obj
100%|██████████| 18.2M/18.2M [00:00<00:00, 114MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1-YPoz8mmY5OS-P8kbBrzK9Wtbk60LpFJ
To: /content/processed/dragon_original.obj
100%|██████████| 18.2M/18.2M [00:00<00:00, 122MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1Dki_b4F3AxHuEe9Hi6nydrbjRxKRg_cu
To: /content/processed/serapis.obj
100%|██████████| 18.2M/18.2M [00:00<00:00, 110MB/s] 
Downloading...
From: https://drive.google.com/uc?id=1lZhpryLdhyUksUD_McERjj5l4naM5Uxt
To: /content/processed/utah_teapot.obj
100%|██████████| 18.2M/18.2M [00:00<00:00, 137MB/s]
Download complete

In [None]:
#@title download helper

from google.colab import files

def download_pth_file(file_path):
    """
    Downloads a .pth file to the local machine from Google Colab.

    Args:
    file_path (str): The path to the .pth file in the Colab environment.
    """
    # Ensure the file exists before attempting to download
    try:
        # This will display a browser download prompt
        files.download(file_path)
    except FileNotFoundError:
        print(f"The file {file_path} does not exist.")
    except Exception as e:
        print(f"An error occurred: {e}")

In [None]:
#@title run - train

config = OccConfig(CONFIG_1LOD)

print(f'load config: {config.name}')

for obj_file in os.listdir("processed"):

    print("##################")

    print("obj path: {}".format(obj_file))

    config.current_obj_path = "processed/{}".format(obj_file)
    config.current_obj = obj_file

    config.log_config()

    current_obj_name = config.current_obj.replace(".obj", "")

    trainer = OccTrainer(config)

    print('model has {} params!'.format(trainer.get_num_params()))

    print("##################")

    trainer.run()


In [12]:
#@title run - reconstruction


# Open up a results file to save
with open("results.csv", mode="a+", encoding="utf-8") as file:

    results = csv.writer(file)
    results.writerow(["Mesh", "Type", "Resolution", "Chamfer distance", "Hausdorff distance"])

    configs = ["one_lod", "m_lod", "hash"]
    resolutions = [64, 128, 256]

    for current_object in os.listdir("processed"):
        for name in configs:

            config = OccConfig(name)

            print(f"processing {current_object}, type {name}")
            model = OCC(config)

            # Find model file
            current_object_name = current_object.replace(".obj", "")
            model_name = f"drive/MyDrive/3dv_hw3/{name}_{current_object_name}_final.pth"

            if f"{name}_{current_object_name}_final.pth" not in os.listdir("drive/MyDrive/3dv_hw3"):
                raise FileNotFoundError()

            print(f"found model at {model_name}")

            model.load_state_dict(torch.load(model_name))
            model.eval()
            model.to(config.device)

            pc = trimesh.load(f"processed/{current_object}")
            verts = np.array(pc.vertices)

            for resolution in resolutions:

                print(f"settings resolution at {resolution}")

                grid, transform = generate_grid(verts, resolutions=resolution)
                rec_verts, rec_faces = reconstruct(model, grid, resolution, transform)

                reconstr_path = f"reconstructions/{current_object.split('.')[0]}_{name}_{resolution}.obj"
                os.makedirs(os.path.dirname(reconstr_path), exist_ok=True)
                trimesh.Trimesh(rec_verts, rec_faces).export(reconstr_path)

                gt_path = f"data/{current_object}"

                chamfer_dist, hausdorff_dist = compute_metrics(
                    reconstr_path, gt_path, num_samples=1000000
                )

                results.writerow([current_object, name, resolution, chamfer_dist, hausdorff_dist])

                print(current_object, name, resolution)
                print(f"Chamfer distance: {chamfer_dist:.4f}")
                print(f"Hausdorff distance: {hausdorff_dist:.4f}")
                print("##################")

                break

            break

        break

processing utah_teapot.obj, type one_lod
found model at drive/MyDrive/3dv_hw3/one_lod_utah_teapot_final.pth
settings resolution at 64
utah_teapot.obj one_lod 64
Chamfer distance: 0.0186
Hausdorff distance: 0.2061
##################
