## Install PyAV
torchvision's video I/O (`read_video` / `write_video`) requires PyAV (`av`) as its video backend.

In [None]:
!pip install av

## Imports

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_video, write_video
import torchvision.transforms.functional as F_t
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.utils import download_url

In [None]:
from IPython.display import HTML
from base64 import b64encode

## Network

In [None]:
class FourierLayer(nn.Module):
    """Random Fourier feature encoding of input coordinates.

    Maps ``x`` (last dim ``in_features``) to ``[sin(2*pi*x@B), cos(2*pi*x@B)]``,
    whose last dim is ``2 * out_features``. ``B`` is sampled once as
    ``scale * N(0, 1)`` and registered as a buffer. It therefore follows
    ``.to(device)`` and is stored in ``state_dict``, but it is not a trainable
    parameter.
    """

    def __init__(self, in_features, out_features, scale):
        super().__init__()
        self.register_buffer("B", scale * torch.randn(in_features, out_features))

    def forward(self, x):
        angles = (2 * math.pi * x) @ self.B
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [None]:
class SignalRegressor(nn.Module):
    """Coordinate MLP that regresses a signal with values in (0, 1).

    Architecture: [optional FourierLayer] -> Linear+ReLU -> (hidden_layers - 1) x
    (Linear+ReLU) -> Linear -> Sigmoid.

    Args:
        in_features: dimensionality of the input coordinates.
        fourier_features: number of random Fourier frequencies. With ``None`` the
            raw coordinates go straight into the first Linear layer.
        hidden_features: width of every hidden Linear layer.
        hidden_layers: number of Linear+ReLU stages, counting the first one.
        out_features: output dimensionality (3 for RGB in this notebook).
        scale: std of the Fourier matrix; ignored when ``fourier_features`` is None.
    """

    def __init__(self, in_features, fourier_features,
                 hidden_features, hidden_layers, out_features, scale):
        super().__init__()

        # Modules are created in the same order as before, so seeded
        # initialization (and state_dict indices) are unchanged.
        if fourier_features is None:
            layers = [nn.Linear(in_features, hidden_features), nn.ReLU()]
        else:
            layers = [
                FourierLayer(in_features, fourier_features, scale),
                # The Fourier layer emits sin and cos features, hence the factor 2.
                nn.Linear(2 * fourier_features, hidden_features),
                nn.ReLU(),
            ]

        for _ in range(hidden_layers - 1):
            layers += [nn.Linear(hidden_features, hidden_features), nn.ReLU()]

        layers += [nn.Linear(hidden_features, out_features), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

## Dataset

In [None]:
class VideoDataset(Dataset):
    """Frame-indexed dataset over one video for coordinate-based regression.

    Item ``idx`` is ``(grid, img)``:
      * ``grid``: (H, W, 3) float tensor of normalized ``(t, y, x)`` coordinates,
        each in [0, 1].
      * ``img``: frame ``idx`` converted to float in [0, 1]. ``read_video``'s
        default ``output_format="THWC"`` makes this (H, W, C).
    """

    def __init__(self, video_path):
        super().__init__()
        self.vframes, self.aframes, self.info = read_video(video_path, pts_unit="sec")
        self.nframes = self.vframes.shape[0]

    def _time_coordinate(self, idx):
        """Map frame index ``idx`` to a normalized time in [0, 1].

        A single-frame video maps to 0.0 instead of dividing by zero.
        """
        return idx / max(self.nframes - 1, 1)

    @staticmethod
    def _coordinate_grid(t, height, width):
        """Build a (height, width, 3) grid whose last axis holds (t, y, x).

        y and x are evenly spaced over [0, 1]. The explicit broadcast keeps
        the H and W axes even when one of them has size 1.
        """
        height_axis = torch.linspace(0, 1, steps=height)
        width_axis = torch.linspace(0, 1, steps=width)
        hh = height_axis[:, None].expand(height, width)
        ww = width_axis[None, :].expand(height, width)
        tt = torch.full_like(hh, t)
        return torch.stack([tt, hh, ww], dim=-1)

    def __getitem__(self, idx):
        img = self.vframes[idx]
        img = F_t.convert_image_dtype(img, dtype=torch.float)
        grid = self._coordinate_grid(self._time_coordinate(idx), img.shape[0], img.shape[1])
        return grid, img

    def __len__(self):
        return self.nframes

## Display Video

In [None]:
# Sample clip: Big Buck Bunny, 240p, ~1 MB, saved as ./bunny.mp4.
# No md5 is given, so torchvision's download_url reuses an existing local file on re-runs.
web_url = "https://sample-videos.com/video123/mp4/240/big_buck_bunny_240p_1mb.mp4"
download_url(web_url, root=".", filename="bunny.mp4")

In [None]:
video_path = "bunny.mp4"

# Embed the clip as a base64 data URL so it plays inline without a file server.
# The `with` block closes the file handle as soon as the bytes are read.
with open(video_path, "rb") as video_file:
    mp4 = video_file.read()
video_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""<video controls> <source src="%s" type="video/mp4"> </video>""" % video_url)

## Dataloader

In [None]:
video_data = VideoDataset(video_path)
# Frame order matters downstream. Training keeps every other frame via [::2]
# inside each batch, which is globally every other frame only while the batch
# size is even. The super-resolution cell reshapes the concatenated predictions
# back into the video. Shuffling must therefore stay off.
video_loader = DataLoader(video_data, batch_size=8, shuffle=False)

## Train and Evaluate

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
total_epochs = 1000
summary_interval = 100
seed = 0  # fixes the random Fourier matrix and the Linear-layer initialization

torch.manual_seed(seed)
video_regressor = SignalRegressor(in_features=3, fourier_features=256,
                                  hidden_features=256, hidden_layers=4, out_features=3, scale=5).to(device)

optim = torch.optim.Adam(lr=1e-3, params=video_regressor.parameters())

for epoch in range(1, total_epochs + 1):
    video_regressor.train()
    for grids, imgs in video_loader:
        # Train on every other frame and every other pixel. At render time the
        # model has to interpolate the skipped frames and pixels.
        grids, imgs = grids[::2, ::2, ::2], imgs[::2, ::2, ::2]
        grids, imgs = grids.to(device), imgs.to(device)
        coords, rgbs = grids.reshape(-1, 3), imgs.reshape(-1, 3)
        optim.zero_grad()
        output = video_regressor(coords)
        train_loss = F.mse_loss(output, rgbs)
        train_loss.backward()
        optim.step()

    if epoch % summary_interval == 0:
        video_regressor.eval()
        # Sum squared errors over all pixels of all frames (trained and skipped
        # alike), then divide once by the element count to get the exact global MSE.
        squared_error = torch.zeros((), dtype=torch.float64, device=device)
        n_values = 0
        with torch.no_grad():
            for grids, imgs in video_loader:
                grids, imgs = grids.to(device), imgs.to(device)
                coords, rgbs = grids.reshape(-1, 3), imgs.reshape(-1, 3)  # use all the pixels for evaluation
                prediction = video_regressor(coords)
                squared_error += F.mse_loss(prediction, rgbs, reduction="sum")
                n_values += rgbs.numel()
            test_loss = squared_error / n_values
            # Pixel values lie in [0, 1], so PSNR = -10 * log10(MSE).
            test_psnr = -10 * torch.log10(test_loss)
            print(f"Epoch: {epoch}, Test PSNR: {test_psnr.item():.6f}")

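## Spatiotemporal Super-Resolution Result
The model was trained on every other frame and every other pixel. Evaluating it on the full coordinate grid of the original video therefore: \
increases the frame rate by 2x relative to the training frames \
increases the frame resolution by 2x (in height and width) relative to the training pixels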

In [None]:
# Query the model at every (t, y, x) of the original video. It was trained on
# only half the frames and half the pixels per axis.
frame_predictions = []
video_regressor.eval()
with torch.no_grad():
    for grids, _ in video_loader:
        coords = grids.to(device).reshape(-1, 3)
        # Move each batch back to the CPU so the whole rendered video never sits in device memory.
        frame_predictions.append(video_regressor(coords).cpu())
# The loader is not shuffled, so concatenation order equals frame order.
super_video = torch.cat(frame_predictions, dim=0).reshape_as(video_data.vframes)
super_video = F_t.convert_image_dtype(super_video, dtype=torch.uint8)

In [None]:
super_path = "bunny_super.mp4"
write_video(super_path, super_video.cpu(), fps=video_data.info["video_fps"])

# Display the super-resolved video inline as a base64 data URL.
# The `with` block closes the file handle as soon as the bytes are read.
with open(super_path, "rb") as video_file:
    mp4 = video_file.read()
video_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""<video controls> <source src="%s" type="video/mp4"> </video>""" % video_url)