In [20]:
from monai.data import Dataset
import zarr
import pandas as pd
from pathlib import Path
import numpy as np

In [131]:
from collections.abc import Sequence


class SkintressionDataset(Dataset):
    def __init__(self, images, curves, params, sample_to_person):
        self.images = zarr.open(images, mode="r")
        self.image_keys = list(self.images.keys())
        self.curves = {curve.stem: pd.read_csv(curve) for curve in Path(curves).glob("*.csv")}
        self.params = pd.read_csv(params)
        self.sample_to_person = pd.read_csv(sample_to_person)
        self.sample_ids = self.params["sample_id"]
        self.indices, self.cumsum = self.calc_indices_and_filter()
    
    def calc_indices_and_filter(self):
        num = 0
        lengths = []
        for img in self.images.keys():
            try:
                self.curves[str(img)]
            except KeyError:
                self.image_keys.remove(img)
                print(f"Removed {img} from dataset")
            else:
                length = self.images[img].shape[0]
                num += length
                lengths.append(length)
        cumsum = np.cumsum(lengths)
        return num, cumsum
    
    def __len__(self):
        return self.indices
    
    def __getitem__(self, index: int | slice | Sequence[int]):
        img_idx = np.digitize(index, self.cumsum)
        sample_id = self.image_keys[img_idx]
        slice_idx = img_idx - self.cumsum[np.digitize(index, self.cumsum)]
        img = self.images[sample_id][slice_idx, ...]
        target = self.params.loc[self.params["sample_id"] == int(sample_id)]
        curve = self.curves[str(sample_id)]
        return img, target, curve, sample_id

In [132]:
# d = SkintressionDataset("../data/stacks.zarr/", "../data/curves/", "../data/params.csv", "../data/sample_to_person.csv")
d = SkintressionDataset("D:/skinstression/data/new/stacks.zarr/", "D:/skinstression/data/new/curves/", "D:/skinstression/data/new/params.csv", "D:/skinstression/data/new/sample_to_person.csv")

Removed 1 from dataset
Removed 2 from dataset
Removed 3 from dataset
Removed 4 from dataset
Removed 5 from dataset


In [134]:
d[0][0].shape

(1000, 1000)

In [105]:
print(d.cumsum[np.digitize(187, d.cumsum)])
print(np.digitize(31, d.cumsum))
print(31 - d.cumsum[np.digitize(31, d.cumsum)])
print(0 - d.cumsum[np.digitize(0, d.cumsum)])

217
1
-31
-30


In [62]:
d.cumsum

array([  30,   62,   93,  124,  155,  186,  217,  248,  279,  310,  341,
        371,  402,  433,  464,  495,  526,  557,  588,  619,  650,  681,
        711,  742,  773,  804,  835,  866,  897,  928,  959,  990, 1021,
       1051, 1082, 1113, 1144, 1175, 1206, 1237, 1268, 1299, 1330, 1361,
       1392, 1423, 1454, 1485, 1515, 1545, 1575, 1605, 1635, 1665, 1694,
       1755, 1785, 1815, 1845, 1875])

d.images[2].shape

In [13]:
d[10]

(array([[0, 0, 1, ..., 6, 6, 6],
        [0, 1, 1, ..., 4, 4, 8],
        [0, 0, 1, ..., 4, 5, 6],
        ...,
        [0, 0, 0, ..., 4, 4, 4],
        [0, 0, 0, ..., 6, 6, 5],
        [0, 0, 1, ..., 5, 6, 3]], dtype=uint8),
     sample_id         A          k        xc
 10         21  1.549959  19.872042  1.301004,
       stress   strain
 0   0.000000  1.00000
 1   0.000313  1.00075
 2   0.000625  1.00300
 3   0.000625  1.02425
 4   0.000625  1.04375
 5   0.001875  1.06425
 6   0.004687  1.08350
 7   0.010625  1.10300
 8   0.023438  1.12225
 9   0.047188  1.14175
 10  0.085625  1.16075
 11  0.122813  1.18025
 12  0.193438  1.19950
 13  0.241875  1.21825
 14  0.347812  1.23725
 15  0.460625  1.25625
 16  0.593750  1.27475
 17  0.729062  1.29375
 18  0.867188  1.31275
 19  0.983437  1.33150
 20  1.098438  1.35050
 21  1.228438  1.36900
 22  1.344375  1.38825,
 21)