# Visual and MAE test for proper registration

This test allows for a quick visual check that the `Registrator` class carries out registration correctly.
For both forward and reverse registration, the expected and actual outputs are displayed, along with their difference for completeness. The mean absolute error (MAE) of that difference should be 0 if the registration is carried out correctly.

## Fill in the project root path

In [None]:
project_root = ""  # e.g., /home/$USERNAME$/git/histalign
%cd {project_root}

## Run the rest of the script

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from histalign.backend.ccf.paths import get_atlas_path
from histalign.backend.io import load_alignment_settings, load_image, load_volume
from histalign.backend.registration import Registrator
from histalign.backend.workspace import VolumeSlicer

In [None]:
def imshow(image: np.ndarray, title: str = "") -> None:
    """Display ``image`` in a new figure titled ``title``, with axes hidden."""
    fig, ax = plt.subplots()
    fig.suptitle(title)

    ax.imshow(image)
    # Tick marks and frame add nothing to a visual comparison.
    ax.set_axis_off()

    fig.tight_layout()
    plt.show()

In [None]:
alignment_path = str(
    Path("tests/registration/resources/complete_alignment_settings.json")
)
alignment_settings = load_alignment_settings(alignment_path)

# If the settings do not point at an existing `.nrrd` volume, fall back to the
# atlas for the configured resolution, downloading it first if it is missing.
# `.suffix` (rather than `.suffixes[-1]`) avoids an IndexError on suffix-less paths.
configured_volume_path = alignment_settings.volume_path
if not configured_volume_path.is_file() or configured_volume_path.suffix != ".nrrd":
    resolution = alignment_settings.volume_settings.resolution
    atlas_path = get_atlas_path(resolution)
    if not Path(atlas_path).exists():
        # `download_atlas` was never imported in this notebook, so this branch
        # used to raise NameError. Imported lazily since it is only needed here.
        # NOTE(review): module path assumed — confirm against histalign.
        from histalign.backend.ccf.downloads import download_atlas

        download_atlas(resolution)

    # Keep `volume_path` a Path, matching how it is used above.
    alignment_settings.volume_path = Path(atlas_path)

In [None]:
# Inputs shared by every test case below: the histology image from the
# settings, and the volume that the comparison loop slices.
histology_image_path = alignment_settings.histology_path
image = load_image(histology_image_path)
volume = load_volume(alignment_settings.volume_path)

# NOTE(review): the meaning of the two positional flags is defined by
# `Registrator.__init__`, which is not visible here — confirm there.
registrator = Registrator(True, True)

In [None]:
# Each name selects a test case in `resources_directory`:
# `<name>_alignment_settings.json` holds the alignment to apply, and
# `<name>_expected_output1.npz` / `<name>_expected_output2.npz` hold the
# expected forward / reverse outputs.
resources_directory = "tests/registration/resources"
tested_parameters = [
    "scale",
    "shear",
    "rotation",
    "translation",
    "offset",
    "pitch",
    "yaw",
    "complete",
]

# Reverse outputs are displayed at every 10th row/column; the MAE is still
# computed on the full-resolution difference.
reverse_display_slicing = (slice(None, None, 10), slice(None, None, 10))

# The volume is the same for every case, so one slicer is reused.
volume_slicer = VolumeSlicer(volume=volume)


def _show_comparison(
    expected: np.ndarray,
    actual: np.ndarray,
    direction: str,
    parameter: str,
    slicing: tuple = (Ellipsis,),
) -> None:
    """Display expected, actual, and difference images for one direction.

    Args:
        expected: Reference output loaded from disk.
        actual: Output produced in this run.
        direction: "forward" or "reverse"; used in the plot titles.
        parameter: Name of the test case; used in the plot titles.
        slicing: Index applied to each image before display only. The default
            shows the full image.

    The mean absolute error (MAE) of ``expected - actual`` is computed on the
    full arrays and shown in the difference plot's title.
    """
    # Widen before subtracting: int16 would wrap for uint16 inputs above
    # 32767 (and for abs(-32768)), silently corrupting the MAE.
    difference = expected.astype(np.int32) - actual.astype(np.int32)
    mae = np.mean(np.abs(difference))

    imshow(
        expected[slicing],
        f"Expected {direction} registration output {{{parameter}}}",
    )
    imshow(
        actual[slicing],
        f"Actual {direction} registration output {{{parameter}}}",
    )
    imshow(
        difference[slicing],
        f"Difference {{{parameter}}} (MAE = {mae})",
    )


for parameter in tested_parameters:
    current_settings = load_alignment_settings(
        f"{resources_directory}/{parameter}_alignment_settings.json"
    )
    current_settings.volume_path = alignment_settings.volume_path

    # Forward registration: histology pixels above 10 are tripled and overlaid
    # on the corresponding volume slice.
    forwarded_image = registrator.get_forwarded_image(image, current_settings)
    volume_image = volume_slicer.slice(current_settings.volume_settings)

    expected_forwarded_image = load_image(
        f"{resources_directory}/{parameter}_expected_output1.npz"
    )
    actual_forwarded_image = np.where(
        forwarded_image > 10, forwarded_image * 3, volume_image
    )

    # Optional visualisation
    # imshow(volume_image, "Volume image")
    # imshow(forwarded_image, "Registered histology")

    _show_comparison(
        expected_forwarded_image, actual_forwarded_image, "forward", parameter
    )

    # Reverse registration: where the reversed image is zero, the tripled
    # histology image shows through instead.
    reversed_image = registrator.get_reversed_image(current_settings, "atlas", image)

    expected_reversed_image = load_image(
        f"{resources_directory}/{parameter}_expected_output2.npz"
    )
    actual_reversed_image = np.where(reversed_image, reversed_image, image * 3)

    # Optional visualisation
    # imshow(reversed_image[reverse_display_slicing])
    # imshow(image[reverse_display_slicing])

    _show_comparison(
        expected_reversed_image,
        actual_reversed_image,
        "reverse",
        parameter,
        reverse_display_slicing,
    )