In [None]:
from pathlib import Path

import numpy as np
import torch
from PIL import Image

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
def open_image(path: Path) -> torch.Tensor:
    img = Image.open(path)
    imgs = []
    for i in range(img.n_frames):
        img.seek(i)
        imgs.append(torch.from_numpy(np.array(img).astype(np.float32)))

    return torch.stack(imgs, dim=0)


# NOTE: The below data files are only present when cloning the repo and not when pip installing the package.
image = open_image("tests/data/yale/light_field_image.tif").to(device)
measured_psf = open_image("tests/data/yale/measured_psf.tif").to(device)
mirrored_psf = torch.flip(measured_psf, dims=(-2, -1)).to(device)

image.shape, measured_psf.shape, mirrored_psf.shape

(torch.Size([1, 2048, 2048]),
 torch.Size([41, 2048, 2048]),
 torch.Size([41, 2048, 2048]))

In [None]:
# @torch.compile
def compute_step_f(
    data: torch.Tensor,  # [k, n, n]
    image: torch.Tensor,  # [1, n, n]
    PSF_fft: torch.Tensor,  # [k, n, n/2+1]
    PSFt_fft: torch.Tensor,  # [k, n, n/2+1]
) -> torch.Tensor:
    """Single step of the multiplicative Richardson-Lucy deconvolution algorithm."""
    denom = torch.fft.irfft2(PSF_fft * torch.fft.rfft2(data), dim=(-2, -1)).sum(dim=0, keepdim=True)  # [1, n, n]
    img_err = image / denom
    return data * torch.fft.fftshift(torch.fft.irfft2(torch.fft.rfft2(img_err) * PSFt_fft), dim=(-2, -1))  # [k, n, n]


jitted_fn = torch.jit.script(compute_step_f)

In [18]:
guess = (torch.ones_like(measured_psf) * 0.5).to(device)  # [k, n, n]

psf_fft = torch.fft.rfft2(measured_psf, dim=(-2, -1)).to(device)  # [k, n, n/2+1]
psft_fft = torch.fft.rfft2(mirrored_psf, dim=(-2, -1)).to(device)  # [k, n, n/2+1]

guess.shape, psf_fft.shape, psft_fft.shape

(torch.Size([41, 2048, 2048]),
 torch.Size([41, 2048, 1025]),
 torch.Size([41, 2048, 1025]))

In [19]:
import time

In [20]:
for _ in range(10):
    start = time.time()
    guess = jitted_fn(guess, image, psf_fft, psft_fft)
    print(f"Time taken: {time.time() - start} seconds")

Time taken: 0.004533052444458008 seconds
Time taken: 0.0034372806549072266 seconds
Time taken: 0.0006814002990722656 seconds
Time taken: 0.0006372928619384766 seconds
Time taken: 0.0006268024444580078 seconds
Time taken: 0.0006577968597412109 seconds
Time taken: 0.0006549358367919922 seconds
Time taken: 0.0006091594696044922 seconds
Time taken: 0.0006577968597412109 seconds
Time taken: 0.0005857944488525391 seconds


In [21]:
torch.jit.save(jitted_fn, "richardson_lucy_step.pt")

In [22]:
loaded_fn = torch.jit.load("richardson_lucy_step.pt")

In [23]:
for _ in range(10):
    start = time.time()
    guess = loaded_fn(guess, image, psf_fft, psft_fft)
    print(f"Time taken: {time.time() - start} seconds")

Time taken: 0.00412750244140625 seconds
Time taken: 0.0041065216064453125 seconds
Time taken: 0.0007188320159912109 seconds
Time taken: 0.0006849765777587891 seconds
Time taken: 0.0006461143493652344 seconds
Time taken: 0.0006556510925292969 seconds
Time taken: 0.0006611347198486328 seconds
Time taken: 0.0006539821624755859 seconds
Time taken: 0.0006382465362548828 seconds
Time taken: 0.0006551742553710938 seconds


In [27]:
def make_circle_mask(
    radius: int,
) -> np.ndarray:
    y, x = np.ogrid[: 2 * radius, : 2 * radius]
    circle_mask = (x - radius) ** 2 + (y - radius) ** 2 <= radius**2
    return circle_mask.astype(np.float32)


def post_processing(
    O: np.ndarray,  # noqa: E741
    center: tuple[int, int],
    radius: int,
) -> np.ndarray:
    circle_mask = np.expand_dims(make_circle_mask(radius), axis=0)
    sub_o = O[:, center[0] - radius : center[0] + radius, center[1] - radius : center[1] + radius]
    return sub_o * circle_mask


proc_guess = post_processing(guess.detach().cpu().numpy(), (1000, 980), 230)

In [28]:
img = Image.fromarray(np.array(proc_guess[0]))
img.save(
    "pytorch_img.tif",
    format="tiff",
    append_images=[Image.fromarray(np.array(proc_guess[i])) for i in range(1, len(proc_guess))],
    save_all=True,
)