In [None]:
import torch
import matplotlib.pyplot as plt
from pssr.data import ImageDataset, PairedImageDataset
from pssr.crappifiers import AdditiveGaussian, Poisson, Crappifier
from pssr.models import ResUNet, ResUNetA
from pssr.loss import SSIMLoss
from pssr.train import train_paired

In [None]:
# Training configuration.
# NOTE: `lr` here is the optimizer *learning rate* (passed to AdamW below), whereas the
# `lr_` prefix in `lr_scale` means *low resolution*.
lr = 1e-3

# High-resolution image size handed to ImageDataset, and the factor relating the
# low-resolution input size (hr_res // lr_scale) to it.
hr_res = 512
lr_scale = 4
# The checkpoint filename uses hr_res // lr_scale, so require an exact division.
assert hr_res % lr_scale == 0, "hr_res must be divisible by lr_scale"

# Seed torch's RNG so weight init / shuffling are repeatable across Restart & Run All.
# NOTE(review): if the crappifier draws noise from numpy/random, those RNGs are not
# covered by this call — confirm against pssr and seed them too if needed.
seed = 0
torch.manual_seed(seed)

In [None]:
# Noise model handed to ImageDataset; presumably used to degrade the high-resolution
# images into synthetic low-resolution training inputs.
# Alternative: AdditiveGaussian(intensity=11, gain=-2, spread=1)
crappifier = Poisson(
    intensity=0.9,
    gain=-2,
    spread=0.1,
)

hr_dir = "testdata/EM/hr_crop"
dataset = ImageDataset(hr_dir, hr_res, crappifier=crappifier)

In [None]:
# Single-channel residual U-Net; `scale=lr_scale` ties the network to the same
# low-res -> high-res factor used for the checkpoint filename.
model_config = dict(
    channels=1,
    hidden=[64, 128, 256, 512, 1024],
    scale=lr_scale,
    depth=3,
)
model = ResUNet(**model_config)

# Alternative architecture: ResUNetA with one list of dilation rates per level
# (five lists, matching the five entries of `hidden`). Uncomment to use instead.
# dilations = [
#     [1,3,15,31],
#     [1,3,15],
#     [1,3],
#     [1],
#     [1],
# ]
# model = ResUNetA(
#     channels=1,
#     hidden=[64, 128, 256, 512, 1024],
#     dilations=dilations,
#     scale=lr_scale,
#     depth=2,
# )

In [None]:
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Extra DataLoader options forwarded by train_paired.
# Page-locked (pinned) host memory only helps host->GPU copies; on a CPU-only machine
# PyTorch warns and ignores it, so enable it only when training on CUDA.
kwargs = dict(
    num_workers = 4,
    pin_memory = device == "cuda",
)

In [None]:
# Objective: pssr's SSIMLoss (mix/ms are its options). Plain MSE is a simpler alternative:
# loss_fn = torch.nn.MSELoss()
loss_fn = SSIMLoss(mix=.8, ms=True)

optim = torch.optim.AdamW(model.parameters(), lr=lr)
# ReduceLROnPlateau multiplies the LR by `factor` once the metric it is stepped with has
# not improved by more than `threshold` (relative, by default) for `patience` steps.
# NOTE(review): which metric train_paired feeds to scheduler.step() is not visible here.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    factor=0.1,
    patience=3,
    threshold=5e-3,
)

epochs = 5
log_frequency = 50

losses = train_paired(
    model=model,
    dataset=dataset,
    device=device,
    batch_size=batch_size,
    dataloader_kwargs=kwargs,
    loss_fn=loss_fn,
    optim=optim,
    scheduler=scheduler,
    epochs=epochs,
    log_frequency=log_frequency,
)

In [None]:
# Training curve for the values returned by train_paired.
# NOTE(review): whether each entry is per batch, per log interval or per epoch is
# decided inside train_paired — adjust the x label once confirmed.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="Logged step", ylabel="Loss")
plt.show()

In [None]:
# Save the trained weights; the filename encodes input size, output size and final loss,
# e.g. "model_128-512_0.123.pth".
low_res = hr_res // lr_scale
final_loss = losses[-1]
checkpoint_path = f"model_{low_res}-{hr_res}_{final_loss:.3f}.pth"
torch.save(model.state_dict(), checkpoint_path)