In [1]:
import torch
import matplotlib.pyplot as plt
from pssr.data import ImageDataset, PairedImageDataset
from pssr.crappifiers import AdditiveGaussian, Poisson, Crappifier
from pssr.models import ResUNet
from pssr.loss import SSIMLoss
from pssr.train import train_paired

import numpy as np
from PIL import Image

In [11]:
# Optimiser learning rate (handed to AdamW in the training cell). Note that the
# `lr_` prefix of `lr_scale` below presumably means "low resolution", not this.
lr = 1e-5

# hr_res: passed to ImageDataset and embedded in saved checkpoint names
#         (presumably the high-resolution crop size — confirm against pssr).
# lr_scale: upscaling factor given to ResUNet as `scale`.
hr_res, lr_scale = 512, 4

In [3]:
# class ScalePoisson(Crappifier):
#     def __init__(self, downres : int, intensity : float = 1, gain : float = 0):
#         self.downres = downres
#         self.intensity = intensity
#         self.gain = gain

#         self.resample = Image.Resampling.BICUBIC
        
#     def crappify(self, image : np.ndarray):
#         image = np.clip(self._interpolate(image, np.random.poisson(image/255*image.max())/image.max()*255) + self.gain, 0, 255)
#         return np.asarray(Image.fromarray(image).resize([self.downres]*2, resample=self.resample).resize(image.shape, resample=self.resample))
    
#     def _interpolate(self, x, y):
#         return x * (1 - self.intensity) + y * self.intensity

In [4]:
# Degradation applied by the dataset (presumably to synthesise low-quality
# inputs from the high-resolution crops — confirm in pssr.data).
crappifier = Poisson(intensity=1, gain=-10)
# Alternative experiment; needs the commented-out ScalePoisson class re-enabled:
# crappifier = ScalePoisson(125, 1, -2)
dataset = ImageDataset("testdata/EM_hr_crop", hr_res, crappifier=crappifier)

In [5]:
# Hidden widths double from 64 up to 1024; `scale` is the upscaling factor
# from the config cell. channels=1 presumably means single-channel (grayscale)
# EM images — confirm against the dataset.
model = ResUNet(
    channels=1,
    scale=lr_scale,
    depth=3,
    hidden=[64, 128, 256, 512, 1024],
)

In [6]:
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Extra options forwarded to train_paired as `dataloader_kwargs`
# (presumably passed straight to torch.utils.data.DataLoader).
kwargs = dict(
    shuffle = True,
    num_workers = 4,
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # machine PyTorch warns that pinning is unused, so enable it just for CUDA.
    pin_memory = device == "cuda",
)

In [7]:
# Resume from a previously trained checkpoint. `map_location` lets weights that
# were saved from GPU tensors load on a CPU-only machine as well, matching the
# CPU fallback chosen for `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load("trained/poisson_scale_512_.6_0.070.pth", map_location=device))

<All keys matched successfully>

In [12]:
# Two-phase schedule: train with mix=.8 until it converges, then repeat the
# same run with mix=.6 (this cell is the mix=.6 phase).
loss_fn = SSIMLoss(mix=.6, ms=True)
optim = torch.optim.AdamW(model.parameters(), lr=lr)

# Divide the LR by 10 after 3 scheduler steps without a relative improvement
# of more than 5e-3. NOTE: `verbose` is deprecated since torch 2.2.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim, factor=0.1, patience=3, threshold=5e-3, verbose=True,
)

train_args = dict(
    model=model,
    dataset=dataset,
    batch_size=batch_size,
    loss_fn=loss_fn,
    optim=optim,
    epochs=10,
    device=device,
    scheduler=scheduler,
    log_frequency=50,
    dataloader_kwargs=kwargs,
)
losses = train_paired(**train_args)

pixel[12.68], psnr[25.65], ssim[0.636]:  38%|███▊      | 229/608 [02:42<04:29,  1.40it/s]


KeyboardInterrupt: 

In [None]:
# Loss values returned by train_paired, in the order they were recorded.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="Recorded entry", ylabel="Loss");

In [None]:
# Checkpoint name records the HR resolution and the last loss value returned.
weights = model.state_dict()
torch.save(weights, f"model_{hr_res}_{losses[-1]:.3f}.pth")

In [13]:
# Manual save: the loss (0.075) is typed in by hand, presumably because the
# training cell was interrupted before `losses` was assigned. ".6" presumably
# records the SSIM mix used for this phase.
# NOTE(review): the load cell reads checkpoints from trained/, but this writes
# to the working directory — move the file there if it should be reloaded.
torch.save(model.state_dict(), f"poisson_scale_{hr_res}_.6_0.075.pth")