In [None]:
import torch
import numpy as np
import torch.nn as nn

# Frequency multiplier shared by the network: Sine.forward computes
# sin(W_0 * x), and sine_init divides its uniform weight bound by W_0.
W_0 = 30

class FCBlock(nn.Module):
    """Fully connected stack of Linear layers with Sine activations.

    Layout: one input layer (Linear + Sine), `num_hidden_layers` hidden
    layers (Linear + Sine) and one output layer (Linear only). Every layer
    is wrapped in its own nn.Sequential, and those wrappers are chained in an
    outer nn.Sequential stored as `self.net`.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        num_hidden_layers: number of hidden Linear + Sine layers placed
            between the input layer and the output layer.
        hidden_features: width of the input layer output and of every
            hidden layer.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features):
        super().__init__()

        # Layers are created in input -> hidden -> output order.
        layers = [nn.Sequential(nn.Linear(in_features, hidden_features), Sine())]
        layers.extend(
            nn.Sequential(nn.Linear(hidden_features, hidden_features), Sine())
            for _ in range(num_hidden_layers)
        )
        layers.append(nn.Sequential(nn.Linear(hidden_features, out_features)))

        self.net = nn.Sequential(*layers)

        # Initialise every weight-bearing module with sine_init first, then
        # overwrite the first layer with its dedicated initialiser.
        self.net.apply(sine_init)
        self.net[0].apply(first_layer_sine_init)

    def forward(self, inputs):
        """Run `inputs` through the whole stack and return the result."""
        return self.net(inputs)

def sine_init(m):
    """Re-initialise `m.weight` in place for a sine-activated layer.

    Modules without a `weight` attribute are left untouched. Otherwise the
    weight is filled uniformly from [-b, b] with
    b = sqrt(6 / fan_in) / W_0, where fan_in is the size of the last
    weight dimension.
    """
    if not hasattr(m, 'weight'):
        return

    with torch.no_grad():
        fan_in = m.weight.size(-1)
        bound = np.sqrt(6 / fan_in) / W_0
        m.weight.uniform_(-bound, bound)

def first_layer_sine_init(m):
    """Re-initialise `m.weight` in place for the first network layer.

    Modules without a `weight` attribute are left untouched. Otherwise the
    weight is filled uniformly from [-1 / fan_in, 1 / fan_in], where fan_in
    is the size of the last weight dimension.
    """
    if not hasattr(m, 'weight'):
        return

    with torch.no_grad():
        bound = 1 / m.weight.size(-1)
        m.weight.uniform_(-bound, bound)

class Sine(nn.Module):
    """Sine activation: returns sin(W_0 * input) element-wise.

    The frequency factor W_0 is read from the module-level constant each
    time `forward` runs; the module itself has no parameters.
    """

    def __init__(self):
        # The original spelled this `__init`, which Python name-mangles to
        # `_Sine__init` and never calls; it only worked because
        # nn.Module.__init__ was inherited. Spell the constructor correctly.
        super().__init__()

    def forward(self, input):
        """Apply the scaled sine to `input` and return a tensor of the same shape."""
        return torch.sin(W_0 * input)

class Siren(nn.Module):
    """Thin wrapper that exposes an FCBlock as `self.net`.

    Note that the parameter order here (hidden_features before
    num_hidden_layers) differs from FCBlock's, so the arguments are
    forwarded by keyword.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        hidden_features: width of the hidden layers.
        num_hidden_layers: number of hidden layers in the FCBlock.
    """

    def __init__(self, in_features, out_features, hidden_features, num_hidden_layers):
        super().__init__()
        self.net = FCBlock(
            in_features=in_features,
            out_features=out_features,
            num_hidden_layers=num_hidden_layers,
            hidden_features=hidden_features,
        )

    def forward(self, inputs):
        """Evaluate the wrapped FCBlock on `inputs`."""
        return self.net(inputs)

# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the cell still runs on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Siren(in_features=4, out_features=3, hidden_features=1024, num_hidden_layers=3)

# NOTE(review): no training step is visible in this part of the notebook; the
# optimizer is kept in case other cells rely on it.
optim = torch.optim.Adam(lr=1e-4, params=model.parameters())

model.to(device)

# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even when the device it was saved from is not available here.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (recent PyTorch versions also accept weights_only=True).
model.load_state_dict(torch.load("outputs/ship-3-1024-30.pth", map_location=device))

In [None]:
import math
import zmq
import cv2

def get_mgrid(sidelen, dim):
    """Build a 4-D coordinate grid as a float32 torch tensor.

    Each axis is an integer index grid normalised to [0, 1] (dividing by
    `sidelen[axis] - 1`, or by 1 for a length-1 axis) and then mapped to
    [-1, 1]. The first two coordinates are afterwards shifted by the
    module-level `shift_factor` and scaled by the module-level
    `scale_factor`; both globals are read at call time.

    Args:
        sidelen: indexable of at least four ints, the number of samples
            along each axis.
        dim: number of coordinate dimensions; only 4 is supported.

    Returns:
        Tensor of shape (sidelen[0] * sidelen[1], sidelen[2], sidelen[3], dim).

    Raises:
        NotImplementedError: if `dim` is not 4.
    """
    if dim != 4:
        raise NotImplementedError(f'Not implemented for dim={dim}')

    index_grid = np.mgrid[:sidelen[0], :sidelen[1], :sidelen[2], :sidelen[3]]
    coords = np.stack(index_grid, axis=-1)[None, ...].astype(np.float32)

    # Normalise every axis to [0, 1]; a length-1 axis divides by 1 instead of 0.
    for axis in range(dim):
        denom = max(sidelen[axis] - 1, 1)
        coords[..., axis] = coords[..., axis] / denom

    # Map [0, 1] -> [-1, 1] and flatten to one coordinate tuple per row.
    coords = torch.from_numpy((coords - 0.5) * 2).view(-1, dim)

    # Shift and scale act on the first two coordinates only.
    shift = torch.Tensor([shift_factor, shift_factor, 0, 0])
    scale = torch.Tensor([scale_factor, scale_factor, 1, 1])
    coords = (coords + shift) * scale

    out_shape = (sidelen[0] * sidelen[1], sidelen[2], sidelen[3], dim)
    return coords.view(out_shape)

def to_uint8(im):
    """Convert a tensor with values nominally in [-1, 1] to a uint8 array.

    The tensor is detached, moved to the CPU and clamped to [-1, 1]; the
    result is mapped linearly to [0, 255] and cast (truncating) to
    np.uint8.
    """
    clamped = im.detach().cpu().clamp(-1, 1)
    unit_range = clamped.numpy() * 0.5 + 0.5
    return (unit_range * 255).astype(np.uint8)

# Animation state: `i` is the current frame index and wraps around after
# `resolution` frames in the serving loop.
resolution = 120
i = 0

# Read by get_mgrid (as module globals) for the first two coordinates;
# scale_factor is also applied to the per-frame coordinates in the loop.
shift_factor = 0
scale_factor = 1

# Grid sizes passed to get_mgrid, in (u, v, h, w) order.
u = v = 64
h = w = 128

# NOTE(review): this builds the full u*v*h*w grid, but the code visible here
# only ever reads mgrid[0] -- confirm whether the full grid is needed.
mgrid = get_mgrid((u, v, h, w), dim=4)

context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind('tcp://*:5555')

# Serve frames forever. The try/finally makes sure the socket and context are
# released when the cell is interrupted; otherwise port 5555 stays bound and
# re-running the cell fails to bind.
try:
    while True:
        # Map the frame index to a phase in [-1, 1) and place the first two
        # coordinates on a circle. The phase uses its own name so that the
        # grid-size global `u` is not overwritten.
        phase = (i - resolution / 2) / (resolution / 2)
        x = math.sin(math.pi * phase)
        z = math.cos(math.pi * phase)
        novel_coords = torch.Tensor([x, z]) * scale_factor

        # Start from the first slice of the grid and overwrite its first two
        # coordinates with the values for this frame.
        im_coord = mgrid[0].clone()
        im_coord[:, :, :2] = novel_coords

        with torch.no_grad():
            im_coord = im_coord.to(device)
            model_out = model(im_coord)
            model_out = to_uint8(model_out)

        # The frame is computed before blocking on the request; the reply is
        # the raw bytes of the vertically flipped uint8 frame.
        socket.recv()
        socket.send(cv2.flip(model_out, 0).tobytes())

        i = (i + 1) % resolution
finally:
    # linger=0 drops any unsent reply so that term() cannot block.
    socket.close(linger=0)
    context.term()