In [None]:
!pip install trimesh

In [None]:
from pathlib import Path
import shutil
import os
import gc
from importlib import reload

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import cv2

import trimesh

from sklearn.metrics import f1_score

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

from tqdm import tqdm

from kaggle_secrets import UserSecretsClient
import wandb

In [None]:
# Execution-environment switch read by the cells below:
# False -> Kaggle secret store and Kaggle input/working paths, True -> the local alternatives.
LOCAL = False

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
if LOCAL:
    # Local runs read the key from the WANDB_API_KEY environment variable. When it is
    # unset, None lets wandb.login fall back to its own credential lookup.
    my_secret = os.environ.get("WANDB_API_KEY")
else:
    # On Kaggle the key is stored in the notebook secret store under "wandb_key".
    user_secrets = UserSecretsClient()
    my_secret = user_secrets.get_secret("wandb_key")
wandb.login(key=my_secret)

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Directory layout depends on where the notebook runs (see LOCAL).
if LOCAL:
    DATA_DIR = Path('../meshes/')
    MODELS_DIR = Path('./models/')
    WORK_DIR = Path('./working/')
else:
    DATA_DIR = Path('../input/implicit-repr-xyz/')
    MODELS_DIR = Path('../input/implicit-repr-models/')
    WORK_DIR = Path('/kaggle/working/')

## Data preparation

In [None]:
def transform_coords(coords, keep_aspect_ratio):
    """Center a point cloud and rescale it into the [-1, 1] bounding box.

    The input array is left untouched; a new array is returned.

    Args:
        coords: array-like of shape (num_points, num_dims).
        keep_aspect_ratio: if True, a single global min/max is used for every axis, so
            the proportions of the cloud are preserved. If False, every axis is
            stretched independently to fill [-1, 1] (distorts geometry, but makes for
            high sample efficiency).

    Returns:
        A new array with the same shape as ``coords``.
    """
    coords = np.asarray(coords)
    # Out-of-place subtraction: an in-place `-=` would modify the caller's array and
    # fails outright for integer inputs.
    coords = coords - np.mean(coords, axis=0, keepdims=True)

    if keep_aspect_ratio:
        coord_max = np.amax(coords)
        coord_min = np.amin(coords)
    else:
        coord_max = np.amax(coords, axis=0, keepdims=True)
        coord_min = np.amin(coords, axis=0, keepdims=True)

    # Map [coord_min, coord_max] -> [0, 1] -> [-0.5, 0.5] -> [-1, 1].
    coords = (coords - coord_min) / (coord_max - coord_min)
    coords -= 0.5
    coords *= 2.

    return coords

In [None]:
class PointCloud(Dataset):
    """Dataset producing SDF training batches sampled from a mesh.

    Every item is a complete batch: ``on_surface_points`` randomly chosen mesh vertices
    (with their vertex normals) followed by the same number of points drawn uniformly
    from the mesh's axis-aligned bounding box.

    Args:
        mesh_path: path of the mesh file, loaded with ``trimesh.load_mesh``.
        on_surface_points: number of on-surface samples per item.
        keep_aspect_ratio: accepted for interface compatibility; it is not used because
            the vertices are consumed without any coordinate transform.
    """

    def __init__(self, mesh_path, on_surface_points, keep_aspect_ratio=True):
        super().__init__()

        print("Loading point cloud")
        mesh = trimesh.load_mesh(mesh_path)
        print("Finished loading point cloud")

        # Off-surface samples are drawn inside this bounding box.
        self.bbox_min, self.bbox_max = mesh.bounds

        self.coords = mesh.vertices
        self.normals = mesh.vertex_normals

        self.on_surface_points = on_surface_points

    def __len__(self):
        """Number of items per epoch: vertices // on_surface_points, but at least 1.

        Items are sampled with replacement, so a mesh with fewer vertices than
        ``on_surface_points`` can still produce a batch; without the lower bound such a
        mesh would yield an empty dataset. A mesh without vertices has length 0.
        """
        num_vertices = self.coords.shape[0]
        if num_vertices == 0:
            return 0
        return max(1, num_vertices // self.on_surface_points)

    def __getitem__(self, idx):
        """Return one randomly sampled batch; ``idx`` is ignored.

        Returns:
            A pair ``({'coords': ...}, {'sdf': ..., 'normals': ...})`` of float tensors
            with ``2 * on_surface_points`` rows each. The on-surface samples come first
            (sdf label 0, real normals), then the off-surface samples (sdf label -1 and
            normals filled with -1 as a "no normal" marker).
        """
        point_cloud_size = self.coords.shape[0]

        off_surface_samples = self.on_surface_points
        total_samples = self.on_surface_points + off_surface_samples

        # On-surface part: vertices drawn with replacement.
        rand_idcs = np.random.choice(point_cloud_size, size=self.on_surface_points)
        on_surface_coords = self.coords[rand_idcs, :]
        on_surface_normals = self.normals[rand_idcs, :]

        # Off-surface part: one uniform draw per axis inside the bounding box.
        off_surface_coords = np.hstack([
            np.random.uniform(self.bbox_min[axis], self.bbox_max[axis],
                              size=off_surface_samples).reshape(-1, 1)
            for axis in range(3)
        ])
        off_surface_normals = np.ones((off_surface_samples, 3)) * -1

        sdf = np.zeros((total_samples, 1))  # on-surface = 0
        sdf[self.on_surface_points:, :] = -1  # off-surface = -1

        coords = np.concatenate((on_surface_coords, off_surface_coords), axis=0)
        normals = np.concatenate((on_surface_normals, off_surface_normals), axis=0)

        return {'coords': torch.from_numpy(coords).float()}, {'sdf': torch.from_numpy(sdf).float(),
                                                              'normals': torch.from_numpy(normals).float()}

In [None]:
# Data settings; 'batch_size' is passed to PointCloud as the number of on-surface points per item.
data_parameters = dict(batch_size=500)

In [None]:
# Training dataset built from mesh 0.obj; every item holds data_parameters['batch_size']
# on-surface points plus the same number of off-surface points.
sdf_dataset = PointCloud(DATA_DIR / '0.obj', on_surface_points=data_parameters['batch_size'])

In [None]:
# Number of items (pre-assembled batches) the dataset yields per epoch.
print(len(sdf_dataset))

In [None]:
# Every dataset item is already a complete batch, so the loader uses batch_size=1;
# collation then only adds a leading dimension of size 1.
dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=1, pin_memory=True)

## Architecture

In [None]:
@torch.no_grad()
def sine_init(m):
    """Re-initialize ``m.weight`` in place for a sine-activated hidden layer.

    Weights are drawn uniformly from ``[-b, b]`` with ``b = sqrt(6 / fan_in) / 30``,
    where ``fan_in`` is the size of the last weight dimension. Modules without a
    ``weight`` attribute are left untouched, which makes the function usable with
    ``nn.Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # See supplement Sec. 1.5 for discussion of factor 30
    bound = np.sqrt(6 / fan_in) / 30
    m.weight.uniform_(-bound, bound)

@torch.no_grad()
def first_layer_sine_init(m):
    """Re-initialize ``m.weight`` in place for the first sine layer of the network.

    Weights are drawn uniformly from ``[-1 / fan_in, 1 / fan_in]``, where ``fan_in`` is
    the size of the last weight dimension. Modules without a ``weight`` attribute are
    skipped.
    """
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
    bound = 1 / fan_in
    m.weight.uniform_(-bound, bound)

In [None]:
class Sine(nn.Module):
    """Sine activation ``sin(w0 * x)``.

    Args:
        w0: frequency factor applied to the input before the sine. The default of 30
            matches the factor used by ``sine_init``.

    ``w0`` is a plain attribute, so the module has no parameters or buffers and does not
    change the state dict of a model that contains it.
    """

    def __init__(self, w0=30):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)

In [None]:
class Siren(nn.Module):
    """Fully connected network with sine activations mapping coordinates to one output.

    Args:
        in_features: size of the last dimension of the input coordinates.
        out_features: size of the last dimension of the output.
        hidden_features: width of every hidden layer.
        num_hidden_layers: number of hidden->hidden layers between the first and the
            last layer.
        outermost_linear: if True, the last layer is a plain linear layer. The default
            (False) keeps a sine activation on the output, which bounds it to [-1, 1].

    The defaults reproduce the fixed architecture (3 -> 256 x 4 -> 1), and the
    state-dict keys (``layers.<i>.0.weight`` / ``.bias``) do not depend on the options,
    so previously saved weights load into a default-constructed model.
    """

    def __init__(self, in_features=3, out_features=1, hidden_features=256,
                 num_hidden_layers=3, outermost_linear=False):
        super().__init__()

        def make_block(n_in, n_out, with_sine=True):
            # Every block is a Sequential so that the Linear is always sub-module "0".
            modules = [nn.Linear(n_in, n_out)]
            if with_sine:
                modules.append(Sine())
            return nn.Sequential(*modules)

        blocks = [make_block(in_features, hidden_features)]
        blocks.extend(make_block(hidden_features, hidden_features)
                      for _ in range(num_hidden_layers))
        blocks.append(make_block(hidden_features, out_features,
                                 with_sine=not outermost_linear))

        self.layers = nn.Sequential(*blocks)

        # Generic sine initialization everywhere, then the first layer is overridden
        # with its dedicated scheme.
        self.layers.apply(sine_init)
        self.layers[0].apply(first_layer_sine_init)

    def forward(self, input):
        """Evaluate the network.

        Args:
            input: dict with a 'coords' tensor whose last dimension is ``in_features``.

        Returns:
            dict with 'model_in' (a detached copy of the coordinates that requires
            grad) and 'model_out' (the network output computed from that copy).
        """
        # Enables us to compute gradients w.r.t. coordinates
        coords_org = input['coords'].clone().detach().requires_grad_(True)
        coords = coords_org

        output = self.layers(coords)

        return {'model_in': coords_org, 'model_out': output}

In [None]:
def calc_gradient(y, x, grad_outputs=None):
    """Differentiate ``y`` with respect to ``x`` via autograd.

    Args:
        y: tensor computed from ``x``.
        x: tensor that requires grad.
        grad_outputs: weights of the vector-Jacobian product; defaults to ones shaped
            like ``y``.

    Returns:
        The gradient, shaped like ``x``. It is built with ``create_graph=True`` and can
        therefore be differentiated again (e.g. when it is part of a loss).
    """
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return grad

In [None]:
def SDFLoss(model_output, gt):
    """Weighted sum of the four SDF training terms.

    Args:
        model_output: dict with 'model_in' (coordinates that require grad) and
            'model_out' (predicted SDF computed from them).
        gt: dict with 'sdf' (0 for on-surface samples, -1 for off-surface samples) and
            'normals' (target normals).

    Returns:
        Scalar tensor: 3e3 * |SDF| on the surface + 1e2 * exp(-1e2 * |SDF|) off the
        surface + 1e2 * (1 - cosine similarity between the SDF gradient and the target
        normal) on the surface + 5e1 * | ||gradient|| - 1 | over all samples.
    """
    pred_sdf = model_output['model_out']
    coords = model_output['model_in']
    gt_sdf, gt_normals = gt['sdf'], gt['normals']

    gradient = calc_gradient(pred_sdf, coords)

    # -1 is the label of off-surface samples; everything else counts as on-surface.
    on_surface = gt_sdf != -1
    no_penalty = torch.zeros_like(pred_sdf)

    sdf_constraint = torch.where(on_surface, pred_sdf, no_penalty)
    inter_constraint = torch.where(on_surface, no_penalty, torch.exp(-1e2 * torch.abs(pred_sdf)))
    normal_mismatch = 1 - F.cosine_similarity(gradient, gt_normals, dim=-1)[..., None]
    normal_constraint = torch.where(on_surface, normal_mismatch, torch.zeros_like(gradient[..., :1]))
    grad_constraint = torch.abs(gradient.norm(dim=-1) - 1)

    losses = dict(
        sdf=torch.abs(sdf_constraint).mean() * 3e3,
        inter=inter_constraint.mean() * 1e2,
        normal_constraint=normal_constraint.mean() * 1e2,
        grad_constraint=grad_constraint.mean() * 5e1,
    )

    return sum(term.mean() for term in losses.values())

In [None]:
# Training settings read by the Lightning module and by the trainer set-up below.
algo_parameters = {
    'lr': 1e-4,  # Adam learning rate (configure_optimizers)
    'epochs': 100,  # passed to the trainer as max_epochs
    'device': device,  # its .type is passed to the trainer as the accelerator
    'device_number': 1,  # passed to the trainer as devices
    'continue': False,  # True -> fit resumes from WORK_DIR / '_.ckpt'
}

In [None]:
class SirenPl(pl.LightningModule):
    """Lightning wrapper that trains a Siren model with SDFLoss.

    Args:
        model: network to train; a freshly constructed ``Siren`` when omitted.
    """

    def __init__(self, model=None):
        super().__init__()

        if model is None:
            model = Siren()
        self.model = model

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss of one (model_input, ground_truth) batch."""
        # On the very first step, ask wandb (when a run is active) to summarize
        # train_loss by its mean.
        if self.trainer.global_step == 0 and wandb.run is not None:
            wandb.define_metric('train_loss', summary='mean')

        model_input, gt = batch
        loss = SDFLoss(self(model_input), gt)

        self.log('train_loss', loss)
        # Logged as a metric so that callbacks can monitor the step counter.
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return {'loss': loss}

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters, learning rate from algo_parameters."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=algo_parameters['lr'])
        return optimizer

## Training

In [None]:
# Lightning module wrapping a freshly initialized Siren network.
pl_model = SirenPl()

In [None]:
def test_model_run(model):
    """Smoke-test the training loop: one fast_dev_run pass of ``model`` over ``dataloader``."""
    print('Model train')
    trainer_options = dict(
        fast_dev_run=True,
        accelerator=algo_parameters['device'].type,
        devices=algo_parameters['device_number'],
    )
    pl.Trainer(**trainer_options).fit(model=model, train_dataloaders=dataloader)

def test_data_pipeline():
    """Print the first five coordinates, SDF labels and normals of one dataset item."""
    model_input, gt = next(iter(sdf_dataset))
    for preview in (model_input['coords'], gt['sdf'], gt['normals']):
        print(preview[:5])

In [None]:
# Quick checks before the real training run: inspect one data item, then run a
# fast_dev_run training pass.
test_data_pipeline()
test_model_run(pl_model)

In [None]:
def train_model(model):
    """Train ``model`` on ``dataloader`` with W&B logging, early stopping and checkpoints.

    Checkpoints are written to WORK_DIR: a periodic one (every 10 epochs, the two with
    the highest ``global_step`` are kept) and the single one with the lowest
    ``train_loss``. When ``algo_parameters['continue']`` is set, training resumes from
    ``WORK_DIR / '_.ckpt'``.
    """
    ckpt_path = str(WORK_DIR)

    wandb_logger = WandbLogger(
        project='implicit-representation',
        log_model=True,
        entity='viktor_povazhuk',
        tags=['baseline'],
        notes='Test run',
    )
    wandb_logger.experiment.config.update(algo_parameters)

    callbacks = [
        EarlyStopping(monitor="train_loss", min_delta=1e-3, patience=100,
                      verbose=True, mode="min", strict=True),
        # Periodic snapshot, ranked by training progress.
        ModelCheckpoint(save_top_k=2, monitor="global_step", mode="max",
                        dirpath=ckpt_path, filename="{epoch}", every_n_epochs=10),
        # Best model so far according to the training loss.
        ModelCheckpoint(save_top_k=1, monitor="train_loss", mode="min",
                        dirpath=ckpt_path, filename="{train_loss:.2f}"),
    ]

    trainer = pl.Trainer(
        max_epochs=algo_parameters['epochs'],
        accelerator=algo_parameters['device'].type,
        devices=algo_parameters['device_number'],
        logger=wandb_logger,
        callbacks=callbacks,
    )

    fit_options = {'model': model, 'train_dataloaders': dataloader}
    if algo_parameters['continue']:
        fit_options['ckpt_path'] = str(WORK_DIR / '_.ckpt')
    trainer.fit(**fit_options)

In [None]:
# Full training run (logs to W&B and writes checkpoints to WORK_DIR).
train_model(pl_model)

In [None]:
# Save only the weights of the wrapped Siren network so they can be loaded without Lightning.
torch.save(pl_model.model.state_dict(), WORK_DIR / 'model.pth')

In [None]:
# Close the active W&B run.
wandb.finish()

## Evaluation

In [None]:
# Reload the trained weights into a fresh CPU model for evaluation.
model = Siren()
# map_location='cpu' makes the load independent of the device the weights were saved
# from; the evaluation cells below run this model on CPU tensors.
model.load_state_dict(torch.load(WORK_DIR / 'model.pth', map_location='cpu'))

model.eval();

In [None]:
# Smoke check of the reloaded model on one dataset item.
inp, _ = next(iter(sdf_dataset))
# On-surface samples come first in every item; with the batch size configured above
# these are the first 500 rows.
points = inp['coords'][:500]

preds = model({'coords': points})['model_out'].detach().numpy()

# Number of predictions with |SDF| < 1e-3. The built-in sum over the 2-D boolean array
# adds it row by row, so the count is printed as a one-element array.
print(sum(abs(preds) < 0.001))

In [None]:
# Inside/outside classification quality of the model, measured on two point sets:
# points near the surface (surface samples plus Gaussian noise) and points drawn
# uniformly from the mesh bounding box.
mesh = trimesh.load_mesh(DATA_DIR / '0.obj')
rng = np.random.default_rng(seed=0)
num_points = 500

# Seed the surface sampling as well, so that the whole evaluation is reproducible.
surface_points, face_idx = trimesh.sample.sample_surface(mesh, num_points, seed=0)
noise = rng.normal(loc=0, scale=1e-2, size=(num_points, 3))
surface_points = surface_points + noise

# NOTE(review): trimesh documents signed_distance as positive INSIDE the mesh, while the
# normal term of the training loss aligns the SDF gradient with the mesh normals. If
# those normals point outward, the model is positive OUTSIDE and the labels below are
# inverted relative to the predictions - confirm the sign convention.
surface_gt = (trimesh.proximity.signed_distance(mesh, surface_points) > 0).astype(int)

bbox_min, bbox_max = mesh.bounds

x = rng.uniform(bbox_min[0], bbox_max[0], size=num_points).reshape(-1, 1)
y = rng.uniform(bbox_min[1], bbox_max[1], size=num_points).reshape(-1, 1)
z = rng.uniform(bbox_min[2], bbox_max[2], size=num_points).reshape(-1, 1)

volume_points = np.hstack([x, y, z])

volume_gt = (trimesh.proximity.signed_distance(mesh, volume_points) > 0).astype(int)

with torch.no_grad():
    surface_preds = model({'coords': torch.from_numpy(surface_points).float()})['model_out']
    # Flatten the (n, 1) model output so that it matches the 1-D ground-truth labels.
    surface_preds = (surface_preds.numpy() > 0).astype(int).ravel()

    surface_f1 = f1_score(surface_gt, surface_preds, average='weighted')

    volume_preds = model({'coords': torch.from_numpy(volume_points).float()})['model_out']
    volume_preds = (volume_preds.numpy() > 0).astype(int).ravel()

    volume_f1 = f1_score(volume_gt, volume_preds, average='weighted')

    print(f'Surface points F1: {surface_f1}')
    print(f'Volume points F1: {volume_f1}')

In [None]:
# F1 of the positive class only (f1_score's default averaging), for comparison with the
# weighted score printed above.
f1_score(surface_gt, surface_preds)

In [None]:
# Sanity check: compare the first vertex / vertex normal reported by trimesh with
# hand-entered reference values.
mesh = trimesh.load_mesh(DATA_DIR / '0.obj')

normal = mesh.vertex_normals[0]
vertice = mesh.vertices[0]

print(vertice)
print(normal)
print(np.linalg.norm(normal))

# NOTE(review): presumably copied from the raw source data of the same mesh - confirm.
xyz_normal = [-1.397990, -4.958488, -0.919577]
xyz_vertex = [-0.001592, -0.498301, -0.796199]  # not used in this cell
# Element-wise ratio between the reference normal and the trimesh normal.
print(xyz_normal / normal)

In [None]:
# Load the mesh from the OBJ file
# Exploration: how many uniformly drawn bounding-box points get a negative signed distance?
mesh = trimesh.load_mesh(DATA_DIR / '0.obj')

# s = trimesh.Scene([mesh])

# s.show()

num_points = 2000
# NOTE: these surface samples are overwritten by the bounding-box samples further down.
points, face_index = trimesh.sample.sample_surface(mesh, num_points)

rng = np.random.default_rng(seed=0)
# NOTE: `noise` is not used in the rest of this cell, but drawing it advances `rng`,
# so it affects the uniform samples below.
noise = rng.normal(loc=0, scale=1e-2, size=(num_points, 3))

bbox_min, bbox_max = mesh.bounds

x = rng.uniform(bbox_min[0], bbox_max[0], size=num_points).reshape(-1, 1)
y = rng.uniform(bbox_min[1], bbox_max[1], size=num_points).reshape(-1, 1)
z = rng.uniform(bbox_min[2], bbox_max[2], size=num_points).reshape(-1, 1)

points = np.hstack([x, y, z])

print(points[:5])

sdf = trimesh.proximity.signed_distance(mesh, points)

# Number of sampled points whose signed distance is negative.
print(sum(sdf < 0))

## Visual evaluation

In [None]:
!pip install plyfile

In [None]:
import logging
import plyfile
import skimage.measure
import time

In [None]:
# Reference mesh and its axis-aligned bounds.
mesh = trimesh.load_mesh(DATA_DIR / '0.obj')

# NOTE(review): create_mesh below defines its own local bbox_min / bbox_max and does not
# read these globals - confirm whether they are still needed.
bbox_min, bbox_max = mesh.bounds

In [None]:
def create_mesh(
    decoder, filename, N=256, max_batch=64 ** 3, offset=None, scale=None
):
    """Sample ``decoder`` on a regular N x N x N grid over [-1, 1]^3 and write the
    extracted surface to a .ply file.

    Args:
        decoder: module mapping an (n, 3) tensor of coordinates to n SDF values. It is
            switched to eval mode and evaluated without gradient tracking.
        filename: output path, forwarded to ``convert_sdf_samples_to_ply``.
        N: number of grid points along each axis.
        max_batch: maximum number of grid points per decoder call.
        offset, scale: forwarded unchanged to ``convert_sdf_samples_to_ply``.

    The grid is evaluated on the device that holds the decoder's parameters. A decoder
    without parameters is evaluated on CUDA when available, otherwise on the CPU.
    """
    start = time.time()
    ply_filename = filename

    decoder.eval()

    try:
        device = next(decoder.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE: the voxel_origin is actually the (bottom, left, down) = (z, y, x) corner, not the middle
    bbox_min = np.array([-1., -1., -1.])
    bbox_max = np.array([1., 1., 1.])
    voxel_origin = np.flip(bbox_min)
    voxel_sizes = (bbox_max - np.flip(voxel_origin)) / (N - 1)
    voxel_size = float(max(voxel_sizes))

    num_samples = N ** 3
    overall_index = torch.arange(0, num_samples, 1, dtype=torch.long)
    samples = torch.zeros(num_samples, 4)

    # First 3 columns: the x, y, z grid index of every flat index. The divisions must
    # be floor divisions: `/` on integer tensors is true division and would produce
    # fractional "indices", i.e. a wrong sampling lattice.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode='floor') % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode='floor') % N

    # Turn the grid indices into x, y, z coordinates.
    samples[:, 0] = (samples[:, 0] * voxel_size) + float(voxel_origin[2])
    samples[:, 1] = (samples[:, 1] * voxel_size) + float(voxel_origin[1])
    samples[:, 2] = (samples[:, 2] * voxel_size) + float(voxel_origin[0])

    samples.requires_grad = False

    # Fourth column: decoder output, filled chunk by chunk. no_grad avoids keeping an
    # autograd graph for every chunk.
    with torch.no_grad():
        for head in range(0, num_samples, max_batch):
            tail = min(head + max_batch, num_samples)
            sample_subset = samples[head:tail, 0:3].to(device)
            samples[head:tail, 3] = decoder(sample_subset).reshape(-1).detach().cpu()

    sdf_values = samples[:, 3].reshape(N, N, N)

    end = time.time()
    print("sampling takes: %f" % (end - start))

    convert_sdf_samples_to_ply(
        sdf_values.data.cpu(),
        voxel_origin,
        voxel_size,
        ply_filename,
        offset,
        scale,
    )


def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    offset=None,
    scale=None,
):
    """
    Convert sdf samples to .ply

    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
    :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
    :voxel_size: float, the size of the voxels
    :ply_filename_out: string, path of the filename to save to
    :offset: optional value subtracted from the vertices after scaling
    :scale: optional value the vertices are divided by

    The zero level set is extracted with marching cubes. When no surface can be
    extracted, a warning is logged and a .ply file without vertices and faces is written.

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """

    start_time = time.time()

    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()

    verts, faces = np.zeros((0, 3)), np.zeros((0, 3))
    try:
        verts, faces, _normals, _values = skimage.measure.marching_cubes(
            numpy_3d_sdf_tensor, level=0.0, spacing=[voxel_size] * 3
        )
        print('Was in marching cubes')
    except (ValueError, RuntimeError) as exc:
        # Typically the level 0 lies outside the sampled value range. Keep the
        # best-effort behavior (empty mesh), but do not hide the reason.
        logging.warning("marching cubes found no surface (%s); writing an empty mesh", exc)

    # transform from voxel coordinates to camera coordinates
    # note x and y are flipped in the output of marching_cubes
    mesh_points = verts + np.asarray(voxel_grid_origin, dtype=verts.dtype)

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    # Fill the structured arrays expected by plyfile column-wise instead of looping
    # over every vertex / face in Python.
    verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    verts_tuple["x"] = mesh_points[:, 0]
    verts_tuple["y"] = mesh_points[:, 1]
    verts_tuple["z"] = mesh_points[:, 2]

    faces_tuple = np.zeros((num_faces,), dtype=[("vertex_indices", "i4", (3,))])
    faces_tuple["vertex_indices"] = faces

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    logging.debug("saving mesh to %s" % (ply_filename_out))
    ply_data.write(ply_filename_out)

    logging.debug(
        "converting to ply format and writing to file took {} s".format(
            time.time() - start_time
        )
    )

In [None]:
class SDFDecoder(torch.nn.Module):
    """Adapter exposing a trained Siren as a plain ``coords -> sdf`` module.

    Args:
        weights_path: state-dict file to load; defaults to
            ``MODELS_DIR / 'model_transform.pth'``.
        device: device the model is placed on; defaults to CUDA when available,
            otherwise the CPU.
    """

    def __init__(self, weights_path=None, device=None):
        super().__init__()
        if weights_path is None:
            weights_path = MODELS_DIR / 'model_transform.pth'
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = Siren()
        # map_location makes the load independent of the device the weights were saved from.
        self.model.load_state_dict(torch.load(weights_path, map_location=device))
        self.model.to(device)

    def forward(self, coords):
        """Return the model output for a tensor of coordinates."""
        model_in = {"coords": coords}
        return self.model(model_in)["model_out"]

In [None]:
# Decoder used for mesh extraction; loads the pretrained weights from MODELS_DIR.
sdf_decoder = SDFDecoder()

In [None]:
# Resource limits for the extraction: cap this process at 20% of the GPU memory and use
# 3 CPU threads for intra-op parallelism.
torch.cuda.set_per_process_memory_fraction(0.2)
torch.set_num_threads(3)

In [None]:
# Grid points per axis for the extraction; create_mesh evaluates resolution ** 3 samples.
resolution = 600

In [None]:
# Extract the zero level set of the decoder and write it to WORK_DIR as a .ply file.
mesh_path = WORK_DIR / "test.ply"

create_mesh(sdf_decoder, mesh_path, N=resolution)

In [None]:
# Load the extracted mesh and display it in the notebook.
mesh = trimesh.load(mesh_path)
# print(mesh)
mesh.show()