In [None]:
import math
import zmq
import cv2
import time
import torch
import numpy as np
import scipy.ndimage
import torch.nn as nn

# Frequency factor shared by the network: Sine.forward computes sin(W_0 * x)
# and sine_init divides its uniform weight bound by the same value.
W_0 = 30

def get_mgrid(sidelen, dim):
    """Build a dense 4-D coordinate grid normalised to [-1, 1] per axis.

    Args:
        sidelen: Indexable with at least four sample counts, one per axis.
        dim: Number of coordinate axes; only ``4`` is supported.

    Returns:
        A float32 tensor of shape
        ``(sidelen[0] * sidelen[1], sidelen[2], sidelen[3], dim)``. Each axis
        runs linearly from -1 to 1 (an axis with a single sample sits at -1).
        The first two coordinates are then offset by the module-level
        ``shift_factor`` and multiplied by the module-level ``scale_factor``;
        the last two are left untouched.

    Raises:
        NotImplementedError: If ``dim`` is not 4.
    """
    if dim != 4:
        raise NotImplementedError(f'Not implemented for dim={dim}')

    n_u, n_v, n_h, n_w = sidelen[0], sidelen[1], sidelen[2], sidelen[3]

    # Integer index grid with the coordinate axis last and a leading batch axis.
    index_grid = np.mgrid[:n_u, :n_v, :n_h, :n_w]
    coords = np.stack(index_grid, axis=-1)[None, ...].astype(np.float32)

    # Map indices to [0, 1]; max(..., 1) avoids dividing by zero on length-1 axes.
    for axis in range(dim):
        coords[..., axis] = coords[..., axis] / max(sidelen[axis] - 1, 1)

    # Recentre to [-1, 1] and flatten to one coordinate row per grid point.
    flat = torch.from_numpy((coords - 0.5) * 2).view(-1, dim)

    shift = torch.Tensor([shift_factor, shift_factor, 0, 0])
    scale = torch.Tensor([scale_factor, scale_factor, 1, 1])
    flat = (flat + shift) * scale

    return flat.view((n_u * n_v, n_h, n_w, dim))
    
def apply_filter(diff_out, window_size):
    """Dilate the True entries of a 2-D mask with a square window, in place.

    Every entry that is True on entry marks a ``window_size`` x ``window_size``
    patch around itself as True. The patch starts ``window_size // 2`` entries
    before the hit along each axis and is clipped at the array borders.

    Args:
        diff_out: 2-D boolean array; modified in place.
        window_size: Side length of the square patch, in entries.

    Returns:
        The same ``diff_out`` array, after dilation.
    """
    half = window_size // 2
    # Snapshot the hits first so patches written below are not dilated again.
    hits = np.where(diff_out == True)  # noqa: E712 - element-wise comparison
    for row, col in zip(hits[0], hits[1]):
        top = row - half
        left = col - half
        # Clamp the start at 0: a negative start would wrap around to the far
        # edge and yield an empty slice, leaving border hits undilated. The
        # stop is start + window_size so odd sizes cover the full window;
        # NumPy clips a stop index that runs past the array end.
        diff_out[max(top, 0):top + window_size,
                 max(left, 0):left + window_size] = True
    return diff_out

def to_uint8(im):
    """Convert a tensor with values nominally in [-1, 1] to a uint8 array.

    The tensor is detached, moved to the CPU and clamped to [-1, 1]; the
    result is mapped linearly onto [0, 255] and truncated to ``np.uint8``.

    Args:
        im: Torch tensor of any shape.

    Returns:
        A NumPy ``uint8`` array with the same shape as ``im``.
    """
    clipped = im.detach().cpu().clamp(-1, 1).numpy()
    unit_range = clipped * 0.5 + 0.5  # [-1, 1] -> [0, 1]
    scaled = unit_range * 255
    return scaled.astype(np.uint8)

class FCBlock(nn.Module):
    """Fully connected network with a Sine activation after all but the last layer.

    The stack is: one input layer (``in_features`` -> ``hidden_features``),
    ``num_hidden_layers`` hidden layers (``hidden_features`` wide), and a
    final linear layer (``hidden_features`` -> ``out_features``) with no
    activation. Every layer is wrapped in its own ``nn.Sequential``.
    """

    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features):
        super().__init__()

        # Build the layers in input -> output order.
        layers = [self._sine_layer(in_features, hidden_features)]
        layers.extend(
            self._sine_layer(hidden_features, hidden_features)
            for _ in range(num_hidden_layers)
        )
        layers.append(nn.Sequential(nn.Linear(hidden_features, out_features)))

        self.net = nn.Sequential(*layers)
        # sine_init touches every weight; the first layer is then re-initialised
        # with its own scheme.
        self.net.apply(sine_init)
        self.net[0].apply(first_layer_sine_init)

    @staticmethod
    def _sine_layer(fan_in, fan_out):
        """Return a Linear(fan_in, fan_out) followed by a Sine activation."""
        return nn.Sequential(nn.Linear(fan_in, fan_out), Sine())

    def forward(self, inputs):
        """Run ``inputs`` through the layer stack and return the result."""
        return self.net(inputs)

def sine_init(m):
    """Re-initialise ``m.weight`` uniformly in +/- sqrt(6 / fan_in) / W_0.

    ``fan_in`` is the size of the weight's last dimension. Modules without a
    ``weight`` attribute are left untouched, which makes the function safe to
    pass to ``nn.Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        fan_in = m.weight.size(-1)
        bound = np.sqrt(6 / fan_in) / W_0
        m.weight.uniform_(-bound, bound)

def first_layer_sine_init(m):
    """Re-initialise ``m.weight`` uniformly in +/- 1 / fan_in.

    ``fan_in`` is the size of the weight's last dimension. Modules without a
    ``weight`` attribute are left untouched, which makes the function safe to
    pass to ``nn.Module.apply``.
    """
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        fan_in = m.weight.size(-1)
        bound = 1 / fan_in
        m.weight.uniform_(-bound, bound)

class Sine(nn.Module):
    """Element-wise sine activation: ``sin(W_0 * input)``.

    ``W_0`` is the module-level frequency factor.
    """

    # The constructor was previously spelled ``__init``, which Python
    # name-mangles to ``_Sine__init`` and never calls.
    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``torch.sin(W_0 * input)``."""
        return torch.sin(W_0 * input)

class Siren(nn.Module):
    """Thin wrapper that exposes an FCBlock under the attribute ``net``."""

    def __init__(self, in_features, out_features, hidden_features, num_hidden_layers):
        super().__init__()
        # FCBlock takes num_hidden_layers before hidden_features; keywords keep
        # the differing parameter order explicit.
        self.net = FCBlock(
            in_features=in_features,
            out_features=out_features,
            num_hidden_layers=num_hidden_layers,
            hidden_features=hidden_features,
        )

    def forward(self, inputs):
        """Forward ``inputs`` through the wrapped FCBlock."""
        return self.net(inputs)

# Prefer the first CUDA device; fall back to the CPU so the server can still
# start on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Siren(in_features=4, out_features=3, hidden_features=1024, num_hidden_layers=3)
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted
# source. map_location remaps tensors saved on another device onto `device`.
model.load_state_dict(torch.load("outputs/ship-3-1024-30.pth", map_location=device))
model.to(device)
print("Loaded model")

resolution = 120   # a status line is printed once every `resolution` requests
shift_factor = 0   # offset applied to the first two coordinates in get_mgrid
scale_factor = 1   # scale applied to the first two coordinates
i = 0              # request counter, kept modulo `resolution`

u, v = 64, 64      # sample counts of the first two grid axes
h, w = 128, 128    # sample counts of the last two grid axes
filter_size = 7 if h == 128 else 9   # window passed to apply_filter
fixed_point_skip = 4                 # stride of pixels re-evaluated on every frame

# NOTE(review): this materialises all u * v * h * w coordinate rows, but the
# code visible here only reads mgrid[0].
mgrid = get_mgrid((u, v, h, w), dim=4)
print("Computed input domain")

context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind('tcp://*:5555')

print("Server started")
eps = 1e-5   # added to the frame end time so the FPS division never hits zero

prev_out = None   # per-pixel channel sum of the previous frame
diff_out = None   # boolean mask of pixels to re-evaluate on the next frame

# Working copy of one (h, w, 4) coordinate slice; its first two channels are
# overwritten with the requested coordinates on every frame.
im_coord = mgrid[0].clone().to(device)

# Request/reply loop: each request is a comma-separated text pair "u,v"; the
# reply is the raw bytes of the rendered uint8 frame. After the first frame
# only pixels flagged in `diff_out` (plus a fixed sparse grid) are re-evaluated.
# The try/finally releases the socket when the loop is interrupted or fails,
# so the port is not left bound (e.g. when a notebook cell is re-run).
try:
    while True:
        start = time.time()

        message_rx = socket.recv()
        angular_coord = np.fromstring(message_rx, dtype=float, sep=',')

        u = angular_coord[0]
        v = angular_coord[1]
        im_coord[:, :, 0] = u * scale_factor
        im_coord[:, :, 1] = v * scale_factor
        with torch.no_grad():
            if diff_out is not None:
                # Always refresh a sparse grid of pixels in addition to the
                # ones that changed between the last two frames.
                diff_out[::fixed_point_skip, ::fixed_point_skip] = True
                pred_coord = im_coord[diff_out]
                pred_out = model(pred_coord)
                pred_out = to_uint8(pred_out)
                model_out[diff_out] = pred_out
            else:
                # First frame: evaluate every pixel.
                model_out = model(im_coord)
                model_out = to_uint8(model_out)

            # Per-pixel channel sum, computed once and reused for both the
            # change mask and the next frame's reference.
            channel_sum = model_out.sum(axis=2)
            if prev_out is not None:
                diff_out = prev_out != channel_sum
                diff_out = apply_filter(diff_out, filter_size)
            prev_out = channel_sum

        socket.send(model_out.tobytes())

        end = time.time() + eps
        i = (i + 1) % resolution
        if i == 0:
            # diff_out is still None until a second frame has been rendered.
            samples = diff_out.sum() if diff_out is not None else "n/a"
            print(f"Received request: {message_rx}, u:{u}, v:{v}, samples:{samples}, FPS:{int(1 / (end - start))}")
finally:
    socket.close(linger=0)
    context.term()