# SIREN Tutorial

In this tutorial, we will explore the basic properties and applications of the SIREN MLP. For the theoretical background, please refer to the paper [Implicit Neural Representations with Periodic Activation Functions](https://arxiv.org/abs/2006.09661).

## Table of Contents
* Fitting an image
* Solving Poisson's equation
* Fitting SDF from Point Clouds

In [None]:
import logging
import os
import time
from collections import OrderedDict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import plyfile
import scipy
import skimage
import skimage.measure
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.image import PeakSignalNoiseRatio
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from tqdm import tqdm

Define a function to generate coordinate grids.

In [None]:
def get_mgrid(sidelen, dim=2):
    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.

    sidelen: int, number of samples along every axis
    dim: int, number of axes

    Returns a tensor of shape (sidelen ** dim, dim). Rows follow matrix ("ij")
    ordering, i.e. the last coordinate varies fastest.'''
    axis = torch.linspace(-1, 1, steps=sidelen)
    # Pass indexing explicitly: 'ij' is the layout the rest of the notebook relies on,
    # and omitting the argument triggers a deprecation UserWarning in recent PyTorch.
    mgrid = torch.stack(torch.meshgrid(*([axis] * dim), indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)

In [None]:
# Sanity check: a 256-per-side 2D grid flattens to 256 * 256 rows of (x, y) pairs.
print(get_mgrid(256).shape)

In [None]:
# Show how grid indices map to rows of the flattened grid: entry [x_index, y_index]
# of the 256x256 grid is stored at row x_index * 256 + y_index.
mgrid = get_mgrid(256)
y_index = 0
for x_index in range(0, 256, 32):
    print(f"The position of index [{x_index}, {y_index}] is {mgrid[x_index*256+y_index]}")

Define helper functions for the differential operators used below: gradient, divergence, and Laplacian. They are evaluated exactly with PyTorch's automatic differentiation rather than approximated with finite differences.

In [None]:
def laplace(y, x):
    """Laplacian of ``y`` w.r.t. ``x``, computed as the divergence of the gradient."""
    return divergence(gradient(y, x), x)

def divergence(y, x):
    """Divergence of the vector field ``y`` w.r.t. ``x``.

    For every component ``i`` of the last axis of ``y``, differentiates that
    component w.r.t. ``x`` and keeps only the i-th partial derivative; the
    partials are summed. The result keeps a trailing axis of size 1 and stays
    differentiable (``create_graph=True``).
    """
    partials = []
    for i in range(y.shape[-1]):
        component = y[..., i]
        (component_grad,) = torch.autograd.grad(
            component, x, torch.ones_like(component), create_graph=True
        )
        partials.append(component_grad[..., i:i+1])
    return sum(partials, 0.)

def gradient(y, x, grad_outputs=None):
    """Gradient of ``y`` w.r.t. ``x`` via autograd.

    grad_outputs: optional weights applied to ``y`` in the vector-Jacobian
        product; defaults to a tensor of ones shaped like ``y``.

    The returned tensor is itself differentiable (``create_graph=True``), so
    higher-order derivatives can be taken from it.
    """
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True)
    return grad

<a id='section_1'></a>
## 1. Fitting an Image

In this section, we parametrize a grayscale image $f(x)$ using the SIREN function $\Phi(x)$, where $x$ represents pixel coordinates.

We optimize $\Phi$ using the following loss function:
$$L=\int_{\Omega} \lVert \Phi(\mathbf{x}) - f(\mathbf{x}) \rVert^2\mathrm{d}\mathbf{x},$$ where $\Omega$ represents the image domain. In practice the integral is approximated by the mean squared error over all pixels.

#### Define a PyTorch Dataset for Image Fitting
We define a PyTorch dataset that returns pairs of pixel coordinates and their corresponding grayscale values from the cameraman image (a standard test image provided by the scikit-image library).

In [None]:
# Preview the raw cameraman test image that is fitted below.
plt.imshow(skimage.data.camera())

In [None]:
def get_cameraman_tensor(sidelength):
    """Load the scikit-image cameraman picture as a normalized tensor.

    The image is resized with ``Resize(sidelength)``, converted with
    ``ToTensor`` and normalized with mean 0.5 / std 0.5.
    NOTE(review): presumably the result is (1, sidelength, sidelength) with
    values in [-1, 1], since the source image is square and grayscale - confirm.
    """
    pipeline = Compose([
        Resize(sidelength),
        ToTensor(),
        Normalize(torch.Tensor([0.5]), torch.Tensor([0.5])),
    ])
    return pipeline(Image.fromarray(skimage.data.camera()))

class ImageFitting(Dataset):
    """Single-item dataset pairing every pixel coordinate with its intensity.

    The whole image is one sample: ``coords`` holds the flattened coordinate
    grid and ``pixels`` the matching values as a column vector.
    """

    def __init__(self, sidelength):
        super().__init__()
        image = get_cameraman_tensor(sidelength)
        # Channel axis goes last before flattening to one value per row.
        self.pixels = image.permute(1, 2, 0).view(-1, 1)
        self.coords = get_mgrid(sidelength, 2)

    def __len__(self):
        # The full image is served as a single item.
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError

        return self.coords, self.pixels

#### Instantiate Dataset and DataLoader

In [None]:
# The dataset has exactly one item (the full 256-per-side image), so one batch
# of size 1 contains every pixel.
cameraman = ImageFitting(256)
dataloader = DataLoader(cameraman, batch_size=1, pin_memory=True, num_workers=0)

#### Define Training Loop

In [None]:
def train(model, dataloader, log_dir, total_steps=500, log_steps=50):
    """Fit ``model`` to the pixel values of a single square image.

    Draws one ``(coords, pixels)`` batch from ``dataloader`` and minimises the
    mean squared error between the model output and ``pixels`` with Adam
    (lr=1e-4). Every ``log_steps`` steps a figure with the model output, the
    norm of its gradient and its Laplacian is saved to ``log_dir/<step>.png``
    and the PSNR against the ground truth is printed.

    model: module whose forward returns ``(output, coords)`` where ``coords``
        requires grad, so derivatives w.r.t. the input can be taken
    dataloader: yields ``(coords, pixels)``; only the first batch is used
    log_dir: str or Path, created if missing
    total_steps: int, number of optimisation steps
    log_steps: int, logging period in steps

    Raises ValueError if the number of pixels is not a perfect square.
    """
    optim = torch.optim.Adam(lr=1e-4, params=model.parameters())
    # Follow the model's device rather than assuming CUDA is available.
    device = next(model.parameters()).device

    model_input, ground_truth = next(iter(dataloader))
    model_input, ground_truth = model_input.to(device), ground_truth.to(device)

    # Infer the image side from the number of pixels instead of hard-coding 256,
    # so the loop works for any square resolution.
    num_pixels = model_input.shape[-2]
    sidelen = int(round(num_pixels ** 0.5))
    if sidelen * sidelen != num_pixels:
        raise ValueError("expected a square image, got %d pixels" % num_pixels)

    def to_image(values):
        # Flattened per-pixel values -> 2D numpy array for imshow.
        return values.detach().cpu().view(sidelen, sidelen).numpy()

    log_dir = Path(log_dir)  # accept plain strings as well as Path objects
    os.makedirs(log_dir, exist_ok=True)
    with tqdm(total=total_steps) as pbar:
        for step in range(total_steps):
            model_output, coords = model(model_input)
            loss = ((model_output - ground_truth)**2).mean()

            if not step % log_steps:
                pbar.set_description("Step %d, Total loss %0.6f" % (step, loss))
                img_grad = gradient(model_output, coords)
                img_laplacian = laplace(model_output, coords)

                panels = (
                    ('Model Output', model_output),
                    ('Gradient', img_grad.norm(dim=-1)),
                    ('Laplacian', img_laplacian),
                )
                fig, axes = plt.subplots(1, 3, figsize=(5, 5))
                for ax, (title, values) in zip(axes, panels):
                    ax.imshow(to_image(values))
                    ax.set_title(title)
                    ax.axis('off')

                fig.savefig(log_dir / f"{step}.png", dpi=300, bbox_inches='tight')
                plt.close(fig)  # Close the figure to free memory

                # NOTE(review): the targets are normalized with mean 0.5 / std 0.5, which
                # presumably spans [-1, 1] (range 2.0); data_range=1.0 is kept from the
                # original but may under-report PSNR - confirm the intended convention.
                psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
                psnr_value = psnr(model_output.detach(), ground_truth)

                print(f"Test at step {step}, PSNR: {psnr_value:0.6f}")

            optim.zero_grad()
            loss.backward()
            optim.step()

            pbar.update(1)

#### Define Baseline Model with ReLU Activation
We implement a baseline model that uses ReLU activation functions for implicit neural representation.

In [None]:
class ReLULayer(nn.Module):
    """Affine map followed by ReLU; building block of the baseline MLP.

    ``is_first`` and ``in_features`` are only stored; they do not influence
    the forward computation.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False):
        super().__init__()
        self.is_first = is_first
        self.in_features = in_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.relu = torch.nn.ReLU()

    def forward(self, input):
        pre_activation = self.linear(input)
        return self.relu(pre_activation)

class ReLUBaseline(nn.Module):
    """MLP baseline built from ReLU layers.

    The network is one input layer, ``hidden_layers`` hidden layers of width
    ``hidden_features`` and an output layer that is either a plain ``nn.Linear``
    (``outermost_linear=True``) or another ReLU layer.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False):
        super().__init__()

        layers = [ReLULayer(in_features, hidden_features, is_first=True)]
        layers.extend(
            ReLULayer(hidden_features, hidden_features, is_first=False)
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
        else:
            head = ReLULayer(hidden_features, out_features, is_first=False)
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return ``(output, coords)`` where ``coords`` is a detached copy that requires grad."""
        # A fresh leaf tensor lets callers differentiate the output w.r.t. the input.
        grad_coords = coords.clone().detach().requires_grad_(True)
        return self.net(grad_coords), grad_coords

#### Train the ReLU Baseline Model

In [None]:
# Define a model
# Define a model: 2D coordinate in, one grayscale value out, with a linear last layer
relu_baseline = ReLUBaseline(
    in_features=2,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True)
relu_baseline.cuda()

# Define a dataloader (single item holding every pixel of the 256-per-side image)
cameraman = ImageFitting(256)
dataloader = DataLoader(cameraman, batch_size=1, pin_memory=True, num_workers=0)

# Training loop; figures are written as <step>.png under log_dir
log_dir = Path("./results/siren/img_relu")
train(relu_baseline, dataloader, log_dir)

#### Define SIREN Model

In [None]:
class SineLayer(nn.Module):
    """Linear layer followed by a sine activation, ``sin(omega_0 * (W x + b))``.

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

    If is_first=True, omega_0 is a frequency factor which simply multiplies the activations
    before the nonlinearity. Different signals may require different omega_0 in the first
    layer - this is a hyperparameter.

    If is_first=False, then the weights will be divided by omega_0 so as to keep the
    magnitude of activations constant, but boost gradients to the weight matrix
    (see supplement Sec. 1.5).
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from the SIREN uniform ranges (bias is left as is)."""
        with torch.no_grad():
            if self.is_first:
                # First layer: U(-1/fan_in, 1/fan_in)
                self.linear.weight.uniform_(-1 / self.in_features,
                                             1 / self.in_features)
            else:
                # Later layers: U(-sqrt(6/fan_in)/omega_0, sqrt(6/fan_in)/omega_0)
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                             np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        ######## Implement from here ########
        # Placeholder so that the class is valid Python before the exercise is solved;
        # replace the raise below with the sine-activated output of this layer.
        raise NotImplementedError("Exercise: implement SineLayer.forward")
        ####### End of Implementation #######

    def forward_with_intermediate(self, input):
        """Return ``(sin(z), z)`` with ``z = omega_0 * linear(input)``."""
        # For visualization of activation distributions
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate

class Siren(nn.Module):
    """MLP of SineLayer blocks.

    The network is one first SineLayer (``first_omega_0``), ``hidden_layers``
    hidden SineLayers (``hidden_omega_0``) and an output layer that is either a
    plain ``nn.Linear`` whose weights use the hidden-layer SIREN range
    (``outermost_linear=True``) or another SineLayer.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        self.net = []
        self.net.append(SineLayer(in_features, hidden_features,
                                  is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)

            # Same uniform range as the non-first SineLayers.
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                              np.sqrt(6 / hidden_features) / hidden_omega_0)

            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features,
                                      is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*self.net)

    def forward(self, coords):
        """Return ``(output, coords)`` where ``coords`` is a detached copy that requires grad."""
        coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        '''Returns not only model output, but also intermediate activations.
        Only used for visualizing activations later!

        The result is an OrderedDict (imported at the top of the file; the original
        code used the name without importing it). It maps 'input' to the input
        tensor and "<layer class>_<counter>" to each recorded tensor: for a
        SineLayer first the pre-activation and then the sine output, for any other
        layer its output. With retain_grad=True, ``retain_grad()`` is called on
        the recorded tensors.'''
        activations = OrderedDict()

        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for i, layer in enumerate(self.net):
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations

#### Exercise 1: Train the SIREN Model

In [None]:
log_dir = Path("./results/siren/img_siren")

######## Implement from here ########

####### End of Implementation #######

<a id='section_2'></a>
## 2. Solving Poisson's Equation

This section demonstrates how to reconstruct an original image when only gradient information is available.

We optimize $\Phi$ using the following loss function:
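$$L=\int_{\Omega} \lVert \nabla\Phi(\mathbf{x}) - \nabla f(\mathbf{x}) \rVert^2\mathrm{d}\mathbf{x},$$ where $\Omega$ represents the image domain. Only the gradients are supervised, so the image can be recovered only up to an additive constant.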

#### Define PyTorch Dataset for Solving Poisson's Equation
Similar to Section 1, we use the cameraman image. This dataset returns pixel coordinates along with their corresponding grayscale values, gradients, and Laplacian.

In [None]:
import scipy.ndimage

class PoissonEqn(Dataset):
    """Single-item dataset of the cameraman image with derivative targets.

    The one item is ``(coords, targets)`` where ``targets`` is a dict with the
    keys 'pixels', 'grads' and 'laplace'.
    """

    def __init__(self, sidelength):
        super().__init__()
        img = get_cameraman_tensor(sidelength)

        ######## Implement from here ########
        # Compute gradient and laplacian
        # Refer: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.sobel.html
        # Placeholders keep this cell valid Python; replace both with numpy arrays
        # holding the image derivative along each spatial axis.
        grads_x = None
        grads_y = None
        ####### End of Implementation #######
        if grads_x is None or grads_y is None:
            raise NotImplementedError("Exercise: compute grads_x and grads_y in PoissonEqn.__init__")
        grads_x, grads_y = torch.from_numpy(grads_x), torch.from_numpy(grads_y)

        # One (d/dx, d/dy) pair per pixel.
        self.grads = torch.stack((grads_x, grads_y), dim=-1).view(-1, 2)
        self.laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]
        self.laplace = torch.from_numpy(self.laplace)

        self.pixels = img.permute(1, 2, 0).view(-1, 1)
        self.coords = get_mgrid(sidelength, 2)

    def __len__(self):
        # The full image is served as a single item.
        return 1

    def __getitem__(self, idx):
        return self.coords, {'pixels':self.pixels, 'grads':self.grads, 'laplace':self.laplace}

In [None]:
# Visualize the Sobel derivatives of the cameraman image along its two spatial axes.
img = get_cameraman_tensor(256)

# Axis 0 of the image array is squeezed away after filtering, so axes 1 and 2 are the
# spatial axes; a trailing singleton axis is appended to each result.
grads_x = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]
grads_y = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]
grads_x, grads_y = torch.from_numpy(grads_x), torch.from_numpy(grads_y)

# Side-by-side plot of the two derivative images.
fig, axes = plt.subplots(1,2, figsize=(12,6))
axes[0].imshow(grads_x)
axes[1].imshow(grads_y)
plt.show()

#### Define Gradient Loss and Training Loop

In [None]:
def gradients_mse(model_output, coords, gt_gradients):
    """Mean over points of the squared L2 distance between model and target gradients.

    The model gradient is taken w.r.t. ``coords`` with autograd; squared errors
    are summed over the last axis before averaging.
    """
    residual = gradient(model_output, coords) - gt_gradients
    return residual.pow(2).sum(-1).mean()

def train(model, dataloader, log_dir, total_steps=500, log_steps=50):
    """Fit ``model`` so that its spatial gradient matches the target image gradient.

    Draws one ``(coords, targets)`` batch from ``dataloader`` and minimises
    ``gradients_mse`` against ``targets['grads']`` with Adam (lr=1e-4). Every
    ``log_steps`` steps a figure with the model output, the norm of its gradient
    and its Laplacian is saved to ``log_dir/<step>.png`` and the PSNR against
    ``targets['pixels']`` is printed.

    model: module whose forward returns ``(output, coords)`` where ``coords``
        requires grad
    dataloader: yields ``(coords, dict of target tensors)``; only the first batch is used
    log_dir: str or Path, created if missing
    total_steps: int, number of optimisation steps
    log_steps: int, logging period in steps

    Raises ValueError if the number of pixels is not a perfect square.
    """
    # Optimise the model that was passed in. The original optimised the global
    # ``relu_baseline``, so any other model given to this function was never updated.
    optim = torch.optim.Adam(lr=1e-4, params=model.parameters())
    # Follow the model's device rather than assuming CUDA is available.
    device = next(model.parameters()).device

    model_input, ground_truth = next(iter(dataloader))
    model_input = model_input.to(device)
    gt = {key: value.to(device) for key, value in ground_truth.items()}

    # Infer the image side from the number of pixels instead of hard-coding 128.
    num_pixels = model_input.shape[-2]
    sidelen = int(round(num_pixels ** 0.5))
    if sidelen * sidelen != num_pixels:
        raise ValueError("expected a square image, got %d pixels" % num_pixels)

    def to_image(values):
        # Flattened per-pixel values -> 2D numpy array for imshow.
        return values.detach().cpu().view(sidelen, sidelen).numpy()

    log_dir = Path(log_dir)  # accept plain strings as well as Path objects
    os.makedirs(log_dir, exist_ok=True)
    with tqdm(total=total_steps) as pbar:
        for step in range(total_steps):
            model_output, coords = model(model_input)
            loss = gradients_mse(model_output, coords, gt['grads'])

            if not step % log_steps:
                pbar.set_description("Step %d, Total loss %0.6f" % (step, loss))
                img_grad = gradient(model_output, coords)
                img_laplacian = laplace(model_output, coords)

                panels = (
                    ('Model Output', model_output),
                    ('Gradient', img_grad.norm(dim=-1)),
                    ('Laplacian', img_laplacian),
                )
                fig, axes = plt.subplots(1, 3, figsize=(5, 5))
                for ax, (title, values) in zip(axes, panels):
                    ax.imshow(to_image(values))
                    ax.set_title(title)
                    ax.axis('off')

                fig.savefig(log_dir / f"{step}.png", dpi=300, bbox_inches='tight')
                plt.close(fig)  # Close the figure to free memory

                # NOTE(review): data_range=1.0 is kept from the original; confirm it
                # matches the value range of gt['pixels'].
                psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
                psnr_value = psnr(model_output.detach(), gt['pixels'])

                print(f"Test at step {step}, PSNR: {psnr_value:0.6f}")

            optim.zero_grad()
            loss.backward()
            optim.step()

            pbar.update(1)

In [None]:
# Single-item dataset at sidelength 128; the one batch carries coordinates plus
# the 'pixels', 'grads' and 'laplace' targets.
cameraman_poisson = PoissonEqn(128)
dataloader = DataLoader(cameraman_poisson, batch_size=1, pin_memory=True, num_workers=0)

#### Exercise 2: Train the ReLU Baseline Model

In [None]:
# Fresh ReLU baseline (2D coordinate in, one value out) for the gradient-supervised task.
relu_baseline = ReLUBaseline(
    in_features=2,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True)
relu_baseline.cuda()

# Train against the Poisson dataloader defined above; figures go to log_dir.
log_dir = Path("./results/siren/img_poisson_relu")
train(relu_baseline, dataloader, log_dir)

#### Exercise 3: Train SIREN Model on Poisson's Equation

In [None]:
log_dir = Path("./results/siren/img_poisson_siren")
######## Implement from here ########

####### End of Implementation #######

## 3. Fitting SDF from Point Clouds

This section demonstrates how to reconstruct a 3D surface from a given point cloud.
The 3D surface is represented as the zero level set of a Signed Distance Field (SDF) $\Phi: x \rightarrow s$, where $x$ is a 3D coordinate and $s$ is the signed distance from $x$ to the surface.

We optimize $\Phi$ using the following loss function $L$:
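$$L=\int_{\Omega} \lVert |\nabla\Phi(x)| - 1 \rVert\mathrm{d}\mathbf{x} + \int_{\Omega_0}\lVert\Phi(x)\rVert + (1-\nabla\Phi(x) \cdot n(x))\mathrm{d}\mathbf{x} + \int_{\Omega\setminus\Omega_0} \Psi(\Phi(x))\mathrm{d}\mathbf{x}$$

Here $\Omega_0$ is the set of on-surface points with normals $n(x)$, and $\Psi(s)=\exp(-\alpha|s|)$ with $\alpha \gg 1$ penalizes off-surface points whose predicted distance is close to zero.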

#### Define PyTorch Dataset

* We use the Thai Statue from the Stanford 3D Scanning Repository as our dataset (located at `/data/thai_statue.xyz`).
* Visualization of xyz files is possible using MeshLab ([link](https://www.meshlab.net/)).
* The dataset samples a specified number of points (`on_surface_points`) from the input point cloud and an equal number of points from the entire domain, returning the SDF and normal values for these points.

In [None]:
class PointCloud(Dataset):
    """Oriented point cloud dataset yielding on-surface and off-surface samples.

    The file is read with ``np.genfromtxt``; the first three columns are used as
    coordinates and the remaining columns as normals. Coordinates are centred
    and rescaled into [-1, 1].

    on_surface_points: int, number of point-cloud samples per item; the same
        number of uniformly random off-surface samples is added
    keep_aspect_ratio: if True a single global min/max is used for rescaling
        (shape preserved); if False every axis is rescaled independently
        (distorts geometry, but makes for high sample efficiency)
    pointcloud_path: path of the point cloud file
    """

    def __init__(self, on_surface_points, keep_aspect_ratio=True,
                 pointcloud_path='/data/thai_statue.xyz'):
        super().__init__()

        print("Loading point cloud")
        point_cloud = np.genfromtxt(pointcloud_path)
        print("Finished loading point cloud")

        coords = point_cloud[:, :3]
        self.normals = point_cloud[:, 3:]

        # Reshape point cloud such that it lies in bounding box of (-1, 1)
        coords -= np.mean(coords, axis=0, keepdims=True)
        if keep_aspect_ratio:
            coord_max = np.amax(coords)
            coord_min = np.amin(coords)
        else:
            coord_max = np.amax(coords, axis=0, keepdims=True)
            coord_min = np.amin(coords, axis=0, keepdims=True)

        # [min, max] -> [0, 1] -> [-1, 1]
        self.coords = (coords - coord_min) / (coord_max - coord_min)
        self.coords -= 0.5
        self.coords *= 2.

        self.on_surface_points = on_surface_points

    def __len__(self):
        # Items are sampled with replacement, so a cloud smaller than
        # on_surface_points can still serve one item instead of an empty dataset.
        return max(1, self.coords.shape[0] // self.on_surface_points)

    def __getitem__(self, idx):
        """Return ``({'coords'}, {'sdf', 'normals'})`` for a fresh random sample.

        ``idx`` is ignored. The first ``on_surface_points`` rows are point-cloud
        samples (sdf 0, stored normals); the remaining rows are uniform samples in
        [-1, 1]^3 marked with sdf -1 and normals of -1.
        """
        point_cloud_size = self.coords.shape[0]

        off_surface_samples = self.on_surface_points  # **2
        total_samples = self.on_surface_points + off_surface_samples

        # Random coords
        rand_idcs = np.random.choice(point_cloud_size, size=self.on_surface_points)

        on_surface_coords = self.coords[rand_idcs, :]
        on_surface_normals = self.normals[rand_idcs, :]

        off_surface_coords = np.random.uniform(-1, 1, size=(off_surface_samples, 3))
        off_surface_normals = np.ones((off_surface_samples, 3)) * -1

        sdf = np.zeros((total_samples, 1))  # on-surface = 0
        sdf[self.on_surface_points:, :] = -1  # off-surface = -1

        coords = np.concatenate((on_surface_coords, off_surface_coords), axis=0)
        normals = np.concatenate((on_surface_normals, off_surface_normals), axis=0)

        return {'coords': torch.from_numpy(coords).float()}, {'sdf': torch.from_numpy(sdf).float(),
                                                              'normals': torch.from_numpy(normals).float()}

#### Define Mesh Export Functions

In [None]:
import logging
import plyfile
import skimage.measure


def create_mesh(
    decoder, filename, N=256, max_batch=64 ** 3, offset=None, scale=None
):
    """Sample ``decoder`` on a regular N x N x N grid over [-1, 1]^3 and export a mesh.

    decoder: module called as ``decoder(points)``; element 0 of its return value
        is used as the SDF prediction (one value per point). It is switched to
        eval mode.
    filename: output .ply path
    N: grid resolution per axis
    max_batch: maximum number of points evaluated per forward pass
    offset, scale: forwarded to ``convert_sdf_samples_to_ply``
    """
    start = time.time()
    ply_filename = filename

    decoder.eval()

    # Evaluate on the decoder's own device; fall back to CUDA (the original behaviour)
    # for a decoder without parameters.
    first_param = next(decoder.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
    voxel_origin = [-1, -1, -1]
    voxel_size = 2.0 / (N - 1)

    overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
    samples = torch.zeros(N ** 3, 4)

    # transform first 3 columns
    # to be the x, y, z index
    # Floor division is required here: `/` on integer tensors is true division in
    # current PyTorch and would yield fractional (off-grid) indices.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode='floor') % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode='floor') % N

    # transform first 3 columns
    # to be the x, y, z coordinate
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[0]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[2]

    num_samples = N ** 3

    samples.requires_grad = False

    head = 0

    while head < num_samples:
        print(head)
        tail = min(head + max_batch, num_samples)
        sample_subset = samples[head:tail, 0:3].to(device)

        # Column 3 stores the predicted SDF value of every grid point.
        samples[head:tail, 3] = (
            decoder(sample_subset)[0]
            .reshape(-1)
            .detach()
            .cpu()
        )
        head += max_batch

    sdf_values = samples[:, 3]
    sdf_values = sdf_values.reshape(N, N, N)

    end = time.time()
    print("sampling takes: %f" % (end - start))

    convert_sdf_samples_to_ply(
        sdf_values.data.cpu(),
        voxel_origin,
        voxel_size,
        ply_filename,
        offset,
        scale,
    )


def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    offset=None,
    scale=None,
):
    """
    Convert sdf samples to .ply

    Extracts the zero level set with marching cubes and writes the vertices and
    triangular faces to a PLY file.

    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
    :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
    :voxel_size: float, the size of the voxels
    :ply_filename_out: string, path of the filename to save to
    :offset: optional value subtracted from the vertices after scaling
    :scale: optional value the vertices are divided by

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """

    start_time = time.time()

    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()

    verts, faces, normals, values = skimage.measure.marching_cubes(
        numpy_3d_sdf_tensor, level=0.0, spacing=[voxel_size] * 3
    )

    # transform from voxel coordinates to camera coordinates
    # note x and y are flipped in the output of marching_cubes
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
    mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
    mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    # Fill the structured arrays column-wise instead of looping over every vertex
    # and face in Python; the stored values are the same.
    verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    verts_tuple["x"] = mesh_points[:, 0]
    verts_tuple["y"] = mesh_points[:, 1]
    verts_tuple["z"] = mesh_points[:, 2]

    faces_tuple = np.zeros((num_faces,), dtype=[("vertex_indices", "i4", (3,))])
    faces_tuple["vertex_indices"] = faces

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    logging.debug("saving mesh to %s" % (ply_filename_out))
    ply_data.write(ply_filename_out)

    logging.debug(
        "converting to ply format and writing to file took {} s".format(
            time.time() - start_time
        )
    )

#### Instantiate Dataset and SIREN Model

In [None]:
# Every item holds 250000 on-surface samples plus as many off-surface samples.
thai_statue = PointCloud(250000)
dataloader = DataLoader(thai_statue, batch_size=1, pin_memory=True, num_workers=0)

# 3D coordinate in, one SDF value out, with a linear last layer.
sdf_siren = Siren(in_features=3, out_features=1, hidden_features=256,
                  hidden_layers=3, outermost_linear=True)
sdf_siren.cuda()

#### Define Loss Function and Training Loop

In [None]:
def sdf_loss(pred_sdf, coords, gt_sdf, gt_normals):
    '''
       Weighted loss terms for fitting an SDF to an oriented point cloud.

       pred_sdf: model prediction at coords
       coords: input coordinates the prediction was computed from (must require grad)
       gt_sdf: target values; entries equal to -1 mark off-surface samples
       gt_normals: target normals for the on-surface samples

       Returns a dict with the keys 'sdf', 'inter', 'normal_constraint' and
       'grad_constraint', each a scalar tensor.
       '''
    grad = gradient(pred_sdf, coords)

    # Every sample whose target is not -1 is treated as lying on the surface.
    on_surface = gt_sdf != -1
    zeros = torch.zeros_like(pred_sdf)

    # On-surface: the prediction itself should be zero.
    sdf_term = torch.where(on_surface, pred_sdf, zeros)
    # Off-surface: penalize predictions close to zero.
    inter_term = torch.where(on_surface, zeros, torch.exp(-1e2 * torch.abs(pred_sdf)))
    # On-surface: the gradient direction should agree with the target normal.
    cosine_gap = 1 - F.cosine_similarity(grad, gt_normals, dim=-1)[..., None]
    normal_term = torch.where(on_surface, cosine_gap, torch.zeros_like(grad[..., :1]))
    # Everywhere: the gradient norm should be one.
    eikonal_term = torch.abs(grad.norm(dim=-1) - 1)

    return {'sdf': torch.abs(sdf_term).mean() * 3e3,
            'inter': inter_term.mean() * 1e2,
            'normal_constraint': normal_term.mean() * 1e2,
            'grad_constraint': eikonal_term.mean() * 5e1}


def train(model, dataloader, log_dir, total_steps=2000, log_steps=250):
    """Fit ``model`` as an SDF to one batch of point-cloud samples.

    Draws one batch from ``dataloader`` and minimises the sum of the ``sdf_loss``
    terms with Adam (lr=1e-4). Every ``log_steps`` steps the current model is
    exported as ``log_dir/<step>.ply`` with ``create_mesh``.

    model: module whose forward returns ``(sdf, coords)`` where ``coords``
        requires grad
    dataloader: yields ``({'coords'}, {'sdf', 'normals'})``; only the first batch is used
    log_dir: str or Path, created if missing
    total_steps: int, number of optimisation steps
    log_steps: int, mesh export period in steps
    """
    optim = torch.optim.Adam(lr=1e-4, params=model.parameters())
    # Follow the model's device rather than assuming CUDA is available.
    device = next(model.parameters()).device

    model_input, ground_truth = next(iter(dataloader))
    model_input = model_input['coords'].to(device)
    sdf_gt = ground_truth['sdf'].to(device)
    normals_gt = ground_truth['normals'].to(device)

    log_dir = Path(log_dir)  # accept plain strings as well as Path objects
    os.makedirs(log_dir, exist_ok=True)
    with tqdm(total=total_steps) as pbar:
        for step in range(total_steps):
            sdf_pred, coords = model(model_input)
            losses = sdf_loss(sdf_pred, coords, sdf_gt, normals_gt)

            train_loss = 0
            log_message = "Step %d, " % step
            for loss_name, loss in losses.items():
                single_loss = loss.mean()
                train_loss += single_loss
                log_message += ("%s %0.2f, " % (loss_name, single_loss.item()))
            pbar.set_description(log_message)

            if not step % log_steps:
                # Mesh the model being trained. The original always meshed the global
                # ``sdf_siren``, whatever model was passed in.
                was_training = model.training
                create_mesh(model, log_dir / f"{step}.ply")
                # create_mesh switches the model to eval mode; restore the previous mode.
                model.train(was_training)

            optim.zero_grad()
            train_loss.backward()
            optim.step()

            pbar.update(1)

#### Exercise 4: Train ReLU Baseline Model

In [None]:
log_dir = Path("./results/siren/pc_relu")
######## Implement from here ########

####### End of Implementation #######

#### Exercise 5: Train SIREN Model

In [None]:
log_dir = Path("./results/siren/pc_siren")
######## Implement from here ########

####### End of Implementation #######