In [2]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import os
from tqdm import tqdm
import cv2
import numpy as np
import dataloader
import imageio
import metrics

In [3]:
torch.manual_seed(42)
np.random.seed(42)

In [4]:
# loads frame sets that are hard to predict
# used to quickly validate model training and progress

In [5]:
# video frames HARD!
# 4 examples from the RedBullVideo

In [7]:
HARD_TEST_INSTANCES = {
    'input_videos/redbull480.mp4': [116, 763, 1089, 1253]
}

In [15]:
def get_frame_by_caption(cap, index):
    cap.set(cv2.CAP_PROP_POS_FRAMES, index)
    _, frame = cap.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
    return frame
    
    

def generate_hard_input(quadratic=False):
    
    
    if quadratic:
        l = [-2, -1, 1, 2]
    else:
        l = [-1, 1]
    
    
    for filename, frame_indices in HARD_TEST_INSTANCES.items():
        
#         video, _, _ = torchvision.io.read_video(filename)

        cap = cv2.VideoCapture(filename)
    
        
        for index in frame_indices:
            frame_indices = [index+i for i in l]
            X = [get_frame_by_caption(cap, i) for i in frame_indices]
            X = torch.from_numpy(np.array(X))
            X = X.permute(0,3,1,2).float()
            print('shape', X.shape)
            y = get_frame_by_caption(cap, index)
            y = torch.from_numpy(y).permute(2,0,1).float()
#             X = video[frame_indices].permute(0,3,1,2).float()
            X = X.unsqueeze(dim=0)
            print('shape2', X.shape, y.shape)
            X = X.unbind(dim=0)
            
            yield X, y
            
        
            
            

            
        

In [16]:
gen = generate_hard_input()

next(gen)

shape torch.Size([2, 3, 480, 720])
shape2 torch.Size([1, 2, 3, 480, 720]) torch.Size([3, 480, 720])


((tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],
  
  
          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
         

In [6]:
def create_grid(y_hat, y, **args):
    '''
    Creates image grid by stacking predictions y_hat and 
    ground truth y.
    
    y_hat: predicted inputs, shape B x 3 x H x W
    y: true values, same shape as y_hat
    '''
    
    assert y_hat.shape == y.shape
    
    inp_tensor = torch.cat([y_hat, y])
    grid = torchvision.utils.make_grid(inp_tensor, **args)
    
    return grid
    
    

In [46]:
def evalute_hard_images(model):
    gen = generate_hard_input(quadratic=False)

    y_hats = []
    ys = []
    for (x1, x2), y in gen:
        x1 = x1.cuda() / 255.
        x2 = x2.cuda() / 255.
        y = y / 255.
        
        y_hat = G(x1, x2).detach().cpu().squeeze(dim=0)
        y_hats.append(y_hat)
        ys.append(y)


    y_hats = torch.stack(y_hats).clamp(0,1)
    ys = torch.stack(ys)
    
    # create image grid
    grid = create_grid(y_hats, ys, padding=20, nrow=4)
    
    return grid
    
    

In [50]:
%%time
grid = evalute_hard_images(G)

plt.imshow(grid.permute(1,2,0))

In [10]:
for name in tqdm(names):
    # init summarywriter
    writer = SummaryWriter(f'runs/{name}')

    # load model
    G = torch.load(f'models/generator_{name}')
    G = G.cuda()

    # eval hard input
    gen = generate_hard_input(quadratic=False)

    y_hats = []
    ys = []
    for (x1, x2), y in gen:
        x1 = x1.cuda() / 255.
        x2 = x2.cuda() / 255.
        y = y / 255.
        
        y_hat = G(x1, x2).detach().cpu().squeeze(dim=0)
        y_hats.append(y_hat)
        ys.append(y)


    y_hats = torch.stack(y_hats).clamp(0,1)
    ys = torch.stack(ys)
    

    # create image grid
    grid = create_grid(y_hats, ys, padding=20, nrow=4)

    writer.add_image('hard_examples', img_tensor=grid)

100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [01:17<00:00,  9.70s/it]


In [51]:
# eval

In [8]:
ds = dataloader.adobe240_dataset()

n_train = int(len(ds) * 0.8)
n_valid = len(ds) - n_train


_, valid = torch.utils.data.random_split(ds, [n_train, n_valid])
dl = torch.utils.data.DataLoader(valid, batch_size=2)

In [9]:
names = os.listdir('runs')
names = names[1:]

filepaths = [f'models/generator_{name}' for name in names]
filepaths.append('models/benchmark_generator_sepconv')

In [11]:
results = {}
for model_path in reversed(filepaths):
    # init summarywriter
    
    # load model
    G = torch.load(model_path)
    G = G.cuda()
    G = G.eval()
    
    psnrs = []
    
    with torch.no_grad():
        for (x1, x2), y in tqdm(dl, total=len(dl), desc=model_path):
            x1 = x1.permute(0,3,1,2).cuda() / 255
            x2 = x2.permute(0,3,1,2).cuda() / 255
            y = y.permute(0,3,1,2).cuda()

            y_hat = G(x1, x2).mul(255).clamp(0,255).int()

            psnr = metrics.psnr(y_hat, y)
            psnrs.extend(psnr)
            
    results[model_path] = psnrs
        
    
        
    

models/benchmark_generator_sepconv: 100%|██████████████████████████████████████████| 1198/1198 [10:29<00:00,  1.90it/s]
models/generator_1586764929_0.0001_1e-05_False: 100%|██████████████████████████████| 1198/1198 [10:26<00:00,  1.91it/s]
models/generator_1586716841_0.0001_1e-05_True: 100%|███████████████████████████████| 1198/1198 [10:27<00:00,  1.91it/s]
models/generator_1586697759_0.0001_0_False: 100%|██████████████████████████████████| 1198/1198 [10:27<00:00,  1.91it/s]
models/generator_1586688260_0.0001_0_True: 100%|███████████████████████████████████| 1198/1198 [10:26<00:00,  1.91it/s]
models/generator_1586675975_1e-05_1e-05_False: 100%|███████████████████████████████| 1198/1198 [10:29<00:00,  1.90it/s]
models/generator_1586606508_1e-05_1e-05_True: 100%|████████████████████████████████| 1198/1198 [10:27<00:00,  1.91it/s]
models/generator_1586588130_1e-05_0_False: 100%|███████████████████████████████████| 1198/1198 [10:25<00:00,  1.91it/s]
models/generator_1586529004_1e-05_0_True