# Initialization


In [None]:
import numpy as np
from matplotlib import pyplot as plt

from tasks.optic_disc_cup.datasets import DrishtiDataset, RimOneDataset

# Apply matplotlib's "dark_background" style sheet to the figures drawn below.
plt.style.use("dark_background")


# Dataset Exploration


## RIM-ONE


### Create Dataset


In [None]:
# Options forwarded to RimOneDataset via `sparsity_params`.
# NOTE(review): the meaning of each key is defined by the dataset class, which
# is not visible here — confirm there before tuning the values.
rim_one_sparsity_params: dict = dict(
    contour_radius_dist=4,
    contour_radius_thick=2,
    skeleton_radius_thick=4,
    region_compactness=0.5,
)

# RIM-ONE dataset in "train" mode with a fixed split seed, 3 classes, 5 shots
# and a (256, 256) resize target.
rim_one_data = RimOneDataset(
    mode="train",
    num_classes=3,
    num_shots=5,
    resize_to=(256, 256),
    split_seed=0,
    sparsity_params=rim_one_sparsity_params,
)

### Check Sparse Masks


In [None]:
# Fetch one RIM-ONE sample together with all of its sparse-mask variants.
# NOTE(review): the six positional arguments are defined by
# get_data_with_sparse_all — confirm their meaning against the dataset class.
image, mask, sparse_masks, image_filename = rim_one_data.get_data_with_sparse_all(
    0, 50, 10, 1, 1, 1
)
print(image.shape, image.max(), image.min(), image_filename)
print(mask.shape, mask.dtype, np.unique(mask))

# Two panels per row: the first row shows the image and the dense mask, the
# remaining rows show the sparse masks (rounded up to whole rows).
n_rows = (len(sparse_masks) + 1) // 2 + 1
_, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(12, 6 * n_rows))
axs = axs.flat
axs[0].imshow(image)
axs[1].imshow(mask)
# Sparse masks start at panel 2, right after the image and the dense mask.
for panel, sparse_mask in enumerate(sparse_masks, start=2):
    axs[panel].imshow(sparse_mask)

### Check Image Sizes


In [None]:
# Gather the shape of every RIM-ONE image as returned by read_image_mask; the
# mask returned alongside it is not needed here.
image_sizes = np.array(
    [
        rim_one_data.read_image_mask(img_path, msk_path)[0].shape
        for img_path, msk_path in rim_one_data.get_all_data_path()
    ]
)

# For the first two shape axes, print the distinct sizes with their counts,
# followed by the smallest and largest size.
for axis in (0, 1):
    sizes_along_axis = image_sizes[:, axis]
    print(np.unique(sizes_along_axis, return_counts=True))
    print(sizes_along_axis.min(), sizes_along_axis.max())

## DRISHTI


### Create Dataset


In [None]:
# Options forwarded to DrishtiDataset via `sparsity_params`.
# NOTE(review): the meaning of each key is defined by the dataset class, which
# is not visible here — confirm there before tuning the values.
drishti_sparsity_params: dict = dict(
    contour_radius_dist=4,
    contour_radius_thick=1,
    skeleton_radius_thick=3,
    region_compactness=0.5,
)

# DRISHTI dataset in "train" mode with a fixed split seed, 3 classes, 5 shots
# and a (256, 256) resize target.
drishti_data = DrishtiDataset(
    mode="train",
    num_classes=3,
    num_shots=5,
    resize_to=(256, 256),
    split_seed=0,
    sparsity_params=drishti_sparsity_params,
)

### Check Sparse Masks


In [None]:
# Fetch one DRISHTI sample together with all of its sparse-mask variants.
# NOTE(review): the six positional arguments are defined by
# get_data_with_sparse_all — confirm their meaning against the dataset class.
image, mask, sparse_masks, image_filename = drishti_data.get_data_with_sparse_all(
    1, 50, 20, 1, 1, 1
)
print(image.shape, image.max(), image.min(), image_filename)
print(mask.shape, mask.dtype, np.unique(mask))

# Two panels per row: the first row shows the image and the dense mask, the
# remaining rows show the sparse masks (rounded up to whole rows).
n_rows = (len(sparse_masks) + 1) // 2 + 1
_, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(12, 6 * n_rows))
axs = axs.flat
axs[0].imshow(image)
axs[1].imshow(mask, cmap="gray")
# Sparse masks start at panel 2, right after the image and the dense mask.
for panel, sparse_mask in enumerate(sparse_masks, start=2):
    axs[panel].imshow(sparse_mask)

### Check Image Sizes


In [None]:
# Gather the shape of every DRISHTI image as returned by read_image_mask; the
# mask returned alongside it is not needed here.
image_sizes = []
for image_path, mask_path in drishti_data.get_all_data_path():
    # Read through the DRISHTI dataset object: the paths come from
    # drishti_data, so its own reader must be used rather than the RIM-ONE
    # one left over from the copied cell.
    image, _ = drishti_data.read_image_mask(image_path, mask_path)
    image_sizes.append(image.shape)

image_sizes = np.array(image_sizes)

# For the first two shape axes, print the distinct sizes with their counts,
# followed by the smallest and largest size.
print(np.unique(image_sizes[:, 0], return_counts=True))
print(image_sizes[:, 0].min(), image_sizes[:, 0].max())
print(np.unique(image_sizes[:, 1], return_counts=True))
print(image_sizes[:, 1].min(), image_sizes[:, 1].max())