In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Pick the compute device: Apple-silicon GPU (MPS) first, then CUDA, and fall
# back to CPU so the notebook still runs on machines without Metal support.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
class NeRF(nn.Module):
    """Plain fully-connected network mapping an input vector to ``output_ch`` values.

    Architecture: ``D`` Linear+ReLU layers (the first maps ``input_ch -> W``,
    the rest ``W -> W``) followed by a Linear output head with no activation.
    There is no positional encoding, density output or volume rendering here,
    and the output is unbounded (it is not squashed into [0, 1]).

    Args:
        D: number of Linear+ReLU layers before the output head.
        W: width of every hidden layer.
        input_ch: length of the input feature vector (15 by default).
        output_ch: length of the output vector (3 for RGB).
    """

    def __init__(self, D=8, W=256, input_ch=15, output_ch=3):  # output_ch is 3 for RGB
        super().__init__()
        self.D, self.W = D, W
        self.input_ch, self.output_ch = input_ch, output_ch

        # Fan-in of each hidden layer: the input once, then W for the remaining D-1 layers.
        fan_ins = [input_ch] + [W] * (D - 1)
        self.layers = nn.ModuleList(nn.Linear(fan_in, W) for fan_in in fan_ins)
        self.output_layer = nn.Linear(W, output_ch)

    def forward(self, x):
        """Run ``x`` through every hidden layer with ReLU, then the linear head."""
        features = x
        for hidden in self.layers:
            features = hidden(features).relu()
        return self.output_layer(features)

In [3]:
class LLFFDataset(Dataset):
    """One LLFF scene as a Dataset of (pixels, per-pixel camera parameters) pairs.

    Images are read from ``<root_dir>/<image_subdir>`` and paired with the rows
    of ``<root_dir>/poses_bounds.npy``: row ``i`` belongs to the ``i``-th image
    in sorted filename order, so both must contain the same number of entries.

    Each item is a pair of float32 tensors:
      * pixels: ``(H*W, 3)`` RGB values scaled to [0, 1];
      * cam_params: ``(H*W, C)``, the image's poses_bounds row without its
        last two values, repeated once per pixel (C == 15 for the standard
        17-column LLFF file).

    NOTE(review): every pixel of a view receives the identical input vector, so
    a model trained on these pairs cannot distinguish pixels within one image.

    Args:
        root_dir: scene directory containing ``poses_bounds.npy``.
        image_subdir: sub-directory holding the images (default ``'images_8'``,
            the 8x-downsampled LLFF images).

    Raises:
        ValueError: if the image count and the poses_bounds row count differ.
    """

    # Matched case-insensitively so '.JPG', '.jpg', '.PNG', ... are all found.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, image_subdir='images_8'):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, image_subdir)
        self.images = sorted(
            os.path.join(self.image_dir, f)
            for f in os.listdir(self.image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.poses_bounds = np.load(os.path.join(root_dir, 'poses_bounds.npy'))
        # Rows are matched to images by position; a count mismatch would silently
        # pair images with the wrong poses (or raise IndexError much later).
        if len(self.images) != len(self.poses_bounds):
            raise ValueError(
                f'Found {len(self.images)} images in {self.image_dir} but '
                f'{len(self.poses_bounds)} rows in poses_bounds.npy; they must '
                'match one-to-one (row i <-> i-th image in sorted order).'
            )
        print(f'Loaded {len(self.images)} images.')
        print(f'Camera parameters shape: {self.poses_bounds.shape}')

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # convert('RGB') guarantees exactly 3 channels (drops alpha, expands
        # grayscale/palette) so reshape(-1, 3) cannot mix channels across pixels;
        # the `with` block closes the underlying file handle.
        with Image.open(img_path) as im:
            img = np.asarray(im.convert('RGB')) / 255.0
        # The last two columns of an LLFF poses_bounds row are the near/far depth bounds.
        cam_params = self.poses_bounds[idx, :-2]
        # Flatten the image and repeat cam_params once per pixel to match.
        img = img.reshape(-1, 3)
        cam_params = np.tile(cam_params, (img.shape[0], 1))
        return torch.tensor(img, dtype=torch.float32), torch.tensor(cam_params, dtype=torch.float32)


In [4]:
def train_nerf(nerf, dataset, epochs=100, batch_size=1, lr=5e-4, save_path='nerf_model.pth'):
    """Fit ``nerf`` to ``dataset`` by per-pixel MSE regression, then save its weights.

    Every batch yielded by ``dataset`` is an ``(images, cam_params)`` pair; the
    model is evaluated on ``cam_params`` and its output is compared with
    ``images`` using ``nn.MSELoss`` and optimized with Adam. Training runs on
    the notebook-level ``device``. The ``state_dict`` is written to
    ``save_path`` once, after the last epoch.

    Args:
        nerf: the ``nn.Module`` to train (moved to ``device`` in place).
        dataset: map-style dataset yielding ``(images, cam_params)`` tensors.
        epochs: number of passes over ``dataset``.
        batch_size: items per optimizer step.
        lr: Adam learning rate.
        save_path: destination file for ``nerf.state_dict()``.

    Returns:
        list[float]: mean batch loss of each epoch (the values that are printed).

    Raises:
        ValueError: if ``epochs > 0`` but ``dataset`` yields no batches.
    """
    # Place the parameters on the target device *before* constructing the
    # optimizer, as the torch.optim documentation recommends.
    nerf.to(device)
    nerf.train()  # make sure the module is in training mode even if the caller left it in eval()

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Fail early instead of hitting ZeroDivisionError when averaging the epoch loss.
    if epochs > 0 and len(dataloader) == 0:
        raise ValueError('dataset is empty; there is nothing to train on')
    optimizer = optim.Adam(nerf.parameters(), lr=lr)
    criterion = nn.MSELoss()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        for images, cam_params in tqdm(dataloader):
            images, cam_params = images.to(device), cam_params.to(device)
            optimizer.zero_grad()
            predictions = nerf(cam_params)
            loss = criterion(predictions, images)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        mean_loss = running_loss / len(dataloader)
        epoch_losses.append(mean_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {mean_loss}')

    # Save the model
    torch.save(nerf.state_dict(), save_path)
    print(f'Model saved to {save_path}')
    return epoch_losses

In [39]:
# Training configuration
SEED = 0  # seeds weight initialization and DataLoader shuffling so runs are repeatable
EPOCHS = 2
BATCH_SIZE = 1  # each item holds every pixel of one image, so larger batches multiply memory use
LR = 1e-3
CHUNK_SIZE = 1024  # pixels per forward pass when rendering with the trained model

# Paths -- set the LLFF_SCENE_DIR environment variable instead of editing this cell
DATASET_DIR = os.environ.get(
    'LLFF_SCENE_DIR',
    '/Users/ewojcik/Code/AIxperiments/3D/LLFF_Data/real_iconic/airplants',
)

torch.manual_seed(SEED)

# Initialize Dataset and Model
dataset = LLFFDataset(DATASET_DIR)
nerf_model = NeRF()

# Train the Model (bound to a name so the cell does not echo the return value)
epoch_losses = train_nerf(nerf_model, dataset, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR)

Loaded 30 images.
Camera parameters shape: (30, 17)


  0%|          | 0/30 [00:00<?, ?it/s]

100%|██████████| 30/30 [00:53<00:00,  1.78s/it]


Epoch 1/10, Loss: 0.09768241910884777


100%|██████████| 30/30 [00:48<00:00,  1.60s/it]


Epoch 2/10, Loss: 0.025896312296390535


100%|██████████| 30/30 [00:47<00:00,  1.59s/it]


Epoch 3/10, Loss: 0.021713684995969137


100%|██████████| 30/30 [00:45<00:00,  1.51s/it]


Epoch 4/10, Loss: 0.022999316019316516


100%|██████████| 30/30 [00:48<00:00,  1.63s/it]


Epoch 5/10, Loss: 0.021683336483935514


100%|██████████| 30/30 [00:54<00:00,  1.82s/it]


Epoch 6/10, Loss: 0.021402562720080218


100%|██████████| 30/30 [00:46<00:00,  1.56s/it]


Epoch 7/10, Loss: 0.022627480265994867


100%|██████████| 30/30 [00:44<00:00,  1.47s/it]


Epoch 8/10, Loss: 0.021569070344169935


100%|██████████| 30/30 [00:45<00:00,  1.51s/it]


Epoch 9/10, Loss: 0.021984835093220075


100%|██████████| 30/30 [00:47<00:00,  1.59s/it]

Epoch 10/10, Loss: 0.021986890646318594
Model saved to nerf_model.pth





In [7]:
def load_and_visualize_model(model_path, image_path, cam_params, chunk_size=1024, output_file='output.png'):
    """Render one view with a trained ``NeRF``, save it as an image and display it.

    Every output pixel is predicted from the same camera-parameter vector (the
    same per-pixel repetition ``LLFFDataset`` uses for training inputs), so
    ``image_path`` only determines the height and width of the render.

    Args:
        model_path: file holding a ``NeRF`` ``state_dict`` (as written by ``train_nerf``).
        image_path: reference image whose size the render takes.
        cam_params: parameters of ONE view -- either ``input_ch`` values (a
            poses_bounds row without its bounds) or a full poses_bounds row,
            in which case the trailing two values are dropped exactly as
            ``LLFFDataset`` does.
        chunk_size: pixels evaluated per forward pass; bounds peak memory.
        output_file: path the rendered image is saved to.

    Raises:
        ValueError: if ``chunk_size`` is not positive, or ``cam_params`` does
            not describe exactly one view (e.g. the whole ``(N, 17)`` array
            loaded from poses_bounds.npy).
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    # Load the model. map_location lets a checkpoint saved on MPS/CUDA be
    # restored on whichever `device` is active now.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    nerf_model = NeRF()
    nerf_model.load_state_dict(torch.load(model_path, map_location=device))
    nerf_model.to(device)
    nerf_model.eval()

    # Normalize cam_params to a single input vector of length input_ch.
    n_in = nerf_model.input_ch
    cam_vec = np.asarray(cam_params, dtype=np.float32).reshape(-1)
    if cam_vec.size == n_in + 2:
        cam_vec = cam_vec[:n_in]  # drop near/far bounds, matching LLFFDataset's [:-2]
    if cam_vec.size != n_in:
        raise ValueError(
            f'cam_params must hold one view ({n_in} values, or {n_in + 2} with bounds); '
            f'got shape {np.shape(cam_params)}'
        )

    # Only the size of the reference image is needed; `with` closes the file.
    with Image.open(image_path) as reference:
        width, height = reference.size
    n_pixels = width * height

    # Build one chunk of repeated inputs up front instead of tiling the vector
    # across every pixel; the final, shorter chunk reuses a prefix of it.
    cam_tensor = torch.from_numpy(cam_vec).to(device)
    chunk_inputs = cam_tensor.unsqueeze(0).repeat(min(chunk_size, n_pixels), 1)

    output_flatten = np.empty((n_pixels, nerf_model.output_ch), dtype=np.float32)
    with torch.no_grad():
        for start in range(0, n_pixels, chunk_size):
            stop = min(start + chunk_size, n_pixels)
            predictions = nerf_model(chunk_inputs[:stop - start])
            output_flatten[start:stop] = predictions.cpu().numpy()

    # Clip before casting: the model's output layer has no activation, and
    # out-of-range floats do not saturate when cast to uint8.
    pixels = (np.clip(output_flatten, 0.0, 1.0) * 255).astype(np.uint8)
    output_image = Image.fromarray(pixels.reshape(height, width, nerf_model.output_ch))
    output_image.save(output_file)
    print(f'Output image saved as {output_file}')

    # Display the output image
    fig, ax = plt.subplots()
    ax.imshow(output_image)
    ax.set_title('Model render')
    ax.axis('off')
    plt.show()


In [8]:
MODEL_PATH = 'nerf_model.pth'
IMAGE_PATH = os.path.join(DATASET_DIR, 'images', 'IMG_2066.JPG')  # Adjust to one of your images

# poses_bounds.npy holds one row per image in sorted filename order (the same
# pairing LLFFDataset relies on). Select the row belonging to IMAGE_PATH and drop
# its last two values (near/far bounds) so it matches the model's 15 inputs;
# passing the whole (N, 17) array would not match the model's input size.
image_dir = os.path.dirname(IMAGE_PATH)
image_names = sorted(
    f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)
IMAGE_INDEX = image_names.index(os.path.basename(IMAGE_PATH))
CAM_PARAMS = np.load(os.path.join(DATASET_DIR, 'poses_bounds.npy'))[IMAGE_INDEX, :-2]

load_and_visualize_model(MODEL_PATH, IMAGE_PATH, CAM_PARAMS, chunk_size=CHUNK_SIZE)

: 