In [None]:
# imports

import os
import os.path
import cv2
import glob
import h5py
import tqdm
import argparse
import logging
from PIL import Image 

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
sns.set_theme()
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms

import data, utils, models

In [None]:
# definition for loading a model from a pretrained network file

def load_model(PATH, parallel=False, pretrained=True, old=True, load_opt=True):
    """Build a model from a checkpoint file and optionally restore its state.

    The checkpoint is read once with ``torch.load`` and is expected to be
    indexable with the keys ``"args"``, ``"model"`` and ``"optimizer"``.

    Args:
        PATH: Path of the checkpoint file.
        parallel: If True, a leading ``"module."`` prefix is removed from each
            saved parameter name before it is matched against the model
            (presumably the prefix left by ``nn.DataParallel`` -- confirm
            against how the checkpoint was saved).
        pretrained: If True, copy the tensors in ``checkpoint["model"][0]``
            into the freshly built model.
        old: If True, force ``args.blind_noise = False`` before the model is
            built.
        load_opt: If True, restore the optimizer from the first entry of
            ``checkpoint["optimizer"]``.

    Returns:
        Tuple ``(model, optimizer, args)`` where ``args`` is a copy of the
        namespace stored in the checkpoint.
    """
    # Read the file a single time and always map to CPU, so a checkpoint
    # holding CUDA tensors can also be opened on a CPU-only machine.
    checkpoint = torch.load(PATH, map_location="cpu")

    # Copy the stored namespace so the override below does not touch the
    # object held inside `checkpoint`.
    args = argparse.Namespace(**vars(checkpoint["args"]))
    if old:
        args.blind_noise = False

    model = models.build_model(args)
    # NOTE: `lr` is not a parameter; it is looked up in the notebook's global
    # namespace at call time, so the configuration cell must have run first.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    if load_opt:
        # zip() pairs the single optimizer with the first saved state only.
        for o, state in zip([optimizer], checkpoint["optimizer"]):
            o.load_state_dict(state)

    if pretrained:
        saved_weights = checkpoint["model"][0]
        own_state = model.state_dict()
        prefix = "module."

        for name, param in saved_weights.items():
            # Strip the wrapper prefix only when it is actually present;
            # slicing unconditionally would mangle unprefixed names.
            if parallel and name.startswith(prefix):
                name = name[len(prefix):]
            if name not in own_state:
                # Saved entry has no counterpart in this model: report, skip.
                print("here", name)
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)

    return model, optimizer, args

In [None]:
# necessary variable definitions

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

dataset = "GoPro"

# load the desired video to denoise
# video = "snowboard"
# video = "hypersmooth"
video = "rafting"
# video = "motorbike"

patch_size = 128
stride = 64
is_image = False
n_frames = 5
cpf = 3  # presumably channels per frame; used to slice one frame out of the stacked input
mid = n_frames // 2  # index of the middle frame of the n_frames window

aug = 0

dist = 'G'
mode = 'S'
noise_std = 30
min_noise = 0
max_noise = 100

batch_size = 8
lr = 1e-4
epochs = 100

# Flags forwarded to load_model() by the model-loading cell. They were not
# defined anywhere before that call, which raised NameError on a fresh kernel.
# The values mirror load_model()'s own defaults.
# TODO(review): confirm they match how the chosen checkpoint was saved
# (e.g. set parallel = True if its parameter names carry a "module." prefix).
parallel = False
pretrained = True
old = True
load_opt = True

In [None]:
# Pick the pretrained checkpoint to evaluate. For Single Video Denoising,
# choose the checkpoint that corresponds to the selected video.

# # UDVD
# PATH = "pretrained/blind_video_net.pt"

# # SingleVideo - Set8/GoPro - snowboard
# PATH = "pretrained/single_video_Set8_snowboard_30.pt"

# # SingleVideo - Set8/GoPro - hypersmooth
# PATH = "pretrained/single_video_Set8_hypersmoth_30.pt"

# SingleVideo - Set8/GoPro - rafting
PATH = "pretrained/single_video_Set8_rafting_30.pt"

# # SingleVideo - Set8/GoPro - motorbike
# PATH = "pretrained/single_video_Set8_motorbike_30.pt"

# Build the network and optimizer from the checkpoint, move the network to
# the selected device, and print its architecture.
model, optimizer, args = load_model(
    PATH,
    parallel=parallel,
    pretrained=pretrained,
    old=old,
    load_opt=load_opt,
)
model.to(device)
print(model)

In [None]:
# Data loaders for the selected video.
# NOTE: PATH is re-bound here from the checkpoint file to the dataset folder.

PATH = os.path.join("datasets/Set8", dataset)

train_loader, test_loader = data.build_dataset(
    "SingleVideo",
    PATH,
    batch_size=batch_size,
    dataset=dataset,
    video=video,
    image_size=patch_size,
    stride=stride,
    n_frames=n_frames,
    aug=aug,
    dist=dist,
    mode=mode,
    noise_std=noise_std,
    min_noise=min_noise,
    max_noise=max_noise,
    sample=True,
    heldout=False,
)

In [None]:
# Test: run the model over test_loader and report PSNR/SSIM of the
# post-processed output and of the `mean_image` returned by post_process,
# both measured against the clean middle frame.

valid_meters = {name: utils.AverageMeter() for name in ("valid_psnr", "valid_ssim")}
mean_meters = {name: utils.AverageMeter() for name in ("mean_psnr", "mean_ssim")}

model.eval()
for meter in (*valid_meters.values(), *mean_meters.values()):
    meter.reset()

valid_bar = utils.ProgressBar(test_loader)
running_valid_psnr = 0.0
plist = []  # per-batch PSNR values, in loader order
slist = []  # per-batch SSIM values, in loader order
n_batches = 0

# Channel range of the middle frame inside the stacked multi-frame input.
mid_channels = slice(mid * cpf, (mid + 1) * cpf)

for sample_id, (sample, noisy_inputs) in enumerate(valid_bar):
    with torch.no_grad():
        sample = sample.to(device)
        noisy_inputs = noisy_inputs.to(device)

        out, est_sigma = model(noisy_inputs)

        # Slice the reference and noisy middle frames once per batch.
        clean_frame = sample[:, mid_channels, :, :]
        noisy_frame = noisy_inputs[:, mid_channels, :, :]
        outputs, mean_image = utils.post_process(out, noisy_frame, model="blind-video-net", sigma=noise_std/255, device=device)

        valid_psnr = utils.psnr(clean_frame, outputs, normalized=True, raw=False)
        valid_ssim = utils.ssim(clean_frame, outputs, normalized=True, raw=False)
        plist.append(valid_psnr)
        slist.append(valid_ssim)
        running_valid_psnr += valid_psnr
        valid_meters["valid_psnr"].update(valid_psnr.item())
        valid_meters["valid_ssim"].update(valid_ssim.item())

        mean_psnr = utils.psnr(clean_frame, mean_image, normalized=True, raw=False)
        mean_ssim = utils.ssim(clean_frame, mean_image, normalized=True, raw=False)
        mean_meters["mean_psnr"].update(mean_psnr.item())
        mean_meters["mean_ssim"].update(mean_ssim.item())

        valid_bar.log(dict(**valid_meters, **mean_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

    n_batches += 1

# Average over the batches actually seen. An empty loader previously raised
# NameError here because `sample_id` was never bound.
if n_batches:
    running_valid_psnr /= n_batches
print("EVAL: "+valid_bar.print(dict(**valid_meters, **mean_meters, lr=optimizer.param_groups[0]["lr"])))