# Movie Prediction Script

This script uses the function `get_final_preds` to obtain predictions either from a movie path or from a movie given as a NumPy array. The function will be integrated into the MATLAB GUI developed by Rado.

The function uses a (trained) UNet model to perform image segmentation and returns the segmentation and instances arrays.

Author: Prisca Dotti  
Last Modified: 23.10.2023.

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [2]:
import os

import imageio
import napari
import numpy as np
import torch
from torch import nn

from config import TrainingConfig, config
from utils.training_inference_tools import get_final_preds
from utils.training_script_utils import init_model
from utils.visualization_tools import (
    get_annotations_contour,
    get_discrete_cmap,
    get_labels_cmap,
)

  from .autonotebook import tqdm as notebook_tqdm


Parameters needed to configure the dataset and the UNet model (these could later be hard-coded in the function).

In [3]:
### Set training-specific parameters ###

# Load the configuration of the trained model
config_path = os.path.join("config_files", "config_final_model.ini")
params = TrainingConfig(training_config_file=config_path)
params.run_name = "final_model"

# Checkpoint to load; the number presumably is the training iteration at which
# the checkpoint was saved (TODO confirm). Produces "network_100000.pth".
checkpoint_iteration = 100_000
model_filename = f"network_{checkpoint_iteration}.pth"

[16:34:51] [  INFO  ] [   config   ] <290 > -- Loading C:\Users\prisc\Code\sparks_project\config_files\config_final_model.ini


Load UNet model

In [4]:
### Configure UNet ###
# TODO: switch back to params.set_device(device="auto") once GPU inference is validated
params.set_device(device="cpu")  # temporary

network = init_model(params=params)

# Move the model to the GPU if available
if params.device.type != "cpu":
    network = nn.DataParallel(network).to(params.device, non_blocking=True)

### Load UNet model ###

# Path to the saved model checkpoint (note: this is a file path, not a directory)
models_relative_path = os.path.join(
    "models", "saved_models", params.run_name, model_filename
)
model_dir = os.path.realpath(os.path.join(config.basedir, models_relative_path))

# Read the checkpoint once. A checkpoint saved from a DataParallel-wrapped model
# prefixes every key with "module.", so the keys only match a model that has the
# same wrapping. Pick the object to load into based on both sides, instead of
# retrying on error (which double-wrapped the model on the GPU path).
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(model_dir, map_location=params.device)
saved_with_data_parallel = any(key.startswith("module.") for key in state_dict)
model_is_data_parallel = isinstance(network, nn.DataParallel)

if saved_with_data_parallel and not model_is_data_parallel:
    print(
        "Checkpoint was saved from a DataParallel model; "
        "loading it through a temporary DataParallel wrapper..."
    )
    # The wrapper exposes `network` under "module.", matching the checkpoint keys;
    # the weights are loaded into `network` itself, so the wrapper can be discarded.
    nn.DataParallel(network).load_state_dict(state_dict)
elif model_is_data_parallel and not saved_with_data_parallel:
    # Model is wrapped (GPU path) but the checkpoint is not: load into the inner module.
    network.module.load_state_dict(state_dict)
else:
    network.load_state_dict(state_dict)

Failed to load the model, as it was trained with DataParallel. Wrapping it in DataParallel and retrying...
Network should be on CPU, removing DataParallel wrapper...


Define movie path

In [9]:
# Input movie, resolved from the project base directory instead of a user-specific
# absolute path.
# NOTE(review): assumes config.basedir is the project root (the same base used to
# locate saved models and the evaluation output folder) -- confirm.
movie_path = os.path.join(config.basedir, "data", "sparks_dataset", "05_video.tif")

# Alternative test movie (cropped, shape (904, 53, 284)):
# movie_path = r"C:\Users\dotti\Desktop\cropped 34_video.tif"

Run the sample through the U-Net (using the function `get_final_preds`)

In [10]:
# Switch the network to evaluation mode before predicting.
network.eval()

# Predict on the movie stored at `movie_path`; the call yields the
# segmentation array and the instances array.
segmentation, instances = get_final_preds(
    movie_path=movie_path, model=network, params=params
)

Save the predicted segmentation and instances to disk

In [11]:
def _to_smallest_uint(labels):
    """Cast a non-negative label array to the smallest unsigned dtype that holds it.

    A plain ``np.uint8`` cast wraps values above 255 (modulo 256), which would
    silently merge distinct instance ids when a movie contains more than 255
    instances. Arrays whose maximum fits in uint8 are cast exactly as before.

    Args:
        labels: array-like of labels (assumed non-negative -- TODO confirm).

    Returns:
        The labels as a uint8, uint16 or uint32 ndarray.

    Raises:
        ValueError: if the largest label does not fit in uint32.
    """
    labels = np.asarray(labels)
    max_label = int(labels.max()) if labels.size else 0
    for dtype in (np.uint8, np.uint16, np.uint32):
        if max_label <= np.iinfo(dtype).max:
            return labels.astype(dtype)
    raise ValueError(f"Label value {max_label} does not fit in uint32.")


# Name the output files after the input movie, or use a generic name when no
# movie path was provided.
if movie_path is not None:
    movie_filename = os.path.splitext(os.path.basename(movie_path))[0]
else:
    movie_filename = "sample_movie"

# Set the output directory (exist_ok avoids a separate existence check)
out_dir = os.path.join(config.basedir, "evaluation", "matlab_inference_script")
os.makedirs(out_dir, exist_ok=True)

# Save the segmentation and instances on disk as .tif files
imageio.volwrite(
    os.path.join(out_dir, f"{movie_filename}_unet_segmentation.tif"),
    _to_smallest_uint(segmentation),
)
imageio.volwrite(
    os.path.join(out_dir, f"{movie_filename}_unet_instances.tif"),
    _to_smallest_uint(instances),
)

Visualize U-Net's predictions with Napari

In [11]:
# Open the original movie to display underneath the predictions
sample = np.asarray(imageio.volread(movie_path))

# Set up napari colormaps
# (optional: pass colormap=("colors", cmap) to add_image for a discrete gray LUT)
cmap = get_discrete_cmap(name="gray", lut=16)
labels_cmap = get_labels_cmap()

# Show only the border of each class region by default (segmentation array)
segmentation_border = get_annotations_contour(segmentation)

viewer = napari.Viewer()
viewer.add_image(
    sample,
    name="input movie",
)

viewer.add_labels(
    segmentation_border,
    name="segmentation",
    opacity=0.9,
    color=labels_cmap,
)  # only visualize border

# Filled regions, hidden by default. Given a distinct name so napari does not
# have to rename a second "segmentation" layer.
viewer.add_labels(
    segmentation,
    name="segmentation (filled)",
    opacity=0.5,
    color=labels_cmap,
    visible=False,
)

# Trailing ";" suppresses the layer repr in the cell output.
viewer.add_labels(
    instances,
    name="instances",
    opacity=0.5,
);

<Labels layer 'instances' at 0x2585fdbfdf0>