In [None]:
import os
import numpy as np
from itertools import chain

import torch
from torch.utils.data import DataLoader

from model import Siren
from util import get_mgrid, jacobian, VideoFitting

In [None]:
def train(path, total_steps, lambda_interf=0.01, lambda_flow=0.02, verbose=True, steps_til_summary=100, lr=1e-4):
    """Fit a scene + rain decomposition of the video at `path` using three SIRENs.

    Networks (all take 3 input features):
        g  : (x, y, t) -> (dx, dy, w). The first two outputs displace (x, y);
             the last output `w` is passed to f1 as its third input.
        f1 : (x + dx, y + dy, w) -> 3 scene logits (sigmoid -> o_scene).
        f2 : (x, y, t) -> 1 rain logit (sigmoid -> o_rain).
    The reconstruction is `(1 - o_rain) * o_scene + o_rain`.

    Args:
        path: video location, forwarded to `VideoFitting`.
        total_steps: number of optimisation steps.
        lambda_interf: weight of the L1 penalty on the rain layer.
        lambda_flow: weight of the L1 penalty on `jacobian(h, xyt)` (output of g
            differentiated w.r.t. its input coordinates).
        verbose: when True, print losses every `steps_til_summary` steps.
        steps_til_summary: print interval; values <= 0 disable printing.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-4).

    Returns:
        (g, f1, f2, video) where `video` is `VideoFitting(path).video`.
    """
    def _make_siren(out_features):
        # All three networks share one architecture; only the output width differs.
        net = Siren(in_features=3, out_features=out_features, hidden_features=256,
                    hidden_layers=5, outermost_linear=True)
        return net.cuda()

    # Construction order (g, f1, f2) is kept so weight initialisation consumes
    # the RNG in the same sequence as before.
    g = _make_siren(3)
    f1 = _make_siren(3)
    f2 = _make_siren(1)

    optim = torch.optim.Adam(lr=lr, params=chain(g.parameters(), f1.parameters(), f2.parameters()))

    v = VideoFitting(path)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)
    # The dataset is consumed as a single item holding every coordinate / colour;
    # presumably (num_points, 3) each — TODO confirm against VideoFitting.
    model_input, ground_truth = next(iter(videoloader))
    model_input, ground_truth = model_input[0].cuda(), ground_truth[0].cuda()

    # A quarter frame of points per step; never 0, which would yield empty
    # batches and a NaN mean loss for very small frames.
    batch_size = max(1, (v.H * v.W) // 4)
    num_points = len(model_input)
    for step in range(total_steps):
        # Cycle through the flattened coordinates batch by batch.
        start = (step * batch_size) % num_points
        end = min(start + batch_size, num_points)

        # Gradients w.r.t. the coordinates are needed for the flow (Jacobian) loss.
        xyt = model_input[start:end].requires_grad_()
        h = g(xyt)
        xy_, w = xyt[:, :-1] + h[:, :-1], h[:, [-1]]
        o_scene = torch.sigmoid(f1(torch.cat((xy_, w), -1)))
        o_rain = torch.sigmoid(f2(xyt))
        # Rain is composited over the scene with o_rain acting as the blend weight.
        o = (1 - o_rain) * o_scene + o_rain
        loss_recon = (o - ground_truth[start:end]).abs().mean()
        loss_interf = o_rain.abs().mean()
        loss_flow = jacobian(h, xyt).abs().mean()
        loss = loss_recon + lambda_interf * loss_interf + lambda_flow * loss_flow

        # Honour `verbose` (previously ignored) and avoid modulo-by-zero.
        if verbose and steps_til_summary > 0 and step % steps_til_summary == 0:
            print("Step [%04d/%04d]: recon=%0.8f, interf=%0.4f, flow=%0.4f"
                  % (step, total_steps, loss_recon.item(), loss_interf.item(), loss_flow.item()))

        optim.zero_grad()
        loss.backward()
        optim.step()

    return g, f1, f2, v.video

In [None]:
# Fit the scene/rain networks to the rainy clip; `orig` receives the source video.
g, f1, f2, orig = train(path='./data/rain', total_steps=5000)

In [None]:
def _to_uint8_frames(x):
    """Quantise a float array in [0, 1] to uint8 and split it along axis 0 into frames.

    Rounds instead of truncating, so values such as 254.9999 (e.g. from a
    uint8 -> /255 -> *255 round trip) map back to 255 rather than 254.
    """
    x = np.clip(np.rint(x * 255), 0, 255).astype(np.uint8)
    return [x[i] for i in range(len(x))]


with torch.no_grad():
    # `orig` is the video tensor returned by `train`, indexed (N, C, H, W).
    N, _, H, W = orig.size()
    xyt = get_mgrid([H, W, N]).cuda()

    # Evaluate the networks one frame's worth of points at a time instead of the
    # whole H*W*N grid at once, to bound GPU memory. This assumes g/f1/f2 act on
    # each coordinate row independently (per-point MLPs), so chunking does not
    # change the results — TODO confirm for the Siren implementation.
    chunk_size = H * W
    scene_chunks, rain_chunks = [], []
    for s in range(0, len(xyt), chunk_size):
        xyt_c = xyt[s:s + chunk_size]
        h = g(xyt_c)
        xy_, w = xyt_c[:, :-1] + h[:, :-1], h[:, [-1]]
        scene_chunks.append(torch.sigmoid(f1(torch.cat((xy_, w), -1))).cpu())
        rain_chunks.append(torch.sigmoid(f2(xyt_c)).cpu())

    # Grid is flattened in (H, W, N) order; reorder to frame-major (N, H, W[, C]).
    o_scene = torch.cat(scene_chunks).view(H, W, N, 3).permute(2, 0, 1, 3).numpy()
    o_rain = torch.cat(rain_chunks).view(H, W, N).permute(2, 0, 1).numpy()
    o_scene = _to_uint8_frames(o_scene)
    o_rain = _to_uint8_frames(o_rain)
    # NOTE: this overwrites `orig` (tensor -> list of uint8 frames) because the
    # next cell consumes it as frames; re-run the training cell before re-running
    # this one.
    orig = _to_uint8_frames(orig.permute(0, 2, 3, 1).detach().cpu().numpy())

In [None]:
# Save out video
# %pip install --user imageio imageio-ffmpeg
from base64 import b64encode

import imageio
from IPython.display import HTML

OUT_DIR = './data'
FPS = 1  # playback rate of the written mp4 files

fn_orig = os.path.join(OUT_DIR, 'rain_orig.mp4')
fn_scene = os.path.join(OUT_DIR, 'rain_scene.mp4')
fn_rain = os.path.join(OUT_DIR, 'rain_interf.mp4')
for fn, frames in ((fn_orig, orig), (fn_scene, o_scene), (fn_rain, o_rain)):
    imageio.mimwrite(fn, frames, fps=FPS)


# Display video inline
def _mp4_data_url(path):
    """Return the mp4 file at `path` as a base64 `data:` URL, closing the file afterwards."""
    with open(path, 'rb') as fh:
        return "data:video/mp4;base64," + b64encode(fh.read()).decode()


data_url_orig = _mp4_data_url(fn_orig)
data_url_scene = _mp4_data_url(fn_scene)
data_url_rain = _mp4_data_url(fn_rain)
HTML(f'''
<video width=512 controls autoplay loop>
      <source src="{data_url_orig}" type="video/mp4">
</video>
<video width=512 controls autoplay loop>
      <source src="{data_url_scene}" type="video/mp4">
</video>
<video width=512 controls autoplay loop>
      <source src="{data_url_rain}" type="video/mp4">
</video>
''')