# Using `zea.Models`: UNet Example for Ultrasound Image Inpainting

In this notebook, we demonstrate how to use the `zea.Models` interface with a popular deep learning architecture: the UNet. We'll use a pretrained UNet to inpaint missing regions in ultrasound images, and visualize the results. This workflow can be adapted for other tasks and models in the `zea` toolbox.

In [1]:
import os
os.environ["KERAS_BACKEND"] = "torch"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [2]:
import matplotlib.pyplot as plt
from keras import ops

from zea import init_device, log
from zea.backend.tensorflow.dataloader import make_dataloader
from zea.models.unet import UNet
from zea.models.lpips import LPIPS
from zea.agent.masks import random_uniform_lines
from zea.visualize import plot_image_grid, set_mpl_style

[1m[38;5;36mzea[0m[0m: Using backend 'torch'



We use the GPU when one is available: `init_device` picks the best available device. Optionally, we also set the matplotlib style for plotting.

In [3]:
device = init_device(verbose=False)
set_mpl_style()
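
For readers curious what "picking the best available device" can look like, here is a minimal sketch. It is not how `init_device` is implemented in `zea`; the helper name `pick_device` is hypothetical, and it simply prefers the first CUDA GPU when PyTorch reports one and falls back to the CPU otherwise.

```python
def pick_device() -> str:
    """Return "cuda:0" if PyTorch can see a GPU, else "cpu" (hypothetical helper)."""
    try:
        import torch  # optional dependency; fall back to CPU if it is missing
    except ImportError:
        return "cpu"
    return "cuda:0" if torch.cuda.is_available() else "cpu"


device_name = pick_device()
print(device_name)
```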

## Load Data

We load a small batch from the CAMUS validation dataset hosted on Hugging Face Hub.

In [4]:
import torch

n_imgs = 8

# Build the dataloader and draw one batch inside the selected device context
with torch.device(device):
    val_dataset = make_dataloader(
        "hf://zeahub/camus-sample/val",
        key="data/image",
        batch_size=n_imgs,
        shuffle=True,
        image_size=[128, 128],
        resize_type="resize",
        image_range=[-60, 0],
        normalization_range=[-1, 1],
        seed=42,
    )
    batch = next(iter(val_dataset))
    # Keep values inside the normalized range [-1, 1]
    batch = ops.clip(batch, -1, 1)

[1m[38;5;36mzea[0m[0m: Using pregenerated dataset info file: [33m/home/devcontainer15/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/dataset_info.yaml[0m ...
[1m[38;5;36mzea[0m[0m: ...for reading file paths in [33m/home/devcontainer15/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val[0m
[1m[38;5;36mzea[0m[0m: Dataset was validated on [32mJune 13, 2025[0m
[1m[38;5;36mzea[0m[0m: Remove [33m/home/devcontainer15/.cache/zea/huggingface/datasets/datasets--zeahub--camus-sample/snapshots/617cf91a1267b5ffbcfafe9bebf0813c7cee8493/val/validated.flag[0m if you want to redo validation.
[1m[38;5;36mzea[0m[0m: H5Generator: Shuffled data.


RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
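
The `image_range=[-60, 0]` and `normalization_range=[-1, 1]` arguments in the loader call suggest that pixel values (presumably in dB) are rescaled to the range the model expects. Below is a minimal NumPy sketch of such a rescaling, assuming a plain linear mapping. The helper `rescale` is hypothetical and not part of `zea`, whose internal implementation may differ.

```python
import numpy as np


def rescale(x, src=(-60.0, 0.0), dst=(-1.0, 1.0)):
    """Linearly map values from the `src` interval to the `dst` interval."""
    (s_lo, s_hi), (d_lo, d_hi) = src, dst
    return (x - s_lo) / (s_hi - s_lo) * (d_hi - d_lo) + d_lo


db_values = np.array([-60.0, -30.0, 0.0, 5.0])
normalized = rescale(db_values)            # 5 dB lies outside the source range
clipped = np.clip(normalized, -1.0, 1.0)   # mirrors the clip applied to the batch
print(normalized)
print(clipped)
```

Values outside the source interval map outside `[-1, 1]`, which is one reason to clip after normalizing.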

## Load UNet Model

We use a pretrained UNet model from `zea` for inpainting.

In [5]:
presets = list(UNet.presets.keys())
log.info(f"Available built-in zea presets for UNet: {presets}")

model = UNet.from_preset("unet-echonet-inpainter")

[1m[38;5;36mzea[0m[0m: Available built-in zea presets for UNet: ['unet-echonet-inpainter']


## Simulate Missing Data

We simulate missing data by masking out random columns in each image (e.g., 75% missing). This is a common scenario in cognitive ultrasound where some scanlines may be missing (i.e. not acquired) or corrupted.
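
To make the masking step concrete, here is a minimal NumPy sketch that is independent of `zea`. It assumes images shaped `(batch, height, width, channels)` and keeps a random quarter of the columns in each image. `random_column_mask` is a hypothetical helper, not the `random_uniform_lines` function used in this notebook, whose sampling scheme may differ.

```python
import numpy as np


def random_column_mask(n_keep, n_columns, n_masks, seed=0):
    """Return an (n_masks, n_columns) 0/1 mask with exactly n_keep ones per row."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((n_masks, n_columns), dtype=np.float32)
    for row in mask:
        # Pick n_keep distinct column indices to keep for this image.
        keep = rng.choice(n_columns, size=n_keep, replace=False)
        row[keep] = 1.0
    return mask


n_imgs, height, width = 8, 128, 128
rng = np.random.default_rng(1)
images = rng.uniform(-1, 1, (n_imgs, height, width, 1)).astype(np.float32)

mask = random_column_mask(width // 4, width, n_imgs)
# Broadcast the column mask over the height and channel axes.
corrupted = images * mask[:, None, :, None]

print(mask.sum(axis=1))         # number of kept columns per image
print((corrupted == 0).mean())  # fraction of zeroed pixels
```

Note that multiplying by the mask sets missing columns to 0, which is mid-gray rather than black when images are normalized to `[-1, 1]`.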

In [7]:
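n_columns = batch.shape[2]  # image width, i.e. the number of scan lines (128 here)
# Keep n_columns // 4 random lines per image, so 75% of the columns are masked out
mask = random_uniform_lines(n_columns // 4, n_columns, n_imgs)
# Broadcast the per-image line mask over the height and channel axes
corrupted = batch * ops.cast(mask[:, None, :, None], batch.dtype)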

TypeError: ones() received an invalid combination of arguments - got (device=str, dtype=torch.dtype, size=Tensor, ), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)


## Inpaint with UNet

We use the UNet to inpaint the missing regions.

In [10]:
# Run the pretrained UNet on the corrupted batch and keep outputs in [-1, 1]
inpainted = model(corrupted)
inpainted = ops.clip(inpainted, -1, 1)

## Evaluate Perceptual Similarity

We use the LPIPS metric to evaluate perceptual similarity between the ground truth and inpainted images. For a more detailed example of this metric, see [this notebook](lpips_example.ipynb).

In [12]:
lpips = LPIPS.from_preset("lpips")
# Compare the ground truth with the inpainted images (one score per image)
lpips_scores = lpips([batch, inpainted])
lpips_scores = ops.convert_to_numpy(lpips_scores)

## Visualization

We plot the ground truth, corrupted, inpainted, and error images, one row per image type. The LPIPS score is shown on each inpainted image. Note that this model was trained on the EchoNet-Dynamic dataset, whereas here we test it on the CAMUS dataset.

In [13]:
error = ops.abs(batch - inpainted)
imgs = ops.concatenate([batch, corrupted, inpainted, error], axis=0)
imgs = ops.convert_to_numpy(imgs)

cmaps = ["gray"] * (3 * n_imgs) + ["viridis"] * n_imgs

fig, _ = plot_image_grid(
    imgs,
    vmin=-1,
    vmax=1,
    ncols=n_imgs,
    remove_axis=False,
    cmap=cmaps,
    figsize=(n_imgs * 2, 6),
)

titles = ["Ground Truth", "Corrupted", "Inpainted", "Error"]
for i, ax in enumerate(fig.axes[: len(titles) * n_imgs]):
    if i % n_imgs == 0:
        ax.set_ylabel(titles[i // n_imgs])

# Show LPIPS score on each inpainted image
for ax, lpips_score in zip(fig.axes[n_imgs * 2 : 3 * n_imgs], lpips_scores):
    ax.text(
        0.95,
        0.95,
        f"LPIPS: {float(lpips_score):.4f}",
        ha="right",
        va="top",
        transform=ax.transAxes,
        fontsize=8,
        color="yellow",
    )
plt.show()

NameError: name 'batch' is not defined
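
The slicing in the plotting cell relies on the grid being filled row by row, so that each row holds one image type. Assuming that row-major layout, the following sketch shows how a flat axis index maps to a row title and a column, and why `fig.axes[n_imgs * 2 : 3 * n_imgs]` selects the inpainted row. The helper `locate` is introduced only for illustration.

```python
n_imgs = 8
titles = ["Ground Truth", "Corrupted", "Inpainted", "Error"]


def locate(i, ncols=n_imgs):
    """Map a flat, row-major axis index to (row title, column index)."""
    return titles[i // ncols], i % ncols


# Indices of the third row, which holds the inpainted images.
inpainted_idx = list(range(2 * n_imgs, 3 * n_imgs))
print([locate(i) for i in inpainted_idx][:2])

# One colormap per image: grayscale for the first three rows, viridis for the error row.
cmaps = ["gray"] * (3 * n_imgs) + ["viridis"] * n_imgs
print(len(cmaps), cmaps[3 * n_imgs])
```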

You can try other UNet presets or experiment with different masking strategies to explore the capabilities of `zea.Models`!