In [1]:
from abc import abstractmethod, abstractproperty, ABC
from typing import Callable, Iterator, TypeAlias, NoReturn, Any
import pandas as pd
import numpy as np
import numpy.typing as npt
import torch
import tensorflow as tf


# Interface
# =========


from abc import abstractmethod, ABC
from typing import Callable, Iterator, Any, Protocol

from torch.utils.data import Dataset as TorchDataset

from data_protocols import TorchMapDatasetProtocol, TorchIterableDatasetProtocol, \
    TorchDatasetProtocol

# Implementation
# ==============

# Types
# -----

# This type alias refers to a function that takes in a np array and returns a np array.
NPArrayTransform: TypeAlias = Callable[[np.ndarray], np.ndarray]

# This type alias refers to a function that transforms one label into another. Note that the 
# datatype may change.
LabelTransform: TypeAlias = Callable[[int | str], int | str]


# Adapters
# --------

class NPArray2TorchDatasetAdapter(TorchMapDatasetProtocol, TorchIterableDatasetProtocol):
    """
    This class is a helper class that allows us to convert:
    - a dataset where each example consists of a np array and an associated 
      label in string or integer format,
    - into a pytorch dataset.
    It is used in the implementation of the ImageDataSetImplementation class.
    """
    def __init__(
            self, 
            X: list[np.ndarray], 
            y: list[int | str], 
            transform: NPArrayTransform | None = None,
            target_transform: LabelTransform | None = None,        
            gpu: bool = False
    ):
        """
        Note that we expect input to be of type list and not iterator, because the latter does not 
        have a length method.
        Consider making this less restrictive in the future, but it may not be necessary to support
        iterables without length such as generators. While it is possible to require input length
        as an argument, this gets clumsy if you want to make sure it is only required if the input 
        data structures do not have a length method (need to use generics?). 
        Even in cases where the input would just be a folder path, the validate function should be 
        able to perform the validation right after it has listed all the files. (This should even be 
        the case for a stream of batches, where the validation could take place for each batch.)
        """
        self.X = X
        self.y = y
        self.transform = transform
        self.target_transform = target_transform
        self.gpu = gpu
        self.__validate_input()

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, index) -> tuple[torch.Tensor, int | str]:
        unlabeled_example: np.ndarray = self.X[index]
        if self.transform is not None:
            unlabeled_example = self.transform(unlabeled_example)
        
        label: int | str = self.y[index]
        if self.target_transform is not None:
            label = self.target_transform(label)

        tensor: torch.Tensor = torch.from_numpy(unlabeled_example)
        return (tensor, label)
    
    def __validate_input(self) -> None | NoReturn:
        if self.gpu == True:
            raise NotImplementedError(
                "Conversion is not yet implemented when using GPU. Use Pytorch datastructures to " \
                "leverage GPU."
            )
        
        if len(self.y) == 0:
            raise ValueError("y must not be empty")
        
        if len(self.X) != len(self.y):
            raise ValueError("X and y must have the same length")
        
        # If validation passed, we don't return anything
        else:
            return None
        
    def numpy(self) -> np.ndarray:
        return np.array(self.X)


class TorchDataset2NPArrayAdapter():
    """
    This adapter takes a torch dataset and implements the same methods, but instead of tensors it 
    returns np arrays. 
    (Note that NumPy does not have an equivalent data structure to Pytorch's dataset
    """
    def __init__(self, dataset: TorchDatasetProtocol: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType, gpu: bool = False) -> None:
        self.dataset = dataset
        self.gpu = gpu

    def __len__(self) -> int:
        return len(self.dataset)  # type: ignore
    
    def __getitem__(self, index) -> tuple[np.ndarray, int | str]:
        if self.gpu == True:
            raise NotImplementedError(
                "Conversion is not yet implemented when using GPU. Use Pytorch datastructures to" \
                "leverage GPU."
            )
        else:
            tensor, label = self.dataset[index]
            # Need to detach and specify device. See https://discuss.pytorch.org/t/should-it-really-be-necessary-to-do-var-detach-cpu-numpy/35489/2
            numpy_data: np.ndarray = tensor.detach().cpu().numpy()
            return (numpy_data, label)

# class LabelFromFolderDataset(TorchDatasetProtocol: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType):

# Main implementation
# -------------------

class ImageDataSetImplementation(DataSetProtocol):
    """
    This uses the pytorch format internally, but allows instantiation from and 
    conversion to numpy ~~and tensoflow~~ format as well.
    """
    def __init__(self, torch_dataset: TorchDatasetProtocol: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType, gpu: bool = False):
        if gpu:
            raise NotImplementedError("GPU support is not yet implemented")
        self.data = torch_dataset
        self.gpu = gpu
    

    @classmethod
    def from_torch(cls, torch_dataset):
        return cls(torch_dataset=torch_dataset)

    def to_torch(self) -> TorchDatasetProtocol: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType:
        return self.data
    

    @classmethod
    def from_numpy(cls, X: list[np.ndarray], y: list[int | str], gpu: bool = False):
        torch_dataset: TorchDatasetProtocol: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType: TypeAlias = TorchMapDatasetProtocol | TorchIterableDatasetType = NPArray2TorchDatasetAdapter(X=X, y=y)
        return cls(torch_dataset=torch_dataset)

    def to_numpy_dataset(self):
        """
        This returns a "numpy dataset", i.e. an object similar to Pytorch's datasets, except that it
        returns np arrays instead of tensors.
        """
        return TorchDataset2NPArrayAdapter(dataset=self.data, gpu=self.gpu)


    # TODO: distinguish between instantiating from tensor versus tf.dataset
    # @classmethod
    # def from_tensorflow(cls, data: tf.Tensor, gpu: bool = False):
    #     if gpu == True:
    #         raise NotImplementedError("Conversion is not yet implemented when using GPU. Use Pytorch datastructures to leverage GPU.")
    #     else:
    #         numpy_data: np.ndarray = data.numpy()
    #         torch_data: torch.Tensor = torch.from_numpy(numpy_data)
    #         return cls(torch_data=torch_data)
        
    # def to_tensorflow(self) -> tf.Tensor:
    #     if self.gpu == True:
    #         raise NotImplementedError("Conversion is not yet implemented when using GPU. Use Pytorch datastructures to leverage GPU.")
    #     else:
    #         numpy_data: np.ndarray = self.to_numpy()
    #         tf_data: tf.Tensor = tf.convert_to_tensor(numpy_data)
    #         return tf_data

SyntaxError: invalid syntax (497133051.py, line 115)

# Nested Folders: Fungi Classification
https://archive.ics.uci.edu/dataset/773/defungi

In [6]:
# Assumes data has been downloaded to this directory
DATA_DIR_FUNGI = 'data/defungi/'

In [3]:
from pathlib import Path
from functools import cached_property
from dataclasses import dataclass

Label: TypeAlias = str | int

# class _ReadFunctionConfig(ABC):
#     """
#     This serves as the supertype for configs passed to the read function.  These will usually be 
#     data classes (which already have the `asdict()` method implemented), and need to be registered 
#     as a permissible type of ReadFunctionConfig by calling `ReadFunctionConfig.register(my_config)`. 
#     """
#     @abstractmethod
#     def asdict(self) -> dict[str, Any]: 
#         pass


@dataclass
class FileReader:
    reader_function: Callable[[Any], torch.Tensor]
    reader_function_config: dict # _ReadFunctionConfig


# class FileReader(Protocol):
#     @property
#     def reader_function(self) -> Callable[[Path, _ReadFunctionConfig], torch.Tensor]:
#         ...
#     @property
#     def reader_function_config(self) -> _ReadFunctionConfig:
#         ...


class LabelFromFoldernameDataset(TorchMapDatasetProtocol, TorchIterableDatasetProtocol):
    """
    This class implements the protocols for a PyTorch Datasets (both map and list). 
    
    (Note that the inheritance from these protocols does not make this class a Protocol itself. For 
    that it would have to inherit directly from Protocol as well.)
    """
    def __init__(
        self, 
        file_reader: FileReader,
        data_dir: str, 
        expected_filetype: str | None = None,
        input_transform: Callable | None = None, 
        target_transform: Callable | None = None,
    ):
        self.file_reader = file_reader
        self.data_dir: Path = Path(data_dir)
        self.expected_filetype = expected_filetype
        self.input_transform = input_transform
        self.target_transform = target_transform

    @cached_property
    def _inputfile_to_label_mapping(self) -> dict[str, Label]:
        """
        Args:
        expected_filetype: If specified, ValueError is raised unless all files in subfolder have the 
            specified extension. Do NOT include the dot in the extension. Set to None to skip check.

        This function lists the paths of all input files, and maps them to the associated label, 
        which is based on the subfolder a file lives in. The reason to do this ahead of time, rather
        than iterating over all the subfolders when we need to retrieve the data, is allows us to 
        easily shuffle the data. 

        Note that this approach has the downside of needing to hold the paths to all input files in 
        memory. Once this becomes a problem, we can refine this algorithm, but it will probably add 
        a lot of complexity to do so. 
        One possible solution would be to still list the number of files for each subfolder, but to 
        only store the *number* of files in each directory. When we want to retrieve the data, we 
        can then sample from the distribution of folder names/label plus index (ranging from one to 
        the number of files for that specific label).

        Note on type hints: While it would be preferable to to distinguish directory from file 
        paths, I didn't find subtypes or Protocols of Path to distinguish between these. 
        ToDo: Define Protocols for Path and DirectoryPath, and use them here. 
        """
        def _check_filetype(file: Path, expected_filetype: str | None) -> None | NoReturn:
            if expected_filetype is None:
                return None  # Skip check

            else:
                # First remove dot from file suffix
                actual_filetype =  file.suffix[1:]
                if actual_filetype == expected_filetype:
                    return None  # passed validation
                else:
                    raise ValueError(f"{file} does not have expected extension {expected_filetype}")

        inputfile_to_label_mapping: dict[str, Label] = {}
        subfolders: list[Path] = [
            obj for obj in self.data_dir.iterdir() 
            if obj.is_dir()
        ]

        for subfolder in subfolders:
            # Label is same as folder name
            label: Label = subfolder.name  
            # Find all files in the subfolder and save their names
            for file in subfolder.iterdir():
                _check_filetype(file=file, expected_filetype=self.expected_filetype)
                inputfile_to_label_mapping[file.name] = label

        return inputfile_to_label_mapping

    def __len__(self) -> int:
        """
        This function returns the number of files in the data directory.
        """
        return len(
            self._inputfile_to_label_mapping
        )

    def __getitem__(self, index) -> tuple[torch.Tensor, int | Label]:
        """
        This function returns an example (file + label) at the specified index.

        Todo: Make  sure that the repeated casting to list is not taking too much time. Otherwise,
        consider storing data in a two dimensional array, where each row contains path and label. 
        The trade-off is that this will consume more memory, though. A further optimization that 
        would avoid this memory overhead is to only store the order of labels, along with the
        number of files for each label, as already hinted above. If the list of paths is stored in
        the same order as the labels, it is easy to compute the label for a path at a given index.

        Another alternative: Create a more memory efficient version that only implements the
        Protocol for Torch*Iterable*DataSet.

        """
        # todo: try dict.items()

        # Get all keys, which represent the filenames. Then convert to list, so we can use use index
        inputfiles: list[str] = list(
            self._inputfile_to_label_mapping \
            .keys()
        )
        # Get path and label for that index
        input_file: str = inputfiles[index]
        label: Label = self._inputfile_to_label_mapping[input_file]
        
        # Load file at selected path into memory
        absolute_path_to_input: Path = self.data_dir / str(label) / input_file
        reader_args: dict[str, Any] = self.file_reader.reader_function_config  #.asdict()
        input: torch.Tensor = self.file_reader.reader_function(
            str(absolute_path_to_input), 
            **reader_args,
        )
        
        # Apply transformations, if specified
        if self.input_transform is not None:
            input = self.input_transform(input)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return (input, label)


    # def __iter__(self) -> tuple[torch.Tensor, int | Label]:
    #     """
    #     This function iterates over all files in the data directory, and returns the associated 
    #     label and feature. 
    #     """
    #     datapath_to_label_mapping = self._inputfile_to_label_mapping
    #     for datapath, label in datapath_to_label_mapping.items():
    #         feature: torch.Tensor = torch.load(datapath)
    #         if self.feature_transform is not None:
    #             feature = self.feature_transform(feature)
    #         if self.target_transform is not None:
    #             label = self.target_transform(label)
    #         yield (feature, label)



In [7]:
import torchvision

fungi_dataset = LabelFromFoldernameDataset(
    file_reader=FileReader(
        reader_function=torchvision.io.read_image,
        reader_function_config={'mode': torchvision.io.ImageReadMode.UNCHANGED},
    ),
    data_dir=DATA_DIR_FUNGI,
    expected_filetype='jpg',
)


<cell>3: [1m[31merror:[m Cannot instantiate abstract class


In [8]:
fungi_dataset[0]

(tensor([[[116, 117, 117,  ..., 143, 142, 142],
          [117, 117, 118,  ..., 143, 143, 142],
          [119, 119, 119,  ..., 144, 143, 143],
          ...,
          [146, 145, 144,  ..., 144, 145, 146],
          [146, 146, 144,  ..., 144, 145, 146],
          [147, 146, 145,  ..., 144, 145, 146]],
 
         [[121, 122, 122,  ..., 146, 145, 145],
          [122, 122, 123,  ..., 146, 146, 145],
          [121, 121, 121,  ..., 147, 146, 146],
          ...,
          [136, 135, 134,  ..., 151, 152, 153],
          [136, 136, 134,  ..., 151, 152, 153],
          [137, 136, 135,  ..., 151, 152, 153]],
 
         [[117, 118, 118,  ..., 139, 138, 138],
          [118, 118, 119,  ..., 139, 139, 138],
          [118, 118, 118,  ..., 140, 139, 139],
          ...,
          [137, 136, 135,  ..., 144, 145, 146],
          [137, 137, 135,  ..., 144, 145, 146],
          [138, 137, 136,  ..., 144, 145, 146]]], dtype=torch.uint8),
 'H1')

In [9]:
len(fungi_dataset)

9114