# Kedro + PyTorch

List train files: there are different dataset abstractions to choose from; PyTorch code typically uses torchvision's `ImageFolder`.
Split into train and validation sets (the split is made per image).
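
A minimal sketch of a per-image split, using only the standard library. The record layout (`path` and `label` keys), the function name `split_train_validation`, and the 80/20 ratio are assumptions for illustration, not part of the original pipeline.

```python
import random
from typing import Dict, List, Tuple

Record = Dict[str, str]


def split_train_validation(
    records: List[Record], validation_fraction: float = 0.2, seed: int = 42
) -> Tuple[List[Record], List[Record]]:
    """Shuffle image records reproducibly and split them into two lists."""
    if not 0.0 <= validation_fraction <= 1.0:
        raise ValueError("validation_fraction must be between 0 and 1")
    shuffled = list(records)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    n_validation = round(len(shuffled) * validation_fraction)
    return shuffled[n_validation:], shuffled[:n_validation]


# Hypothetical records: one entry per image file
records = [
    {"path": f"train_images/{label}/{i}.jpeg", "label": label}
    for label in ("cat", "dog")
    for i in range(5)
]
train, validation = split_train_validation(records, validation_fraction=0.2)
```

Because each record is one image, no image can end up in both subsets. Fixing the seed keeps the split stable between runs.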

```
type: PartitionedDataSet
dataset:
    type: pandas.CSVDataset
    save
path: "01_raw/train_images"
filename_suffix: ".jpeg"
```

In [1]:
from typing import Any, Dict, Callable
import pandas as pd

In [None]:
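def partition_load_func():
    # Placeholder only: Kedro's PartitionedDataSet supplies one such
    # zero-argument load function per partition.
    pass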

In [None]:
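# Node that lists all train files
def list_files(
        partitioned_file_list: Dict[str, Callable[[], Any]], limit: int = -1
) -> pd.DataFrame:
    results = []

    # Sort by partition key so the output order is deterministic
    for _partition_key, partition_load_func in sorted(partitioned_file_list.items()):
        # PartitionedDataSet supplies one load function per partition;
        # calling it runs the underlying dataset's load for that file
        record = partition_load_func()
        results.append(record)

    df = pd.DataFrame(results)
    # A negative limit keeps every row; otherwise take a reproducible sample,
    # capped at the number of available rows
    return df if limit < 0 else df.sample(n=min(limit, len(df)), random_state=42)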
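
To show where `partition_load_func` comes from: a `PartitionedDataSet` hands the node a dictionary that maps each partition key to a zero-argument load function, and nothing is read until that function is called. The sketch below imitates that dictionary in plain Python. The helper `make_loader`, the keys, and the paths are hypothetical stand-ins, not Kedro code.

```python
from typing import Callable, Dict


def make_loader(path: str) -> Callable[[], Dict[str, str]]:
    """Return a zero-argument function that 'loads' one partition on demand."""

    def load() -> Dict[str, str]:
        # Stand-in for the underlying dataset's load: the label is the
        # name of the directory that contains the file
        label = path.split("/")[-2]
        return {"path": path, "label": label}

    return load


# Hypothetical stand-in for the dictionary Kedro passes to the node
partitions = {
    "dog/002": make_loader("train_images/dog/002.jpg"),
    "cat/001": make_loader("train_images/cat/001.jpg"),
}

# Same loop shape as in list_files: sort by key, then call each loader
rows = [load() for _key, load in sorted(partitions.items())]
```

Each element of `rows` is one dictionary with `path` and `label`, which is the shape `pd.DataFrame(results)` turns into two columns.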

In the data catalog (`conf/base/catalog.yml` in a default Kedro project), add:

```
train_images_list:
    type: PartitionedDataSet
    path: /data/01_raw/train_images
    dataset: kedro_pytorch.datasets.FileWithDirAsLabel # this is a custom implementation
    filename_suffix: ".jpg"
```

The pipeline's package contains a `datasets.py` module that implements this class:

In [None]:
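from pathlib import Path

# Class names as spelled in Kedro versions that use the "DataSet" naming
from kedro.io import AbstractDataSet, DataSetError


class FileWithDirAsLabel(AbstractDataSet):
    '''Returns a dictionary with the file path and its label (in my case it would be targets).

    The label is the name of the directory that contains the file.
    '''
    def __init__(self, filepath: str):
        self.path = filepath

    def _load(self) -> dict:
        p = Path(self.path)
        return {'path': self.path, 'label': p.parent.name}

    def _save(self, data: Any) -> None:
        raise DataSetError("FileWithDirAsLabel is read-only!")

    def _describe(self) -> Dict[str, Any]:
        return {'path': self.path}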
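
The labelling rule used in `_load` can be checked without Kedro. This is a small sketch of that one step, assuming files are laid out as `<root>/<label>/<file>`; the function name `path_and_label` and the example path are made up for illustration.

```python
from pathlib import PurePosixPath
from typing import Dict


def path_and_label(filepath: str) -> Dict[str, str]:
    """Pair a file path with the name of its parent directory as the label."""
    # PurePosixPath keeps the result identical on every operating system
    return {"path": filepath, "label": PurePosixPath(filepath).parent.name}


record = path_and_label("data/01_raw/train_images/cat/0001.jpg")
```

Here `record["label"]` is `"cat"`, the same directory-as-class convention that `ImageFolder` relies on.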

Now go to the nodes in the pipeline. Another dataset is created in the catalog:
```
train_dataset:
    type: datasets.KedroPyTorchImageDataSet
    path: data/train
```

In [None]:
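from torch.utils.data import Dataset


class KedroPyTorchImageDataSet(Dataset, AbstractDataSet):
    def __init__(
        self,
        path: str,
        path_column: str = "path",
        label_column: str = "label",
        target_transform: Callable = None,  # assumed: optional callable applied to each label
    ):
        super().__init__()
        self.target_transform_fn = target_transform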