# Dataset and Loader


In [1]:
import torch
import numpy as np
from matplotlib import pyplot as plt

## Dummy Dataset


In [2]:
from data.dummy_dataset import DummyFSDataset, DummySimpleDataset

In [None]:
# Build the "val" split of the simple dummy dataset.
# NOTE(review): the positional arguments ("val", 3, (256, 256)) presumably map to
# mode, num_classes and resize_to, as in the keyword-based constructors used for
# the optic-disc datasets further down — confirm against DummySimpleDataset.
dummy_simple_dataset = DummySimpleDataset(
    "val",
    3,
    (256, 256),
    max_items=100,
    seed=0,
    split_val_size=0.2,
    split_val_fold=0,
    split_test_size=0.2,
    split_test_fold=0,
    cache_data=False,
    dataset_name="Dummy",
)

print(len(dummy_simple_dataset))
# Each item unpacks into image, mask, name and one extra value that is ignored.
# NOTE(review): index 79 requires the split to expose at least 80 items; compare
# with the length printed above — confirm it is valid for the "val" split.
img, msk, name, _ = dummy_simple_dataset[79]

# Summary of the sample: name, image shape/dtype/value range, mask shape/dtype
# and the distinct mask values.
print(
    name,
    img.shape,
    img.dtype,
    img.min(),
    img.max(),
    msk.shape,
    msk.dtype,
    torch.unique(msk),
)

# Axis 0 (presumably the channel axis) is moved to the end before display.
plt.imshow(np.moveaxis(img.numpy(), 0, -1))
# plt.imshow(msk.numpy())

In [None]:
# Build the "train" split of the few-shot dummy dataset.
# Only the "point" sparsity option is active; the other modes are kept below as
# commented-out examples.
# NOTE(review): the positional arguments ("train", 3, (256, 256)) presumably map
# to mode, num_classes and resize_to — confirm against DummyFSDataset.
dummy_fs_dataset = DummyFSDataset(
    "train",
    3,
    (256, 256),
    max_items=25,
    seed=0,
    split_val_size=0.2,
    split_val_fold=0,
    split_test_size=0.2,
    split_test_fold=0,
    cache_data=False,
    dataset_name="Dummy",
    shot_options=[5, 10],
    sparsity_options=[
        ("point", [1, 5, 10, 20]),
        # ("grid", (10, 20)),
        # ("contour", "random"),
        # ("skeleton", (0.1, 0.5)),
        # ("region", 0.5),
    ],
    sparsity_params={},
    support_query_data="mixed",
    support_batch_mode="mixed",
    query_batch_size=10,
    split_query_size=0.5,
    split_query_fold=0,
    num_iterations=5.0,
)

# Number of underlying items, then the dataset length next to its
# `num_iterations` attribute.
print(len(dummy_fs_dataset.items))
print(len(dummy_fs_dataset), dummy_fs_dataset.num_iterations)

In [None]:
print(dummy_fs_dataset.support_batches)

# Walk over every entry of the few-shot dataset and print one summary line:
# support sizes | sparsity mode and value | query sizes.
for i in range(len(dummy_fs_dataset)):
    support, query, _ = dummy_fs_dataset[i]
    supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
    qry_img, qry_msk, qry_name = query

    # Floats inside a list of sparsity values are rounded to two decimals for
    # readability; any other value is printed unchanged.
    if isinstance(supp_sparsity_value, list):
        shown_sparsity = [
            round(v, 2) if isinstance(v, float) else v for v in supp_sparsity_value
        ]
    else:
        shown_sparsity = supp_sparsity_value

    support_sizes = (supp_img.shape[0], supp_msk.shape[0], len(supp_name))
    query_sizes = (qry_img.shape[0], qry_msk.shape[0], len(qry_name))
    print(*support_sizes, "|", supp_sparsity_mode, shown_sparsity, "|", *query_sizes)

In [None]:
# Inspect the first entry of the few-shot dataset in detail.
support, query, _ = dummy_fs_dataset[0]
(supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value) = support
(qry_img, qry_msk, qry_name) = query

# Support part: image shape/dtype/value range, then mask shape/dtype/distinct values.
supp_summary = (
    supp_img.shape,
    supp_img.dtype,
    supp_img.min(),
    supp_img.max(),
    supp_msk.shape,
    supp_msk.dtype,
    torch.unique(supp_msk),
)
print(*supp_summary)
print(supp_name, supp_sparsity_mode, supp_sparsity_value)

# Query part: same summary, followed by the sample names.
qry_summary = (
    qry_img.shape,
    qry_img.dtype,
    qry_img.min(),
    qry_img.max(),
    qry_msk.shape,
    qry_msk.dtype,
    torch.unique(qry_msk),
)
print(*qry_summary)
print(qry_name)
print()

# plt.imshow(supp_msk[0].numpy())
# plt.imshow(qry_msk[0].numpy())

## Dummy Loader


In [None]:
from torch.utils.data import ConcatDataset, DataLoader

# Wrap the few-shot dataset in a loader. With batch_size=None the DataLoader
# does no automatic batching, so every iteration yields one dataset item.
# Shuffling is enabled only when the dataset is in "train" mode.
dummy_loader = DataLoader(
    ConcatDataset([dummy_fs_dataset]),
    batch_size=None,
    shuffle=dummy_fs_dataset.mode == "train",
    num_workers=0,
    pin_memory=True,
)

# Peek at the first item only.
for batch in dummy_loader:
    support, query, dataset_name = batch

    # Report the type of the unpacked `support` directly; attribute access on
    # `batch` would only work if the item happened to be a namedtuple.
    print(type(support))
    print(support[0].shape, support[1].shape, support[2][:4], support[3])
    print(query[0].shape, query[1].shape, query[2])
    print(dataset_name)

    break

# Sparse Masks


## Initialization


In [1]:
import numpy as np
from matplotlib import pyplot as plt

from data.few_sparse_dataset import FewSparseDataset
from data.typings import SparsityValue
from tasks.optic_disc_cup.datasets import (
    RimOne3TrainFSDataset,
    DrishtiTrainFSDataset,
    RefugeTrainFSDataset,
    RefugeValFSDataset,
)

# Render the figures of this section with matplotlib's "dark_background" style.
plt.style.use("dark_background")


In [2]:
# One sparsity value per sparse-annotation mode. This mapping is read as a
# global by plot_multiple_images_masks and handed to
# dataset.get_data_with_sparse_all.
# NOTE(review): "point" looks like a count while the other modes look like
# fractions — confirm the expected units in the dataset implementation.
sparsity_values: dict[str, SparsityValue] = {
    "point": 25,
    "grid": 0.5,
    "contour": 0.5,
    "skeleton": 0.5,
    "region": 0.5,
}

In [3]:
def print_image_mask(image, mask):
    """Print a one-line summary of an image and a one-line summary of its mask.

    The image line shows shape, dtype, minimum and maximum; the mask line shows
    shape, dtype and the distinct values it contains.
    """
    image_summary = (image.shape, image.dtype, image.min(), image.max())
    print(*image_summary)
    mask_summary = (mask.shape, mask.dtype, np.unique(mask))
    print(*mask_summary)


def plot_masks(mask, sparse_masks):
    """Show the dense mask followed by every sparse mask on a two-column grid.

    ``sparse_masks`` is a mapping; its values are drawn in iteration order
    after ``mask``. Every axis, including an unused trailing one, is hidden.
    """
    # Enough rows of two panels to hold the dense mask plus all sparse masks.
    row_count = len(sparse_masks) // 2 + 1
    _, grid = plt.subplots(row_count, 2, figsize=(5, row_count * 2.5))
    assert isinstance(grid, np.ndarray)
    cells = grid.flat
    for cell in cells:
        cell.axis("off")
    panels = [mask, *sparse_masks.values()]
    for position, panel in enumerate(panels):
        cells[position].imshow(panel)
    plt.tight_layout()

In [4]:
def plot_multiple_images_masks(
    dataset: FewSparseDataset, indices: list[int], keys: list[str]
):
    """Plot a grid comparing several samples across the requested panel kinds.

    Each column is one sample from ``indices`` and each row is one entry of
    ``keys``. Rows are filled top-down in a fixed order: the image (if
    ``"image"`` is requested), the dense mask (if ``"dense"`` is requested),
    then the sparse masks in the order returned by the dataset.

    Args:
        dataset: Dataset providing ``get_data_with_sparse_all``.
        indices: Sample indices, one per column.
        keys: Panel names to show: ``"image"``, ``"dense"`` and/or keys of the
            sparse-mask mapping returned by the dataset.

    Sparse masks are generated with the module-level ``sparsity_values``.
    """
    ncols = len(indices)
    nrows = len(keys)
    # squeeze=False keeps `axs` two-dimensional even with a single row or a
    # single column, so the `axs[r, c]` indexing below is always valid.
    _, axs = plt.subplots(
        nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False
    )
    assert isinstance(axs, np.ndarray)
    for c, index in enumerate(indices):
        image, mask, sparse_masks, _ = dataset.get_data_with_sparse_all(
            index, sparsity_values
        )
        # Collect the requested panels for this column in display order.
        panels = []
        if "image" in keys:
            panels.append(image)
        if "dense" in keys:
            panels.append(mask)
        panels.extend(sm for key, sm in sparse_masks.items() if key in keys)

        for r, panel in enumerate(panels):
            axs[r, c].imshow(panel)
            axs[r, c].axis("off")
    plt.tight_layout()

## RIM-ONE-3-train


In [None]:
# Sparse-mask generation settings for RIM-ONE-3; the dict is passed unchanged
# to the dataset constructor below as `sparsity_params`.
# NOTE(review): the meaning and units of each key are defined by the dataset
# implementation, which is not visible here — confirm there before tuning.
rim_one_3_sparsity_params: dict = {
    "point_dot_size": 10,
    "grid_spacing": 25,
    "grid_dot_size": 7,
    "contour_radius_dist": 5,
    "contour_radius_thick": 2.5,
    "skeleton_radius_thick": 5,
    "region_compactness": 0.4,
}

# RIM-ONE-3 dataset constructed with mode="train", 3 classes and
# resize_to=(256, 256).
rim_one_3_train_data = RimOne3TrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=rim_one_3_sparsity_params,
)

In [None]:
# image, mask, sparse_masks, _ = rim_one_3_train_data.get_data_with_sparse_all(0, sparsity_values)
# print_image_mask(image, mask)
# plot_masks(mask, sparse_masks)

In [None]:
# One column per sample (indices 0-7), one row per sparse annotation mode.
plot_multiple_images_masks(
    dataset=rim_one_3_train_data,
    indices=list(range(8)),
    keys=["point", "grid", "contour", "skeleton", "region"],
)

## DRISHTI-GS-train


In [None]:
# Sparse-mask generation settings for DRISHTI-GS; the dict is passed unchanged
# to the dataset constructor below as `sparsity_params`.
# NOTE(review): the meaning and units of each key are defined by the dataset
# implementation, which is not visible here — confirm there before tuning.
drishti_sparsity_params: dict = {
    "point_dot_size": 10,
    "grid_spacing": 25,
    "grid_dot_size": 7,
    "contour_radius_dist": 5,
    "contour_radius_thick": 2,
    "skeleton_radius_thick": 5,
    "region_compactness": 0.5,
}

# DRISHTI-GS dataset constructed with mode="train", 3 classes and
# resize_to=(256, 256).
drishti_train_data = DrishtiTrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=drishti_sparsity_params,
)

In [None]:
# One column per sample (indices 0-7), one row per sparse annotation mode.
plot_multiple_images_masks(
    dataset=drishti_train_data,
    indices=list(range(8)),
    keys=["point", "grid", "contour", "skeleton", "region"],
)

## REFUGE-train


In [5]:
# Sparse-mask generation settings for REFUGE-train; the dict is passed
# unchanged to the dataset constructor below as `sparsity_params`.
# NOTE(review): the meaning and units of each key are defined by the dataset
# implementation, which is not visible here — confirm there before tuning.
refuge_train_sparsity_params: dict = {
    "point_dot_size": 10,
    "grid_spacing": 25,
    "grid_dot_size": 7,
    "contour_radius_dist": 7,
    "contour_radius_thick": 3,
    "skeleton_radius_thick": 5,
    "region_compactness": 0.4,
}

# REFUGE-train dataset constructed with mode="train", 3 classes and
# resize_to=(256, 256).
refuge_train_data = RefugeTrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=refuge_train_sparsity_params,
)

In [None]:
# One column per sample (indices 0-7), one row per sparse annotation mode.
plot_multiple_images_masks(
    dataset=refuge_train_data,
    indices=list(range(8)),
    keys=["point", "grid", "contour", "skeleton", "region"],
)

## REFUGE-val


In [None]:
# Sparse-mask generation settings for REFUGE-val; the dict is passed unchanged
# to the dataset constructor below as `sparsity_params`.
# NOTE(review): the meaning and units of each key are defined by the dataset
# implementation, which is not visible here — confirm there before tuning.
refuge_val_sparsity_params: dict = {
    "point_dot_size": 10,
    "grid_spacing": 25,
    "grid_dot_size": 7,
    "contour_radius_dist": 7,
    "contour_radius_thick": 3,
    "skeleton_radius_thick": 5,
    "region_compactness": 0.5,
}

# REFUGE-val dataset with 3 classes and resize_to=(256, 256).
# NOTE(review): mode="train" is used with the validation dataset class —
# presumably intentional; confirm against the dataset's split handling.
refuge_val_data = RefugeValFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=refuge_val_sparsity_params,
)

In [None]:
# One column per sample (indices 0-7), one row per sparse annotation mode.
plot_multiple_images_masks(
    dataset=refuge_val_data,
    indices=list(range(8)),
    keys=["point", "grid", "contour", "skeleton", "region"],
)

## Publications


In [6]:
# Use matplotlib's default style for the figures in this section.
plt.style.use("default")

In [7]:
# Same REFUGE-train sparsity settings and dataset construction as in the
# REFUGE-train section; both names are rebound here.
# NOTE(review): presumably repeated so this section can be run on its own —
# confirm, or reuse the earlier objects.
refuge_train_sparsity_params: dict = {
    "point_dot_size": 10,
    "grid_spacing": 25,
    "grid_dot_size": 7,
    "contour_radius_dist": 7,
    "contour_radius_thick": 3,
    "skeleton_radius_thick": 5,
    "region_compactness": 0.4,
}

refuge_train_data = RefugeTrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=refuge_train_sparsity_params,
)

# Sample 11 supplies the `image` and `mask` used by the figure cells below.
image, mask, _ = refuge_train_data.get_data(11)

In [None]:
# Dense mask of the selected sample on a single 9x9-inch axes without frame.
_, axs = plt.subplots(nrows=1, ncols=1, figsize=(9, 9))
axs.axis("off")
axs.imshow(mask)

In [None]:
# One row of four panels: "point" sparse masks generated from the same
# image/mask pair at increasing sparsity values.
_, axs = plt.subplots(nrows=1, ncols=4, figsize=(28, 8))
assert isinstance(axs, np.ndarray)
for ax, sparsity_value in zip(axs, (13, 25, 37, 50)):
    sparse_mask = refuge_train_data.get_sparse_mask(
        "point", mask, image, sparsity_value
    )
    ax.axis("off")
    ax.imshow(sparse_mask)
plt.tight_layout(pad=3)

In [None]:
# 4x4 grid: one row per sparse mode, one column per sparsity value, all
# generated from the same image/mask pair.
_, axs = plt.subplots(nrows=4, ncols=4, figsize=(28, 28))
assert isinstance(axs, np.ndarray)
for r, sparse_mode in enumerate(("grid", "contour", "skeleton", "region")):
    for ax, sparsity_value in zip(axs[r], (0.25, 0.5, 0.75, 1.0)):
        sparse_mask = refuge_train_data.get_sparse_mask(
            sparse_mode, mask, image, sparsity_value
        )
        ax.axis("off")
        ax.imshow(sparse_mask)
plt.tight_layout(pad=3)

## Documents


In [5]:
# Use matplotlib's default style for the figures in this section.
plt.style.use("default")

In [6]:
# Only the two contour settings are supplied for this section; both names
# rebind the objects created in earlier sections.
# NOTE(review): the remaining sparsity keys presumably fall back to defaults
# inside the dataset — confirm against the dataset implementation.
refuge_train_sparsity_params: dict = {
    "contour_radius_dist": 15,
    "contour_radius_thick": 3,
}

refuge_train_data = RefugeTrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=refuge_train_sparsity_params,
)

In [7]:
# image, mask, _ = refuge_train_data.get_data(2)
# mask = (mask == 2).astype(mask.dtype)

# sparse_mask = refuge_train_data.get_sparse_mask("contour", mask, image, 1.0)
# sparse_mask = (sparse_mask == 1).astype(sparse_mask.dtype)

# # horizontal mirror
# image = np.flip(image, axis=1)
# sparse_mask = np.flip(sparse_mask, axis=1)

# # vertical mirror
# image = np.flip(image, axis=0)
# sparse_mask = np.flip(sparse_mask, axis=0)

# # rotate 90 degrees
# image = np.rot90(image)
# sparse_mask = np.rot90(sparse_mask)

# # crop
# image = image[40:-40, 40:-40]
# sparse_mask = sparse_mask[40:-40, 40:-40]


# def rotate_and_crop(image, angle):
#     from scipy.ndimage import rotate

#     rotated_image = rotate(image, angle, reshape=False)
#     h, w = rotated_image.shape[:2]
#     crop_size = int(min(h, w) / np.sqrt(2))
#     start_x = (w - crop_size) // 2
#     start_y = (h - crop_size) // 2
#     cropped_image = rotated_image[
#         start_y : start_y + crop_size, start_x : start_x + crop_size
#     ]
#     return cropped_image


# image = rotate_and_crop(image, 45)
# sparse_mask = rotate_and_crop(sparse_mask, 45)


# def shear_and_crop(image):
#     from skimage.transform import AffineTransform, warp

#     transform = AffineTransform(shear=(0.35, 0.0))
#     sheared_image = warp(image, transform.inverse, mode="constant", cval=0)
#     h, w = sheared_image.shape[:2]
#     crop_size = int(min(h, w) / np.sqrt(2))
#     start_x = (w - crop_size) // 2
#     start_y = (h - crop_size) // 2
#     cropped_image = sheared_image[
#         start_y + 4 : start_y + crop_size, start_x - 37 : start_x + crop_size - 41
#     ]
#     return cropped_image


# image = shear_and_crop(image)
# sparse_mask = shear_and_crop(sparse_mask)

In [8]:
# image, _, _ = refuge_train_data.get_data(2)

# # increase brightness
# image = np.clip(image * 1.5, 0, 255).astype(image.dtype)


# def grayscale(image):
#     from skimage.color import rgb2gray

#     return rgb2gray(image)


# image = grayscale(image)


# def clahe(image):
#     from skimage.exposure import equalize_adapthist

#     image = equalize_adapthist(image, clip_limit=0.01)
#     return image


# image = clahe(image)


# def update_gamma(image):
#     from skimage.exposure import adjust_gamma

#     return adjust_gamma(image, 1.5)  # type: ignore


# image = update_gamma(image)

In [9]:
# from scipy.ndimage import gaussian_filter, map_coordinates

# _, mask, _ = refuge_train_data.get_data(2)
# mask = (mask == 2).astype(mask.dtype)

# sparse_mask = refuge_train_data.get_sparse_mask("contour", mask, image, 1.0)
# sparse_mask = (sparse_mask == 1).astype(sparse_mask.dtype)


# def generate_random_displacement(shape, alpha, sigma, random_state=None):
#     if random_state is None:
#         random_state = np.random.RandomState(None)
#     displacement = (random_state.rand(*shape) * 2 - 1) * alpha
#     return gaussian_filter(displacement, sigma, mode="constant", cval=0)


# def global_displacement(image, alpha, sigma):
#     shape = image.shape
#     dx = generate_random_displacement(shape, alpha, sigma)
#     dy = generate_random_displacement(shape, alpha, sigma)
#     x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
#     indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
#     return map_coordinates(image, indices, order=1, mode="reflect").reshape(shape)


# def directional_displacement(image, alpha, sigma, direction):
#     shape = image.shape
#     if direction == "horizontal":
#         dx = generate_random_displacement(shape, alpha, sigma)
#         dy = np.zeros_like(dx)
#     elif direction == "vertical":
#         dx = np.zeros_like(image)
#         dy = generate_random_displacement(shape, alpha, sigma)
#     x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
#     indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
#     return map_coordinates(image, indices, order=1, mode="reflect").reshape(shape)


# def abrupt_displacement(image, alpha, sigma):
#     shape = image.shape
#     dx = generate_random_displacement(shape, alpha, sigma // 2)
#     dy = generate_random_displacement(shape, alpha, sigma // 2)
#     x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
#     indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
#     return map_coordinates(image, indices, order=1, mode="reflect").reshape(shape)


# sparse_mask = global_displacement(sparse_mask, 150, 8)
# sparse_mask = directional_displacement(sparse_mask,150, 8, "vertical")
# sparse_mask = abrupt_displacement(sparse_mask, 30, 7)

In [10]:
# image, mask, _ = refuge_train_data.get_data(2)
# mask = (mask == 2).astype(mask.dtype)

# sparse_mask = refuge_train_data.get_sparse_mask("contour", mask, image, 1.0)
# sparse_mask = (sparse_mask == 1).astype(sparse_mask.dtype)

# def apply_gaussian_filter(image, sigma):
#     from scipy.ndimage import gaussian_filter
#     return gaussian_filter(image, sigma)

# image[:, :, 0] = apply_gaussian_filter(image[:, :, 0], 8)
# image[:, :, 1] = apply_gaussian_filter(image[:, :, 1], 8)
# image[:, :, 2] = apply_gaussian_filter(image[:, :, 2], 8)

# sparse_mask = apply_gaussian_filter(sparse_mask.astype(np.float32), 8)
# sparse_mask = sparse_mask * (1.0 / sparse_mask.max())


In [70]:
# from scipy.ndimage import binary_erosion, binary_dilation

# _, mask, _ = refuge_train_data.get_data(2)
# mask = (mask == 2).astype(mask.dtype)

# sparse_mask = refuge_train_data.get_sparse_mask("contour", mask, image, 1.0)
# sparse_mask = (sparse_mask == 1).astype(sparse_mask.dtype)


# def erosion(image, structure=None):
#     return binary_erosion(image, structure=structure).astype(image.dtype)


# def dilation(image, structure=None):
#     return binary_dilation(image, structure=structure).astype(image.dtype)


# def skeletonization(image, structure=None):
#     skeleton = np.zeros_like(image)
#     temp_image = image.copy()
#     while np.any(temp_image):
#         eroded = binary_erosion(temp_image, structure=structure)
#         skeleton_layer = temp_image & ~binary_dilation(eroded, structure=structure)  # type: ignore
#         skeleton |= skeleton_layer
#         temp_image = eroded
#     return skeleton


# sparse_mask = erosion(sparse_mask)
# sparse_mask = dilation(sparse_mask)
# sparse_mask = dilation(sparse_mask) - sparse_mask
# sparse_mask = skeletonization(sparse_mask)


In [68]:
# from skimage.color import rgb2gray
# from skimage.filters import sobel
# from skimage.segmentation import (
#     slic,
#     felzenszwalb,
#     quickshift,
#     watershed,
#     mark_boundaries,
# )
# from skimage.util import img_as_float

# image, _, _ = refuge_train_data.get_data(2)

# image_float = img_as_float(image)

# segments_slic = slic(image_float, n_segments=100, compactness=10, start_label=1)

# segments_felzenszwalb = felzenszwalb(image_float, scale=40, sigma=0.5, min_size=50)

# segments_quickshift = quickshift(image_float, kernel_size=3, max_dist=15, ratio=0.5)

# gradient = sobel(rgb2gray(image))
# segments_watershed = watershed(gradient, markers=100, compactness=0.0005)  # type: ignore

# image = mark_boundaries(image, segments_slic, color=(0, 0, 1))
# image = mark_boundaries(image, segments_felzenszwalb, color=(0, 0, 1))
# image = mark_boundaries(image, segments_quickshift, color=(0, 0, 1))
# image = mark_boundaries(image, segments_watershed, color=(0, 0, 1))


In [None]:
# Current `image` on a single 6x6-inch axes without frame; the gray colormap
# only takes effect when the image is single-channel.
_, axs = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
axs.axis("off")
axs.imshow(image, cmap="gray")

In [None]:
# Current `sparse_mask` on a single 6x6-inch axes without frame, in grayscale.
_, axs = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
axs.axis("off")
axs.imshow(sparse_mask, cmap="gray")

# Region Analysis


In [1]:
import numpy as np
import pandas as pd
import altair as alt

In [None]:
# Imports for the region-coverage analysis: SLIC superpixels, array typing and
# the project's test datasets.
from skimage import segmentation
from numpy.typing import NDArray

from data.typings import FewSparseDatasetKwargs
from tasks.optic_disc_cup.datasets import (
    DrishtiTestFSDataset,
    RefugeTestFSDataset,
    RimOne3TestFSDataset,
)


In [2]:
# Keyword arguments shared by the three test datasets below. Only the "region"
# sparsity mode is configured, with 250 segments and compactness 0.5 passed as
# sparsity parameters.
dataset_kwargs: FewSparseDatasetKwargs = {
    "seed": 0,
    "split_val_fold": 0,
    "split_test_fold": 0,
    "cache_data": True,
    "query_batch_size": 5,
    "split_query_size": 0.5,
    "split_query_fold": 0,
    "shot_options": [1, 5, 10, 15, 20],
    "sparsity_options": [("region", [0.1, 0.25, 0.5, 0.75, 1.0])],
    "support_query_data": "mixed",
    "support_batch_mode": "full_permutation",
    "split_test_size": 1,
    "sparsity_params": {"region_segments": 250, "region_compactness": 0.5},
}

# Each dataset receives the shared kwargs plus its own `dataset_name`; the dict
# union builds a new dict, so `dataset_kwargs` itself is left unmodified.
drishti_dataset = DrishtiTestFSDataset(
    mode="test",
    num_classes=3,
    resize_to=(256, 256),
    **(dataset_kwargs | {"dataset_name": "DRISHTI-GS-test"}),
)
refuge_dataset = RefugeTestFSDataset(
    mode="test",
    num_classes=3,
    resize_to=(256, 256),
    **(dataset_kwargs | {"dataset_name": "REFUGE-test"}),
)
rim_one_3_dataset = RimOne3TestFSDataset(
    mode="test",
    num_classes=3,
    resize_to=(256, 256),
    **(dataset_kwargs | {"dataset_name": "RIM-ONE-3-test"}),
)

In [3]:
def get_region_coverage(
    img: NDArray,
    msk: NDArray,
    segments: int,
    compactness: float,
    num_classes: int = 3,
) -> list[int]:
    """Measure how much of each class is covered by single-class SLIC superpixels.

    The image is split into superpixels with ``segmentation.slic``. A superpixel
    is "pure" for class ``c`` when every mask pixel inside it equals ``c``.

    Args:
        img: Image handed to ``segmentation.slic``.
        msk: Integer label mask indexed with the superpixel map; values must be
            non-negative integers (required by ``np.bincount``).
        segments: Requested number of superpixels (``n_segments``).
        compactness: SLIC compactness.
        num_classes: Number of classes ``0 .. num_classes - 1`` to report.

    Returns:
        A flat list of plain ``int`` values:
        ``[pure_superpixels, total_superpixels,
        covered_px_class_0, total_px_class_0, covered_px_class_1, ...]``
        where ``covered_px_class_c`` counts the mask pixels lying in superpixels
        that are pure for class ``c`` and ``total_px_class_c`` counts all mask
        pixels equal to ``c``.
    """
    superpixels = segmentation.slic(
        img, n_segments=segments, compactness=compactness, start_label=1
    )
    labels = np.unique(superpixels)

    # Pixel counts are accumulated directly per class. This avoids painting a
    # sentinel-filled copy of the mask, which would need a dtype able to hold a
    # negative marker and breaks for unsigned masks.
    pure_region_counts = [0] * num_classes
    covered_pixels = [0] * num_classes
    for label in labels:
        region = msk[superpixels == label].ravel()
        histogram = np.bincount(region, minlength=num_classes)
        region_size = int(histogram.sum())
        if region_size == 0:
            continue
        for c in range(num_classes):
            if int(histogram[c]) == region_size:
                pure_region_counts[c] += 1
                covered_pixels[c] += region_size
                break  # a non-empty region can be pure for one class only

    coverage_list = [sum(pure_region_counts), len(labels)]
    for c in range(num_classes):
        coverage_list.append(covered_pixels[c])
        # int(...) keeps the return value consistent with the list[int] annotation.
        coverage_list.append(int(np.sum(msk == c)))

    return coverage_list


In [None]:
# region_segments_options = [50, 100, 150, 200, 250, 300, 350, 400]
# region_compactness_options = [10**-1, 10**-0.5, 10**0, 10**0.5, 10**1, 10**1.5, 10**2]
# region_options = [
#     (s, c) for s in region_segments_options for c in region_compactness_options
# ]

# coverage_data = []
# for i, (segments, compactness) in enumerate(region_options):
#     print(f" {i + 1}/{len(region_options)}")
#     for dataset in [drishti_dataset, refuge_dataset, rim_one_3_dataset]:
#         name = dataset.dataset_name
#         indices = set(dataset.query_indices)
#         for i in indices:
#             img, msk, _ = dataset.get_data(i)
#             (
#                 cover_seg,
#                 total_seg,
#                 cover_0,
#                 total_0,
#                 cover_1,
#                 total_1,
#                 cover_2,
#                 total_2,
#             ) = get_region_coverage(img, msk, segments, compactness)
#             coverage_data.append(
#                 {
#                     "segments": segments,
#                     "compactness": compactness,
#                     "dataset": name,
#                     "index": i,
#                     "covered_segments": cover_seg,
#                     "total_segments": total_seg,
#                     "covered_class_0": cover_0,
#                     "total_class_0": total_0,
#                     "covered_class_1": cover_1,
#                     "total_class_1": total_1,
#                     "covered_class_2": cover_2,
#                     "total_class_2": total_2,
#                 }
#             )

# coverage_df = pd.DataFrame(coverage_data)
# coverage_df.to_csv("logs/region_coverage.csv", index=False)