In [None]:
# Environment setup. `%pip` (rather than `!pip`) installs into the environment of the running kernel.
%pip install torch torchvision tensorboard
# `-y` answers apt's confirmation prompt, which a notebook cell has no way to answer interactively.
!apt-get install -y colmap

# Fetch the reference NeRF repository and make it the working directory.
# NOTE: every relative path used by later cells is resolved inside `nerf/` from here on.
!git clone https://github.com/bmild/nerf.git
%cd nerf

In [None]:

%%writefile nerf_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeRF(nn.Module):
    """Fully connected network with skip connections that re-inject the input.

    The network is a stack of ``D`` ReLU-activated linear layers of width ``W``
    followed by a single linear output layer. After every hidden layer whose
    index is listed in ``skips`` the original input is concatenated in front of
    the activations, so the next layer receives ``W + input_ch`` features.

    Args:
        D: Number of hidden linear layers.
        W: Width of every hidden layer.
        input_ch: Size of the last dimension of the input.
        output_ch: Number of values produced per input row.
        skips: Indices of hidden layers after which the input is concatenated
            again. The sequence is copied, so neither the default nor a
            caller-owned list is shared between instances. An index of
            ``D - 1`` is not supported, because the output layer always
            expects exactly ``W`` features.
    """

    def __init__(self, D=8, W=256, input_ch=3, output_ch=4, skips=(4,)):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.output_ch = output_ch
        # Private copy: a shared mutable default would leak edits across instances.
        self.skips = list(skips)

        # Layer k + 1 consumes the output of layer k, so it needs the wider
        # input exactly when k is a skip index.
        layers = [nn.Linear(input_ch, W)]
        for i in range(D - 1):
            in_features = W + input_ch if i in self.skips else W
            layers.append(nn.Linear(in_features, W))
        self.pts_linears = nn.ModuleList(layers)
        self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has ``input_ch`` features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``output_ch``
            features in the last dimension. No output activation is applied.
        """
        h = x
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                # Re-inject the raw input for the layer that follows.
                h = torch.cat([x, h], -1)
        return self.output_linear(h)


Writing nerf_model.py


In [None]:
import cv2
import os

# Split the input video into numbered PNG frames.
video_path = '/content/video.mp4'
# Absolute path on purpose: the setup cell runs `%cd nerf`, so a relative folder
# would be created inside the cloned repository, while the COLMAP cell reads
# its images from /content/extracted_frames/.
output_folder = '/content/extracted_frames/'

os.makedirs(output_folder, exist_ok=True)

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail loudly instead of silently reporting "Extracted 0 frames."
    cap.release()
    raise FileNotFoundError(f"Could not open video: {video_path}")

frame_count = 0
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No more frames could be read from the video.
            break
        frame_path = os.path.join(output_folder, f"frame_{frame_count:04d}.png")
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(frame_path, frame):
            raise IOError(f"Could not write frame: {frame_path}")
        frame_count += 1
finally:
    # Release the capture handle even if reading or writing fails midway.
    cap.release()

print(f"Extracted {frame_count} frames.")



Extracted 220 frames.


In [None]:

# --- Sparse reconstruction ---
# Extract features from every frame and store them in the COLMAP database.
!colmap feature_extractor --database_path /content/database.db --image_path /content/extracted_frames/

# Match features between all pairs of images.
!colmap exhaustive_matcher --database_path /content/database.db

# `-p` keeps the cell re-runnable when the directory already exists.
!mkdir -p /content/sparse
!colmap mapper --database_path /content/database.db --image_path /content/extracted_frames/ --output_path /content/sparse

# --- Dense reconstruction ---
# NOTE(review): assumes the mapper wrote its model to /content/sparse/0 - confirm after the mapper step.
!mkdir -p /content/dense
!colmap image_undistorter --image_path /content/extracted_frames/ --input_path /content/sparse/0 --output_path /content/dense --output_type COLMAP
!colmap patch_match_stereo --workspace_path /content/dense --workspace_format COLMAP --PatchMatchStereo.geom_consistency true
!colmap stereo_fusion --workspace_path /content/dense --workspace_format COLMAP --input_type geometric --output_path /content/dense/fused.ply


In [None]:
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from nerf_model import NeRF

# Training configuration.
num_epochs = 100
steps_per_epoch = 100  # one epoch = 100 optimizer steps, matching the global-step layout used for logging
batch_size = 1024
learning_rate = 1e-4
log_dir = '/content/logs/'

model = NeRF()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.MSELoss()

writer = SummaryWriter(log_dir)


def data_loader(num_batches=steps_per_epoch):
    """Yield a finite number of placeholder training batches.

    The batches are random noise, not rays sampled from the captured frames:
    each item is an ``(inputs, targets)`` pair of shape ``(batch_size, 3)`` and
    ``(batch_size, 4)``.

    Args:
        num_batches: How many batches to yield before the generator is
            exhausted. The generator must be finite, otherwise the epoch loop
            below never advances past its first epoch.
    """
    for _ in range(num_batches):
        yield torch.randn(batch_size, 3), torch.randn(batch_size, 4)


try:
    for epoch in range(num_epochs):
        for step, (rays, targets) in enumerate(data_loader()):
            optimizer.zero_grad()
            outputs = model(rays)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Monotonic global step across epochs for the TensorBoard x-axis.
            writer.add_scalar('Loss/train', loss.item(), epoch * steps_per_epoch + step)

        # Reports the loss of the last batch of the epoch.
        print(f"Epoch {epoch} Loss: {loss.item()}")

    torch.save(model.state_dict(), '/content/nerf_model.pth')
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()


In [None]:
import torch
import cv2
import numpy as np
from nerf_model import NeRF

# Render a short video from the trained network.
model = NeRF()
# map_location='cpu': frames are converted with .numpy(), which needs CPU tensors,
# and this also lets a checkpoint saved on a GPU load on a CPU-only runtime.
model.load_state_dict(torch.load('/content/nerf_model.pth', map_location='cpu'))
model.eval()

# Placeholder camera "views": 60 random 3-vectors.
new_views = [torch.randn(3) for _ in range(60)]

# The network maps one 3-vector to 4 values, so a single forward pass on a view
# yields 4 numbers, not an image. To obtain real H x W x 3 frames, the network
# is evaluated on a grid of points in the z = 0 plane shifted by the view vector.
# NOTE(review): this is a stand-in, not volume rendering along camera rays;
# replace it with proper ray sampling once camera poses are wired in.
frame_height, frame_width = 256, 256
rows = torch.linspace(-1.0, 1.0, frame_height).view(frame_height, 1).expand(frame_height, frame_width)
cols = torch.linspace(-1.0, 1.0, frame_width).view(1, frame_width).expand(frame_height, frame_width)
pixel_points = torch.stack([cols, rows, torch.zeros(frame_height, frame_width)], dim=-1).reshape(-1, 3)

rendered_images = []
with torch.no_grad():
    for view in new_views:
        values = model(pixel_points + view)  # (frame_height * frame_width, 4)
        # Keep the first three output channels as the colour of each pixel.
        img = values[:, :3].reshape(frame_height, frame_width, 3).numpy()
        value_range = img.max() - img.min()
        if value_range > 0:
            img = (img - img.min()) / value_range * 255
        else:
            # A constant frame would otherwise divide by zero and produce NaNs.
            img = np.zeros_like(img)
        rendered_images.append(img.astype(np.uint8))

height, width = rendered_images[0].shape[:2]
# VideoWriter takes the frame size as (width, height).
out = cv2.VideoWriter('/content/output_video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height))
try:
    for img in rendered_images:
        out.write(img)
finally:
    # Finalise the container even if writing a frame fails.
    out.release()


In [None]:
import os
import cv2
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from nerf_model import NeRF

def load_test_images(path):
    """Load every readable image in ``path``, ordered by file name.

    Args:
        path: Directory to read from.

    Returns:
        list: The arrays returned by ``cv2.imread``, one per readable entry.
        Entries for which ``cv2.imread`` returns ``None`` are skipped and
        reported instead of being returned as ``None`` placeholders.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``path`` cannot be listed.
    """
    images = []
    for file in sorted(os.listdir(path)):
        full_path = os.path.join(path, file)
        img = cv2.imread(full_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            print(f"Skipping unreadable file: {full_path}")
            continue
        images.append(img)
    return images

def load_rendered_images():
    """Return 60 stand-in "rendered" frames filled with random noise.

    Nothing is read from disk and no model is evaluated: every frame is a
    fresh ``torch.randn`` sample converted to NumPy, so metrics computed on
    these frames say nothing about render quality.

    Returns:
        list: 60 float32 arrays of shape ``(256, 256, 3)``.
    """
    frames = []
    for _ in range(60):
        noise = torch.randn(256, 256, 3)
        frames.append(noise.numpy())
    return frames

# Compare rendered frames against held-out test images with PSNR and SSIM.
test_images = load_test_images('/content/test_images/')
rendered_images = load_rendered_images()

if len(test_images) != len(rendered_images):
    # zip() below stops at the shorter list; make the truncation visible.
    print(f"Warning: {len(rendered_images)} rendered vs {len(test_images)} test images; "
          f"only the first {min(len(rendered_images), len(test_images))} pairs are compared.")

psnr_values = []
ssim_values = []
for rendered, test in zip(rendered_images, test_images):
    # Both metrics need matching dtypes. The test images come from cv2.imread
    # (8-bit, 0-255), so compare as float64 on that scale with an explicit
    # data_range instead of letting scikit-image guess it from the dtype.
    reference = test.astype('float64')
    candidate = rendered.astype('float64')
    # Signature is (image_true, image_test): the ground truth goes first.
    psnr_value = psnr(reference, candidate, data_range=255)
    # channel_axis=-1 replaces the deprecated `multichannel=True`.
    ssim_value = ssim(reference, candidate, channel_axis=-1, data_range=255)
    psnr_values.append(psnr_value)
    ssim_values.append(ssim_value)
    print(f"PSNR: {psnr_value}, SSIM: {ssim_value}")

if not psnr_values:
    # Avoid an opaque ZeroDivisionError when there is nothing to average.
    raise ValueError("No image pairs to evaluate: check /content/test_images/ and the rendered frames.")

print(f"Average PSNR: {sum(psnr_values) / len(psnr_values)}, Average SSIM: {sum(ssim_values) / len(ssim_values)}")
