In [None]:
import torch
from pssr.data import ImageDataset, SlidingDataset, PairedImageDataset, PairedSlidingDataset
from pssr.crappifiers import AdditiveGaussian, Poisson, SaltPepper, MultiCrappifier
from pssr.models import ResUNet, RDResUNet, SwinIR
from pssr.util import SSIMLoss
from pssr.train import train_paired

In [None]:
data_path = "your/path" # Folder to load images from
hr_res = 512 # Resolution of images or image tiles
lr_scale = 4 # Scale ratio between low-resolution and high-resolution images, shared between dataset and model
n_frames = -1 # Set to the number of stacked frames if using 2.5D or 3D data; -1 means single-channel (2D) images
seed = 0 # Random seed so Restart & Run All reproduces the same run

# Seed torch before building the dataset and model so that weight initialization (and any
# torch-based randomness used while preparing data) is repeatable.
# NOTE(review): crappifiers/datasets may also draw from numpy or Python's `random`; seed those too if so.
torch.manual_seed(seed)

# Crappifier parameters should be adjusted to match your data
crappifier = MultiCrappifier(Poisson(intensity=1.2, spread=0.05), SaltPepper(spread=0.1))

# Use SlidingDataset instead to load from image sheets (e.g. .czi files)
dataset = ImageDataset(data_path, hr_res, lr_scale, n_frames, crappifier)

# A ResUNet runs fast for the quality of the predictions, although any PyTorch model can be used
model = ResUNet(
    # One input/output channel for plain 2D images, otherwise one channel per stacked frame
    channels=1 if n_frames == -1 else n_frames,
    scale=lr_scale,
)

In [None]:
# Consider increasing or decreasing batch size for your amount of allocated memory
batch_size = 16

device = "cuda" if torch.cuda.is_available() else "cpu"

# Passed to train_paired as `dataloader_kwargs` (presumably forwarded to torch's DataLoader).
# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so request it only when
# training on CUDA; on CPU-only machines PyTorch ignores it and emits a warning.
kwargs = dict(
    num_workers = 4, # Background data-loading workers; lower this if RAM is limited
    pin_memory = device == "cuda",
)

In [None]:
epochs = 10
lr = 1e-3

log_frequency = 50 # Decrease to log losses more often
save_checkpoints = True # Save model checkpoints
epoch_collage = True # View training progress as images (LR, PSSR, HR pairs)

# Simple scheduler options: ReduceLROnPlateau multiplies the learning rate by `factor` once the
# monitored loss has failed to improve by more than `threshold` (relative, PyTorch's default
# threshold_mode) for `patience` consecutive scheduler steps
factor = 0.1
patience = 3
threshold = 5e-3

# MS-SSIM loss typically trains faster than MSE loss
loss_fn = SSIMLoss()
optim = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=factor, patience=patience, threshold=threshold)

train_losses, val_losses = train_paired(
    model=model,
    dataset=dataset,
    batch_size=batch_size,
    loss_fn=loss_fn,
    optim=optim,
    epochs=epochs,
    device=device,
    scheduler=scheduler,
    log_frequency=log_frequency,
    # None disables writing checkpoints / collages
    checkpoint_dir="checkpoints" if save_checkpoints else None,
    collage_dir="collages" if epoch_collage else None,
    dataloader_kwargs=kwargs,
)

In [None]:
# Save the trained weights, naming the file by LR-HR resolution and, when available, the final
# validation loss. Guarding the lookup avoids an IndexError (and losing the trained model) when
# no validation losses were recorded, e.g. with epochs = 0.
loss_tag = f"_{val_losses[-1]:.3f}" if val_losses is not None and len(val_losses) > 0 else ""
torch.save(model.state_dict(), f"model_{hr_res//lr_scale}-{hr_res}{loss_tag}.pth")

In [None]:
import matplotlib.pyplot as plt

# Plot the training losses returned by train_paired.
# NOTE(review): the x-axis unit (batch, log interval, or epoch) depends on how train_paired
# records losses — confirm and add an xlabel accordingly.
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set(title="Training loss", ylabel="Loss")
plt.show()