# Notebook: pyolimp - Neural Network Examples

In [None]:
from typing import Callable
import torch
from torch import Tensor
from olimp.precompensation._demo import demo
from olimp.precompensation._demo_cvd import demo as demo_cvd

* DWDN

In [None]:
from olimp.precompensation.nn.models.dwdn import PrecompensationDWDN


def demo_dwdn(
    image: Tensor, psf: Tensor, progress: Callable[[float], None]
) -> Tensor:
    """Precompensate ``image`` with the pretrained DWDN model.

    :param image: Input image tensor handed to ``model.preprocess``.
    :param psf: Point-spread function; cast to ``float32`` before use.
    :param progress: Callback receiving a completion fraction
        (0.1 after preprocessing, 1.0 after inference).
    :return: The single precompensated tensor produced by the model.
    """
    model = PrecompensationDWDN.from_path(path="hf://RVI/dwdn.pt")

    with torch.inference_mode():
        # Cast once so preprocess() and arguments() both receive the same
        # float32 PSF; previously arguments() got the un-cast tensor.
        psf = psf.to(torch.float32)
        inputs = model.preprocess(image, psf)
        progress(0.1)
        (precompensation,) = model(inputs, **model.arguments(inputs, psf))
        progress(1.0)
        return precompensation


demo("DWDN", demo_dwdn, mono=True, num_output_channels=3)

* USRNET

In [None]:
from olimp.precompensation.nn.models.usrnet import PrecompensationUSRNet


def demo_usrnet(
    image: Tensor,
    psf: Tensor,
    progress: Callable[[float], None],
) -> Tensor:
    """Precompensate ``image`` with the pretrained USRNet model.

    The PSF is converted to float32 and inputs are prepared with
    ``scale_factor=1`` and ``noise_level=0``. ``progress`` is called with
    0.1 once the inputs are ready and with 1.0 after inference.
    """
    usrnet = PrecompensationUSRNet.from_path(path="hf://RVI/usrnet.pth")
    with torch.inference_mode():
        kernel = psf.to(torch.float32)
        prepared = usrnet.preprocess(
            image,
            kernel,
            scale_factor=1,
            noise_level=0,
        )

        progress(0.1)
        # The model is expected to yield exactly one output tensor.
        (result,) = usrnet(prepared)
        progress(1.0)
        return result


demo(
    "USRNET",
    demo_usrnet,
    mono=True,
    num_output_channels=3,
)

* CVAE

In [None]:
from olimp.precompensation.nn.models.cvae import CVAE


def demo_cvae(
    image: Tensor,
    psf: Tensor,
    progress: Callable[[float], None],
) -> Tensor:
    """Precompensate ``image`` with the pretrained CVAE model.

    :param image: Input image tensor handed to ``model.preprocess``.
    :param psf: Point-spread function; cast to ``float32`` before use.
    :param progress: Callback receiving a completion fraction
        (0.1 after preprocessing, 1.0 after inference).
    :return: The first of the model's three outputs (the precompensation).
    """
    model = CVAE.from_path("hf://RVI/cvae.pth")
    with torch.inference_mode():
        psf = psf.to(torch.float32)
        inputs = model.preprocess(image, psf)
        progress(0.1)
        # Only the first output is used. The other two (presumably the latent
        # mean and log-variance, judging by their names) are discarded, so
        # they are bound to underscore-prefixed names.
        precompensation, _mu, _logvar = model(inputs)
        progress(1.0)
        return precompensation


demo("CVAE", demo_cvae, mono=True)

* VAE

In [None]:
from olimp.precompensation.nn.models.vae import VAE


def demo_vae(
    image: Tensor,
    psf: Tensor,
    progress: Callable[[float], None],
) -> Tensor:
    """Precompensate ``image`` with the pretrained VAE model.

    The PSF is converted to float32 before preprocessing. The model must
    return exactly three values; only the first is returned, and the other
    two are discarded (presumably latent mean/log-variance, per the naming
    used elsewhere in this notebook).
    """
    vae = VAE.from_path("hf://RVI/vae.pth")
    with torch.inference_mode():
        prepared = vae.preprocess(image, psf.to(torch.float32))
        progress(0.1)
        result, _, _ = vae(prepared)
        progress(1.0)
        return result


demo(
    "VAE",
    demo_vae,
    mono=True,
)

* UNET - efficientnet-b0

In [None]:
from olimp.precompensation.nn.models.unet_efficient_b0 import (
    PrecompensationUNETB0,
)


def demo_unet(
    image: Tensor,
    psf: Tensor,
    progress: Callable[[float], None],
) -> Tensor:
    """Precompensate ``image`` with the pretrained UNET (EfficientNet-B0) model.

    The PSF is converted to float32 before preprocessing. ``progress`` is
    called with 0.1 after preprocessing and 1.0 after inference.
    """
    checkpoint = "hf://RVI/unet-efficientnet-b0.pth"
    unet = PrecompensationUNETB0.from_path(checkpoint)
    with torch.inference_mode():
        prepared = unet.preprocess(image, psf.to(torch.float32))
        progress(0.1)
        # The model is expected to yield exactly one output tensor.
        (result,) = unet(prepared)
        progress(1.0)
        return result


demo(
    "UNET",
    demo_unet,
    mono=True,
)

* UNETVAE

In [None]:
from olimp.precompensation.nn.models.unetvae import UNETVAE


def demo_unetvae(
    image: Tensor,
    psf: Tensor,
    progress: Callable[[float], None],
) -> Tensor:
    """Precompensate ``image`` with the pretrained UNET-VAE hybrid model.

    The PSF is converted to float32 before preprocessing. The model must
    return exactly three values; only the first (the precompensation) is
    returned.
    """
    unetvae = UNETVAE.from_path("hf://RVI/unetvae.pth")
    with torch.inference_mode():
        kernel = psf.to(torch.float32)
        prepared = unetvae.preprocess(image, kernel)
        progress(0.1)
        result, _, _ = unetvae(prepared)
        progress(1.0)
        return result


demo(
    "UNETVAE",
    demo_unetvae,
    mono=True,
)

* VDSR

In [None]:
from olimp.precompensation.nn.models.vdsr import VDSR


def demo_vdsr(
    image: Tensor,
    psf: Tensor,
    progress: Callable[[float], None],
) -> Tensor:
    """Precompensate ``image`` with the pretrained VDSR model.

    The PSF is converted to float32 before preprocessing. ``progress`` is
    called with 0.1 after preprocessing and 1.0 after inference.
    """
    vdsr = VDSR.from_path("hf://RVI/vdsr.pth")
    with torch.inference_mode():
        prepared = vdsr.preprocess(image, psf.to(torch.float32))
        progress(0.1)
        # The model is expected to yield exactly one output tensor.
        (result,) = vdsr(prepared)
        progress(1.0)
        return result


demo(
    "VDSR",
    demo_vdsr,
    mono=True,
)

* CVD-SWIN

In [None]:
from olimp.precompensation.nn.models.cvd_swin.cvd_swin_4channels import (
    CVDSwin4Channels,
)
from olimp.simulate.color_blindness_distortion import ColorBlindnessDistortion


def demo_cvd_swin(
    image: Tensor,
    distortion: ColorBlindnessDistortion,
    progress: Callable[[float], None],
) -> tuple[torch.Tensor]:
    """Precompensate ``image`` for color-vision deficiency with CVD-SWIN.

    :param image: Input image tensor handed to ``preprocess``.
    :param distortion: Distortion supplied by ``demo_cvd``; not used here
        (the hue angle passed to ``preprocess`` is fixed at 0 degrees).
    :param progress: Callback receiving a completion fraction
        (0.1 after preprocessing, 1.0 after inference).
    :return: One-element tuple holding the post-processed precompensation.
    """
    cvd_swin = CVDSwin4Channels.from_path()
    # Pure inference: disable autograd so no graph is recorded, as the other
    # demos do with inference_mode(). no_grad() is used rather than
    # inference_mode() so the returned tensor stays an ordinary tensor,
    # since how demo_cvd consumes it isn't visible here.
    with torch.no_grad():
        model_input = cvd_swin.preprocess(
            image, hue_angle_deg=torch.tensor([0.0])
        )
        progress(0.1)
        precompensation = cvd_swin(model_input)
        progress(1.0)
        return (cvd_swin.postprocess(precompensation[0]),)


distortion = ColorBlindnessDistortion.from_type("protan")
demo_cvd(
    "CVD-SWIN",
    demo_cvd_swin,
    distortion=distortion,
)