In [1]:
import os
import numpy as np
from itertools import chain

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize

from model import Siren, Homography
from util import get_mgrid, apply_homography, jacobian, VideoFitting

In [2]:
def train(path, total_steps, lambda_interf=0.001, lambda_excl=0.002, verbose=True, steps_til_summary=100,
          device='cuda'):
    """Fit a two-layer (scene + moire) implicit representation to a video.

    Three networks are optimised jointly with a single Adam optimiser:
    ``g`` (``Homography``) is evaluated on the last input column ``t``,
    ``f1`` (``Siren``, 2 inputs) is evaluated on the coordinates returned by
    ``apply_homography(xy, g(t))``, and ``f2`` (``Siren``, 3 inputs) is
    evaluated on ``(xy, t)`` directly. The reconstruction is ``f1 + f2``.

    Args:
        path: Passed unchanged to ``VideoFitting`` as the data location.
        total_steps: Number of optimisation steps to run.
        lambda_interf: Weight of the mean absolute value of the ``f2`` output.
        lambda_excl: Weight of the gradient exclusion term between the two layers.
        verbose: If true, print the three loss terms every ``steps_til_summary`` steps.
        steps_til_summary: Logging period, in steps.
        device: Anything accepted by ``torch.device``; the networks and the
            training tensors are moved there. Defaults to ``'cuda'``, which
            matches the previously hard-coded behaviour.

    Returns:
        Tuple ``(g, f1, f2, v.video)``: the three trained networks and the
        ``video`` attribute of the ``VideoFitting`` dataset.
    """
    device = torch.device(device)

    transform = Compose([
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
    ])
    v = VideoFitting(path, transform)
    # Page-locked host memory only helps host-to-GPU copies, so tie it to the device type.
    videoloader = DataLoader(v, batch_size=1, pin_memory=(device.type == 'cuda'), num_workers=0)

    g = Homography(hidden_features=256, hidden_layers=2).to(device)
    f1 = Siren(in_features=2, out_features=3, hidden_features=256,
               hidden_layers=4, outermost_linear=True).to(device)
    f2 = Siren(in_features=3, out_features=3, hidden_features=128,
               hidden_layers=4, outermost_linear=True).to(device)

    params = chain(g.parameters(), f1.parameters(), f2.parameters())
    optim = torch.optim.Adam(lr=1e-4, params=params)

    # Only the first item yielded by the loader is used; [0] strips the loader's batch axis.
    model_input, ground_truth = next(iter(videoloader))
    model_input, ground_truth = model_input[0].to(device), ground_truth[0].to(device)

    num_samples = len(model_input)
    # A quarter of one frame's pixel count per step; never let it fall to 0,
    # which would yield empty batches and NaN means.
    batch_size = max(1, (v.H * v.W) // 4)
    for step in range(total_steps):
        # Walk through the samples in consecutive windows, wrapping around at the end.
        start = (step * batch_size) % num_samples
        end = min(start + batch_size, num_samples)

        # Coordinates need gradients because the exclusion loss differentiates
        # the network outputs with respect to them.
        xy = model_input[start:end, :-1].requires_grad_()
        t = model_input[start:end, [-1]].requires_grad_()
        h = g(t)
        xy_ = apply_homography(xy, h)
        o_scene = f1(xy_)
        o_moire = f2(torch.cat((xy, t), -1))
        o = o_scene + o_moire
        loss_recon = ((o - ground_truth[start:end]) ** 2).mean()
        loss_interf = o_moire.abs().mean()

        # Each layer is differentiated w.r.t. its own input coordinates
        # (warped xy_ for the scene, raw xy for the moire layer).
        g_scene = jacobian(o_scene, xy_)
        g_moire = jacobian(o_moire, xy)
        # Reciprocal norm ratios rescale the two gradient fields before the tanh product.
        # NOTE(review): no epsilon guards these divisions; a zero norm would give inf/NaN.
        n_scene = (g_moire.norm(dim=0, keepdim=True) / g_scene.norm(dim=0, keepdim=True)).sqrt()
        n_moire = (g_scene.norm(dim=0, keepdim=True) / g_moire.norm(dim=0, keepdim=True)).sqrt()
        loss_excl = (torch.tanh(n_scene * g_scene) * torch.tanh(n_moire * g_moire)).pow(2).mean()

        loss = loss_recon + lambda_interf * loss_interf + lambda_excl * loss_excl

        if verbose and not step % steps_til_summary:
            print("Step [%04d/%04d]: recon=%0.4f, interf=%0.4f, excl=%0.4f" % (step, total_steps, loss_recon, loss_interf, loss_excl))

        optim.zero_grad()
        loss.backward()
        optim.step()

    return g, f1, f2, v.video

In [3]:
# Fit the homography, scene and moire networks; `orig` is the dataset's video as returned by train().
g, f1, f2, orig = train(path='./data/moire', total_steps=3000)

Step [0000/3000]: recon=0.4119, interf=0.0673, excl=0.0026
Step [0100/3000]: recon=0.0291, interf=0.1817, excl=0.1533
Step [0200/3000]: recon=0.0213, interf=0.1586, excl=0.1740
Step [0300/3000]: recon=0.0150, interf=0.1306, excl=0.1748
Step [0400/3000]: recon=0.0121, interf=0.1096, excl=0.1853
Step [0500/3000]: recon=0.0082, interf=0.0942, excl=0.1972
Step [0600/3000]: recon=0.0082, interf=0.0831, excl=0.2062
Step [0700/3000]: recon=0.0065, interf=0.0750, excl=0.2126
Step [0800/3000]: recon=0.0062, interf=0.0681, excl=0.2167
Step [0900/3000]: recon=0.0046, interf=0.0643, excl=0.2189
Step [1000/3000]: recon=0.0052, interf=0.0606, excl=0.2141
Step [1100/3000]: recon=0.0033, interf=0.0577, excl=0.2100
Step [1200/3000]: recon=0.0030, interf=0.0556, excl=0.2102
Step [1300/3000]: recon=0.0040, interf=0.0541, excl=0.2062
Step [1400/3000]: recon=0.0033, interf=0.0528, excl=0.2100
Step [1500/3000]: recon=0.0024, interf=0.0512, excl=0.1931
Step [1600/3000]: recon=0.0022, interf=0.0500, excl=0.19

In [4]:
# NOTE: this cell rebinds `orig` from the tensor returned by train() to a list
# of uint8 frames, so re-run the training cell before re-running this one.
def _to_uint8_frames(frames):
    """Convert an array whose first axis indexes frames into a list of uint8 frames.

    Undoes the Normalize(0.5, 0.5) applied in train() (x * 0.5 + 0.5), clips to
    [0, 1] so out-of-range values cannot wrap around in the uint8 cast, and
    scales to [0, 255].
    """
    frames = (np.clip(frames * 0.5 + 0.5, 0, 1) * 255).astype(np.uint8)
    return list(frames)


with torch.no_grad():
    N, _, H, W = orig.size()
    # Query the trained networks on the grid built by get_mgrid for [H, W, N].
    xyt = get_mgrid([H, W, N]).cuda()
    h = g(xyt[:, [-1]])
    o_scene = f1(apply_homography(xyt[:, :-1], h))
    o_moire = f2(xyt)
    # Reshape the flat outputs to (H, W, N, 3), then move the N axis to the front.
    o_scene = _to_uint8_frames(o_scene.view(H, W, N, 3).permute(2, 0, 1, 3).detach().cpu().numpy())
    o_moire = _to_uint8_frames(o_moire.view(H, W, N, 3).permute(2, 0, 1, 3).detach().cpu().numpy())
    # Same conversion for the input video, after moving axis 1 (size-3 in the views above) to the end.
    orig = _to_uint8_frames(orig.permute(0, 2, 3, 1).detach().cpu().numpy())

In [5]:
# Save out video
# ! pip install --user imageio imageio-ffmpeg
from base64 import b64encode

import imageio
from IPython.display import HTML

OUT_DIR = './data'
FPS = 1  # frame rate passed to imageio for all three videos


def _mp4_data_url(path):
    """Read the file at `path` and return it as a base64 `data:video/mp4` URL."""
    # The context manager closes the handle promptly instead of leaving it to the GC.
    with open(path, 'rb') as fh:
        return "data:video/mp4;base64," + b64encode(fh.read()).decode()


fn_orig = os.path.join(OUT_DIR, 'moire_orig.mp4')
fn_scene = os.path.join(OUT_DIR, 'moire_scene.mp4')
fn_moire = os.path.join(OUT_DIR, 'moire_interf.mp4')
imageio.mimwrite(fn_orig, orig, fps=FPS)
imageio.mimwrite(fn_scene, o_scene, fps=FPS)
imageio.mimwrite(fn_moire, o_moire, fps=FPS)

# Display video inline
data_url_orig = _mp4_data_url(fn_orig)
data_url_scene = _mp4_data_url(fn_scene)
data_url_moire = _mp4_data_url(fn_moire)
# Keep HTML(...) as the cell's last expression so the notebook renders it.
HTML(f'''
<video width=512 controls autoplay loop>
      <source src="{data_url_orig}" type="video/mp4">
</video>
<video width=512 controls autoplay loop>
      <source src="{data_url_scene}" type="video/mp4">
</video>
<video width=512 controls autoplay loop>
      <source src="{data_url_moire}" type="video/mp4">
</video>
''')