In [4]:
from monai.data import Dataset
import zarr
import pandas as pd
from pathlib import Path

In [85]:
from collections.abc import Sequence


class SkisntressionDataset(Dataset):
    """Dataset that pairs image data from a zarr store with per-sample parameters.

    One item exists per row of the params table. Indexing with an integer
    returns a tuple ``(img, target, sample_id)``:

    * ``img`` is ``self.images[sample_id][0, ...]``, i.e. the first entry along
      the leading axis of whatever is stored under that sample ID
      (presumably a channel or z-slice axis -- TODO confirm).
    * ``target`` is a ``pandas.DataFrame`` holding the row(s) of the params
      table whose ``sample_id`` column equals the sample's ID.
    * ``sample_id`` is the value read from the ``sample_id`` column.

    Args:
        images: Location passed to ``zarr.open(..., mode="r")``.
        curves: Directory; every ``*.csv`` file in it is read into a DataFrame
            and stored in ``self.curves`` under the file's stem.
        params: CSV file that must contain a ``sample_id`` column.
        sample_to_person: CSV file read into ``self.sample_to_person``.
    """

    def __init__(self, images, curves, params, sample_to_person):
        self.images = zarr.open(images, mode="r")
        self.curves = {curve.stem: pd.read_csv(curve) for curve in Path(curves).glob("*.csv")}
        self.params = pd.read_csv(params)
        self.sample_to_person = pd.read_csv(sample_to_person)
        self.sample_ids = self.params["sample_id"]
        # Initialise the monai base class so the attributes it defines
        # (``data``, ``transform``) exist on this instance as well.
        super().__init__(data=self.sample_ids)

    def __len__(self):
        """Return the number of samples (rows of the params table)."""
        return len(self.sample_ids)

    def __getitem__(self, index: int | slice | Sequence[int]):
        """Fetch one sample by position, or a subset for a slice / index sequence.

        Args:
            index: Integer position (negative values count from the end), a
                slice, or a sequence of integer positions.

        Returns:
            ``(img, target, sample_id)`` for an integer index. Slices and
            sequences are delegated to the base ``Dataset.__getitem__``.
        """
        # The body below only makes sense for a single position; let the base
        # class build the subset for slices and sequences (it calls back into
        # this method with plain integers).
        if isinstance(index, slice) or (isinstance(index, Sequence) and not isinstance(index, str)):
            return super().__getitem__(index)
        # Positional lookup: independent of the Series' index labels and
        # therefore also valid for negative positions.
        sample_id = self.sample_ids.iloc[index]
        img = self.images[sample_id][0, ...]
        target = self.params.loc[self.params["sample_id"] == int(sample_id)]
        return img, target, sample_id

In [86]:
# Build the dataset from the project's data directory (paths are relative to the notebook).
d = SkisntressionDataset(
    images="../data/stacks.zarr/",
    curves="../data/curves/",
    params="../data/params.csv",
    sample_to_person="../data/sample_to_person.csv",
)

In [84]:
# Spot-check a single item: displays the (img, target, sample_id) tuple at position 10.
# NOTE(review): this cell's execution count (84) is lower than the class-definition (85)
# and instantiation (86) cells, so the output shown may predate the current class code --
# re-run top to bottom on a fresh kernel to confirm.
d[10]

(array([[0, 0, 1, ..., 6, 6, 6],
        [0, 1, 1, ..., 4, 4, 8],
        [0, 0, 1, ..., 4, 5, 6],
        ...,
        [0, 0, 0, ..., 4, 4, 4],
        [0, 0, 0, ..., 6, 6, 5],
        [0, 0, 1, ..., 5, 6, 3]], dtype=uint8),
     sample_id         A          k        xc
 10         21  1.549959  19.872042  1.301004,
 21)