In [1]:
# imports

import os
import os.path
import cv2
import glob
import h5py
import tqdm
import argparse
import logging
from PIL import Image 
import skimage
import skimage.io

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
sns.set_theme()
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms

import sys
sys.path.append('../')
import data, utils, models

In [2]:
# Restore the pretrained checkpoint and place the network on the compute device.

PATH = "../pretrained/raw_video.pt"

# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The loader's result is unpacked as (model, optimizer, args); `optimizer` is
# read again by the evaluation cell to report the learning rate.
model, optimizer, args = utils.load_model(PATH, parallel=True)
model.to(device)

# Dump the module tree so the restored architecture can be inspected.
print(model)

BlindVideoNet(
  (rotate): rotate()
  (denoiser_1): Blind_UNet(
    (enc1): ENC_Conv(
      (conv1): Conv(
        (shift_down): ZeroPad2d(padding=(0, 0, 1, 0), value=0.0)
        (crop): crop()
        (replicate): ReplicationPad2d((1, 1, 1, 1))
        (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (relu): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (conv2): Conv(
        (shift_down): ZeroPad2d(padding=(0, 0, 1, 0), value=0.0)
        (crop): crop()
        (replicate): ReplicationPad2d((1, 1, 1, 1))
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (relu): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (conv3): Conv(
        (shift_down): ZeroPad2d(padding=(0, 0, 1, 0), value=0.0)
        (crop): crop()
        (replicate): ReplicationPad2d((1, 1, 1, 1))
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (relu): LeakyReLU(negative_slope=0.1, inplace=True)
  

In [3]:
# Data loader. Edit `isos` below as required: list every ISO value to
# evaluate, or a single one as done here.

PATH = "../datasets/RawVideo"

# Keyword options forwarded unchanged to data.build_dataset.
loader_options = dict(
    batch_size=1,
    image_size=1080,
    stride=1920 - 1080,  # NOTE(review): presumably frame width minus crop size — confirm
    n_frames=5,
    aug=0,
    scenes=[7, 8, 9, 10, 11],
    isos=[1600],
)

# Only the second element of the returned triple (the validation loader) is kept.
_, valid_loader, _ = data.build_dataset("RawVideo", PATH, **loader_options)

In [4]:
# Test: run the model over the validation loader and report PSNR / SSIM.
#
# Results left in the namespace after this cell:
#   plist / slist       - per-sample PSNR / SSIM values as returned by utils
#   running_valid_psnr  - mean of the per-sample PSNR values
#   valid_meters        - AverageMeter objects backing the printed summary

# Switch to evaluation mode once (the original cell did this twice).
model.eval()

cpf = 1  # presumably "channels per frame" — TODO confirm against the dataset
mid = 2  # index of the frame compared with the output; presumably the centre
         # of the n_frames=5 window requested from the loader — TODO confirm

valid_meters = {name: utils.AverageMeter() for name in ("valid_psnr", "valid_ssim")}
for meter in valid_meters.values():
    meter.reset()

valid_bar = utils.ProgressBar(valid_loader)

# The learning rate is only echoed in the log line; nothing updates it during
# evaluation, so read it once instead of on every iteration.
current_lr = optimizer.param_groups[0]["lr"]

running_valid_psnr = 0.0
plist = []
slist = []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for sample_id, (sample, noisy_inputs) in enumerate(valid_bar):
        sample = sample.to(device)
        noisy_inputs = noisy_inputs.to(device)

        # The model returns a pair; the second element is not used here.
        outputs, _ = model(noisy_inputs)

        # Channels of the `mid` frame, the reference for both metrics.
        target = sample[:, (mid * cpf):((mid + 1) * cpf), :, :]

        valid_psnr = utils.psnr(target, outputs, normalized=False, raw=True)
        valid_ssim = utils.ssim(target, outputs, normalized=False, raw=True)
        plist.append(valid_psnr)
        slist.append(valid_ssim)
        running_valid_psnr += valid_psnr
        valid_meters["valid_psnr"].update(valid_psnr.item())
        valid_meters["valid_ssim"].update(valid_ssim.item())

        valid_bar.log(dict(**valid_meters, lr=current_lr), verbose=True)

# Average over the samples actually seen. Counting via len(plist) instead of
# the loop index avoids a NameError when the loader yields nothing.
num_samples = len(plist)
if num_samples:
    running_valid_psnr /= num_samples
    print("EVAL: " + valid_bar.print(dict(**valid_meters, lr=current_lr)))
else:
    print("EVAL: validation loader yielded no samples; nothing to report")

                                                                                                                 

EVAL: valid_psnr 48.036 | valid_ssim 0.998 | lr 1.0e-04


