In [None]:
# imports

import os
import os.path
import cv2
import glob
import h5py
import tqdm
import argparse
import logging
from PIL import Image 
import skimage
import skimage.io

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
sns.set_theme()
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms

import data, utils, models

In [None]:
# definition for loading model from a pretrained network file

def load_model(PATH, parallel=True, pretrained=True, old=True, load_opt=False):
    """Rebuild a model and a fresh Adam optimizer from a training checkpoint.

    The checkpoint at ``PATH`` is read once, onto the CPU. It is indexed as a
    dict with:
      - ``"args"``: a namespace passed (as a copy) to ``models.build_model``;
      - ``"model"``: a sequence whose first element is the model state dict;
      - ``"optimizer"``: a sequence of optimizer state dicts (only read when
        ``load_opt`` is True).

    Args:
        PATH: path to the checkpoint file.
        parallel: if True, saved parameter names carrying a ``"module."``
            prefix (as written by wrappers such as ``nn.DataParallel``) have
            that prefix stripped before being matched against the model.
        pretrained: if True, copy the saved weights into the freshly built model.
        old: if True, force ``args.blind_noise = False``.
            NOTE(review): presumably for checkpoints saved before that option
            existed — confirm.
        load_opt: if True, also restore the optimizer state from the checkpoint.

    Returns:
        ``(model, optimizer, args)``: the model returned by
        ``models.build_model``, an Adam optimizer over its parameters with
        ``lr=1e-4``, and the copied build namespace.
    """
    # Load once, onto the CPU, so GPU-saved checkpoints also work on CPU-only
    # machines. NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which may refuse the pickled namespace in "args".
    checkpoint = torch.load(PATH, map_location="cpu")
    # Shallow-copy the namespace so the mutation below does not touch the checkpoint.
    args = argparse.Namespace(**vars(checkpoint["args"]))
    if old:
        args.blind_noise = False

    model = models.build_model(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    if load_opt:
        for opt, opt_state in zip([optimizer], checkpoint["optimizer"]):
            opt.load_state_dict(opt_state)

    if pretrained:
        saved_state = checkpoint["model"][0]
        own_state = model.state_dict()
        prefix = "module."

        for name, param in saved_state.items():
            # Strip the wrapper prefix only when it is actually present, so
            # unprefixed keys are not mangled and silently dropped.
            if parallel and name.startswith(prefix):
                name = name[len(prefix):]
            if name not in own_state:
                # Report saved keys that have no counterpart in the model.
                print("here", name)
                continue
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            # copy_ writes into the model's own tensors, handling any device move.
            own_state[name].copy_(param)

    return model, optimizer, args

In [None]:
# load pretrained model
# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint file read by load_model (build args + saved weights).
PATH = "pretrained/raw_video.pt"

model, optimizer, args = load_model(PATH)
# nn.Module.to moves the parameters and returns the same module object.
model = model.to(device)
print(model)

In [None]:
# data loader
# Build only the validation loader (the other two returned loaders are
# discarded), restricted to scenes 7-11 at ISO 1600, one item per batch.
PATH = "datasets/RawVideo"

_, valid_loader, _ = data.build_dataset(
    "RawVideo",
    PATH,
    batch_size=1,
    image_size=1080,
    # NOTE(review): presumably frame width (1920) minus crop size (1080) — confirm
    # against data.build_dataset.
    stride=1920 - 1080,
    n_frames=5,
    aug=0,
    scenes=[7, 8, 9, 10, 11],
    isos=[1600],
)

In [None]:
# Test
# Evaluate the model on the validation loader: the model output is scored with
# PSNR and SSIM against the slice of the clean `sample` selected by `mid`/`cpf`.

cpf = 1  # NOTE(review): presumably channels per frame in `sample` — confirm
mid = 2  # NOTE(review): presumably the centre frame of the n_frames=5 window — confirm

valid_meters = {name: utils.AverageMeter() for name in ("valid_psnr", "valid_ssim")}

model.eval()
for meter in valid_meters.values():
    meter.reset()

valid_bar = utils.ProgressBar(valid_loader)
running_valid_psnr = 0.0
plist = []
slist = []
# Inference only: disable autograd once for the whole loop.
with torch.no_grad():
    for sample_id, (sample, noisy_inputs) in enumerate(valid_bar):
        sample = sample.to(device)
        noisy_inputs = noisy_inputs.to(device)

        outputs, est_sigma = model(noisy_inputs)

        # Clean reference channels for the frame the model reconstructs.
        target = sample[:, (mid * cpf):((mid + 1) * cpf), :, :]
        valid_psnr = utils.psnr(target, outputs, normalized=False, raw=True)
        valid_ssim = utils.ssim(target, outputs, normalized=False, raw=True)
        plist.append(valid_psnr)
        slist.append(valid_ssim)
        running_valid_psnr += valid_psnr
        valid_meters["valid_psnr"].update(valid_psnr.item())
        valid_meters["valid_ssim"].update(valid_ssim.item())

        valid_bar.log(dict(**valid_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

# Mean PSNR over evaluated samples. Guarded so an empty loader does not raise
# (the original divided by `sample_id + 1`, which is unbound in that case).
if plist:
    running_valid_psnr /= len(plist)
print("EVAL: " + valid_bar.print(dict(**valid_meters, lr=optimizer.param_groups[0]["lr"])))