# Dataset


## Dummy Loader


In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from data.dummy_dataset import DummyDataset, DummySimpleDataset

In [None]:
# Build the simple (non few-shot) dummy dataset on the "val" split and
# inspect a single sample from it.
dummy_simple_dataset = DummySimpleDataset(
    "val",
    3,
    (256, 256),
    max_items=100,
    seed=0,
    split_val_size=0.2,
    split_test_size=0.2,
    dataset_name="Dummy",
)

print(len(dummy_simple_dataset))
# NOTE(review): index 79 assumes the "val" split holds at least 80 items;
# compare against the length printed above.
img, msk, name, _ = dummy_simple_dataset[79]

img_stats = (img.shape, img.dtype, img.min(), img.max())
msk_stats = (msk.shape, msk.dtype, torch.unique(msk))
print(name, *img_stats, *msk_stats)

# Move axis 0 to the end (presumably CHW -> HWC) so matplotlib can draw it.
plt.imshow(np.moveaxis(img.numpy(), 0, -1))
# plt.imshow(msk.numpy())

In [None]:
# Few-shot dummy dataset on the "train" split; the commented-out entries are
# alternative shot/sparsity settings kept for quick experimentation.
dummy_dataset = DummyDataset(
    "train",
    3,
    (256, 256),
    max_items=100,
    seed=0,
    split_val_size=0.2,
    split_test_size=0.2,
    dataset_name="Dummy",
    shot_options=[1, 5, 10, 20],
    # shot_options="all",
    sparsity_options=[
        ("point", [1, 5, 10, 20]),
        # ("grid", (10, 20)),
        # ("contour", "random"),
        # ("skeleton", (0.1, 0.5)),
        # ("region", 0.5),
    ],
    shot_sparsity_permutation=True,
    num_iterations=1.0,
    query_batch_size=5,
    split_query_size=0.9,
)


def _image_mask_stats(image, mask):
    """Return image (shape, dtype, min, max) followed by mask (shape, dtype, unique values)."""
    return (
        image.shape,
        image.dtype,
        image.min(),
        image.max(),
        mask.shape,
        mask.dtype,
        torch.unique(mask),
    )


print(len(dummy_dataset))
support, query, _ = dummy_dataset[0]

supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
qry_img, qry_msk, qry_name = query

print(*_image_mask_stats(supp_img, supp_msk))
print(supp_name, supp_sparsity_mode, supp_sparsity_value)
print(*_image_mask_stats(qry_img, qry_msk))
print(qry_name)
print()

# plt.imshow(supp_msk[0].numpy())
# plt.imshow(qry_msk[0].numpy())

In [None]:
from torch.utils.data import ConcatDataset, DataLoader

dummy_loader = DataLoader(
    ConcatDataset([dummy_dataset]),
    # batch_size=None disables automatic batching, so each dataset item is
    # yielded as-is (presumably already a full support/query episode).
    batch_size=None,
    shuffle=dummy_dataset.mode == "train",
    num_workers=0,
    pin_memory=True,
)

# Peek at the first item only.
for batch in dummy_loader:
    support, query, dataset_name = batch

    # Use the unpacked value rather than ``batch.support``: attribute access
    # only works if the item is a namedtuple, while the positional unpacking
    # above works for any 3-tuple (and is the same object for a namedtuple).
    print(type(support))
    print(support[0].shape, support[1].shape, support[2][:4], support[3])
    print(query[0].shape, query[1].shape, query[2])
    print(dataset_name)

    break

In [None]:
print(dummy_dataset.support_batches)
# For every episode print: leading-dim sizes of the support images/masks and
# the number of support names, the sparsity mode/value, then the same three
# sizes for the query side.
for i in range(len(dummy_dataset)):
    support, query, _ = dummy_dataset[i]
    supp_img, supp_msk, supp_name, supp_sparsity_mode, supp_sparsity_value = support
    qry_img, qry_msk, qry_name = query
    support_sizes = (supp_img.shape[0], supp_msk.shape[0], len(supp_name))
    query_sizes = (qry_img.shape[0], qry_msk.shape[0], len(qry_name))
    print(*support_sizes, supp_sparsity_mode, supp_sparsity_value, *query_sizes)

## Other


In [None]:
from tasks.optic_disc_cup.datasets import RimOneDataset


# Sparsity settings sampled for the RIM-ONE support sets.
rim_one_sparsity_options = [
    ("point", [1, 5, 10, 20]),
    ("grid", [20, 10]),
    ("contour", [0.1, 0.5, 1]),
]

rim_one = RimOneDataset(
    "train",
    3,
    (256, 256),
    dataset_name="RIM-ONE DL",
    # Item count and train/val/test split.
    max_items=100,
    seed=0,
    split_val_size=0.2,
    split_test_size=0.2,
    # Few-shot episode construction.
    shot_options=[1, 5, 10, 20],
    sparsity_options=rim_one_sparsity_options,
    shot_sparsity_permutation=True,
    query_batch_size=2,
    split_query_size=0.5,
)

# for i in range(rim_one.num_iterations):
#     print(rim_one.support_batches[i], rim_one.support_sparsities[i])

# Utilities


## Metric


In [None]:
import torch

from learners.metrics import MultiIoUMetric


metric = MultiIoUMetric()

# Sanity check with two identical tensors. Argument order is assumed to be
# (predictions, target) -- confirm against MultiIoUMetric.
example_preds = torch.tensor([1, 0, 1])
example_target = torch.tensor([1, 0, 1])
metric(example_preds, example_target)

metric.compute()

## Utils


In [None]:
from utils.diff_dict import diff_dict


# Two nested dicts that differ in added/removed keys, nested values, list
# contents and tuple contents, to exercise diff_dict.
dict_a = {
    "config": {},
    "a": {"a": {"p": 1}},
    "b": 2,
    "p": (12, 34),
    "d": {"e": 4, "f": 5, "g": [6, 7, 8]},
    "g": [6, 7, 8],
}
dict_b = {
    "config": {},
    "a": {"a": {"p": 1, "q": 0}},
    "c": 3,
    "d": {"e": 4, "f": 5, "g": [6, 9, 8, {"a": 1}]},
    "g": [6, 9],
    "p": (12, 33, 34),
}

diff_dict(dict_a, dict_b)

In [None]:
import os
import datetime


log_dir = "logs/SL"
# Print the last-modified time (naive local time, ISO 8601) of every entry in
# log_dir, in the order os.listdir returns them.
for item in os.listdir(log_dir):
    modified_at = datetime.datetime.fromtimestamp(
        os.path.getmtime(os.path.join(log_dir, item))
    )
    print(modified_at.isoformat())

In [None]:
def make_batch_sample_indices(
    population_size: int, sample_size: int, batch_size: int
) -> list[list[int]]:
    """Randomly sample indices from a population and group them by batch.

    The population ``0 .. population_size - 1`` is split into consecutive
    batches of ``batch_size`` items (the last batch may be shorter).
    ``sample_size`` distinct indices are drawn uniformly at random, and each
    one is recorded as an offset within the batch that contains it.

    Args:
        population_size: Number of items in the population.
        sample_size: Number of distinct items to draw.
        batch_size: Size of each consecutive batch; must be positive.

    Returns:
        ``ceil(population_size / batch_size)`` lists, one per batch, each
        holding the ascending within-batch offsets of the sampled items.

    Raises:
        ValueError: If ``batch_size`` is not positive, or (from
            ``random.sample``) if ``sample_size`` is negative or larger than
            ``population_size``.
    """
    import random

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Ceiling division: one bucket per (possibly partial) batch, so a
    # population that is an exact multiple of batch_size gets no trailing
    # empty bucket.
    num_batches = -(-population_size // batch_size)
    batch_samples: list[list[int]] = [[] for _ in range(num_batches)]
    for index in sorted(random.sample(range(population_size), sample_size)):
        batch_index, offset = divmod(index, batch_size)
        batch_samples[batch_index].append(offset)
    return batch_samples


make_batch_sample_indices(100, 10, 20)

In [None]:
def make_batch_sample_indices_multi(
    iterations_batches: list[tuple[int, int]], total_samples: int
) -> list[list[int]]:
    """Spread a sample budget across several groups of batches.

    Each ``(iterations, batch)`` pair describes a group of ``iterations``
    batches of ``batch`` items, i.e. a population of ``iterations * batch``
    items. ``total_samples`` is divided between the groups in proportion to
    their populations (largest-remainder rounding, ties broken at random).
    Each group is then sampled with ``make_batch_sample_indices``.

    Args:
        iterations_batches: ``(iterations, batch_size)`` pairs, one per group.
        total_samples: Total number of items to sample over all groups.

    Returns:
        The per-batch offset lists produced by ``make_batch_sample_indices``
        for each group, concatenated in group order.

    Raises:
        ValueError: If ``total_samples`` is negative or exceeds the combined
            population of all groups.
    """
    import random

    populations = [iterations * batch for iterations, batch in iterations_batches]
    sum_populations = sum(populations)
    if not 0 <= total_samples <= sum_populations:
        raise ValueError(
            f"total_samples must be in [0, {sum_populations}], got {total_samples}"
        )

    # Exact integer largest-remainder apportionment: floor every proportional
    # share, then hand one extra sample to the groups with the largest
    # remainders. A group's share can never go below 0 or above its
    # population, which randomly nudging rounded shares did not guarantee
    # (and which made random.sample raise).
    samples = [0] * len(populations)
    remainders = [0] * len(populations)
    if sum_populations:
        for i, population in enumerate(populations):
            samples[i], remainders[i] = divmod(
                population * total_samples, sum_populations
            )
    leftover = total_samples - sum(samples)
    by_remainder = sorted(
        range(len(populations)),
        key=lambda i: (-remainders[i], random.random()),
    )
    for i in by_remainder[:leftover]:
        samples[i] += 1

    batch_samples: list[list[int]] = []
    zipped = zip(iterations_batches, populations, samples)
    for (_, batch), population, sample in zipped:
        batch_samples += make_batch_sample_indices(population, sample, batch)

    return batch_samples


make_batch_sample_indices_multi([(5, 3), (4, 2), (10, 1)], 20)

## WandB


In [None]:
import wandb

from config.constants import WANDB_SETTINGS
from utils.wandb import wandb_login

In [None]:
def wandb_log_dataset_ref(dataset_path: str, dataset_name: str, dummy: bool = False):
    """Log a local dataset location to Weights & Biases as a reference artifact.

    Logs in, starts a run tagged ``helper``, creates a ``dataset``-type
    artifact named ``dataset_name`` that references ``dataset_path`` through
    a ``file://`` URI, logs it, and finishes the run.

    Args:
        dataset_path: Dataset location on the local filesystem. Relative
            paths are resolved against the current working directory.
        dataset_name: Artifact name; also used in the run name.
        dummy: If True, log to ``WANDB_SETTINGS["dummy_project"]`` instead of
            ``WANDB_SETTINGS["project"]``.
    """
    from pathlib import Path

    wandb_login()
    wandb.init(
        tags=["helper"],
        project=WANDB_SETTINGS["dummy_project" if dummy else "project"],
        name=f"log dataset {dataset_name}",
    )
    try:
        dataset_artifact = wandb.Artifact(dataset_name, type="dataset")
        # ``as_uri`` produces a well-formed file URI on every platform
        # (``file:///D:/...`` on Windows). Naive ``"file://" + path`` turns a
        # drive letter, or the first component of a relative path, into the
        # URI's host.
        dataset_artifact.add_reference(Path(dataset_path).resolve().as_uri())
        wandb.log_artifact(dataset_artifact)
    finally:
        # Close the run even if building or logging the artifact fails.
        wandb.finish()


# wandb_log_dataset_ref("D:/Penelitian/FWS/Data/DRISHTI-GS", "DRISHTI", True)
# wandb_log_dataset_ref("D:/Penelitian/FWS/Data/RIM-ONE", "RIM-ONE", True)

In [None]:
# ac = wandb.Api().artifact_collection(
#     "run_table", "pandegaaz/few-shot-weakly-seg-old/run-svgff5kf-metrics"
# )

# ac.delete()

# for art in ac.artifacts():
#     print(art.name, art.id)

# art: wandb.Artifact = ac.artifacts()[0]  # type: ignore

# print(art.name, art.aliases)

# art.download("ppp/qqq")



# Other
