In [1]:
from torch.utils.data import Dataset
import torch
import numpy as np
from utils import get_ray_directions, get_rays
import os
import json
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Tuple
import lightning as L
from rendering import rendering
from utils import test

from dataset import LegoDataset, LegoDataModule
from model import Nerf

# Pick the GPU when PyTorch can see one, otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [2]:
# Test split of the Lego scene at 400x400; the rendering cells below read
# per-view rays and target colours from this object.
lego_dataset = LegoDataset(
    root_dir="dataset/lego/",
    split="test",
    img_shape=(400, 400),
)

In [None]:
# Train a fresh Nerf with a one-epoch warm-up pass.
# `device` is already chosen in the first cell, so it is not redefined here.
#
# NOTE(review): `training` and `warmup_dataloader` are not defined or imported
# anywhere in this notebook, so on a fresh kernel this cell cannot run. Fail up
# front with an explicit message instead of a bare NameError part-way through.
_missing = [name for name in ("training", "warmup_dataloader") if name not in globals()]
if _missing:
    raise NameError(
        "Define or import " + ", ".join(_missing) + " before running this training cell."
    )

# hyperparams
tn = 2.0  # presumably the near ray bound
tf = 6.0  # presumably the far ray bound
nb_epochs = 16
learning_rate = 1e-3
gamma = 0.5
nb_bins = 100

model = Nerf().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Multiply the LR by `gamma` at scheduler steps 2, 4 and 8 (epochs, assuming
# `training` calls scheduler.step() once per epoch -- TODO confirm).
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[2, 4, 8], gamma=gamma
)

# warmup on 1 epoch
training_loss = training(
    model, optimizer, scheduler, warmup_dataloader, tn, tf, nb_bins, 1, device
)
fig, ax = plt.subplots()
ax.plot(training_loss)
ax.set(title="Warm-up training loss", xlabel="step", ylabel="loss")
plt.show()

# Full training run (disabled; `train_dataloader` is not defined in this notebook):
# training_loss = training(model, optimizer, scheduler, train_dataloader, tn, tf, nb_bins, nb_epochs, device)
# plt.plot(training_loss)
# plt.show()

In [3]:
# Settings for evaluating the trained checkpoint. The checkpoint file name
# ("16_epoch_192_bins_400") matches nb_epochs and nb_bins below.
# tn / tf are forwarded to `test` as the ray sampling bounds (presumably near/far).
tn, tf = 2.0, 6.0
nb_epochs = 16
learning_rate = 5e-4
gamma = 0.5
nb_bins = 192

In [None]:
import lightning as L
from model import NeRFLightning

# Load the trained LightningModule and switch it to inference mode.
# map_location places the weights on `device`, the same device the rays are
# moved to in the rendering cells below; without it, placement depends on the
# installed Lightning version's default mapping.
lit_nerf = NeRFLightning.load_from_checkpoint(
    "models/16_epoch_192_bins_400_nerf.ckpt", map_location=device
)
lit_nerf.eval()

In [None]:
from utils import test

# Render every view of the test split with the loaded model.
# NOTE(review): 200 is hard-coded (size of the Blender "lego" test split);
# consider len(lego_dataset) if LegoDataset implements __len__.
NB_TEST_VIEWS = 200

with torch.no_grad():  # inference only -- no autograd graph needed
    for idx in tqdm(range(NB_TEST_VIEWS), desc="rendering test views"):
        sample = lego_dataset[idx]  # fetch the view once, reuse its rays/target
        img, mse, psnr = test(
            lit_nerf,
            sample["rays_origin"].reshape(-1, 3).to(device).type(torch.float),
            sample["rays_direction"].reshape(-1, 3).to(device).type(torch.float),
            tn,
            tf,
            image_index=idx,
            nb_bins=nb_bins,  # configured bin count (192 for this checkpoint)
            chunk_size=10,
            height=lego_dataset.img_shape[0],
            width=lego_dataset.img_shape[1],
            target=sample["rgbs"].numpy(),
            outputs_dir="lkjasdaf",  # NOTE(review): placeholder directory name
            metrics=False,
        )

### Save the Lightning model's state dict

In [None]:
# lit_nerf.state_dict()
# torch.save(obj=lit_nerf.state_dict(), f="models/lit_nerf_state_dict.pt")

### Convert the Lightning checkpoint into plain `Nerf` (`nn.Module`) weights

In [4]:
import torch
from model import Nerf

# Load the raw Lightning checkpoint. map_location keeps this loadable on
# CPU-only machines even when the checkpoint was written from a GPU run.
checkpoint = torch.load("models/epoch=16-step=83670.ckpt", map_location=device)

# Keys in the Lightning state_dict carry a "nerf." prefix (presumably because
# the LightningModule holds the network as `self.nerf`), which plain `Nerf`
# does not expect. Strip it; keys without the prefix pass through unchanged.
prefix = "nerf."
mapped_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}

nerf = Nerf().to(device)
nerf.load_state_dict(mapped_state_dict)
# Save the bare nn.Module weights so they can be loaded without Lightning.
torch.save(obj=nerf.state_dict(), f="models/16_epoch_192_bins_400_nerf.pt")

In [None]:
# Render a single test view as a quick visual check.
idx = 2
sample = lego_dataset[idx]  # fetch the view once, reuse its rays/target
with torch.no_grad():  # inference only -- no autograd graph needed
    img, mse, psnr = test(
        lit_nerf,
        sample["rays_origin"].reshape(-1, 3).to(device).type(torch.float),
        sample["rays_direction"].reshape(-1, 3).to(device).type(torch.float),
        tn,
        tf,
        image_index=idx,
        # Use the configured bin count; the checkpoint name indicates 192 bins,
        # matching the batch-rendering cell.
        nb_bins=nb_bins,
        chunk_size=20,
        height=lego_dataset.img_shape[0],
        width=lego_dataset.img_shape[1],
        target=sample["rgbs"].numpy(),
        outputs_dir="tbuadslfjk",  # NOTE(review): placeholder directory name
        metrics=False,
    )

In [7]:
import json
import os
from typing import Tuple


def get_avg_metrics(
    outputs_json_dir: str = "outputs", filename: str = "nerf_testing.json"
) -> Tuple[float, float]:
    """Return the mean PSNR and mean MSE stored in a test-metrics JSON file.

    The JSON file must contain ``"psnr"`` and ``"mse"`` lists of numbers.

    Args:
        outputs_json_dir: Directory that contains ``filename``. A path to the
            JSON file itself is also accepted. The default ``"outputs"``
            resolves to ``outputs/nerf_testing.json``.
        filename: Name of the metrics file inside ``outputs_json_dir``; ignored
            when ``outputs_json_dir`` is not a directory.

    Returns:
        ``(avg_psnr, avg_mse)``.

    Raises:
        ValueError: If either metric list is empty.
    """
    # Honour the caller's path (it was previously ignored in favour of a
    # hard-coded "outputs/nerf_testing.json").
    json_path = (
        os.path.join(outputs_json_dir, filename)
        if os.path.isdir(outputs_json_dir)
        else outputs_json_dir
    )
    with open(json_path, "r") as f:
        data = json.load(f)

    psnrs, mses = data["psnr"], data["mse"]
    if not psnrs or not mses:
        raise ValueError(f"{json_path} contains no PSNR/MSE values to average")
    return sum(psnrs) / len(psnrs), sum(mses) / len(mses)

(29.200356294523996, 0.0012457877128773586)