In [5]:
from torch.utils.data import Dataset
import torch
import numpy as np
from utils import get_ray_directions, get_rays
import os
import json
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Tuple
import lightning as L
from rendering import rendering

from dataset import LegoDataset, LegoDataModule
from model import Nerf

In [6]:
# Build the Lego dataset object used by every later cell.
# The "test" split is requested with img_shape=(200, 200).
lego_dataset = LegoDataset(
    split="test",
    img_shape=(200, 200),
    root_dir="dataset/lego/",
)

In [10]:
def _center_crop_rays(flat_values, n_images, n_rows, n_cols, row_window, col_window):
    """Crop a per-image window out of flattened per-ray values.

    ``flat_values`` is viewed as ``(n_images, n_rows, n_cols, 3)``, the
    ``row_window`` x ``col_window`` slices are kept for every image, and the
    result is flattened back to rows of 3 values.
    """
    grid = flat_values.reshape(n_images, n_rows, n_cols, 3)
    return grid[:, row_window, col_window, :].reshape(-1, 3)


def _stack_ray_columns(origins, directions, colors):
    """Cast the three tensors to ``torch.float`` and join them along dim=1.

    The resulting column layout is ``[origin | direction | rgb]``.
    """
    return torch.cat(
        (
            origins.type(torch.float),
            directions.type(torch.float),
            colors.type(torch.float),
        ),
        dim=1,
    )


# Sizes of the two spatial axes used when un-flattening the per-ray tensors.
img_rows = lego_dataset.img_shape[0]
img_cols = lego_dataset.img_shape[1]

# Loader over every ray of the dataset.
train_dataloader = DataLoader(
    _stack_ray_columns(
        lego_dataset.all_rays_origin,
        lego_dataset.all_rays_direction,
        lego_dataset.all_rgbs,
    ),
    batch_size=1024,
    num_workers=3,
    shuffle=True,
)

# Bounds of the central window (middle half of each axis).
# Row bounds keep their original names; column bounds are derived from the
# second axis so the crop stays centred for non-square images too.
image_1_4 = int(img_rows * (1 / 4))
image_3_4 = int(img_rows * (3 / 4))
image_col_1_4 = int(img_cols * (1 / 4))
image_col_3_4 = int(img_cols * (3 / 4))
center_rows = slice(image_1_4, image_3_4)
center_cols = slice(image_col_1_4, image_col_3_4)

# Loader restricted to the central window of every image (used for warm-up).
warmup_dataloader = DataLoader(
    _stack_ray_columns(
        *(
            _center_crop_rays(
                values,
                len(lego_dataset),
                img_rows,
                img_cols,
                center_rows,
                center_cols,
            )
            for values in (
                lego_dataset.all_rays_origin,
                lego_dataset.all_rays_direction,
                lego_dataset.all_rgbs,
            )
        )
    ),
    batch_size=1024,
    num_workers=3,
    shuffle=True,
)

In [11]:
def training(
    model,
    optimizer,
    scheduler,
    dataloader,
    tn,
    tf,
    nb_bins,
    nb_epochs,
    device="cpu",
    save_path="models/model_nerf",
):
    """Train ``model`` on batches of rays and return the per-step loss history.

    Every batch must be a 2-D tensor whose columns are laid out as
    ``[ray origin (0:3) | ray direction (3:6) | target colour (6:)]``.
    The loss is the mean squared error between the output of ``rendering``
    and the target columns.

    Args:
        model: Network handed to ``rendering``. It is moved to the CPU for
            saving at the end of each epoch and then back to ``device``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        dataloader: Re-iterable collection of batches that supports ``len()``.
        tn: Forwarded unchanged to ``rendering``
            (presumably the near bound; confirm in the rendering module).
        tf: Forwarded unchanged to ``rendering``
            (presumably the far bound; confirm in the rendering module).
        nb_bins: Forwarded unchanged to ``rendering``.
        nb_epochs: Number of full passes over ``dataloader``.
        device: Device the batches are moved to before rendering.
        save_path: File the whole model object is written to after each
            epoch. Missing parent directories are created.

    Returns:
        list[float]: The loss of every optimisation step, in execution order.
    """
    training_loss = []

    for epoch in range(nb_epochs):
        # Build a fresh progress bar each epoch: a bar created once around an
        # iterator is exhausted after the first pass, which would turn every
        # later epoch into a no-op.
        progress_bar = tqdm(
            dataloader,
            total=len(dataloader),
            desc=f"Training Epoch: {epoch}",
        )

        for batch in progress_bar:
            origin = batch[:, :3].to(device)
            direction = batch[:, 3:6].to(device)
            target = batch[:, 6:].to(device)

            prediction = rendering(model, origin, direction, tn, tf, nb_bins, device)
            loss = ((prediction - target) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            progress_bar.set_postfix({"loss": step_loss})
            training_loss.append(step_loss)

        scheduler.step()

        # Checkpoint after every epoch. Create the target directory first so
        # a missing folder cannot throw away a finished epoch.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.cpu(), save_path)
        model.to(device)

    # Hand the history back so callers can plot or inspect it.
    return training_loss

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# hyperparams
# tn / tf / nb_bins are passed straight through `training`
# (presumably near bound, far bound and samples per ray; confirm in rendering).
tn, tf = 2.0, 6.0
nb_bins = 100
nb_epochs = 16
learning_rate, gamma = 1e-3, 0.5

model = Nerf().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# The learning rate is multiplied by `gamma` at each listed milestone epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    gamma=gamma,
    milestones=[2, 4, 8],
)


# warmup on 1 epoch (central image crops only)
training_loss = training(
    model,
    optimizer,
    scheduler,
    warmup_dataloader,
    tn,
    tf,
    nb_bins,
    1,
    device,
)
plt.plot(training_loss)
plt.show()

# Full training run over every ray, currently disabled:
# training_loss = training(model, optimizer, scheduler, train_dataloader, tn, tf, nb_bins, nb_epochs, device)
# plt.plot(training_loss)
# plt.show()

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 'models/model_nerf' holds a whole pickled model object (not a state dict),
# so it has to be loaded with weights_only=False; recent PyTorch releases
# default to weights_only=True and refuse to unpickle arbitrary classes.
# NOTE(security): unpickling executes code. Only load files you created.
nerf_model = torch.load(
    'models/model_nerf',
    map_location=device,
    weights_only=False,
).to(device)
nerf_model.eval()

# hyperparams
# Repeated from the training cell, presumably so evaluation can run after a
# kernel restart without re-executing training.
tn = 2.0
tf = 6.0
nb_epochs = 16
learning_rate = 1e-3
gamma = 0.5
nb_bins = 100

In [None]:
from utils import test
# Render dataset entries 55..59 with the loaded model and score each against
# its ground-truth colours.
for idx in range(55, 60):
    # Index the dataset once per image instead of once per field.
    sample = lego_dataset[idx]
    img, mse, psnr = test(
        nerf_model,
        sample['rays_origin'].reshape(-1, 3).to(device).type(torch.float),
        sample['rays_direction'].reshape(-1, 3).to(device).type(torch.float),
        tn,
        tf,
        image_index=idx,
        # Reuse the notebook-level hyperparameter rather than a second literal.
        nb_bins=nb_bins,
        chunk_size=20,
        height=lego_dataset.img_shape[0],
        width=lego_dataset.img_shape[1],
        target=sample['rgbs'].numpy(),
        outputs_dir='nerf_testbruh',
        title=False,
    )

### Export the weights of the Lightning checkpoint

In [22]:
import lightning as L
from model import NeRFLightning

# Restore the Lightning module from its checkpoint and switch it to eval mode.
lit_nerf = NeRFLightning.load_from_checkpoint(
    "logs/LegoNeRF/ei19d3bj/checkpoints/epoch=0-step=3907.ckpt"
)
lit_nerf.eval()

# Persist only the parameters (state dict), not the Lightning module object.
torch.save(obj=lit_nerf.state_dict(), f='models/lit_nerf_state_dict.pt')

### Load the Lightning checkpoint into a plain `nn.Module`

In [16]:
import torch
from model import Nerf
# Load the raw Lightning checkpoint onto the CPU so this cell also works on a
# machine without the GPU the checkpoint was written from.
# NOTE(security): a .ckpt file is a pickle. Only load checkpoints you trust.
checkpoint = torch.load(
    "logs/LegoNeRF/ei19d3bj/checkpoints/epoch=0-step=3907.ckpt",
    map_location="cpu",
)

# Checkpoint keys that start with "nerf." do not match the parameter names of
# a bare `Nerf` module, so that prefix is stripped; other keys pass through.
key_prefix = "nerf."
mapped_state_dict = {
    (name[len(key_prefix):] if name.startswith(key_prefix) else name): tensor
    for name, tensor in checkpoint['state_dict'].items()
}

nerf = Nerf()
nerf.load_state_dict(mapped_state_dict)
torch.save(obj=nerf.state_dict(), f="models/saved_nerf.pt")

{'block1.0.weight': tensor([[-0.1030, -0.0449, -0.1626,  ...,  0.0295,  0.0263,  0.0383],
         [ 0.1102,  0.0949, -0.1201,  ...,  0.0026,  0.0050,  0.0234],
         [-0.1619,  0.0048, -0.0798,  ...,  0.0212, -0.0153,  0.0504],
         ...,
         [-0.1228,  0.2679, -0.0969,  ..., -0.0236,  0.0198, -0.0094],
         [-0.1460,  0.0541,  0.0025,  ...,  0.0031,  0.0038,  0.0026],
         [-0.0293, -0.1723,  0.0568,  ..., -0.0084, -0.0469, -0.0482]],
        device='cuda:0'),
 'block1.0.bias': tensor([ 5.5093e-02,  5.0232e-02, -3.8056e-02, -2.7814e-03, -1.1225e-01,
          8.3858e-02, -9.9738e-02, -7.3286e-02, -5.3965e-02, -5.5836e-02,
          1.0592e-01, -7.5771e-02, -6.5942e-02, -1.7457e-02, -7.1793e-02,
         -5.8036e-02,  8.3389e-02, -1.1971e-01,  4.4042e-03,  7.6238e-02,
         -1.3191e-01, -1.1479e-01, -3.7590e-02, -1.1718e-01, -1.5795e-01,
         -6.0562e-02, -3.3646e-02, -9.1756e-02,  8.6587e-02,  2.5170e-02,
          8.1365e-02, -1.3753e-01, -2.4666e-02, -5.03