In [None]:
import torch
import matplotlib.pyplot as plt
from pssr.data import ImageDataset
from pssr.crappifiers import AdditiveGaussian, Poisson
from pssr.models import ResUNet
from pssr.loss import SSIMLoss
from pssr.train import train_paired

In [None]:
# Optimizer learning rate. Not to be confused with `lr_scale` below, where "lr" means low resolution.
lr = 1e-3

# hr_res: high-resolution size passed to ImageDataset (presumably the crop side length in pixels).
# lr_scale: scale factor passed to ResUNet(scale=...).
hr_res = 512
lr_scale = 4

# Seed torch's global RNG so weight initialisation and DataLoader shuffling repeat across Restart & Run All.
# NOTE(review): confirm whether ImageDataset's val_split draws from torch's RNG or uses its own seed.
seed = 0
torch.manual_seed(seed);

In [None]:
# Degradation model (presumably used to synthesise low-resolution inputs from the high-resolution images).
# Alternative noise model kept for comparison:
# crappifier = AdditiveGaussian(mean=-2, deviation=13)
crappifier = Poisson(gain=-2, intensity=1)

# val_split=0.1 presumably holds out 10% of the images for validation.
# NOTE(review): lr_scale is not passed to ImageDataset; confirm its downscaling factor matches ResUNet(scale=lr_scale).
data_path = "data/EM_hr_1_10"
dataset = ImageDataset(data_path, hr_res, crappifier=crappifier, val_split=0.1)

In [None]:
# Super-resolution network. `scale` is tied to lr_scale from the config cell rather than hard-coded;
# channels=1 presumably means single-channel (grayscale) EM images.
hidden_channels = [64, 128, 256, 512, 1024]
model = ResUNet(channels=1, hidden=hidden_channels, scale=lr_scale, depth=3)

In [None]:
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Keyword arguments handed to train_paired as dataloader_kwargs (presumably forwarded to torch's DataLoader).
kwargs = dict(
    shuffle=True,
    num_workers=4,
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it is wasted
    # and PyTorch emits a warning, so enable it only when training on CUDA.
    pin_memory=device == "cuda",
)

In [None]:
# Optionally resume training from a saved state dict: set this to a checkpoint path,
# e.g. "model_poisson_512_0.081.pth". Left as None, nothing is loaded.
checkpoint_path = None
if checkpoint_path is not None:
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Two-phase schedule: train with ssim_mix=0.8 until the loss plateaus, then re-run this cell with ssim_mix=0.6.
# Re-running rebuilds the loss, optimizer and scheduler (fresh AdamW state, LR reset to `lr`),
# while `model` keeps the weights learned so far because it is created in an earlier cell.
ssim_mix = 0.8
epochs = 20

loss_fn = SSIMLoss(mix=ssim_mix, ms=True)
optim = torch.optim.AdamW(model.parameters(), lr=lr)
# Cut the LR 10x after 5 scheduler steps without a relative improvement of at least 5e-3
# (how often train_paired calls scheduler.step is not visible here; presumably once per epoch).
# `verbose=` is deprecated since PyTorch 2.2, so the final LR is reported explicitly after training instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.1, patience=5, threshold=5e-3)

losses = train_paired(
    model=model,
    dataset=dataset,
    batch_size=batch_size,
    loss_fn=loss_fn,
    optim=optim,
    epochs=epochs,
    device=device,
    scheduler=scheduler,
    log_frequency=50,
    dataloader_kwargs=kwargs,
)

print(f"Final learning rate: {optim.param_groups[0]['lr']:.2e}")

In [None]:
# Loss values returned by train_paired (granularity — per batch, per log step or per epoch — TODO confirm).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Loss curve", xlabel="step", ylabel="loss");

In [None]:
# Optionally persist the trained weights. The filename records the crappifier, hr_res and the last
# recorded loss, following the "model_poisson_512_0.081.pth" pattern used for resuming.
save_model = False
if save_model:
    crappifier_name = type(crappifier).__name__.lower()
    torch.save(model.state_dict(), f"model_{crappifier_name}_{hr_res}_{losses[-1]:.3f}.pth")