In [1]:
import torch
from pssr.crappifiers import AdditiveGaussian, Poisson, Crappifier
from pssr.data import ImageDataset, PairedImageDataset
from pssr.models import ResUNet
from pssr.predict import predict_collage, predict_images, test_metrics

import numpy as np
from PIL import Image

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Experiment constants. lr_scale is passed to ResUNet(scale=...) below.
# NOTE(review): n_images is not referenced in any visible cell — confirm it is still needed.
n_images, lr_scale = 50, 4

In [4]:
# Degradation passed to val_dataset (presumably to synthesize LR inputs from the HR crops).
# NOTE(review): the meaning of the positional arguments (11, -2, 1) lives in
# pssr.crappifiers.AdditiveGaussian — confirm and consider passing them by keyword.
crappifier = AdditiveGaussian(
    11,
    -2,
    1,
)

In [5]:
# Validation set: HR crops only; low-resolution counterparts presumably come from `crappifier`.
val_dataset = ImageDataset(
    "testdata/EM/hr_crop",
    hr_res=512,
    crappifier=crappifier,
    rotation=False,
)
# Test set: paired HR / LR image directories read from disk.
test_dataset = PairedImageDataset(
    "testdata/EM/pairs_align/hr",
    "testdata/EM/pairs_align/lr",
)

In [6]:
# ResUNet for single-channel images; hidden widths double from 64 at each of 5 levels.
# The architecture must match the checkpoint loaded later, so keep these values in sync with it.
model = ResUNet(
    channels=1,
    hidden=[64 * 2**level for level in range(5)],  # == [64, 128, 256, 512, 1024]
    scale=lr_scale,
    depth=3,
)

In [7]:
batch_size = 16

# DataLoader options forwarded as `dataloader_kwargs` to the prediction/metrics calls.
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine PyTorch
# warns and ignores it, so enable it only when CUDA is available.
# `device` is defined once, in the setup cell; it is intentionally not redefined here.
kwargs = dict(
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [8]:
# Load trained weights. map_location="cpu" lets a checkpoint that was saved from a GPU run
# load on a CPU-only machine; load_state_dict then copies the tensors into the model's
# existing parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer torch versions support weights_only=True to restrict this).
model.load_state_dict(torch.load("trained/ssim_0.8_pairs_0.13.pth", map_location="cpu"))

<All keys matched successfully>

In [9]:
# Prediction collages: first for the validation set, then for the paired test set.
predict_collage(
    model, val_dataset, batch_size, device,
    prefix="val", dataloader_kwargs=kwargs,
)
predict_collage(
    model, test_dataset, batch_size, device,
    prefix="test", dataloader_kwargs=kwargs,
)

In [10]:
# Per-image predictions on the paired test set, with output directory "preds/outs".
predict_images(
    model,
    test_dataset,
    device,
    out_dir="preds/outs",
    norm=True,
)

100%|██████████| 42/42 [00:00<00:00, 45.26it/s]


In [12]:
# Quantitative evaluation on the paired test set. Reuse the configured batch_size
# instead of a hard-coded 16 so every evaluation call shares one setting.
test_metrics(model, test_dataset, batch_size, device, dataloader_kwargs=kwargs)

100%|██████████| 3/3 [00:01<00:00,  2.34it/s]


{'mse': 0.0018880930812364177,
 'pixel': 11.069975609083793,
 'psnr': 27.256001846824386,
 'ssim': 0.621225908442369}