In [1]:
import torch

In [2]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [3]:
def compute_step_f(
    data: torch.Tensor,  # [k, n, n]
    image: torch.Tensor,  # [1, n, n]
    PSF_fft: torch.Tensor,  # [k, n, n/2+1]
    PSFt_fft: torch.Tensor,  # [k, n, n/2+1]
) -> torch.Tensor:
    """Single step of the multiplicative Richardson-Lucy deconvolution algorithm.

    Args:
        data: current estimate, [k, n, n] (a [1, n, n] estimate broadcasts
            against the k PSF spectra).
        image: observed image, [1, n, n].
        PSF_fft: ``rfft2`` of the PSFs, [k, n, n//2+1].
        PSFt_fft: ``rfft2`` of the mirrored (adjoint) PSFs, [k, n, n//2+1].

    Returns:
        ``data`` multiplied by the back-projected ratio ``image / model``, [k, n, n].
    """
    # irfft2 cannot recover an odd last-axis length from a half spectrum (its
    # default output length is 2 * (m - 1)), so pass the true spatial size.
    data_hw = data.shape[-2:]
    # Forward model: convolve every component with its PSF, then sum components.
    denom = torch.fft.irfft2(PSF_fft * torch.fft.rfft2(data), s=data_hw, dim=(-2, -1)).sum(dim=0, keepdim=True)  # [1, n, n]
    img_err = image / denom
    # Back-project the ratio through the adjoint PSFs.
    # NOTE(review): fftshift is applied here but not to the forward convolution
    # above; confirm this matches how the PSFs are centered.
    correction = torch.fft.irfft2(torch.fft.rfft2(img_err) * PSFt_fft, s=img_err.shape[-2:], dim=(-2, -1))
    return data * torch.fft.fftshift(correction, dim=(-2, -1))  # [k, n, n]


def richard_lucy_10(img: torch.Tensor, psf: torch.Tensor) -> torch.Tensor:
    """Run 10 iterations of multiplicative Richardson-Lucy deconvolution.

    Args:
        img: observed image, [1, n, n].
        psf: point-spread functions, [k, n, n] (one per component) or [n, n].

    Returns:
        The deconvolved estimate. It starts as ``ones_like(img)`` and broadcasts
        against the k PSF terms, so with a [k, n, n] ``psf`` it ends as [k, n, n].
    """
    PSF_fft = torch.fft.rfft2(psf)  # [k, n, n/2+1]
    # Adjoint kernel: mirror each PSF over its two *spatial* axes only.
    # flip([0, 1]) would also reverse the PSF stack (dim 0) and leave the last
    # spatial axis unflipped for 3-D input, pairing components with wrong kernels.
    PSFt_fft = torch.fft.rfft2(psf.flip([-2, -1]))  # [k, n, n/2+1]
    # Initial estimate has img's shape ([1, n, n]); it becomes [k, n, n] after
    # the first update via broadcasting.
    data = torch.ones_like(img)

    for _ in range(10):
        data = compute_step_f(data, img, PSF_fft, PSFt_fft)

    return data


# Compile with TorchScript and save. richard_lucy_10 is fully type-annotated,
# so no example inputs are needed: `example_inputs` only drives MonkeyType-based
# type inference and must be a list of argument tuples when it is used.
jitted_fn = torch.jit.script(richard_lucy_10)
torch.jit.save(jitted_fn, "richardson_lucy_10.pt")

