In [1]:
from monai.data import ITKReader, ImageDataset, TiffFileWSIReader, ImageReader, Dataset, CSVDataset, DatasetFunc
from monai.transforms import LoadImage
from functools import partial
from monai.data.image_reader import _copy_compatible_dict, _stack_images
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
import torch
from os.path import abspath
from typing import Sequence
from pathlib import Path
import h5torch
from time import time
import tifffile
import matplotlib.pyplot as plt
import pyvips
import numpy as np
import zarr
from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg
import pandas as pd
from monai.data.utils import pad_list_data_collate

tifffile_imread, _ = optional_import("tifffile", name="imread")

In [11]:
from collections.abc import Sequence
from monai.config import PathLike
from numpy import ndarray

@require_pkg("tifffile")
class TiffFileReaderv1(ImageReader):
    supported_suffixes = ["tif", "tiff", "svs"]

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def get_data(self, img) -> tuple[ndarray, dict]:
        img_array: list[np.ndarray] = []
        compatible_meta: dict = {}

        for i in ensure_tuple(img):
            print(i)
            data = self._get_array_data(i)
            img_array.append(data)
            header = self._get_meta_dict(i)
            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i, self.affine_lps_to_ras)
            header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS
            header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy()
            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
            if self.channel_dim is None:  # default to "no_channel" or -1
                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
                    float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
                )
            else:
                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
            _copy_compatible_dict(header, compatible_meta)

        return _stack_images(img_array, compatible_meta), compatible_meta

    def read(self, data, **kwargs):
        img_ = []

        filenames: Sequence[PathLike] = ensure_tuple(data)
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for name in filenames:
            name = f"{name}"
            img_.append(tifffile.imread(name, **kwargs_))
        return img_ if len(filenames) > 1 else img_[0]

    def verify_suffix(self, filename):
        if filename.endswith(self.supported_suffixes):
            return True
        else:
            return False



In [107]:
@require_pkg(pkg_name="tifffile")
class TiffFileReader(ImageReader):

    supported_suffixes = ["tif", "tiff"]
    backend = "tifffile"

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):
        """
        Read whole slide image objects from given file or list of files.

        Args:
            data: file name or a list of file names to read.
            kwargs: additional args that overrides `self.kwargs` for existing keys.

        Returns:
            whole slide image object or list of such objects.

        """
        img_list: list = []

        filenames: Sequence[PathLike] = ensure_tuple(data)
        kwargs_ = self.kwargs.copy()
        kwargs_.update(kwargs)
        for filename in filenames:
            store = tifffile_imread(filename, **kwargs_)
            img = zarr.open(store, mode="r")
            img_list.append(img)

        return img_list if len(filenames) > 1 else img_list[0]

    def verify_suffix(self, filename):
        if filename.endswith(self.supported_suffixes):
            return True
        else:
            return False
    
    def get_data(self, img) -> tuple[ndarray, dict]:
        # img = np.array(img)
        img = torch.tensor(np.array(img))
        return img, {}


In [115]:
images = list(Path("../data/stacks/").glob("*.tif"))
dataset = ImageDataset(images, reader=TiffFileReader(aszarr=True))
dataset.loader.readers = [TiffFileReader(aszarr=True)]
b = time()
print(dataset[1].meta)
print(time() - b)

{'filename_or_obj': '../data/stacks/10.tif', affine: tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], dtype=torch.float64)}
1.0520861148834229


In [150]:
images = list(Path("../data/stacks/").glob("*.tif"))
dataset = ImageDataset(images, reader='ITKReader')
print(dataset.loader.readers)
b = time()
print(dataset[0].shape)
print(time() - b)

[<monai.data.image_reader.ITKReader object at 0x7fe38a930950>, <monai.data.image_reader.NumpyReader object at 0x7fe38a932c10>, <monai.data.image_reader.PILReader object at 0x7fe38a930790>, <monai.data.image_reader.ITKReader object at 0x7fe3b26eb150>]
torch.Size([1000, 1000, 30])
0.3076913356781006


In [145]:
from collections.abc import Sequence


class A(Dataset):
    def __init__(self, data: Sequence, transform = None) -> None:
        """
        Args:
            data: input data to load and transform to generate dataset for model.
            transform: a callable data transform on input data.

        """
        self.data = data
        self.transform = transform


    def __getitem__(self, index: int | slice | Sequence[int]):
        # store = tifffile.imread(self.data[index]["img"], aszarr=True)
        # img = zarr.open(store)
        # img = torch.tensor(np.array(img))
        img = tifffile.imread(self.data[index]["img"])
        img = torch.tensor(np.array(img))
        return img

In [146]:
a = A([dict(img=images[i]) for i in range(len(images))])

In [147]:
b = time()
x = a[0]
print(time() - b)

0.33558106422424316


In [124]:
b = time()
store = tifffile.imread("../data/stacks/1.tif", aszarr=True)
img = torch.tensor(zarr.open(store, mode="r"))
print(time() -b)

3.473780870437622


In [56]:
from monai.data import ImageDataset
import pandas as pd
from pathlib import Path
import torch
from typing import Sequence

class SkinstressionDataset(ImageDataset):
    def __init__(self, image_dir, curves_dir, df, variables, *args, **kwargs) -> None:
        self.df = df
        image_files = (Path(image_dir) / df["filename"]).tolist()
        self.curves_dir = Path(curves_dir)
        assert isinstance(variables, list), "no variables selected. E.g. variables=['a', 'k', 'xc']"
        labels = torch.tensor(df[list(variables)].to_numpy())
        super().__init__(image_files=image_files, labels=labels, *args, **kwargs)
    
    def __getitem__(self, index):
        sample = list(super().__getitem__(index))
        sample[0] = sample[0].unsqueeze(0)  # Because the model expects a batch and color dimension, next to the 3D image.
        sample_id = str(self.df.loc[index]["sample_id"])
        df_curves = pd.read_csv(str(self.curves_dir / Path(sample_id).with_suffix(".csv")))
        sample.append(df_curves["strain"].to_numpy())
        sample.append(df_curves["stress"].to_numpy())
        return tuple(sample)

image_dir = "../data/stacks/"
curves_dir = "../data/curves/"
df = pd.read_csv("../data/stacks/params.csv")
df_persons = pd.read_csv("../data/stacks/sample_to_person.csv")
df = df.merge(df_persons, on="sample_id")
df["filename"] = [str(index) + ".tif" for index in df["sample_id"]]
df.columns = list(map(str.lower, df.columns))
variables = ["k"]
# variables = ["a", "k", "xc"]
dataset = SkinstressionDataset(image_dir=image_dir, curves_dir=curves_dir, df=df, variables=variables)

In [57]:
dataset[0]

(metatensor([[[[  0.,   3.,   4.,  ...,   9.,   1., 144.],
           [  2.,   4.,   5.,  ...,   1.,   3., 127.],
           [  0.,   2.,   2.,  ...,   2.,   0., 155.],
           ...,
           [ 12.,  11.,  14.,  ...,  21.,  24.,  94.],
           [  6.,   6.,  11.,  ...,  29.,   9.,  85.],
           [  0.,   8.,  14.,  ...,  18.,  25.,  91.]],
 
          [[  3.,   3.,   7.,  ...,  12.,   0., 138.],
           [  1.,   1.,   3.,  ...,   5.,   0., 114.],
           [  3.,   2.,   6.,  ...,   7.,   2., 146.],
           ...,
           [  5.,  11.,  14.,  ...,  20.,  20.,  86.],
           [ 10.,  11.,   7.,  ...,  22.,  20.,  86.],
           [ 10.,   6.,  10.,  ...,  24.,  27.,  92.]],
 
          [[  9.,   2.,   9.,  ...,   9.,   0., 134.],
           [  0.,   0.,   8.,  ...,   2.,   0., 135.],
           [  0.,   4.,   2.,  ...,   4.,   3., 131.],
           ...,
           [  6.,   6.,  10.,  ...,  12.,  15.,  79.],
           [ 10.,  11.,   6.,  ...,  18.,  15.,  84.],
       

In [59]:
dataset[0][1]

tensor([28.9636], dtype=torch.float64)

In [53]:
def filter_and_sort_ineligible_data(data: list[Path], csv_dataset):
    eligible_sample_ids = [sample["sample_id"] for sample in csv_dataset]
    eligible_data = []
    for _id in eligible_sample_ids:
        for sample in data:
            if int(sample.stem) != _id:
                continue
            eligible_data.append(sample)
    return eligible_data

target_dataset = CSVDataset(["../data/stacks/params.csv", "../data/stacks/sample_to_person.csv"])
image_files = list(Path("../data/stacks").glob("*.tif"))
image_files = DatasetFunc(image_files, filter_and_sort_ineligible_data, csv_dataset=target_dataset)
img_dataset = ImageDataset(image_files, reader="ITKReader")

In [193]:
sorted(variables[::-1])

['A', 'k']

In [206]:
from collections.abc import Sequence

import pandas as pd


class SkinstressionDataset(Dataset):
    def __init__(self, params, cols, sample_to_person, image_dir, curves_dir=None, suffix=".tif", reader="ITKReader"):
        """
        cols: list of params (str) to be selected from csv files in params. Order must correspond with model output.
        """
        self.target_dataset = CSVDataset([params, sample_to_person], col_names=["sample_id", "person_id"] + sorted(cols))
        if curves_dir is not None:
            curve_files = list(Path(curves_dir).glob("*.csv"))
            self.curves_datasets = []
            for curve_file in curve_files:
                curve_dataset = pd.read_csv(curve_file)
                self.curves_datasets.append(curve_dataset)
        else:
            self.curves_datasets = None
        image_files = list(Path(image_dir).glob(f"*{suffix}"))
        image_files = DatasetFunc(image_files, self.filter_and_sort_ineligible_data, csv_dataset=self.target_dataset)
        self.img_dataset = ImageDataset(image_files, reader=reader)
        self.sample_info = self.pop_sample_infos()
    
    def pop_sample_infos(self):
        sample_infos = []
        for data in self.target_dataset:
            sample_info = {
                "sample_id": data.pop("sample_id"),
                "person_id": data.pop("person_id"),
            }
            sample_infos.append(sample_info)
        return sample_infos
    
    @staticmethod
    def filter_and_sort_ineligible_data(data: list[Path], csv_dataset: CSVDataset):
        eligible_sample_ids = map(lambda x: x["sample_id"], csv_dataset)
        eligible_data = []
        for _id in eligible_sample_ids:
            for sample in data:
                if int(sample.stem) != _id:
                    continue
                eligible_data.append(sample)
        return eligible_data

    def __getitem__(self, index: int | slice | Sequence[int]):
        out = {
            "img": self.img_dataset[index],
            "target": self.target_dataset[index],
            "sample_info": self.sample_info[index],
        }
        if self.curves_datasets is not None:
            out["curve"] = self.curves_datasets[index]
        return out

skin = SkinstressionDataset("../data/stacks/params.csv", ["k", "A"], "../data/stacks/sample_to_person.csv", "../data/stacks/", "../data/curves")

In [1]:
skin[0]["sample_info"], skin[0]["target"], skin[0]["img"].meta["filename_or_obj"]

NameError: name 'skin' is not defined

In [1]:
from skinstression.dataset import SkinstressionDataModule

dm = SkinstressionDataModule("../data/stacks", "../data/stacks/params.csv", "../data/curves/", "../data/sample_to_person.csv", ["k", "A"], batch_size=1)

In [2]:
dm.setup("fit")
loader = dm.train_dataloader()

there are 51/60 eligible samples
train: 12 val: 5 test: 15


In [3]:
batch = next(iter(loader))
batch['sample_info']

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not MetaTensor

In [None]:
from skinstression.model import Skinstression

model = Skinstression()

In [None]:
batch["img"].shape

AttributeError: 'list' object has no attribute 'shape'

In [2]:
from monai.data import ZipDataset

class SkinstressionDataset(ZipDataset):
    def __init__(self, params, cols, sample_to_person, image_dir, curve_dir=None, suffix=".tif", reader="ITKReader"):
        self.target_dataset = CSVDataset([str(params), str(sample_to_person)], col_names=["sample_id", "person_id"] + sorted(cols))
        if curve_dir is not None:
            curve_files = list(Path(curve_dir).glob("*.csv"))
            self.curves_datasets = []
            for curve_file in curve_files:
                curve_dataset = pd.read_csv(curve_file)
                curve_dataset = curve_dataset.to_numpy()
                self.curves_datasets.append(curve_dataset)
        else:
            self.curves_datasets = None
        image_files = list(Path(image_dir).glob(f"*{suffix}"))
        image_files = DatasetFunc(image_files, self.filter_and_sort_ineligible_data, csv_dataset=self.target_dataset)
        self.img_dataset = ImageDataset(image_files, reader=reader)
        self.sample_info = self.pop_sample_infos()
        super().__init__([self.img_dataset, self.target_dataset, self.sample_info, self.curves_datasets])
    
    def pop_sample_infos(self):
        sample_infos = []
        for data in self.target_dataset:
            sample_info = {
                "sample_id": data.pop("sample_id"),
                "person_id": data.pop("person_id"),
            }
            sample_infos.append(sample_info)
        return sample_infos
    
    @staticmethod
    def filter_and_sort_ineligible_data(data: list[Path], csv_dataset: CSVDataset):
        eligible_sample_ids = map(lambda x: x["sample_id"], csv_dataset)
        eligible_data = []
        for _id in eligible_sample_ids:
            for sample in data:
                if int(sample.stem) != _id:
                    continue
                eligible_data.append(str(sample))
        return eligible_data

In [3]:
skin = SkinstressionDataset("../data/stacks/params.csv", ["k", "A"], "../data/stacks/sample_to_person.csv", "../data/stacks/", "../data/curves")

In [9]:
from monai.data import DataLoader
import torch
from monai.data import pad_list_data_collate
from monai.transforms.croppad.functional import pad_func

def custom_collate_fn(batch):
    # Get the maximum height and width among all images in the batch
    max_height = np.max([item[0].shape[0] for item in batch])
    max_width = np.max([item[0].shape[1] for item in batch])

    # Pad each image to the maximum height and width
    padded_images = []
    for item in batch:
        img = item[0].moveaxis(-1, 0)
        pad_height = max_height - img.shape[1]
        pad_width = max_width - img.shape[2]
        padded_img = pad_func(img, ((0, 0), (0, pad_height), (0, pad_width)), {})
        padded_images.append(padded_img)

    # Stack the padded images
    stacked_images = torch.stack(padded_images)

    # Return the stacked images along with other items in the batch
    return {'img': stacked_images, 'target': [item[1] for item in batch]}

loader = DataLoader(skin, batch_size=5, collate_fn=custom_collate_fn)
next(iter(loader))


loader = DataLoader(skin, batch_size=5, collate_fn=custom_collate_fn)
next(iter(loader))

padded
torch.Size([31, 1000, 1000])
padded
torch.Size([31, 1000, 1000])


{'img': metatensor([[[[ 3.,  0.,  0.,  ...,  4.,  0.,  1.],
           [ 1.,  0.,  2.,  ...,  2.,  2.,  2.],
           [ 0.,  1.,  0.,  ...,  0.,  1.,  1.],
           ...,
           [14., 16., 17.,  ...,  1.,  2.,  2.],
           [15., 25., 22.,  ...,  1.,  1.,  0.],
           [17., 11., 20.,  ...,  0.,  0.,  1.]],
 
          [[ 1.,  0.,  0.,  ...,  1.,  1.,  0.],
           [ 0.,  1.,  2.,  ...,  0.,  3.,  1.],
           [ 0.,  0.,  1.,  ...,  2.,  0.,  1.],
           ...,
           [22., 18., 17.,  ...,  2.,  1.,  2.],
           [17., 21., 14.,  ...,  1.,  2.,  0.],
           [18., 14., 18.,  ...,  2.,  0.,  0.]],
 
          [[ 1.,  3.,  0.,  ...,  2.,  0.,  2.],
           [ 1.,  1.,  0.,  ...,  1.,  0.,  3.],
           [ 1.,  1.,  1.,  ...,  4.,  3.,  2.],
           ...,
           [17., 18., 18.,  ...,  3.,  0.,  0.],
           [24., 21., 12.,  ...,  0.,  0.,  1.],
           [11., 16., 14.,  ...,  0.,  0.,  1.]],
 
          ...,
 
          [[10.,  9.,  6.,  ..., 

In [10]:
batch = next(iter(loader))

padded
torch.Size([31, 1000, 1000])


In [17]:
batch["target"]

[{'A': 5.8810762382094, 'k': 24.949814880240133},
 {'A': 3.6637352195486574, 'k': 21.96298150284788},
 {'A': 2.4274528696190725, 'k': 17.765872097226683},
 {'A': 3.146787124907819, 'k': 46.6717087788422},
 {'A': 4.105054523470894, 'k': 29.315148105929325}]