In [1]:
import numpy as np
from PIL import Image
from skopt.space import Real
from pssr.crappifiers import AdditiveGaussian, Poisson, Crappifier
from pssr.data import PairedImageDataset
from pssr.train import approximate_crappifier
from pssr.predict import predict_collage

In [2]:
# gaussian_crappifier = AdditiveGaussian(mean=0, deviation=5)
# poisson_crappifier = Poisson(intensity=1.5)

# dataset = ImageDataset("data/EM_hr_1_10", crappifier=poisson_crappifier)

In [3]:
dataset = PairedImageDataset("testdata/EM_pairs_crop/hr", "testdata/EM_pairs_crop/lr", hr_res=512, lr_res=128, rotation=False)

In [4]:
# 13, -2
gaussian_space = [Real(0, 15),
         Real(-5, 5),
]

# 1, -2
poisson_space = [Real(0, 5),
         Real(-5, 5),
]

In [5]:
crappfier = AdditiveGaussian
space = gaussian_space

# crappfier = Poisson
# space = poisson_space

In [6]:
# result = approximate_crappifier(crappfier, space, dataset, n_samples=42, opt_kwargs=dict(n_calls=100, n_initial_points=10))

In [7]:
# result.x

In [9]:
# crappfier = crappfier(*result.x)
# crappfier = Poisson(1, -10)

In [10]:
import numpy as np
from PIL import Image

def _collage_preds(lr, hr_hat, hr, image_range : int = 255, max_images : int = 5):
    lr, hr_hat, hr = [_image_stack(data, image_range, max_images) for data in (lr, hr_hat, hr)]
    lr = lr.resize((hr.width, hr.height), Image.Resampling.NEAREST)
    if hr_hat.size != hr.size:
         hr_hat = hr_hat.resize((hr.width, hr.height), Image.Resampling.NEAREST)

    return _image_stack([lr, hr_hat, hr], image_range, raw=False)

def _image_stack(data, image_range, max_images : int = 5, raw : bool = True):
    images = [Image.fromarray(image, mode="L") for image in (np.clip(data.detach().cpu().numpy()[:min(max_images, len(data)), 0], 0, image_range)*(255//image_range)).astype(np.uint8)] if raw else data
    width, height = images[0].width, images[0].height
    stack = Image.new("L", (width, height*len(images))) if raw else Image.new("L", (width*len(images), height))
    for idx, image in enumerate(images):
        if raw:
            stack.paste(image, (0, height*idx))
        else:
            stack.paste(image, (width*idx, 0))
    return stack

In [11]:
import torch, os
from torch.utils.data import DataLoader

prefix = "poisson"
n_images = 8
batch_size = 8

train_dataloader = DataLoader(dataset, batch_size, sampler=range(0, dataset.val_len))

collage = Image.new("L", (dataset.hr_res*3, dataset.hr_res*n_images))
remaining = n_images
for idx, (hr, lr) in enumerate(train_dataloader):
    images = [torch.tensor(np.asarray(Image.fromarray(hr[0].numpy()).resize([dataset.lr_res]*2, Image.Resampling.BILINEAR))) for hr in hr]
    images = torch.stack(images).unsqueeze(1)
    images = crappfier.crappify(images)
    
    collage.paste(_collage_preds(lr, images, hr, 255, min(remaining, batch_size)), (0, dataset.hr_res*batch_size*idx))

    remaining -= batch_size
    if remaining <= 0:
        break

os.makedirs("preds", exist_ok=True)
collage.save(f"preds/{prefix}collage{n_images}.png")

torch.Size([8, 1, 128, 128])


TypeError: Cannot handle this data type: (1, 1, 128, 128), |u1