In [1]:
import torch
from pssr.crappifiers import Crappifier
from pssr.data import ImageDataset, PairedImageDataset
from pssr.models import ResUNet
from pssr.predict import predict_collage, predict_images, test_metrics

import numpy as np
from PIL import Image

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
n_images = 50  # image count passed to predict_collage for the validation collage
lr_scale = 4   # upscaling factor passed to ResUNet(scale=...)

# Seed the global RNGs so the on-the-fly Poisson crappification (np.random) and
# any torch-side randomness give the same validation inputs on every
# Restart & Run All. torch.manual_seed also fixes the base seed that DataLoader
# worker processes derive their own seeds from. NOTE(review): assumes pssr draws
# from these global RNGs rather than private generators — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
class ScalePoisson(Crappifier):
    """Crappifier that adds Poisson (shot) noise, then blurs by down/up-sampling.

    Args:
        downres: Side length, in pixels, of the square intermediate image the
            noisy image is shrunk to before being resized back to its input size.
        intensity: Blend weight between the clean image (0) and its
            Poisson-noised version (1).
        gain: Constant added to the blended image before clipping to [0, 255].
    """

    def __init__(self, downres: int, intensity: float = 1, gain: float = 0):
        self.downres = downres
        self.intensity = intensity
        self.gain = gain

        self.resample = Image.Resampling.BICUBIC

    def crappify(self, image: np.ndarray):
        """Return a noisy, resolution-degraded copy of ``image``.

        ``image`` is assumed to be a 2-D array with values in [0, 255]
        (TODO confirm against the dataset loader). The result has the same
        height and width as the input.
        """
        peak = image.max()
        if peak != 0:
            # The Poisson rate is scaled by the image's own peak and then mapped
            # back, so each pixel's expected value equals the input pixel.
            noisy = np.random.poisson(image / 255 * peak) / peak * 255
        else:
            # Rate 0 everywhere means Poisson sampling yields all zeros; produce
            # them directly instead of computing 0/0, which would give NaNs.
            noisy = np.zeros_like(image, dtype=float)
        image = np.clip(self._interpolate(image, noisy) + self.gain, 0, 255)

        # PIL sizes are (width, height) while numpy shapes are (rows, cols);
        # swap them so non-square inputs are restored to their original size
        # instead of being transposed.
        height, width = image.shape[:2]
        degraded = (
            Image.fromarray(image)
            .resize([self.downres] * 2, resample=self.resample)
            .resize((width, height), resample=self.resample)
        )
        return np.asarray(degraded)

    def _interpolate(self, x, y):
        """Linearly blend ``x`` (weight ``1 - intensity``) with ``y`` (weight ``intensity``)."""
        return x * (1 - self.intensity) + y * self.intensity
    
# Full-strength Poisson noise, no gain offset, 125-px intermediate resolution.
# NOTE(review): 125 differs from hr_res / lr_scale = 512 / 4 = 128 — confirm intended.
crappifier = ScalePoisson(downres=125, intensity=1, gain=0)

In [5]:
# Validation inputs are synthesized from high-res crops by `crappifier`;
# the test set uses real hr/lr image pairs. Rotation augmentation is off for both.
val_dataset = ImageDataset(
    "testdata/EM_hr_crop",
    hr_res=512,
    crappifier=crappifier,
    val_split=0.1,
    rotation=False,
)
test_dataset = PairedImageDataset(
    "testdata/EM_pairs_crop/hr",
    "testdata/EM_pairs_crop/lr",
    rotation=False,
)

In [6]:
# Hidden widths double at each of five levels: [64, 128, 256, 512, 1024].
# channels=1 presumably means single-channel (grayscale EM) images — confirm.
model = ResUNet(
    channels=1,
    hidden=[64 * 2**level for level in range(5)],
    scale=lr_scale,
    depth=3,
)

In [7]:
batch_size = 16

# Loader options, presumably forwarded to torch.utils.data.DataLoader by the
# pssr.predict helpers. `device` comes from the setup cell; it is not redefined
# here. Pinned host memory only speeds up host-to-GPU copies, so request it
# only when running on CUDA (on CPU-only machines it is useless and warns).
kwargs = dict(
    num_workers=4,
    pin_memory=(device == "cuda"),
)

In [8]:
# map_location lets a checkpoint saved from GPU tensors load when `device` has
# fallen back to "cpu"; without it torch.load raises if CUDA is unavailable.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints,
# and consider weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load("trained/poisson_scale_512_.6_0.075.pth", map_location=device))

<All keys matched successfully>

In [9]:
# Collages of model predictions: `n_images` validation images, and every image
# in the paired test set (its size is read from the dataset rather than
# hard-coded, so it stays correct if the test data changes).
predict_collage(model, val_dataset, batch_size, n_images, device, prefix="val", dataloader_kwargs=kwargs)
predict_collage(model, test_dataset, batch_size, len(test_dataset), device, prefix="test", dataloader_kwargs=kwargs)

In [10]:
# Write per-image predictions for the test set to disk.
# out_res=500 presumably sets the saved image size — confirm against pssr.predict.
predict_images(
    model,
    test_dataset,
    device,
    out_dir="preds/poisson/outs",
    out_res=500,
)

In [11]:
# Score the model on the paired test set; the returned dict is the cell output.
metrics = "mse pixel psnr ssim".split()
test_metrics(
    model,
    test_dataset,
    1,  # presumably the evaluation batch size (one image per step) — confirm
    metrics,
    device,
    dataloader_kwargs=kwargs,
)

  0%|          | 0/42 [00:00<?, ?it/s]

100%|██████████| 42/42 [00:00<00:00, 63.09it/s]


{'mse': 0.003279957717971965,
 'pixel': 14.526859811883154,
 'psnr': 21.13236667996361,
 'ssim': 0.5753835723513648}