In [None]:
from __future__ import annotations

import json
import hashlib
import os.path
from pathlib import Path
import pickle
import random
from time import perf_counter
from typing import Optional

import cv2
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from skimage import filters, restoration
import tifffile

from histalign.backend.ccf.downloads import (
    download_annotation_volume,
    download_structure_mask,
)
from histalign.backend.ccf.paths import get_annotation_path, get_structure_mask_path
from histalign.backend.io import load_volume
from histalign.backend.models import Resolution

In [56]:
def imshow(
    image: np.ndarray,
    title: str | None = None,
    figsize: tuple[int, int] | None = None,
    cmap: str | None = "gray",
) -> None:
    """Displays an image in a new matplotlib figure with the axes hidden.

    Args:
        image (np.ndarray): Image to display.
        title (str | None, optional): Figure title. Omitted when `None`.
        figsize (tuple[int, int] | None, optional):
            Figure size forwarded to `plt.figure`.
        cmap (str | None, optional):
            Colour map forwarded to `plt.imshow`. The special value `"distinct"`
            builds a colour map from the image with `generate_distinct_cmap`.
    """
    plt.figure(figsize=figsize)

    if title is not None:
        plt.suptitle(title)
    plt.axis(False)

    if cmap == "distinct":
        # NOTE(review): `generate_distinct_cmap` is not defined in the visible part
        # of this notebook; it must exist in the kernel before using "distinct".
        cmap = generate_distinct_cmap(image)

    plt.imshow(image, cmap=cmap)

    plt.tight_layout()
    plt.show()


def get_annotation_contours(image: np.ndarray) -> list[list[np.ndarray]]:
    """Extracts the contours of every distinct value found in an image.

    Args:
        image (np.ndarray): Image whose unique values each delimit a region.

    Returns:
        list[list[np.ndarray]]:
            One group of contours per unique value, in the order returned by
            `np.unique`.
    """

    def contours_of(label) -> list[np.ndarray]:
        # Binary mask of the pixels equal to `label`; only the contours (first
        # element of the `cv2.findContours` result) are kept.
        binary = (image == label).astype(np.uint8)
        return cv2.findContours(binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[0]

    return [contours_of(label) for label in np.unique(image)]


def get_structures_contours(
    structures: list[str],
    index: tuple[int, int, int],
    resolution: Resolution | None = None,
) -> list[np.ndarray]:
    """Computes the contours of structure masks on a single slice of their volume.

    Masks that are not available locally are downloaded first.

    Args:
        structures (list[str]): Names of the structures to extract contours for.
        index (tuple[int, int, int]):
            Index into the mask volume. A value of -1 selects the whole axis, so
            two -1 entries and one fixed coordinate yield a 2D slice.
        resolution (Resolution | None, optional):
            Resolution of the mask volumes. Defaults to `Resolution.MICRONS_100`.

    Returns:
        list[np.ndarray]: The contours of all the structures, concatenated.
    """
    if resolution is None:
        # The default is resolved lazily so the signature does not depend on the
        # enum being evaluated at definition time.
        resolution = Resolution.MICRONS_100

    slicer = tuple(slice(None) if value == -1 else value for value in index)

    contours = []
    for structure in structures:
        volume_path = get_structure_mask_path(structure, resolution)
        if not os.path.exists(volume_path):
            download_structure_mask(structure, resolution)

        volume = load_volume(volume_path, return_raw_array=True)

        # Any non-zero voxel is considered part of the structure.
        image = volume[slicer]
        contours.extend(
            cv2.findContours(
                (image > 0).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE
            )[0]
        )

    return contours


def draw_contours(image: np.ndarray, contours: list[np.ndarray]) -> np.ndarray:
    """Draws all the given contours in white onto an image, in place.

    Args:
        image (np.ndarray): Image to draw on. It is modified in place.
        contours (list[np.ndarray]): Contours to draw.

    Returns:
        np.ndarray: The same `image` object, to allow chaining.
    """
    # A contour index of -1 draws every contour in the list.
    cv2.drawContours(image, contours, -1, (255, 255, 255))

    return image


def save_contours(contours: list[np.ndarray], file_path: str | Path) -> None:
    """Saves contours of varying lengths as one padded, masked array.

    Every contour is cast to uint16 and zero-padded along its first axis to the
    length of the longest contour. A boolean mask of the same shape marks the
    padding (`True`) so it can be stripped again on load. Both arrays are written
    to a compressed `.npz` archive under the keys `data` and `mask`.

    Args:
        contours (list[np.ndarray]): Contours to save. Must not be empty.
        file_path (str | Path): Destination forwarded to `np.savez_compressed`.

    Raises:
        ValueError: When `contours` is empty.
    """
    if len(contours) == 0:
        raise ValueError("Cannot save an empty list of contours.")

    largest_shape = max((contour.shape for contour in contours), key=lambda x: x[0])

    # Allocate the stacked arrays once instead of concatenating per contour.
    # Everything starts out masked; real points are unmasked as they are copied.
    data = np.zeros((len(contours), *largest_shape), dtype=np.uint16)
    mask = np.ones(data.shape, dtype=bool)

    for i, contour in enumerate(contours):
        point_count = contour.shape[0]
        data[i, :point_count] = contour.astype(np.uint16)
        mask[i, :point_count] = False

    np.savez_compressed(file_path, data=data, mask=mask)


def load_masked_array(file_path: str | Path) -> np.ma.MaskedArray:
    """Reads a masked array from an `.npz` archive.

    The archive is expected to hold the values under the key `data` and the
    matching boolean mask under the key `mask`.

    Args:
        file_path (str | Path): Archive to read.

    Returns:
        np.ma.MaskedArray: The stored values, cast to int32, with their mask.
    """
    archive = np.load(file_path)
    try:
        values = archive["data"].astype(np.int32)
        hidden = archive["mask"]
    finally:
        # Release the underlying file even if a key is missing.
        archive.close()

    return np.ma.array(values, mask=hidden)


def load_contours(file_path: str | Path) -> list[np.ndarray]:
    """Loads contours from a masked-array archive.

    Each entry along the first axis of the stored array is one contour. Its masked
    values are discarded and the rest is reshaped to `(-1, 1, 2)`.

    Args:
        file_path (str | Path): Archive to read with `load_masked_array`.

    Returns:
        list[np.ndarray]: One array of shape `(n, 1, 2)` per stored contour.
    """
    masked_array = load_masked_array(file_path)

    contours = []
    for i in range(masked_array.shape[0]):
        # Keep only the unmasked values of this row.
        visible = ~masked_array.mask[i]
        row = masked_array[i]
        contours.append(row[visible].data.reshape(-1, 1, 2))

    return contours


def chaikins_corner_cutting(coords: np.ndarray, refinements: int = 5):
    """Smooths a polyline by repeatedly cutting its corners.

    Each refinement doubles the number of points: every point is duplicated and
    each copy is blended 75%/25% with one of its neighbours. The first and last
    points are blended with themselves and therefore stay in place.

    Args:
        coords (np.ndarray): Points of the polyline, one per row.
        refinements (int, optional): Number of corner-cutting passes.

    Returns:
        np.ndarray: The refined points.
    """
    points = np.array(coords)

    for _ in range(refinements):
        doubled = np.repeat(points, 2, axis=0)

        # `partner` holds, for each duplicated point, the neighbour it is pulled
        # towards.
        partner = np.empty_like(doubled)
        partner[0], partner[-1] = doubled[0], doubled[-1]
        partner[1:-1:2] = doubled[2::2]
        partner[2::2] = doubled[1:-1:2]

        points = 0.75 * doubled + 0.25 * partner

    return points

In [None]:
CONTOUR_CACHE_PATH = Path("contour_cache.pkl")
# NOTE: despite its ".json" suffix, this file is written and read with pickle (see
# `ContourCache.save` and `ContourCache.load`). The name is kept unchanged so that
# caches already on disk remain loadable.
CONTOUR_CACHE_MAP_PATH = Path("contour_cache_map.json")

Contour = np.ndarray


class ContourCache:
    """A size-bounded cache of contour groups, persisted to disk with pickle.

    Each entry is a list of contours stored under a string key. `_contours` holds
    the entries in insertion order (index 0 is the oldest) and `_key_contours_dict`
    maps every key to the index of its entry in `_contours`. `_count` is the total
    number of individual contours across all entries.
    """

    auto_save: bool
    max_contours: int
    max_points: int

    _contours: list[list[Contour]]
    _count: int
    _key_contours_dict: dict[str, int]

    def __init__(
        self, auto_save: bool = True, max_contours: int = 2_000, max_points: int = 5_000
    ) -> None:
        """A cache object granting access to the local contour cache.

        The default sizes should result in at most a 20MB file on disk for the contours
        (and another smaller one for the dictionary) and around twice that when loaded
        in memory. This doubling is due to the fact that OpenCV needs contours as int32
        but we store them as uint16.

        Args:
            auto_save (bool, optional):
                Whether the cache should save after every addition and removal.
            max_contours (int, optional):
                How many contours the cache should hold at any one point. This is an
                upper limit and might never be 100% utilised.
            max_points (int, optional):
                How many points any one contour is allowed to have.
        """
        self.auto_save = auto_save
        self.max_contours = max_contours
        self.max_points = max_points

        self._contours = []
        self._count = 0
        self._key_contours_dict = {}

    @property
    def entries(self) -> list[str]:
        """Keys currently present in the cache."""
        return list(self._key_contours_dict)

    def has_contours(self, key: str) -> bool:
        """Returns whether the given key is associated with cached contours.

        Note that providing a key, even obtained from an add operation, is not
        guaranteed to return `True`. If the cache grew and dropped the contours
        associated with this key, the function will return `False`.

        Args:
            key (str): Key to retrieve the contours with.

        Returns:
            bool: Whether the key is associated with cached contours.
        """
        return key in self._key_contours_dict

    def insert_contours(
        self,
        contours: list[Contour],
        key: str = "",
    ) -> str:
        """Adds contours to the cache.

        If the key already exists, its previous contours are replaced and the entry
        becomes the most recent one. If the cache would overflow, the oldest entries
        are evicted until the new contours fit.

        Args:
            contours (list[Contour]): Contours to cache.
            key (str, optional):
                Optional key to assign to the contours. If omitted, it is automatically
                generated.

        Returns:
            str:
                The key associated with the contours. This can be used to retrieve the
                contours if they are still in the cache.

        Raises:
            ValueError:
                When attempting to add contours with more points than allowed or when
                attempting to add more contours than the cache can store.
        """
        # `default=0` lets an empty list through instead of failing inside `max`.
        max_points = max((contour.shape[0] for contour in contours), default=0)
        if max_points > self.max_points:
            raise ValueError(
                f"Tried adding a contour with more points than allowed "
                f"({max_points} > {self.max_points})."
            )

        contour_count = len(contours)
        if contour_count > self.max_contours:
            raise ValueError(
                f"Tried adding more contours in one go than the cache allows "
                f"({contour_count} > {self.max_contours})."
            )

        key = key or self.generate_key(contours)

        # Drop the previous entry first so eviction below can never remove (or
        # double-count) the entry that is being replaced.
        if self.has_contours(key):
            self._remove(key)

        # Ensure we don't go over the contour limit
        overflow = self._count + contour_count - self.max_contours
        if overflow > 0:
            self._invalidate_contours(overflow)

        self._contours.append(contours)
        self._key_contours_dict[key] = len(self._contours) - 1
        self._count += contour_count

        if self.auto_save:
            self.save()

        return key

    def load_contours(self, key: str, allow_missing: bool = False) -> list[Contour]:
        """Retrieves the contours associated with the given key.

        Args:
            key (str): Key used to retrieve the contours.
            allow_missing (bool, optional):
                Whether to return an empty list instead of raising when the key is
                not associated with any contours.

        Returns:
            list[Contour]:
                A list containing the retrieved contours or an empty list if no contours
                exist and `allow_missing` is `True`.

        Raises:
            ValueError: When the cache misses and `allow_missing` is `False`.
        """
        if not self.has_contours(key):
            if allow_missing:
                return []
            raise ValueError(f"No cached contours for the given key '{key}'.")

        return self._contours[self._key_contours_dict[key]]

    def pop_contours(
        self, key: str, allow_missing: bool = False, decrement_indices: bool = True
    ) -> list[Contour]:
        """Pops the contours associated with the given key from the cache.

        Args:
            key (str): Key used to retrieve the contours.
            allow_missing (bool, optional):
                Whether to return an empty list instead of raising when the key is
                not associated with any contours.
            decrement_indices (bool, optional):
                Whether to decrement other indices on removal. Passing `False`
                leaves the indices of later entries stale; the caller is then
                responsible for fixing them before using the cache again.

        Returns:
            list[Contour]:
                The removed contours or an empty list if no contours exist and
                `allow_missing` is `True`.

        Raises:
            ValueError: When the cache misses and `allow_missing` is `False`.
        """
        if not self.has_contours(key):
            if allow_missing:
                return []
            raise ValueError(f"No cached contours for the given key '{key}'.")

        contours = self._remove(key, decrement_indices=decrement_indices)

        if self.auto_save:
            self.save()

        return contours

    # noinspection PyTypeChecker
    def save(self) -> None:
        """Saves the cache to disk.

        Both the contours and the key-to-index dictionary are written with pickle.
        """
        with open(CONTOUR_CACHE_PATH, "wb") as handle:
            pickle.dump(self._contours, handle)

        with open(CONTOUR_CACHE_MAP_PATH, "wb") as handle:
            pickle.dump(self._key_contours_dict, handle)

    @classmethod
    def load(cls) -> "ContourCache":
        """Loads and returns the contour cache.

        The cache is only restored when both the contours file and the dictionary
        file exist. Contours without their dictionary could neither be retrieved
        nor evicted, so a partial cache is treated as empty.

        Returns:
            ContourCache: The contour cache.
        """
        instance = cls()

        if CONTOUR_CACHE_PATH.exists() and CONTOUR_CACHE_MAP_PATH.exists():
            # SECURITY: unpickling executes arbitrary code. Only load cache files
            # written by this class on a trusted machine.
            with open(CONTOUR_CACHE_PATH, "rb") as handle:
                contours = pickle.load(handle)
            with open(CONTOUR_CACHE_MAP_PATH, "rb") as handle:
                key_contours_dict = pickle.load(handle)

            instance._contours = contours
            instance._key_contours_dict = key_contours_dict
            instance._count = sum(len(group) for group in contours)

        return instance

    @staticmethod
    def generate_key(contours: list[Contour]) -> str:
        """Generates a key for the provided contours.

        The key is the MD5 digest of the concatenated MD5 digests of each contour's
        raw bytes.

        Returns:
            str: The key generated from the contours.
        """
        keys = [hashlib.md5(contour.tobytes()).hexdigest() for contour in contours]

        return hashlib.md5("".join(keys).encode("UTF-8")).hexdigest()

    def _remove(self, key: str, decrement_indices: bool = True) -> list[Contour]:
        """Removes the entry of an existing key without saving to disk.

        Args:
            key (str): Key of the entry to remove. It must exist.
            decrement_indices (bool, optional):
                Whether to shift the indices of later entries down by one.

        Returns:
            list[Contour]: The removed contours.
        """
        index = self._key_contours_dict.pop(key)
        contours = self._contours.pop(index)

        self._count -= len(contours)
        if decrement_indices:
            self._decrement_indices(index)

        return contours

    def _decrement_indices(self, start_index: int) -> None:
        """Decrements all indices after `start_index`.

        Args:
            start_index (int): Index of the entry that was just removed.
        """
        # Only values of existing keys change, so iterating while assigning is safe.
        for key, index in self._key_contours_dict.items():
            if index > start_index:
                self._key_contours_dict[key] = index - 1

    def _invalidate_contours(self, count: int) -> int:
        """Removes at least `count` contours from the cache, oldest entries first.

        Stops early if the cache runs out of entries.

        Returns:
            int: How many contours were removed.
        """
        removed = 0
        while removed < count and self._key_contours_dict:
            # The entry with the smallest index is the oldest one. Indices are
            # fixed up after every removal so the dictionary and the list stay
            # in sync for the next iteration.
            oldest_key = min(
                self._key_contours_dict, key=self._key_contours_dict.__getitem__
            )
            removed += len(self._remove(oldest_key))

        return removed

In [None]:
resolution = Resolution.MICRONS_25
shape = (528, 320, 456)
# -1 selects a whole axis; this picks the 2D slice at position 50 of the 2nd axis.
index = (-1, 50, -1)

with open("ccf_annotations_expanded.csv") as handle:
    contents = handle.read()

# Skip the header row and ignore blank lines (such as a trailing newline), which
# would otherwise fail when indexing the third column.
lines = [line for line in contents.split("\n")[1:] if line.strip()]
structures = [line.split(",")[2].strip() for line in lines]

# Names as they appear in the CSV, mapped to the names used in their place.
renamed_structures = {
    "Primary somatosensory area upper limb": "Primary somatosensory area, upper limb",
    "Retrosplenial area lateral agranular part": (
        "Retrosplenial area, lateral agranular part"
    ),
    "Retrosplenial area dorsal part": "Retrosplenial area, dorsal part",
    "Retrosplenial area ventral part": "Retrosplenial area, ventral part",
    "Primary somatosensory area barrel field": (
        "Primary somatosensory area, barrel field"
    ),
    "Primary somatosensory area lower limb": "Primary somatosensory area, lower limb",
    "Primary somatosensory area mouth": "Primary somatosensory area, mouth",
    "Primary somatosensory area nose": "Primary somatosensory area, nose",
    "Primary somatosensory area trunk": "Primary somatosensory area, trunk",
    "Primary somatosensory area unassigned": "Primary somatosensory area, unassigned",
    "Posteromedial visual area": "posteromedial visual area",
}

# Structures dropped from the list with no replacement.
dropped_structures = {
    "Anterior visual area",
    "Laterointermediate visual area",
    "Rostrolateral area",
}

# Remove every occurrence of the renamed/dropped names, then append each
# replacement exactly once. Filtering (rather than `list.remove`) does not fail
# when a name is absent from the CSV and makes the cell safe to re-run.
structures = [
    structure
    for structure in structures
    if structure not in renamed_structures and structure not in dropped_structures
]
structures += list(renamed_structures.values())

cache = ContourCache.load()
# Saving is done explicitly once the cache has been filled.
cache.auto_save = False

In [None]:
# Compute and cache the contours of every structure that is not cached yet.
for structure in structures:
    key = str([resolution.value, index, structure])

    if not cache.has_contours(key):
        try:
            contours = get_structures_contours([structure], index, resolution)
        except KeyError:
            print(f"Unknown structure: {structure}")
        else:
            # Structures without any contour on this slice are not cached.
            if len(contours) >= 1:
                cache.insert_contours(contours, key)

# Auto-save is disabled for this cache, so persist it once at the end.
cache.save()

In [None]:
# Gather the cached contours of all the structures into one flat list. Structures
# with nothing cached contribute nothing.
all_contours = []
for structure in structures:
    key = str([resolution.value, index, structure])
    cached = cache.load_contours(key, allow_missing=True)
    all_contours.extend(cached)

In [None]:
# The overlay spans the axes of the volume that `index` leaves free (-1).
free_axes = np.flatnonzero(np.asarray(index) == -1)
overlay_shape = np.asarray(shape)[free_axes]

overlay = np.zeros(overlay_shape, dtype=np.uint8)
draw_contours(overlay, all_contours)

imshow(overlay, figsize=(20, 20))