In [1]:
from pathlib import Path
import shutil
import os
import gc
from importlib import reload

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import cv2
from numpy.random import default_rng
import trimesh

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

from tqdm import tqdm

# from kaggle_secrets import UserSecretsClient
import wandb



In [2]:
LOCAL = True

In [3]:
if LOCAL:
    my_secret = "3954148eac0eeb54c223e2c9e928de862ea74f68"
else:
    user_secrets = UserSecretsClient()
    my_secret = user_secrets.get_secret("wandb_key") 
wandb.login(key=my_secret)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mviktor_povazhuk[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/viktor/.netrc


True

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [5]:
DATA_DIR = Path('../input/implicit-repr-xyz/') if not LOCAL else Path('../meshes/')
MODELS_DIR = Path('../input/implicit-repr-models/') if not LOCAL else Path('./models/')
WORK_DIR = Path('/kaggle/working/') if not LOCAL else Path('./working/')

## Data preparation

In [6]:
class PointCloud(Dataset):
    def __init__(self, mesh_path, on_surface_points, keep_aspect_ratio=True):
        super().__init__()

        print("Loading point cloud")
        mesh = trimesh.load_mesh(mesh_path)
        print("Finished loading point cloud")

        coords = mesh.vertices
        self.normals = mesh.vertex_normals

        # Reshape point cloud such that it lies in bounding box of (-1, 1) (distorts geometry, but makes for high
        # sample efficiency)
        coords -= np.mean(coords, axis=0, keepdims=True)
        if keep_aspect_ratio:
            coord_max = np.amax(coords)
            coord_min = np.amin(coords)
        else:
            coord_max = np.amax(coords, axis=0, keepdims=True)
            coord_min = np.amin(coords, axis=0, keepdims=True)

        self.coords = (coords - coord_min) / (coord_max - coord_min)
        self.coords -= 0.5
        self.coords *= 2.

        self.on_surface_points = on_surface_points

    def __len__(self):
        return self.coords.shape[0] // self.on_surface_points

    def __getitem__(self, idx):
        point_cloud_size = self.coords.shape[0]

        off_surface_samples = self.on_surface_points  # **2
        total_samples = self.on_surface_points + off_surface_samples

        # Random coords
        rand_idcs = np.random.choice(point_cloud_size, size=self.on_surface_points)

        on_surface_coords = self.coords[rand_idcs, :]
        on_surface_normals = self.normals[rand_idcs, :]

        off_surface_coords = np.random.uniform(-1, 1, size=(off_surface_samples, 3))
        off_surface_normals = np.ones((off_surface_samples, 3)) * -1

        sdf = np.zeros((total_samples, 1))  # on-surface = 0
        sdf[self.on_surface_points:, :] = -1  # off-surface = -1

        coords = np.concatenate((on_surface_coords, off_surface_coords), axis=0)
        normals = np.concatenate((on_surface_normals, off_surface_normals), axis=0)

        return {'coords': torch.from_numpy(coords).float()}, {'sdf': torch.from_numpy(sdf).float(),
                                                              'normals': torch.from_numpy(normals).float()}

In [7]:
data_parameters = {
    'batch_size': 1400,
}

In [8]:
sdf_dataset = PointCloud(DATA_DIR / '0.xyz', on_surface_points=data_parameters['batch_size'])

Loading point cloud
Finished loading point cloud


In [9]:
dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)

## Architecture

In [10]:
@torch.no_grad()
def sine_init(m):
    if hasattr(m, 'weight'):
        num_input = m.weight.size(-1)
        # See supplement Sec. 1.5 for discussion of factor 30
        m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)

@torch.no_grad()
def first_layer_sine_init(m):
    if hasattr(m, 'weight'):
        num_input = m.weight.size(-1)
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        m.weight.uniform_(-1 / num_input, 1 / num_input)

In [11]:
class Sine(nn.Module):
    def forward(self, x):
        return torch.sin(30 * x)

In [12]:
class Siren(nn.Module):
    def __init__(self):
        super().__init__()

        in_features = 3
        out_features = 1
        hidden_features = 256
        num_hidden_layers = 3

        self.layers = []

        self.layers.append(nn.Sequential(
            nn.Linear(in_features, hidden_features),
            Sine()
        ))

        for i in range(num_hidden_layers):
            self.layers.append(nn.Sequential(
                nn.Linear(hidden_features, hidden_features),
                Sine()
            ))
        
        self.layers.append(nn.Sequential(
            nn.Linear(hidden_features, out_features),
            Sine()
        ))

        self.layers = nn.Sequential(*self.layers)

        self.layers.apply(sine_init)
        self.layers[0].apply(first_layer_sine_init)
    
    def forward(self, input):
        # Enables us to compute gradients w.r.t. coordinates
        coords_org = input['coords'].clone().detach().requires_grad_(True)
        coords = coords_org

        output = self.layers(coords)

        return {'model_in': coords_org, 'model_out': output}

In [13]:
def calc_gradient(y, x, grad_outputs=None):
    if grad_outputs is None:
        grad_outputs = torch.ones_like(y)
    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
    return grad

In [14]:
def SDFLoss(model_output, gt):
    gt_sdf = gt['sdf']
    gt_normals = gt['normals']

    coords = model_output['model_in']
    pred_sdf = model_output['model_out']

    gradient = calc_gradient(pred_sdf, coords)

    # Wherever boundary_values is not equal to zero, we interpret it as a boundary constraint.
    sdf_constraint = torch.where(gt_sdf != -1, pred_sdf, torch.zeros_like(pred_sdf))
    inter_constraint = torch.where(gt_sdf != -1, torch.zeros_like(pred_sdf), torch.exp(-1e2 * torch.abs(pred_sdf)))
    normal_constraint = torch.where(gt_sdf != -1, 1 - F.cosine_similarity(gradient, gt_normals, dim=-1)[..., None],
                                    torch.zeros_like(gradient[..., :1]))
    grad_constraint = torch.abs(gradient.norm(dim=-1) - 1)

    losses = {
        'sdf': torch.abs(sdf_constraint).mean() * 3e3,
        'inter': inter_constraint.mean() * 1e2,
        'normal_constraint': normal_constraint.mean() * 1e2,
        'grad_constraint': grad_constraint.mean() * 5e1
    }

    full_loss = 0
    for loss_name, loss in losses.items():
        full_loss += loss.mean()

    return full_loss

In [15]:
algo_parameters = {
    'lr': 1e-4,
    'epochs': 1,
    'device': device,
    'device_number': 1,
    'continue': False,
}

In [16]:
class SirenPl(pl.LightningModule):
    def __init__(self, model=None):
        super().__init__()
        
        self.model = model if model is not None else Siren()
    
    def forward(self, x):
        y = self.model(x)
        return y

    def training_step(self, batch, batch_idx):
        if self.trainer.global_step == 0:
            if wandb.run is not None:
                wandb.define_metric('train_c_loss', summary='mean')

                wandb.define_metric('val_c_loss', summary='mean')

        model_input, gt = batch

        model_output = self(model_input)
        loss = SDFLoss(model_output, gt)
        
        self.log('train_loss', loss)
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return {'loss': loss}
    
    # Called each training
    def configure_optimizers(self):
        optim = torch.optim.Adam(lr=algo_parameters['lr'], params=self.model.parameters())
        
        return optim

## Training

In [17]:
model = SirenPl()

In [18]:
def test_model_run(model):
    print('Model train')
    trainer = pl.Trainer(fast_dev_run=True, accelerator=algo_parameters['device'].type,
                         devices=algo_parameters['device_number'])
    trainer.fit(model=model, train_dataloaders=dataloader)

def test_data_pipeline():
    input, gt = next(iter(sdf_dataset))
    print(input['coords'][:5])
    print(gt['sdf'][:5])
    print(gt['normals'][:5])

In [19]:
test_data_pipeline()
test_model_run(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


tensor([[-0.0128, -0.1823,  0.6272],
        [-0.0874, -0.1921,  0.6574],
        [ 0.0131,  0.2925,  0.1196],
        [-0.0179, -0.1200, -0.7367],
        [-0.0243, -0.1350,  0.6689]])
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]])
tensor([[ 2.7577, -1.3920, -1.2207],
        [ 3.5160,  2.6087,  1.7178],
        [ 2.3228, -2.2470, -1.5867],
        [ 5.0516,  0.3837,  1.5555],
        [ 2.1252,  0.3616, -0.9825]])
Model train



  | Name  | Type  | Params
--------------------------------
0 | model | Siren | 198 K 
--------------------------------
198 K     Trainable params
0         Non-trainable params
198 K     Total params
0.795     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [24]:
def train_model(model):
    ckpt_path = WORK_DIR
    ckpt_path = str(ckpt_path)
    
    wandb_logger = WandbLogger(project='implicit-representation',
                               log_model=True,
                               entity='viktor_povazhuk',
                               tags=['baseline'],
                               notes='Test run')
    wandb_logger.experiment.config.update(algo_parameters)
    
    early_stop_callback = EarlyStopping(monitor="train_loss", min_delta=1e-3, patience=5,
                                        verbose=True, mode="min", strict=True)
    checkpoint_callback = ModelCheckpoint(
        save_top_k=2,
        monitor="global_step",
        mode="max",
        dirpath=ckpt_path,
        filename="ckpt_{epoch}",
        every_n_epochs=10
    )
    
    trainer = pl.Trainer(max_epochs=algo_parameters['epochs'], accelerator=algo_parameters['device'].type,
                         devices=algo_parameters['device_number'], logger=wandb_logger,
                         callbacks=[early_stop_callback, checkpoint_callback],
                         )
    if algo_parameters['continue']:
        trainer.fit(model=model, train_dataloaders=dataloader, ckpt_path=str(WORK_DIR / '_.ckpt'))
    else:
        trainer.fit(model=model, train_dataloaders=dataloader)

In [25]:
train_model(model)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type  | Params
--------------------------------
0 | model | Siren | 198 K 
--------------------------------
198 K     Trainable params
0         Non-trainable params
198 K     Total params
0.795     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Metric train_loss improved. New best score: 864.946
`Trainer.fit` stopped: `max_epochs=1` reached.


In [28]:
torch.save(model.state_dict(), WORK_DIR / 'model.pth')

In [26]:
wandb.finish()

## Evaluation

In [None]:
model = Siren().load_state_dict(torch.load(WORK_DIR / 'model.pth'))

model.eval()

# load mesh -> sample points

with torch.no_grad():
    pass

In [28]:
mesh = trimesh.load_mesh('../meshes/0.obj')

normal = mesh.vertex_normals[0]
vertice = mesh.vertices[0]

print(vertice)
print(normal)
print(np.linalg.norm(normal))

xyz_normal = [-1.397990, -4.958488, -0.919577]
xyz_vertex = [-0.001592, -0.498301, -0.796199]
print(xyz_normal / normal)

[-0.00159161 -0.49830121 -0.79619932]
[-0.26713738 -0.94750222 -0.17571901]
0.9999999999999999
[5.23322502 5.23322047 5.23322435]


In [26]:
# Load the mesh from the OBJ file
mesh = trimesh.load_mesh('../meshes/0.obj')

# s = trimesh.Scene([mesh])

# s.show()

# Calculate the normals at the surface points
normals = mesh.vertex_normals
vertices = mesh.vertices

# # Access the surface points and their corresponding normals
# surface_points = mesh.vertices[mesh.faces]
# surface_normals = normals[mesh.faces]

# # Print the surface points and their normals
# for i in range(5):
#     print('Surface point:', surface_points[i])
#     print('Normal:', surface_normals[i])
#     print()

[-0.00159161 -0.49830121 -0.79619932]
[-0.26713738 -0.94750222 -0.17571901]


array([5.23322502, 5.23322047, 5.23322435])