# Dataset and Loader


In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt

## Dummy Dataset


In [None]:
from data.dummy_dataset import DummyFSDataset, DummySimpleDataset

In [None]:
dummy_simple_dataset = DummySimpleDataset(
    "val",
    3,
    (256, 256),
    max_items=100,
    seed=0,
    split_val_size=0.2,
    split_val_fold=0,
    split_test_size=0.2,
    split_test_fold=0,
    cache_data=False,
    dataset_name="Dummy",
)

print(len(dummy_simple_dataset))

# Inspect one sample. The size of the "val" split depends on the split
# settings above and may be smaller than 80, so clamp the hard-coded index
# to the last valid position instead of risking an out-of-range lookup.
sample_index = min(79, len(dummy_simple_dataset) - 1)
img, msk, name, _ = dummy_simple_dataset[sample_index]

print(
    name,
    img.shape,
    img.dtype,
    img.min(),
    img.max(),
    msk.shape,
    msk.dtype,
    torch.unique(msk),
)

# Move the leading channel axis to the end for imshow.
plt.imshow(np.moveaxis(img.numpy(), 0, -1))
# plt.imshow(msk.numpy())

In [None]:
# Few-shot dummy dataset that exercises every sparsity mode.
fs_sparsity_options = [
    ("point", [1, 5, 10, 20]),
    ("grid", (10, 20)),
    ("contour", "random"),
    ("skeleton", (0.1, 0.5)),
    ("region", 0.5),
]

dummy_fs_dataset = DummyFSDataset(
    "train",
    3,
    (256, 256),
    max_items=25,
    seed=0,
    split_val_size=0.2,
    split_val_fold=0,
    split_test_size=0.2,
    split_test_fold=0,
    cache_data=False,
    dataset_name="Dummy",
    shot_options="all",
    sparsity_options=fs_sparsity_options,
    sparsity_params={},
    shot_sparsity_permutation=True,
    homogen_support_batch=True,
    query_batch_size=10,
    split_query_size=0.5,
    split_query_fold=0,
    num_iterations=5.0,
)

# Number of raw items, then number of episodes and the iteration setting.
n_fs_items = len(dummy_fs_dataset.items)
print(n_fs_items)
print(len(dummy_fs_dataset), dummy_fs_dataset.num_iterations)

In [None]:
def _round_sparsity_value(value, ndigits=2):
    """Return ``value`` with its float components rounded for compact printing.

    A bare float is rounded. In a list or tuple, each float element is
    rounded and the list/tuple kind is kept. Any other value is returned
    unchanged.
    """
    if isinstance(value, float):
        return round(value, ndigits)
    if isinstance(value, (list, tuple)):
        rounded = [round(v, ndigits) if isinstance(v, float) else v for v in value]
        return rounded if isinstance(value, list) else tuple(rounded)
    return value


# One line per episode: support sizes | sparsity mode/value | query sizes.
print(dummy_fs_dataset.support_batches)
for i in range(len(dummy_fs_dataset)):
    support, query, _ = dummy_fs_dataset[i]
    supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
    qry_img, qry_msk, qry_name = query
    print(
        supp_img.shape[0],
        supp_msk.shape[0],
        len(supp_name),
        "|",
        supp_sparsity_mode,
        _round_sparsity_value(supp_sparsity_value),
        "|",
        qry_img.shape[0],
        qry_msk.shape[0],
        len(qry_name),
    )

In [None]:
def _tensor_summary(images, masks):
    """Collect shape, dtype and value-range info for an image batch and its masks."""
    return (
        images.shape,
        images.dtype,
        images.min(),
        images.max(),
        masks.shape,
        masks.dtype,
        torch.unique(masks),
    )


support, query, _ = dummy_fs_dataset[0]

supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
qry_img, qry_msk, qry_name = query

# Support set: tensor stats, then names and sparsity configuration.
print(*_tensor_summary(supp_img, supp_msk))
print(supp_name, supp_sparsity_mode, supp_sparsity_value)
# Query set: tensor stats, then names.
print(*_tensor_summary(qry_img, qry_msk))
print(qry_name)
print()

# plt.imshow(supp_msk[0].numpy())
# plt.imshow(qry_msk[0].numpy())

## Dummy Loader


In [None]:
from torch.utils.data import ConcatDataset, DataLoader

dummy_loader = DataLoader(
    ConcatDataset([dummy_fs_dataset]),
    # batch_size=None disables automatic batching: each element is yielded
    # exactly as the dataset returns it (one episode per iteration).
    batch_size=None,
    shuffle=dummy_fs_dataset.mode == "train",
    num_workers=0,
    pin_memory=True,
)

# Inspect only the first episode yielded by the loader.
for batch in dummy_loader:
    support, query, dataset_name = batch

    # Use the unpacked `support` instead of `batch.support`: attribute access
    # only works when the dataset returns a NamedTuple, while unpacking works
    # for any 3-element sequence.
    print(type(support))
    print(support[0].shape, support[1].shape, support[2][:4], support[3])
    print(query[0].shape, query[1].shape, query[2])
    print(dataset_name)

    break

# Sparse Masks


## Initialization


In [None]:
# numpy / matplotlib are re-imported here even though they are imported at the
# top of the notebook, presumably so this section can be run on its own.
import numpy as np
from matplotlib import pyplot as plt

from data.few_sparse_dataset import FewSparseDataset
from data.typings import SparsityValue
from tasks.optic_disc_cup.datasets import (
    RimOne3TrainFSDataset,
    DrishtiTrainFSDataset,
    RefugeTrainFSDataset,
    RefugeValFSDataset,
)

# Use matplotlib's dark theme for every figure plotted after this cell.
plt.style.use("dark_background")


In [None]:
# Sparsity value for each sparse-annotation mode. The plotting helpers below
# pass this to `get_data_with_sparse_all`.
sparsity_values: dict[str, SparsityValue] = dict(
    point=25,
    grid=0.5,
    contour=0.5,
    skeleton=0.5,
    region=0.5,
)

In [None]:
def print_image_mask(image, mask):
    """Print a short summary of an image and its mask.

    Emits two lines: ``shape dtype min max`` for ``image``, then
    ``shape dtype unique-values`` for ``mask``.
    """
    image_info = (image.shape, image.dtype, image.min(), image.max())
    print(*image_info)
    mask_info = (mask.shape, mask.dtype, np.unique(mask))
    print(*mask_info)


def plot_masks(mask, sparse_masks):
    """Plot a dense mask followed by each of its sparse variants.

    The panels are laid out on a two-column grid with every axis hidden. The
    dense ``mask`` fills the first panel. Each value of ``sparse_masks``
    follows, in dict order.

    Args:
        mask: Dense mask (anything ``imshow`` accepts).
        sparse_masks: Mapping whose values are the sparse masks to show.
    """
    # len // 2 + 1 rows of 2 always give at least len + 1 cells.
    n_rows = len(sparse_masks) // 2 + 1
    _, axs = plt.subplots(n_rows, 2, figsize=(5, n_rows * 2.5), squeeze=False)
    axs = axs.ravel()
    for ax in axs:
        ax.axis("off")
    axs[0].imshow(mask)
    for ax, sparse_mask in zip(axs[1:], sparse_masks.values()):
        ax.imshow(sparse_mask)
    plt.tight_layout()

In [None]:
def plot_multiple_images_masks(
    dataset: FewSparseDataset, indices: list[int], keys: list[str]
):
    """Plot images and masks for several dataset items in a grid.

    Each entry of ``indices`` gets one column. Within a column, rows are
    filled in this order:
    - the image, if ``"image"`` is in ``keys``;
    - the dense mask, if ``"dense"`` is in ``keys``;
    - each sparse mask whose mode is in ``keys``, in the order
      ``dataset.get_data_with_sparse_all`` returns them.

    Sparse masks are generated with the module-level ``sparsity_values``.
    """
    ncols = len(indices)
    nrows = len(keys)
    # squeeze=False keeps ``axs`` 2-D even for a single row or column, so
    # ``axs[r, c]`` works for any number of indices/keys.
    _, axs = plt.subplots(
        nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False
    )
    # Hide every axis up front so cells left empty (e.g. a key with no
    # matching sparse mask) don't show bare ticks.
    for ax in axs.flat:
        ax.axis("off")
    for c, index in enumerate(indices):
        image, mask, sparse_masks, _ = dataset.get_data_with_sparse_all(
            index, sparsity_values
        )
        panels = []
        if "image" in keys:
            panels.append(image)
        if "dense" in keys:
            panels.append(mask)
        panels.extend(sm for key, sm in sparse_masks.items() if key in keys)
        for r, panel in enumerate(panels):
            axs[r, c].imshow(panel)
    plt.tight_layout()

## RIM-ONE-3-train


In [None]:
# Sparsity generation parameters for the RIM-ONE-3 training set.
rim_one_3_sparsity_params: dict = dict(
    point_dot_size=10,
    grid_spacing=25,
    grid_dot_size=7,
    contour_radius_dist=5,
    contour_radius_thick=2.5,
    skeleton_radius_thick=5,
    region_compactness=0.4,
)

rim_one_3_train_data = RimOne3TrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=rim_one_3_sparsity_params,
)

In [None]:
# Single-sample inspection (disabled). Uncomment to print the image/mask
# summary and plot every sparse mask for item 0 of the RIM-ONE-3 train set.
# image, mask, sparse_masks, _ = rim_one_3_train_data.get_data_with_sparse_all(0, sparsity_values)
# print_image_mask(image, mask)
# plot_masks(mask, sparse_masks)

In [None]:
# First eight RIM-ONE-3 training items, one row per sparse annotation mode.
sample_indices = list(range(8))
sparse_keys = ["point", "grid", "contour", "skeleton", "region"]
plot_multiple_images_masks(rim_one_3_train_data, sample_indices, sparse_keys)

## DRISHTI-GS-train


In [None]:
# Sparsity generation parameters for the DRISHTI-GS training set.
drishti_sparsity_params: dict = dict(
    point_dot_size=10,
    grid_spacing=25,
    grid_dot_size=7,
    contour_radius_dist=5,
    contour_radius_thick=2,
    skeleton_radius_thick=5,
    region_compactness=0.5,
)

drishti_train_data = DrishtiTrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=drishti_sparsity_params,
)

In [None]:
# First eight DRISHTI-GS training items, one row per sparse annotation mode.
sample_indices = list(range(8))
sparse_keys = ["point", "grid", "contour", "skeleton", "region"]
plot_multiple_images_masks(drishti_train_data, sample_indices, sparse_keys)

## REFUGE-train


In [None]:
# Sparsity generation parameters for the REFUGE training set.
refuge_train_sparsity_params: dict = dict(
    point_dot_size=10,
    grid_spacing=25,
    grid_dot_size=7,
    contour_radius_dist=7,
    contour_radius_thick=3,
    skeleton_radius_thick=5,
    region_compactness=0.4,
)

refuge_train_data = RefugeTrainFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=refuge_train_sparsity_params,
)

In [None]:
# First eight REFUGE training items, one row per sparse annotation mode.
sample_indices = list(range(8))
sparse_keys = ["point", "grid", "contour", "skeleton", "region"]
plot_multiple_images_masks(refuge_train_data, sample_indices, sparse_keys)

## REFUGE-val


In [None]:
# Sparsity generation parameters for the REFUGE validation set.
refuge_val_sparsity_params: dict = dict(
    point_dot_size=10,
    grid_spacing=25,
    grid_dot_size=7,
    contour_radius_dist=7,
    contour_radius_thick=3,
    skeleton_radius_thick=5,
    region_compactness=0.5,
)

# NOTE(review): the REFUGE *val* dataset is built with mode="train", like the
# other sections; confirm this is intended and not meant to be mode="val".
refuge_val_data = RefugeValFSDataset(
    mode="train",
    num_classes=3,
    resize_to=(256, 256),
    sparsity_params=refuge_val_sparsity_params,
)

In [None]:
# First eight REFUGE validation items, one row per sparse annotation mode.
sample_indices = list(range(8))
sparse_keys = ["point", "grid", "contour", "skeleton", "region"]
plot_multiple_images_masks(refuge_val_data, sample_indices, sparse_keys)