In [None]:
import torch
from pssr.crappifiers import AdditiveGaussian, Poisson, Crappifier
from pssr.data import ImageDataset, PairedImageDataset
from pssr.models import ResUNet
from pssr.predict import predict_collage, predict_images, test_metrics

In [None]:
# Checkpoint to evaluate.
model_path = "trained/ssim_resunet_0.12.pth"

# High-resolution image size, upscaling factor, and DataLoader batch size.
hr_res, lr_scale, batch_size = 512, 4, 16

# Degradation model handed to the validation ImageDataset below.
crappifier = Poisson()

In [None]:
# Validation set: HR crops combined with `crappifier` at `lr_scale`
# (presumably to synthesize the LR inputs; confirm in pssr.data), rotation off.
val_dataset = ImageDataset(
    "testdata/EM/hr_crop",
    hr_res,
    lr_scale,
    crappifier,
    rotation=False,
)

# Test set: aligned pairs read from separate HR and LR directories.
test_dataset = PairedImageDataset(
    "testdata/EM/pairs_align/hr",
    "testdata/EM/pairs_align/lr",
)

In [None]:
# Single-channel ResUNet whose hidden widths double per level:
# 64, 128, 256, 512, 1024.
model = ResUNet(
    channels=1,
    hidden=[64 * 2**level for level in range(5)],
    scale=lr_scale,
    depth=3,
)

In [None]:
# Load the trained weights. map_location="cpu" lets a checkpoint saved on a GPU
# be deserialized on a CPU-only machine, and it avoids a transient GPU copy.
# load_state_dict copies the values into the model's existing (CPU) parameters
# either way, so the model's final device is unchanged.
# NOTE(review): once the minimum torch version is >= 1.13, consider
# weights_only=True for safer unpickling of untrusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Extra DataLoader options forwarded to the prediction/metric helpers.
# Pinned (page-locked) memory only helps host->GPU transfers. Without CUDA,
# the DataLoader just warns and skips it, so only request it on CUDA.
kwargs = dict(
    num_workers=4,
    pin_memory=device == "cuda",
)

In [None]:
# Produce a collage for each split, validation first and then test, with the
# split name as the output prefix.
for split_prefix, split_dataset in (("val", val_dataset), ("test", test_dataset)):
    predict_collage(
        model,
        split_dataset,
        batch_size,
        device,
        prefix=split_prefix,
        dataloader_kwargs=kwargs,
    )

In [None]:
# Save per-image predictions for the paired test set under preds/outs.
predict_images(
    model,
    test_dataset,
    device,
    out_dir="preds/outs",
    norm=True,
)

In [None]:
# Compute metrics on the paired test set. This uses the configured `batch_size`
# rather than a hard-coded 16, so changing the config affects every
# prediction/metric call consistently.
test_metrics(model, test_dataset, batch_size, device, dataloader_kwargs=kwargs)