In [None]:
# # coord animation
# import os
# import math
# import glob
# import cairo
# import imageio

# files = glob.glob(f'graphics/coord/*.png')
# for f in files:
#     os.remove(f)

# resolution = 60
# for i in range(resolution):
#     u = (i-resolution/2) / (resolution/2)
#     x = (math.cos(math.pi * u) + 1) / 2.0
#     z = (math.sin(math.pi * u) + 1) / 2.0
#     coord_scale = 0.9
#     x = x * coord_scale + 0.5 * (1-coord_scale)
#     z = z * coord_scale + 0.5 * (1-coord_scale)
    
#     pixel_scale = 128
#     ims = cairo.ImageSurface(cairo.FORMAT_ARGB32, pixel_scale, pixel_scale)
#     cr = cairo.Context(ims)
#     cr.scale(pixel_scale, pixel_scale)

#     gray_c = 0.3

#     cr.set_source_rgb(gray_c, gray_c, gray_c)
#     cr.rectangle(0.0, 0.0, 1.0, 1.0)
#     cr.fill()

    
#     size = 0.1
#     cr.set_source_rgb(1.0, 0.0, 0.0)
#     cr.rectangle(x - size/2, z - size/2, size, size)
#     cr.fill()

#     ims.write_to_png(f"graphics/coord/{i:02d}.png")

# anim_file = f'graphics/coord.mp4'

# frames = []
# filenames = glob.glob(f'graphics/coord/*.png')
# filenames = sorted(filenames)
# for i, filename in enumerate(filenames):
#     frames.append(imageio.imread(filename))

# imageio.mimsave(anim_file, frames, fps=30)

In [None]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt

# Offset and gain applied to the two angular (first two) coordinates only.
shift_factor = 0
scale_factor = 1

def get_mgrid(sidelen, dim):
    """Build a dense 4-D coordinate grid normalised to [-1, 1] per axis.

    Args:
        sidelen: Sequence whose first four entries are the number of samples
            along each of the four axes.
        dim: Number of coordinate dimensions; only ``4`` is supported.

    Returns:
        A float32 tensor of shape
        ``(sidelen[0] * sidelen[1], sidelen[2], sidelen[3], dim)``. The first
        two coordinates have ``shift_factor`` added and are then multiplied by
        ``scale_factor``; the last two are left in [-1, 1].

    Raises:
        NotImplementedError: If ``dim`` is not 4.
    """
    if dim != 4:
        raise NotImplementedError(f'Not implemented for dim={dim}')

    extents = (sidelen[0], sidelen[1], sidelen[2], sidelen[3])

    # Integer lattice, coordinate axis last, plus a leading singleton axis.
    axes = np.meshgrid(*[np.arange(n) for n in extents], indexing='ij')
    lattice = np.stack(axes, axis=-1)[None, ...].astype(np.float32)

    # Rescale each axis to [0, 1]; max(..., 1) keeps singleton axes finite.
    for axis, extent in enumerate(extents):
        lattice[..., axis] /= max(extent - 1, 1)

    # [0, 1] -> [-1, 1], then flatten to one row per sample.
    coords = torch.from_numpy((lattice - 0.5) * 2).view(-1, dim)

    offset = torch.Tensor([shift_factor, shift_factor, 0, 0])
    gain = torch.Tensor([scale_factor, scale_factor, 1, 1])
    coords = (coords + offset) * gain

    return coords.view((extents[0] * extents[1], extents[2], extents[3], dim))

class LightFieldData:
    """In-memory light-field dataset with view-level and pixel-level sampling.

    After ``load_data_capture`` the instance holds:
      * ``data``  - ``(u * v, h, w, 3)`` float32 RGB images scaled to [-1, 1]
      * ``mgrid`` - ``(u * v, h, w, 4)`` coordinates produced by ``get_mgrid``
    Both are indexed by the same flat view index ``ui * v + vi``.
    """

    def __init__(self, batch_size, device=None):
        """
        Args:
            batch_size: Default number of views returned by ``next_batch``.
            device: Target device for returned batches. Defaults to ``cuda:0``
                when CUDA is available and to the CPU otherwise.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.batch_size = batch_size

    def load_data_capture(self, data_folder, views=(64, 64), image_size=(256, 256)):
        """Load a grid of captured views from ``data_folder``.

        Images are read from ``{data_folder}/{vi:03d}-{ui:03d}.png``, resized,
        converted to RGB and rescaled from [0, 255] to [-1, 1].

        Args:
            data_folder: Directory containing the PNG views.
            views: ``(u, v)`` number of views along the two angular axes.
            image_size: ``(height, width)`` every view is resized to.

        Raises:
            FileNotFoundError: If a view image is missing or cannot be decoded.
        """
        u, v = views
        h, w = image_size
        c = 3
        data = torch.zeros((u * v, h, w, c))
        for ui in range(u):
            for vi in range(v):
                path = f'{data_folder}/{vi:03d}-{ui:03d}.png'
                im = cv2.imread(path)
                if im is None:
                    # cv2.imread signals failure by returning None, not by raising.
                    raise FileNotFoundError(f'Missing or unreadable light-field view: {path}')
                # cv2.resize expects dsize as (width, height).
                im = cv2.resize(im, (w, h))
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                im = torch.from_numpy(((im * 2.0 / 255.0) - 1.0).astype(np.float32))
                # Row stride is v so the index matches get_mgrid's (u * v, ...) flattening.
                data[ui * v + vi] = im

        self.mgrid = get_mgrid((u, v, h, w), dim=4)
        self.data = data
        self.indices = np.arange(len(self.data))
        # Start "exhausted" so the first next_batch() call triggers a shuffle.
        self.cursor = len(self.indices)
        self.height = h
        self.width = w

        print(f"Data:{self.data.shape}, Type:{self.data.dtype}, Range:{(self.data.min(), self.data.max())}")
        print(f"Mesh Grid:{self.mgrid.shape}, Type:{self.mgrid.dtype}, Range:{(self.mgrid.min(), self.mgrid.max())}")

    def shuffle(self):
        """Draw a new random view order and rewind the cursor."""
        self.indices = np.random.permutation(self.indices)
        self.cursor = 0

    def next_batch(self, batch_size=None):
        """Return the next batch of whole views and their coordinates.

        When fewer than ``batch_size`` views remain in the current order, the
        order is reshuffled first and the remaining views are dropped.

        Args:
            batch_size: Number of views; defaults to ``self.batch_size``.

        Returns:
            ``(batch_data, batch_coord)`` on ``self.device``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if self.cursor + batch_size > len(self.indices):
            self.shuffle()
        batch_indices = self.indices[self.cursor:self.cursor + batch_size]
        self.cursor += batch_size
        batch_data = self.data[batch_indices]
        batch_coord = self.mgrid[batch_indices]
        return batch_data.to(self.device), batch_coord.to(self.device)

    def pixel_batch(self, batch_size):
        """Sample ``batch_size`` random pixels (with replacement) across all views.

        Returns:
            ``(batch_data, batch_coord)`` of shapes ``(batch_size, 3)`` and
            ``(batch_size, 4)`` on ``self.device``.
        """
        flat_data = self.data.view(-1, 3)
        flat_coord = self.mgrid.view(-1, 4)
        pixel_ix = np.random.choice(flat_data.shape[0], batch_size)
        return flat_data[pixel_ix].to(self.device), flat_coord[pixel_ix].to(self.device)

    def tensor_to_plt(self, in_tensor):
        """Display the first three channels of a [-1, 1] image tensor with matplotlib."""
        plt.imshow(in_tensor[...,0:3].detach().cpu().clamp(-1,1) * 0.5 + 0.5)
        plt.show()

In [None]:
# Create the dataset wrapper (default batch of 8 views) and load the capture
# stored under datasets/Ship-360 into memory; load_data_capture prints the
# resulting tensor shapes and value ranges.
lf_data = LightFieldData(batch_size=8)
lf_data.load_data_capture('datasets/Ship-360')

In [None]:
import torch.nn as nn

# Frequency multiplier shared by Sine.forward (sin(W_0 * x)) and sine_init
# (weight bound divided by W_0); it is also embedded in the checkpoint name.
W_0 = 30

class FCBlock(nn.Module):
    """Fully connected stack with sine activations and a linear output stage.

    Layout: one input stage, ``num_hidden_layers`` hidden stages (each
    ``Linear`` followed by ``Sine``), then a final ``Linear`` without activation.
    Every stage is its own ``nn.Sequential`` inside ``self.net``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        num_hidden_layers: Number of hidden-to-hidden stages.
        hidden_features: Width of every hidden stage.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features):
        super().__init__()

        # Input width of every sine-activated stage, in order.
        stage_inputs = [in_features] + [hidden_features] * num_hidden_layers
        stages = [
            nn.Sequential(nn.Linear(fan_in, hidden_features), Sine())
            for fan_in in stage_inputs
        ]
        stages.append(nn.Sequential(nn.Linear(hidden_features, out_features)))

        self.net = nn.Sequential(*stages)
        # Generic init on every layer first, then override the first stage.
        self.net.apply(sine_init)
        self.net[0].apply(first_layer_sine_init)

    def forward(self, inputs):
        """Run ``inputs`` through the whole stack."""
        stack = self.net
        return stack(inputs)

def sine_init(m):
    """Re-draw ``m.weight`` uniformly in +/- sqrt(6 / fan_in) / W_0.

    Modules without a ``weight`` attribute are left untouched, so this can be
    passed to ``nn.Module.apply``. Biases are not modified.
    """
    if not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    bound = np.sqrt(6 / fan_in) / W_0
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)

def first_layer_sine_init(m):
    """Re-draw ``m.weight`` uniformly in +/- 1 / fan_in.

    Modules without a ``weight`` attribute are skipped; biases are not modified.
    """
    if not hasattr(m, 'weight'):
        return
    limit = 1 / m.weight.size(-1)
    with torch.no_grad():
        m.weight.uniform_(-limit, limit)

class Sine(nn.Module):
    """Sine activation: ``sin(W_0 * x)`` applied element-wise.

    ``W_0`` is read from module scope at call time. The module has no
    parameters or buffers, so it adds nothing to a state dict.
    """

    # The constructor used to be spelled ``__init``; inside a class body that
    # name is mangled to ``_Sine__init`` and was never invoked.
    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``sin(W_0 * input)`` with the same shape as ``input``."""
        return torch.sin(W_0 * input)

class Siren(nn.Module):
    """Thin wrapper exposing an ``FCBlock`` under the attribute ``net``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        hidden_features: Width of every hidden stage.
        num_hidden_layers: Number of hidden-to-hidden stages.
    """

    def __init__(self, in_features, out_features, hidden_features, num_hidden_layers):
        super().__init__()
        # FCBlock takes (in, out, num_hidden_layers, hidden_features) - note the
        # different argument order, hence keywords.
        self.net = FCBlock(
            in_features=in_features,
            out_features=out_features,
            num_hidden_layers=num_hidden_layers,
            hidden_features=hidden_features,
        )

    def forward(self, inputs):
        """Evaluate the wrapped network on ``inputs``."""
        backbone = self.net
        return backbone(inputs)

# Prefer the first GPU, but fall back to the CPU so the cell also runs on
# machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
layers = 3
features = 1024
model = Siren(in_features=4, out_features=3, hidden_features=features, num_hidden_layers=layers)

# Checkpoint name encodes the architecture: ship360-<layers>-<features>-<W_0>.
model_str = f"ship360-{layers}-{features}-{W_0}"
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# opened on a CPU-only machine) instead of failing while unpickling.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load(f"outputs/{model_str}.pth", map_location=device))

# Move the model before building the optimizer so the optimizer is created
# over the parameters as they live on the target device.
model.to(device)

optim = torch.optim.Adam(lr=1e-4, params=model.parameters())

# Last expression of the cell: echo the architecture, as the original cell did.
model

In [None]:
import os
import math
import glob
import imageio

def to_uint8(im):
    """Convert a [-1, 1] tensor to a ``uint8`` NumPy array in [0, 255].

    Values outside [-1, 1] are clamped first. The result is rounded to the
    nearest integer before the cast: a plain ``astype(np.uint8)`` truncates,
    so float error (e.g. 99.99999) could shift a pixel down by one level.

    Args:
        im: Tensor of any shape; it is detached and moved to the CPU.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``im``.
    """
    unit = im.detach().cpu().clamp(-1, 1).numpy() * 0.5 + 0.5
    return np.rint(unit * 255).astype(np.uint8)

# Per-model output folder for the individual frames.
vid_folder = f'graphics/{model_str}'
# exist_ok replaces the racy "check, then create" sequence.
os.makedirs(vid_folder, exist_ok=True)

# Remove frames left over from a previous run so stale images cannot linger
# next to the new ones.
files = glob.glob(f'{vid_folder}/*.png')
for f in files:
    os.remove(f)

resolution = 120
frames = []
for i in range(resolution):
    # u sweeps [-1, 1); (x, z) therefore traces one full circle.
    u = (i-resolution/2) / (resolution/2)
    x = math.sin(math.pi * u)
    z = math.cos(math.pi * u)
    novel_coords = torch.Tensor([x, z]) * scale_factor

    # Reuse the pixel coordinates of the first view and overwrite only the
    # two angular coordinates with the novel viewpoint.
    im_coord = lf_data.mgrid[0].clone()
    im_coord[:,:,:2] = novel_coords

    with torch.no_grad():
        im_coord = im_coord.to(device)
        model_out = model(im_coord)
        model_out = to_uint8(model_out)

    # OpenCV writes BGR, so convert on the way out; imwrite reports failure
    # through its return value rather than by raising.
    frame_path = f"{vid_folder}/{i:03d}.png"
    if not cv2.imwrite(frame_path, cv2.cvtColor(model_out, cv2.COLOR_RGB2BGR)):
        raise IOError(f"Could not write frame: {frame_path}")

    # Keep the RGB frame in memory: it is exactly what the lossless PNG holds,
    # so re-reading every file (via the deprecated imageio.imread) is unnecessary.
    frames.append(model_out)

anim_file = f'{vid_folder}.mp4'

imageio.mimsave(anim_file, frames, fps=30)

In [None]:
rows = 4
cols = 4

# Draw rows*cols ground-truth views and preview the first two inline.
batch_data, _ = lf_data.next_batch(rows*cols)
for preview in batch_data[:2]:
    lf_data.tensor_to_plt(preview)

# Assemble a rows x cols contact sheet with 16-pixel black gutters between
# neighbouring tiles (none on the outer border).
sheet_rows = []
for i in range(rows):
    tiles = []
    for j in range(cols):
        gt_frame = to_uint8(batch_data[i*cols+j])
        tiles.append(gt_frame)
        if j < cols-1:
            tiles.append(np.zeros((gt_frame.shape[0], 16, 3), dtype=np.uint8))
    row_canvas = np.concatenate(tiles, axis=1)
    sheet_rows.append(row_canvas)
    if i < rows-1:
        sheet_rows.append(np.zeros((16, row_canvas.shape[1], 3), dtype=np.uint8))

all_canvas = np.concatenate(sheet_rows, axis=0)

imageio.imwrite(f'graphics/views.png', all_canvas)