# Dataset and Loader


In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt

## Dummy Dataset


In [None]:
from data.dummy_dataset import DummyFSDataset, DummySimpleDataset

In [None]:
# Build the dummy "simple" dataset in "val" mode and inspect a single sample.
# Keyword options are collected first so the constructor call stays short.
simple_dataset_options = {
    "max_items": 100,
    "seed": 0,
    "split_val_size": 0.2,
    "split_test_size": 0.2,
    "dataset_name": "Dummy",
}
dummy_simple_dataset = DummySimpleDataset(
    "val", 3, (256, 256), **simple_dataset_options
)

print(len(dummy_simple_dataset))

# Each item unpacks into four values; the last one is not used here.
# NOTE(review): index 79 assumes the "val" mode exposes at least 80 items —
# confirm against the length printed above.
img, msk, name, _ = dummy_simple_dataset[79]

# One-line summary: sample name, image shape/dtype/value range, then mask
# shape/dtype and the distinct values present in the mask.
sample_summary = (
    name,
    img.shape,
    img.dtype,
    img.min(),
    img.max(),
    msk.shape,
    msk.dtype,
    torch.unique(msk),
)
print(*sample_summary)

# Move axis 0 to the last position before plotting (presumably converting a
# channels-first image to the channels-last layout imshow expects — confirm).
img_for_plot = np.moveaxis(img.numpy(), 0, -1)
plt.imshow(img_for_plot)
# Alternative view, showing the mask instead of the image:
# plt.imshow(msk.numpy())

In [None]:
# Build the dummy few-shot dataset in "train" mode and inspect its first item.
# Keyword options are collected first; the commented entries are alternative
# settings kept around for quick experimentation.
fs_dataset_options = {
    "max_items": 100,
    "seed": 0,
    "split_val_size": 0.2,
    "split_test_size": 0.2,
    "dataset_name": "Dummy",
    "shot_options": [1, 5, 10, 20],
    # "shot_options": "all",
    "sparsity_options": [
        ("point", [1, 5, 10, 20]),
        # ("grid", (10, 20)),
        # ("contour", "random"),
        # ("skeleton", (0.1, 0.5)),
        # ("region", 0.5),
    ],
    "shot_sparsity_permutation": True,
    "num_iterations": 1.0,
    "query_batch_size": 5,
    "split_query_size": 0.9,
}
dummy_fs_dataset = DummyFSDataset("train", 3, (256, 256), **fs_dataset_options)


def _summarize_pair(image, mask):
    """Return shape/dtype/min/max of *image* and shape/dtype/unique values of *mask*."""
    return (
        image.shape,
        image.dtype,
        image.min(),
        image.max(),
        mask.shape,
        mask.dtype,
        torch.unique(mask),
    )


print(len(dummy_fs_dataset))

# An item is a triple; the third element is ignored in this cell.
support, query, _ = dummy_fs_dataset[0]

# The support part carries five fields, the query part three.
supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
qry_img, qry_msk, qry_name = query

print(*_summarize_pair(supp_img, supp_msk))
print(supp_name, supp_sparsity_mode, supp_sparsity_value)
print(*_summarize_pair(qry_img, qry_msk))
print(qry_name)
print()

# Optional visual checks of the first support / query mask:
# plt.imshow(supp_msk[0].numpy())
# plt.imshow(qry_msk[0].numpy())

In [None]:
# Show the dataset's support batches, then walk over every item and print the
# leading-dimension sizes of its support and query parts together with the
# sparsity mode/value used for the support set.
print(dummy_fs_dataset.support_batches)

# Index-based access is kept on purpose: iterating a map-style dataset
# directly relies on it raising IndexError past the end, which is not
# guaranteed here.
for i in range(len(dummy_fs_dataset)):
    support, query, _ = dummy_fs_dataset[i]
    supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
    qry_img, qry_msk, qry_name = query

    # Support-side fields first, then query-side fields.
    support_fields = (
        supp_img.shape[0],
        supp_msk.shape[0],
        len(supp_name),
        supp_sparsity_mode,
        supp_sparsity_value,
    )
    query_fields = (qry_img.shape[0], qry_msk.shape[0], len(qry_name))
    print(*support_fields, *query_fields)

## Dummy Loader


In [None]:
from torch.utils.data import ConcatDataset, DataLoader

# Wrap the few-shot dataset in a DataLoader and look at the first item it
# yields. batch_size=None disables the DataLoader's automatic batching, so
# every dataset item is passed through individually instead of being collated
# with others. Shuffling is enabled only when the dataset is in "train" mode.
dummy_loader = DataLoader(
    ConcatDataset([dummy_fs_dataset]),
    batch_size=None,
    shuffle=dummy_fs_dataset.mode == "train",
    num_workers=0,
    pin_memory=True,
)

for batch in dummy_loader:
    support, query, dataset_name = batch

    # Inspect the already-unpacked value rather than `batch.support`:
    # attribute access only works when the item is a namedtuple-like object
    # and raises AttributeError on a plain tuple/list, whereas the unpacked
    # name is valid in both cases and refers to the same first element.
    print(type(support))
    print(support[0].shape, support[1].shape, support[2][:4], support[3])
    print(query[0].shape, query[1].shape, query[2])
    print(dataset_name)

    # Only the first item is needed for this sanity check.
    break

## RIM-ONE Dataset


In [None]:
from tasks.optic_disc_cup.datasets import RimOneFSDataset


# Build the RIM-ONE few-shot dataset in "train" mode. Keyword options are
# gathered in a dict so the constructor call itself stays on one line.
rim_one_options = {
    "max_items": 100,
    "seed": 0,
    "split_val_size": 0.2,
    "split_test_size": 0.2,
    "dataset_name": "RIM-ONE DL",
    "shot_options": [1, 5, 10, 20],
    # Each entry pairs a sparsity mode with the values to use for that mode.
    "sparsity_options": [
        ("point", [1, 5, 10, 20]),
        ("grid", [20, 10]),
        ("contour", [0.1, 0.5, 1]),
    ],
    "shot_sparsity_permutation": True,
    "query_batch_size": 2,
    "split_query_size": 0.5,
}
rim_one = RimOneFSDataset("train", 3, (256, 256), **rim_one_options)

# Uncomment to print the support batch and sparsity of every iteration:
# for i in range(rim_one.num_iterations):
#     print(rim_one.support_batches[i], rim_one.support_sparsities[i])