<a href="https://colab.research.google.com/github/reshalfahsi/novel-view-synthesis/blob/master/Novel_View_Synthesis_Using_NeRF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Novel View Synthesis Using NeRF**

## **Important Libraries**

### **Install**

In [None]:
!pip install -q --no-cache-dir lightning torchmetrics moviepy wget

### **Import**

In [None]:
try:
    import lightning as L
except:
    import lightning as L

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure
)

from google.colab.patches import cv2_imshow
from IPython.display import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data

import numpy as np
from tqdm import tqdm
import imageio.v2 as imageio
import matplotlib.pyplot as plt

from moviepy.editor import VideoFileClip

from glob import glob

import copy
import math
import os
import random
import time
import warnings

import cv2
import wget

# Silence library warnings for a cleaner notebook log.
warnings.filterwarnings("ignore")

# %matplotlib inline
# Global plot styling: light-gray axes and STIX/Computer-Modern fonts.
plt.rcParams.update(
    {
        'axes.facecolor': 'lightgray',
        'mathtext.fontset': 'cm',
        'font.family': 'STIXGeneral',
    }
)

## **Configuration**

In [None]:
# Optimisation.
LEARNING_RATE = 3.1e-4
BATCH_SIZE = 4
EPOCH = 1600

# Ray sampling: depth interval [NEAR, FAR] and number of samples per ray.
NEAR = 2.0
FAR = 6.0
NUM_SAMPLES = 32

# Network: positional-encoding frequency count and hidden width of the MLP.
POS_ENCODE_DIMS = 16
NeRF_WIDTH = 108

## **Dataset**

### **Download and Process**

In [None]:
# Download the tiny NeRF dataset once; later runs reuse the cached file.
if not os.path.exists('tiny_nerf_data.npz'):
    DATASET_URL = (
        "http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz"
    )
    wget.download(DATASET_URL, 'tiny_nerf_data.npz')

# Read the three arrays eagerly; the context manager closes the .npz archive.
with np.load("./tiny_nerf_data.npz") as dataset:
    IMAGES = dataset["images"]
    POSES = dataset["poses"]
    FOCAL = float(dataset["focal"])

del dataset

### **Utils**

In [None]:
class ViewSynthesisDataset(data.Dataset):
    """Tiny-NeRF views with per-pixel ray samples, split into train/val/test.

    Images and poses are read from the module-level ``IMAGES`` / ``POSES``
    arrays. The first 80% of the views (up to ``_split_index``) form a
    train+val pool that is itself split 80/20 into train and val; the
    remaining views are the test split. Calling the dataset with a split name
    returns a copy restricted to that split.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(self):
        super().__init__()

        # No split selected yet: __len__/__getitem__ raise until one is chosen.
        self.split = None

        self.num_images, self.H, self.W, _ = IMAGES.shape
        # End of the train+val pool; views from here on form the test split.
        self._split_index = int(0.8 * self.num_images)

    def __call__(self, split=None):
        """Return a shallow copy of this dataset restricted to ``split``.

        A copy is returned instead of mutating ``self`` so that datasets made
        for different splits (e.g. the train and the val dataloader) never
        share one ``split`` attribute. Only the small bookkeeping attributes
        are copied; images and poses stay in the module-level arrays.

        Raises:
            ValueError: if ``split`` is not 'train', 'val' or 'test'.
        """
        if split not in self._SPLITS:
            raise ValueError(
                f"Please specify dataset split: train, val, and test (got {split!r})"
            )
        view = copy.copy(self)
        view.split = split
        return view

    def _split_bounds(self):
        """Return the slice of ``IMAGES``/``POSES`` covered by the current split.

        Raises:
            KeyError: if no split (or an unknown split) is selected.
        """
        if self.split is None:
            raise KeyError("'split' is not defined yet")
        train_end = int(0.8 * self._split_index)
        bounds = {
            "train": slice(None, train_end),
            "val": slice(train_end, self._split_index),
            "test": slice(self._split_index, self.num_images),
        }
        if self.split not in bounds:
            raise KeyError(f"unknown split {self.split!r}")
        return bounds[self.split]

    def __len__(self):
        return len(IMAGES[self._split_bounds()])

    def get_rays(self, pose):
        """Compute one camera ray per pixel for the camera ``pose``.

        Args:
            pose: Matrix whose top-left 3x3 block is used as the camera
                rotation and whose last column (first three rows) is used as
                the camera position (presumably a 4x4 camera-to-world matrix
                — TODO confirm).

        Returns:
            ``(ray_origins, ray_directions)``, each of shape (H, W, 3).
            Directions are not normalised: in camera space they are
            ``(x / FOCAL, y / FOCAL, -1)``.
        """
        xs = torch.arange(self.W) - (self.W / 2 - 0.5)
        ys = torch.arange(self.H) - (self.H / 2 - 0.5)
        # Image rows grow downwards; negate so camera-space y points up.
        xs, ys = torch.meshgrid(xs, -ys, indexing="xy")

        zs = torch.full_like(xs, -FOCAL)

        directions = torch.stack([xs, ys, zs], dim=-1)
        directions = directions / FOCAL

        # Rotate directions into world space; every ray starts at the camera position.
        camera_matrix = pose[:3, :3]
        camera_origin = pose[:3, -1]

        ray_directions = torch.einsum("ij,hwj->hwi", camera_matrix, directions)
        ray_origins = camera_origin.expand(ray_directions.shape)

        return ray_origins, ray_directions

    def render_rays(self, ray_origins, ray_directions):
        """Sample ``NUM_SAMPLES`` jittered points along every ray.

        Depths start evenly spaced on [NEAR, FAR] and each gets independent
        uniform noise in [0, (FAR - NEAR) / NUM_SAMPLES), so every call draws
        a different sample set.

        Returns:
            ``(rays, t_vals)``: points ``o + t d`` of shape
            (..., NUM_SAMPLES, 3) and their ``t`` values of shape
            (..., NUM_SAMPLES).
        """
        t_vals = torch.linspace(NEAR, FAR, NUM_SAMPLES)

        noise = torch.rand(
            torch.Size(list(ray_origins.shape[:-1]) + [NUM_SAMPLES])
        ) * (FAR - NEAR) / NUM_SAMPLES
        t_vals = t_vals + noise

        # Equation: r(t) = o + td -> Building the "r" here.
        rays = ray_origins[..., None, :] + (
            ray_directions[..., None, :] * t_vals[..., None]
        )

        return rays, t_vals

    def __getitem__(self, idx):
        """Return ``(image, pose, rays, t_vals)`` for view ``idx`` of the current split."""
        bounds = self._split_bounds()
        image = torch.Tensor(IMAGES[bounds][idx])
        pose = torch.Tensor(POSES[bounds][idx])

        ray_origins, ray_directions = self.get_rays(pose)
        rays, t_vals = self.render_rays(ray_origins, ray_directions)

        return image, pose, rays, t_vals

In [None]:
# Deliberately rebind the class name to a single instance: later cells call
# `ViewSynthesisDataset("train" / "val" / "test")` (its __call__) to pick a split
# and read `ViewSynthesisDataset.H`, `.W`, `.get_rays` and `.render_rays` directly.
ViewSynthesisDataset = ViewSynthesisDataset()

## **Model**

### **Utils**

In [None]:
class AvgMeter:
    """Accumulate scalar tensors and summarise them on demand.

    Args:
        complete: If true, ``show()`` returns ``(mean, std)``; otherwise it
            returns only the mean.
    """

    def __init__(self, complete=False):
        self.reset()
        self.complete = complete

    def reset(self):
        """Drop every recorded value."""
        self.scores = []

    def update(self, val):
        """Record one value (a tensor)."""
        self.scores += [val]

    @property
    def score(self):
        """Recorded values converted to NumPy arrays."""
        return [value.numpy() for value in self.scores]

    def show(self):
        """Return the mean of the recorded values (and their std if ``complete``)."""
        stacked = torch.stack(self.scores)
        if not self.complete:
            return torch.mean(stacked)
        return torch.mean(stacked), torch.std(stacked)

### **Positional Encoding**

In [None]:
class PositionalEncoding(nn.Module):
    """Fourier-feature encoding of coordinates.

    For an input whose last dimension is ``D`` the output's last dimension is
    ``D * (1 + 2 * length)``: the raw input, followed for each frequency
    ``2**k * pi`` (k = 0 .. length - 1) by the sine and then the cosine of the
    scaled input.
    """

    def __init__(self, length):
        super().__init__()
        # Number of frequency octaves.
        self.length = length

    def forward(self, x):
        features = [x]
        for octave in range(self.length):
            scale = 2 ** octave * torch.pi
            features += [torch.sin(scale * x), torch.cos(scale * x)]
        return torch.cat(features, dim=-1)

### **NeRF**

In [None]:
class NeRF(nn.Module):
    """Coordinate MLP mapping 3D sample points to per-point color and density.

    Points are positionally encoded with ``POS_ENCODE_DIMS`` frequencies, pass
    through a trunk of width ``NeRF_WIDTH`` whose second stage re-injects the
    encoded input (skip connection), and are decoded by a density head and a
    color head.

    Args:
        bias: Whether every ``nn.Linear`` layer, including the two output
            layers, has a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()

        self.positional_encoding = PositionalEncoding(POS_ENCODE_DIMS)
        # Raw xyz (3) plus a sin and a cos per frequency for each of the 3 axes.
        encoded_dims = 3 + 3 * 2 * POS_ENCODE_DIMS
        half_width = NeRF_WIDTH // 2

        self.mlp0 = nn.Sequential(
            nn.Linear(encoded_dims, NeRF_WIDTH, bias=bias),
            nn.ReLU(True),
            nn.Linear(NeRF_WIDTH, NeRF_WIDTH, bias=bias),
            nn.ReLU(True),
            nn.Linear(NeRF_WIDTH, NeRF_WIDTH, bias=bias),
            nn.ReLU(True),
        )

        # Input is [encoded points, mlp0 features] (skip connection).
        self.mlp1 = nn.Sequential(
            nn.Linear(encoded_dims + NeRF_WIDTH, NeRF_WIDTH, bias=bias),
            nn.ReLU(True),
            nn.Linear(NeRF_WIDTH, NeRF_WIDTH, bias=bias),
            nn.ReLU(True),
        )

        # Density head; the final ReLU keeps sigma non-negative.
        self.sigma = nn.Sequential(
            nn.Linear(NeRF_WIDTH, half_width, bias=bias),
            nn.ReLU(True),
            nn.Linear(half_width, 1, bias=bias),
            nn.ReLU(True),
        )
        # Color head; the final Sigmoid keeps RGB in (0, 1).
        self.color = nn.Sequential(
            nn.Linear(NeRF_WIDTH, half_width, bias=bias),
            nn.ReLU(True),
            nn.Linear(half_width, 3, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, xs):
        """Return ``(color, sigma)`` with last dimensions 3 and 1 for points ``xs`` (last dim 3)."""
        xs = self.positional_encoding(xs)

        output = self.mlp0(xs)
        output = self.mlp1(torch.cat([xs, output], dim=-1))
        sigma_output = self.sigma(output)
        color_output = self.color(output)

        return color_output, sigma_output

In [None]:
# Network class trained below; its name prefixes checkpoint and plot filenames.
MODEL = NeRF
MODEL_NAME = NeRF.__name__

### **Wrapper**

In [None]:
class ModelWrapper(L.LightningModule):
    """Lightning module that trains and evaluates a NeRF MLP via volume rendering.

    ``arch`` is applied to the sampled points of every ray and must return
    per-point ``(color, sigma)``; :meth:`forward` composites them into pixels.
    Optimisation is manual: Adam is stepped in :meth:`training_step` and a
    ``ReduceLROnPlateau`` scheduler is stepped on the mean validation PSNR.
    Loss/PSNR/SSIM histories are kept for the plots written in
    :meth:`on_train_end`.

    Args:
        arch: Network mapping sample points to ``(color, sigma)``.
        lr: Initial Adam learning rate.
        batch_size: Batch size of the train/val dataloaders.
        max_epoch: Planned number of epochs; also sets the scheduler patience
            (2.5% of it) and the snapshot interval.
        n_visualization_frames: Target number of validation snapshots saved
            over ``max_epoch`` epochs.
    """

    def __init__(
        self,
        arch,
        lr,
        batch_size,
        max_epoch,
        n_visualization_frames=200,
    ):
        super().__init__()

        self.arch = arch

        self.lr = lr
        # Last reported LR; starts above any realistic LR so the initial LR is
        # reported once and every later decrease is reported too.
        self.lr_now = 1e1

        self.batch_size = batch_size
        self.max_epoch = max_epoch

        # The optimizer and the LR scheduler are stepped by hand below.
        self.automatic_optimization = False

        self.train_psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.train_ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

        self.val_psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

        self.test_psnr = PeakSignalNoiseRatio(data_range=1.0)
        self.test_ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

        # Per-epoch histories used for the final plots.
        self.train_loss = list()
        self.val_loss = list()

        self.train_psnr_list = list()
        self.train_ssim_list = list()

        self.val_psnr_list = list()
        self.val_ssim_list = list()

        # Per-batch values within the current epoch.
        self.train_loss_recorder = AvgMeter()
        self.val_loss_recorder = AvgMeter()

        self.train_psnr_recorder = AvgMeter()
        self.train_ssim_recorder = AvgMeter()

        self.val_psnr_recorder = AvgMeter()
        self.val_ssim_recorder = AvgMeter()

        self.test_psnr_recorder = AvgMeter()
        self.test_ssim_recorder = AvgMeter()

        self.loss_function = nn.MSELoss()

        self.n_visualization_frames = n_visualization_frames

    @staticmethod
    def _to_chw(tensor):
        """Convert a (B, H, W, C) tensor to the (B, C, H, W) layout used by the metrics."""
        return tensor.permute(0, 3, 1, 2)

    def forward(self, ray, t_vals):
        """Volume-render a batch of sampled rays.

        Args:
            ray: Sample points, (B, H, W, NUM_SAMPLES, 3).
            t_vals: Ray parameter of every sample, (B, H, W, NUM_SAMPLES).

        Returns:
            ``(color, depth)``: rendered RGB of shape (B, H, W, 3) and the
            weight-averaged ``t`` of shape (B, H, W, 1).
        """
        batch = ray.shape[0]
        height, width = ViewSynthesisDataset.H, ViewSynthesisDataset.W

        color, sigma = self.arch(ray)
        color = color.reshape(batch, height, width, NUM_SAMPLES, 3)
        sigma = sigma.reshape(batch, height, width, NUM_SAMPLES, 1)

        # Distance between adjacent samples; the last interval is effectively
        # infinite so the final sample absorbs whatever light remains. Built
        # from t_vals so it lives on the same device/dtype as the inputs.
        delta = t_vals[..., 1:] - t_vals[..., :-1]
        delta = torch.cat([delta, torch.full_like(t_vals[..., :1], 1e10)], dim=-1)

        # Per-sample opacity, (B, H, W, S, 1).
        alpha = 1.0 - torch.exp(-sigma * delta[..., None])

        # Transmittance T_i = prod_{j<i} (1 - alpha_j): an *exclusive* cumulative
        # product along the sample axis, which is dim=-2 here (dim=-1 is the
        # singleton channel, where a cumprod would be a no-op).
        epsilon = 1e-10
        transmittance = torch.cumprod(1.0 - alpha + epsilon, dim=-2)
        transmittance = torch.cat(
            [torch.ones_like(transmittance[..., :1, :]), transmittance[..., :-1, :]],
            dim=-2,
        )
        weights = alpha * transmittance

        color = torch.sum(weights * color, dim=-2)
        depth = torch.sum(weights * t_vals[..., None], dim=-2)

        return color, depth

    def training_step(self, batch, batch_nb):
        image, pose, rays, t_vals = batch

        color, _ = self(rays, t_vals)
        preds, target = self._to_chw(color), self._to_chw(image)

        loss = self.loss_function(preds, target)

        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        self.log("train_loss", loss, prog_bar=True)
        self.train_loss_recorder.update(loss.data)

        # Calling the metric returns this batch's value. ``update`` + ``compute``
        # would return an aggregate over every batch since the last reset.
        psnr = self.train_psnr(preds.detach(), target).detach().cpu()
        ssim = self.train_ssim(preds.detach(), target).detach().cpu()

        self.log("train_psnr", psnr, prog_bar=True)
        self.log("train_ssim", ssim, prog_bar=True)

        self.train_psnr_recorder.update(psnr)
        self.train_ssim_recorder.update(ssim)

    def on_train_epoch_end(self):
        self.train_loss.append(self.train_loss_recorder.show().data.cpu().numpy())
        self.train_loss_recorder = AvgMeter()
        self.train_psnr_list.append(self.train_psnr_recorder.show().data.cpu().numpy())
        self.train_psnr_recorder = AvgMeter()
        self.train_ssim_list.append(self.train_ssim_recorder.show().data.cpu().numpy())
        self.train_ssim_recorder = AvgMeter()

        self.train_psnr.reset()
        self.train_ssim.reset()

    def validation_step(self, batch, batch_nb):
        image, pose, rays, t_vals = batch

        color, depth = self(rays, t_vals)
        preds, target = self._to_chw(color), self._to_chw(image)

        loss = self.loss_function(preds, target)

        # Sanity-check batches only exercise the code path; record nothing.
        if self.trainer.sanity_checking:
            return

        self.log("val_loss", loss, prog_bar=True)
        self.val_loss_recorder.update(loss.data)

        psnr = self.val_psnr(preds, target).detach().cpu()
        ssim = self.val_ssim(preds, target).detach().cpu()

        self.log("val_psnr", psnr, prog_bar=True)
        self.log("val_ssim", ssim, prog_bar=True)

        self.val_psnr_recorder.update(psnr)
        self.val_ssim_recorder.update(ssim)

        # About n_visualization_frames snapshots over the run, taken from the
        # first validation batch only; max() guards max_epoch < n_visualization_frames.
        interval = max(1, self.max_epoch // self.n_visualization_frames)
        if batch_nb == 0 and (self.current_epoch + 1) % interval == 0:
            self._save_snapshot(image[0], color[0], depth[0])

    def _save_snapshot(self, image, color, depth):
        """Save a GT / prediction / depth / validation-loss figure for this epoch."""
        epoch = self.current_epoch
        fig, ax = plt.subplots(
            nrows=1,
            ncols=4,
            figsize=(22, 4.8),
        )

        ax[0].imshow(np.clip(image.detach().cpu().numpy(), 0.0, 1.0))
        ax[0].set_title(f"Ground Truth: {epoch}")

        ax[1].imshow(np.clip(color.detach().cpu().numpy(), 0.0, 1.0))
        ax[1].set_title(f"Predicted: {epoch}")

        # Depth is a weighted average of t (sampled from [NEAR, FAR]), so clipping
        # to [0, 1] would saturate it; let imshow autoscale the single channel.
        ax[2].imshow(depth.detach().cpu().numpy()[..., 0])
        ax[2].set_title(f"Depth Map: {epoch}")

        ax[3].plot(self.val_loss)
        ax[3].grid()
        ax[3].set_title(f"Loss Plot: {epoch}")

        fig.tight_layout()
        fig.savefig(
            f"experiment/training/images/{str(epoch).zfill(5)}.png"
        )
        # Close (not just clear) the figure so figures do not pile up across epochs.
        plt.close(fig)

    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:
            return

        psnr = self.val_psnr_recorder.show()
        sch = self.lr_schedulers()
        sch.step(psnr)

        lr_now_ = self.optimizers().param_groups[0]["lr"]
        if self.lr_now > lr_now_:
            self.lr_now = lr_now_
            print(
                f"[{MODEL_NAME}] Learning Rate Changed: {lr_now_} - Epoch: {self.current_epoch}"
            )

        self.val_loss.append(self.val_loss_recorder.show().data.cpu().numpy())
        self.val_loss_recorder = AvgMeter()
        self.val_psnr_list.append(psnr.data.cpu().numpy())
        self.val_psnr_recorder = AvgMeter()
        self.val_ssim_list.append(self.val_ssim_recorder.show().data.cpu().numpy())
        self.val_ssim_recorder = AvgMeter()

        self.val_psnr.reset()
        self.val_ssim.reset()

    def test_step(self, batch, batch_nb):
        image, pose, rays, t_vals = batch

        color, _ = self(rays, t_vals)
        preds, target = self._to_chw(color), self._to_chw(image)

        psnr = self.test_psnr(preds, target).detach().cpu()
        ssim = self.test_ssim(preds, target).detach().cpu()

        self.log("test_psnr", psnr, prog_bar=True)
        self.log("test_ssim", ssim, prog_bar=True)

        self.test_psnr_recorder.update(psnr)
        self.test_ssim_recorder.update(ssim)

    @staticmethod
    def _save_curve(train_values, val_values, title, ylabel, path):
        """Plot per-epoch train/validation curves and save them to ``path``."""
        fig = plt.figure(figsize=(6.4, 4.8))
        plt.plot(train_values, color="r", label="train")
        plt.plot(val_values, color="b", label="validation")
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid()
        fig.savefig(path)
        plt.close(fig)

    def on_train_end(self):
        # Loss
        self._save_curve(
            self.train_loss,
            self.val_loss,
            "Loss Curves",
            "Loss",
            f"experiment/training/{MODEL_NAME}_loss_plot.png",
        )

        # Evaluation Metrics
        self._save_curve(
            self.train_psnr_list,
            self.val_psnr_list,
            "PSNR Curves",
            "PSNR",
            f"experiment/training/{MODEL_NAME}_psnr_plot.png",
        )
        self._save_curve(
            self.train_ssim_list,
            self.val_ssim_list,
            "SSIM Curves",
            "SSIM",
            f"experiment/training/{MODEL_NAME}_ssim_plot.png",
        )

    def train_dataloader(self):
        return data.DataLoader(
            dataset=ViewSynthesisDataset("train"),
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return data.DataLoader(
            dataset=ViewSynthesisDataset("val"),
            batch_size=self.batch_size,
            shuffle=False,
        )

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.arch.parameters(),
            lr=self.lr,
        )

        # Stepped manually on validation PSNR (higher is better).
        lr_scheduler = {
            "scheduler": optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="max",
                factor=0.25,
                patience=int(0.025 * self.max_epoch),
            ),
            "name": "lr_scheduler",
        }

        return [optimizer], [lr_scheduler]

In [None]:
# LightningModule class used for training, testing and rendering below.
WRAPPER = ModelWrapper

In [None]:
# Checkpoints are written under this directory; the best one has a fixed name.
EXPERIMENT_DIR = "experiment/"
BEST_MODEL_PATH = os.path.join(
    EXPERIMENT_DIR,
    f"{MODEL_NAME}_best.ckpt",
)

## **Training**

In [None]:
# Output tree for checkpoints, curves and per-epoch snapshots.
for _out_dir in ("experiment", "experiment/training", "experiment/training/images"):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# Fresh seed every run, printed so a run can be reproduced afterwards.
SEED = int(np.random.randint(2**31 - 1))
print(f"Random seed: {SEED}")

In [None]:
model, trainer, checkpoint, early_stopping = None, None, None, None


def _train_loop():
    """Build the model, callbacks and Trainer, then fit.

    Training resumes from ``BEST_MODEL_PATH`` when that checkpoint exists.

    Returns:
        tuple: ``(model, trainer, checkpoint, early_stopping)`` so they stay
        available at module level for inspection after training.
    """
    seed_everything(SEED, workers=True)

    model = WRAPPER(MODEL(), LEARNING_RATE, BATCH_SIZE, EPOCH)

    # Keep only the best checkpoint by validation PSNR (higher is better).
    checkpoint = ModelCheckpoint(
        monitor='val_psnr',
        dirpath=EXPERIMENT_DIR,
        mode='max',
        filename=f"{MODEL_NAME}_best",
    )
    print(MODEL_NAME)
    # Stop after int(0.1 * EPOCH) validation checks without a PSNR improvement.
    early_stopping = EarlyStopping(
        monitor="val_psnr",
        min_delta=0.00,
        patience=int(0.1 * EPOCH),
        verbose=False,
        mode="max",
    )

    ckpt_path = BEST_MODEL_PATH if os.path.exists(BEST_MODEL_PATH) else None

    trainer = Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=EPOCH,
        logger=False,
        callbacks=[checkpoint, early_stopping],
        log_every_n_steps=5,
    )
    trainer.fit(model, ckpt_path=ckpt_path)

    return model, trainer, checkpoint, early_stopping


model, trainer, checkpoint, early_stopping = _train_loop()

In [None]:
# Display the curves saved at the end of training.
for _plot_file in (
    f"experiment/training/{MODEL_NAME}_loss_plot.png",
    f"experiment/training/{MODEL_NAME}_psnr_plot.png",
    f"experiment/training/{MODEL_NAME}_ssim_plot.png",
):
    img = cv2.imread(_plot_file)
    cv2_imshow(img)

# Stitch the per-epoch validation snapshots into an animated GIF and show it.
filenames = sorted(glob("experiment/training/images/*.png"))
image = [imageio.imread(filename) for filename in tqdm(filenames)]
kargs = {"duration": 1.25, "loop": 0}
imageio.mimsave("experiment/training/result.gif", image, "GIF", **kargs)
with open("experiment/training/result.gif", "rb") as f:
    display(Image(data=f.read(), format="gif"))

## **Testing**

In [None]:
# Score the best checkpoint on the held-out test views, one view per batch.
trainer = Trainer(accelerator='auto', logger=False)
model = WRAPPER(MODEL(), LEARNING_RATE, BATCH_SIZE, EPOCH)
trainer.test(
    model,
    data.DataLoader(ViewSynthesisDataset("test"), batch_size=1, shuffle=False),
    ckpt_path=BEST_MODEL_PATH,
)

## **Inference**

### **Utils**

In [None]:
def get_translation_t(t):
    """Return a 4x4 float32 homogeneous transform that translates by ``t`` along z."""
    transform = torch.eye(4, dtype=torch.float32)
    transform[2, 3] = t
    return transform


def get_rotation_phi(phi):
    """Return the 4x4 float32 rotation by ``phi`` radians about the x axis."""
    cos_p, sin_p = math.cos(phi), math.sin(phi)
    return torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, cos_p, -sin_p, 0.0],
            [0.0, sin_p, cos_p, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=torch.float32,
    )


def get_rotation_theta(theta):
    """Return the 4x4 float32 rotation about the y axis used for the azimuth ``theta``.

    With ``-sin`` in row 0 this equals a right-handed rotation by ``-theta``
    radians about y.
    """
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    return torch.tensor(
        [
            [cos_t, 0.0, -sin_t, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [sin_t, 0.0, cos_t, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=torch.float32,
    )


def pose_spherical(theta, phi, t):
    """
    Build the 4x4 camera-to-world matrix for a camera at distance ``t``,
    elevation ``phi`` and azimuth ``theta`` (both angles in degrees).

    Transforms are applied in order: translate by ``t`` along z, rotate by
    ``phi`` about x, rotate by ``theta`` about y, then a fixed axis
    permutation/flip that swaps y and z and negates x.
    """
    axis_swap = torch.tensor(
        [
            [-1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=torch.float32,
    )
    stages = (
        get_rotation_phi(phi / 180.0 * np.pi),
        get_rotation_theta(theta / 180.0 * np.pi),
        axis_swap,
    )
    pose = get_translation_t(t)
    for stage in stages:
        pose = stage @ pose
    return pose

### **Render**

In [None]:
frames = list()
batch_ray = list()
batch_t = list()


PHI = -15.0
CAMERA_DIST = 3.5
DELTA_THETA = 3.0


model = MODEL()
model = WRAPPER.load_from_checkpoint(
    BEST_MODEL_PATH,
    arch=model,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    max_epoch=EPOCH,
)
# Inference only.
model.eval()


def _render_batch(rays_list, t_list):
    """Render a list of per-view samples and return them as 300x300 uint8 frames."""
    # Follow the model's device rather than guessing from CUDA availability.
    batched_ray = torch.stack(rays_list, dim=0).to(model.device)
    batched_t = torch.stack(t_list, dim=0).to(model.device)

    # Gradients are not needed for rendering; skip building an autograd graph.
    with torch.no_grad():
        color, _ = model(batched_ray, batched_t)

    color = color.cpu().numpy()
    return [
        cv2.resize(
            np.clip(255 * img, 0.0, 255.0).astype(np.uint8),
            (300, 300),
        ) for img in color
    ]


# Orbit the object: one view every DELTA_THETA degrees at fixed PHI / CAMERA_DIST,
# rendered BATCH_SIZE views at a time.
for theta in tqdm(np.linspace(0.0, 360.0, int(360. / DELTA_THETA), endpoint=False)):
    # Get the camera to world matrix.
    c2w = pose_spherical(theta, PHI, CAMERA_DIST)

    ray_orig, ray_dirs = ViewSynthesisDataset.get_rays(c2w)
    rays, t_vals = ViewSynthesisDataset.render_rays(ray_orig, ray_dirs)

    batch_ray.append(rays)
    batch_t.append(t_vals)

    if len(batch_ray) == BATCH_SIZE:
        frames.extend(_render_batch(batch_ray, batch_t))
        batch_ray, batch_t = list(), list()

# Flush the remaining views so the last (possibly partial) batch is not dropped.
if batch_ray:
    frames.extend(_render_batch(batch_ray, batch_t))
    batch_ray, batch_t = list(), list()


video_path = "experiment/result.mp4"
imageio.mimwrite(video_path, frames, fps=30, quality=7, macro_block_size=None)
videoClip = VideoFileClip(video_path)
videoClip.write_gif("experiment/result.gif")
with open("experiment/result.gif", "rb") as f:
    display(Image(data=f.read(), format="gif"))