In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt

# Offset added to, and gain multiplied into, the first two grid axes by get_mgrid.
shift_factor = 0
scale_factor = 1


def get_mgrid(sidelen, dim):
    """Build a dense 4D coordinate grid with every axis normalised to [-1, 1].

    Args:
        sidelen: sequence of four axis lengths (n0, n1, n2, n3).
        dim: number of coordinate axes; only ``4`` is supported.

    Returns:
        A float32 tensor of shape ``(n0 * n1, n2, n3, 4)``. The last axis holds
        the four coordinates of each sample; the first two coordinates are then
        shifted by ``shift_factor`` and scaled by ``scale_factor``.

    Raises:
        NotImplementedError: if ``dim`` is not 4.
    """
    if dim != 4:
        raise NotImplementedError(f'Not implemented for dim={dim}')

    n0, n1, n2, n3 = sidelen[0], sidelen[1], sidelen[2], sidelen[3]
    axes = np.mgrid[:n0, :n1, :n2, :n3]
    grid = np.stack(axes, axis=-1)[None, ...].astype(np.float32)

    # Map integer indices to [0, 1]; a length-1 axis divides by 1 to avoid 0/0.
    for axis, length in enumerate((n0, n1, n2, n3)):
        grid[..., axis] = grid[..., axis] / max(length - 1, 1)

    # [0, 1] -> [-1, 1], flattened to one coordinate vector per row.
    coords = torch.from_numpy((grid - 0.5) * 2).view(-1, dim)

    offset = torch.Tensor([shift_factor, shift_factor, 0, 0])
    gain = torch.Tensor([scale_factor, scale_factor, 1, 1])
    coords = (coords + offset) * gain

    return coords.view((n0 * n1, n2, n3, dim))

class LightFieldData:
    """In-memory light-field dataset: one RGB image per (u, v) view plus its 4D coordinates.

    After ``load_data_capture`` the instance holds:
        data    -- float32 tensor (u * v, h, w, 3), pixel values scaled to [-1, 1]
        mgrid   -- coordinate tensor returned by ``get_mgrid`` with the same leading layout
        indices -- permutation of view indices used by ``next_batch``
        cursor  -- position of the next unread entry in ``indices``
    """

    def __init__(self, batch_size, device=None):
        """
        Args:
            batch_size: default number of views returned by ``next_batch``.
            device: device batches are moved to. Defaults to ``cuda:0`` when
                CUDA is available and falls back to the CPU otherwise.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.batch_size = batch_size

    def load_data_capture(self, data_folder, u=64, v=64, h=256, w=256):
        """Load a ``u`` x ``v`` grid of views from ``data_folder``.

        Files are expected at ``{data_folder}/{vi:03d}-{ui:03d}.png``. Every image
        is resized to ``h`` x ``w``, converted to RGB and scaled to [-1, 1].

        Raises:
            FileNotFoundError: if an expected image cannot be read.
        """
        c = 3
        data = torch.zeros((u * v, h, w, c))
        for ui in range(u):
            for vi in range(v):
                path = f'{data_folder}/{vi:03d}-{ui:03d}.png'
                im = cv2.imread(path)
                if im is None:
                    # cv2.imread signals failure by returning None instead of raising.
                    raise FileNotFoundError(f"Could not read image: {path}")
                # cv2.resize takes dsize as (width, height).
                im = cv2.resize(im, (w, h))
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                im = torch.from_numpy(((im * 2.0 / 255.0) - 1.0).astype(np.float32))
                # Row-major (ui, vi) layout: the stride of ui is v, matching get_mgrid.
                data[ui * v + vi] = im

        self.mgrid = get_mgrid((u, v, h, w), dim=4)
        self.data = data
        self.indices = np.arange(len(self.data))
        # Start "exhausted" so that the first next_batch call shuffles.
        self.cursor = len(self.indices)
        self.height = h
        self.width = w

        print(f"Data:{self.data.shape}, Type:{self.data.dtype}, Range:{(self.data.min(), self.data.max())}")
        print(f"Mesh Grid:{self.mgrid.shape}, Type:{self.mgrid.dtype}, Range:{(self.mgrid.min(), self.mgrid.max())}")

    def shuffle(self):
        """Draw a new random view order and rewind the cursor."""
        self.indices = np.random.permutation(self.indices)
        self.cursor = 0

    def next_batch(self, batch_size=None):
        """Return the next ``batch_size`` views and their coordinates on ``self.device``.

        Reshuffles first when fewer than ``batch_size`` unread views remain, so
        the tail of an epoch that does not fill a batch is skipped.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if self.cursor + batch_size > len(self.indices):
            self.shuffle()
        batch_indices = self.indices[self.cursor:self.cursor + batch_size]
        self.cursor += batch_size
        batch_data = self.data[batch_indices]
        batch_coord = self.mgrid[batch_indices]
        return batch_data.to(self.device), batch_coord.to(self.device)

    def pixel_batch(self, batch_size):
        """Return ``batch_size`` random pixels (sampled with replacement) and their coordinates."""
        flat_data = self.data.view(-1, self.data.shape[-1])
        flat_coord = self.mgrid.view(-1, self.mgrid.shape[-1])
        pixel_ix = np.random.choice(flat_data.shape[0], batch_size)
        return flat_data[pixel_ix].to(self.device), flat_coord[pixel_ix].to(self.device)

    def tensor_to_plt(self, in_tensor):
        """Display the first three channels of a [-1, 1] image tensor with matplotlib."""
        plt.imshow(in_tensor[..., 0:3].detach().cpu().clamp(-1, 1) * 0.5 + 0.5)
        plt.show()

In [None]:
# Build the dataset wrapper, then read the captured views from disk.
capture_dir = 'datasets/Ship-360'
lf_data = LightFieldData(batch_size=8)
# Alternative source (loader not present in this notebook):
# lf_data.load_data_sintel('datasets/', u=3, v=3, channels=3, height=384, width=384, rs=450, bs=1)
lf_data.load_data_capture(capture_dir)

In [None]:
# Sanity check: show the first two loaded views.
all_data, all_coord = lf_data.data, lf_data.mgrid
for view in all_data[:2]:
    lf_data.tensor_to_plt(view)

In [None]:
import torch.nn as nn

# Frequency multiplier of the sine activation: Sine.forward computes sin(W_0 * x),
# and sine_init divides its uniform weight bound by W_0.
W_0 = 30

# class FullyConnected(nn.Module):
#     def __init__(self, in_features, out_features, num_hidden_layers, hidden_features):
#         super().__init__()

#         self.net = []
#         self.net.append(nn.Sequential(nn.Linear(in_features, hidden_features), nn.ReLU()))
#         for i in range(num_hidden_layers):
#             self.net.append(nn.Sequential(nn.Linear(hidden_features, hidden_features), nn.ReLU()))
#         self.net.append(nn.Sequential(nn.Linear(hidden_features, out_features)))

#         self.net = nn.Sequential(*self.net)
    
#     def forward(self, inputs):
#         return self.net(inputs)

# class ResidualBlock(nn.Module):
#     def __init__(self, module):
#         super().__init__()
#         self.module = module

#     def forward(self, inputs):
#         return self.module(inputs) + inputs

class FCBlock(nn.Module):
    """Fully connected network with sine activations and a linear output layer.

    Layout: one input stage (Linear + Sine), ``num_hidden_layers`` hidden stages
    (Linear + Sine) and a final Linear with no activation. All weights are
    initialised with ``sine_init``; the input stage is then re-initialised with
    ``first_layer_sine_init``.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features):
        super().__init__()

        # Input width of each sine stage: the input stage first, then the hidden ones.
        stage_inputs = [in_features] + [hidden_features] * num_hidden_layers
        stages = [
            nn.Sequential(nn.Linear(width, hidden_features), Sine())
            for width in stage_inputs
        ]
        # Alternative hidden stage tried previously:
        # ResidualBlock(nn.Sequential(nn.Linear(hidden_features, hidden_features), Sine()))
        stages.append(nn.Sequential(nn.Linear(hidden_features, out_features)))

        self.net = nn.Sequential(*stages)
        self.net.apply(sine_init)
        self.net[0].apply(first_layer_sine_init)

    def forward(self, inputs):
        """Run ``inputs`` through the stacked stages."""
        return self.net(inputs)

def sine_init(m):
    """Re-initialise ``m.weight`` uniformly in +/- sqrt(6 / fan_in) / W_0.

    Modules without a ``weight`` attribute are left untouched, which makes the
    function safe to pass to ``nn.Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    bound = np.sqrt(6 / fan_in) / W_0
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)

def first_layer_sine_init(m):
    """Re-initialise ``m.weight`` uniformly in +/- 1 / fan_in.

    Modules without a ``weight`` attribute are left untouched, which makes the
    function safe to pass to ``nn.Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    bound = 1 / m.weight.size(-1)
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)

class Sine(nn.Module):
    """Sine activation ``sin(w0 * x)``.

    Args:
        w0: optional frequency multiplier. When omitted, the module-level
            ``W_0`` is read at call time.
    """

    def __init__(self, w0=None):
        super().__init__()
        # Plain attribute (not a Parameter/buffer) so state_dict contents are unchanged.
        self.w0 = w0

    def forward(self, input):
        """Apply the sine activation element-wise."""
        w0 = W_0 if self.w0 is None else self.w0
        return torch.sin(w0 * input)
        # Alternative activation tried previously: torch.sinc(W_0 * input)

class Siren(nn.Module):
    """Single sine-activated MLP: a thin wrapper around one ``FCBlock``."""

    def __init__(self, in_features, out_features, hidden_features, num_hidden_layers):
        super().__init__()
        self.net = FCBlock(
            in_features=in_features,
            out_features=out_features,
            num_hidden_layers=num_hidden_layers,
            hidden_features=hidden_features,
        )

    def forward(self, inputs):
        """Evaluate the wrapped network on ``inputs``."""
        prediction = self.net(inputs)
        return prediction
    
class MultiSiren(nn.Module):
    """A ``split`` x ``split`` grid of ``FCBlock`` networks selected by (u, v).

    The interval [-1, 1] of each of the first two input coordinates is cut into
    ``split`` equal bins; each (u-bin, v-bin) pair owns one sub-network.

    NOTE: ``forward`` selects the sub-network from the first coordinate vector
    of ``inputs`` only, so every sample in a batch must fall into the same bin
    pair (true for a single view, not for randomly sampled pixels).
    """

    def __init__(self, in_features, out_features, hidden_features, num_hidden_layers, split):
        super().__init__()
        self.split = split
        self.uv_nets = nn.ModuleList([
            nn.ModuleList([
                FCBlock(in_features, out_features, num_hidden_layers, hidden_features)
                for _ in range(split)
            ])
            for _ in range(split)
        ])

    def _axis_to_index(self, coord):
        """Return the bin of ``coord``, clamped to ``[0, split - 1]``.

        The bin is the number of interior boundaries ``-1 + k * 2 / split``
        lying strictly below ``coord``. Each boundary is computed directly
        rather than by repeated subtraction, which accumulated rounding error
        and could step past the last bin (e.g. ``split=3`` with ``coord=1.0``).
        """
        step = 2 / self.split
        index = 0
        while index < self.split - 1 and coord > -1 + (index + 1) * step:
            index += 1
        return index

    def uv_to_index(self, u, v):
        """Map coordinates ``u`` and ``v`` to the ``(ui, vi)`` sub-network index."""
        return (self._axis_to_index(u), self._axis_to_index(v))

    def forward(self, inputs):
        """Evaluate the sub-network chosen by the first (u, v) pair in ``inputs``."""
        # One host transfer for both values instead of two separate .item() calls.
        u, v = inputs.flatten()[:2].tolist()
        ui, vi = self.uv_to_index(u, v)
        return self.uv_nets[ui][vi](inputs)

device = torch.device("cuda:0")
layers = 3
features = 1024
# model = Siren(in_features=4, out_features=3, hidden_features=features, num_hidden_layers=layers)
model = MultiSiren(in_features=4, out_features=3, hidden_features=features, num_hidden_layers=layers, split=2)
# model = FullyConnected(in_features=4, out_features=3, hidden_features=1024, num_hidden_layers=3)

# Move the parameters to the target device before the optimizer captures them,
# so the optimizer is guaranteed to hold the on-device parameters.
model = model.to(device)

optim = torch.optim.Adam(lr=1e-4, params=model.parameters())
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

# Last expression of the cell: display the architecture.
model

In [None]:
from datetime import datetime
import torch.nn.functional as F

dummy_input = torch.randn(64*64, 4, device="cuda")  # only needed by the ONNX export below
model_str = f"ship360-multi-{layers}-{features}-{W_0}"
print(model_str)

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs("outputs", exist_ok=True)
checkpoint_path = f"outputs/{model_str}.pth"

# Set to an int to stop automatically (e.g. for Restart & Run All);
# None keeps training until the kernel is interrupted.
max_steps = None

step = 0
while max_steps is None or step < max_steps:
    # batch_data, batch_coord = lf_data.pixel_batch(batch_size=100_000)
    # One view per step: MultiSiren routes a whole batch through a single sub-network.
    batch_data, batch_coord = lf_data.next_batch(batch_size=1)
    model_output = model(batch_coord)
    loss = (model_output - batch_data).abs().mean()  # mean absolute (L1) error

    optim.zero_grad()
    loss.backward()
    optim.step()
    step += 1
    if step == 1 or step % 50 == 0:
        print(f"{datetime.now()} step:{step:04d}, loss:{loss.item():0.8f}")
        torch.save(model.state_dict(), checkpoint_path)
        # torch.onnx.export(model,
        #                   dummy_input,
        #                   f"outputs/{model_str}.onnx",
        #                   export_params=True,
        #                   opset_version=9,
        #                   do_constant_folding=True,
        #                   input_names = ['x'],
        #                   output_names = ['y']
        #                  )
    # if step % 150 == 0:
    #     scheduler.step()

In [None]:
# Compare ground truth and reconstruction for one view.
with torch.no_grad():
    batch_data, batch_coord = lf_data.next_batch(batch_size=1)
    model_output = model(batch_coord)
for target, prediction in zip(batch_data, model_output):
    lf_data.tensor_to_plt(target)
    lf_data.tensor_to_plt(prediction)