# Input_data package usage example

## Setup

### Print available datasets

In [None]:
from src.input_data import create_dataset, SupportedDatasets, list_supported_datasets

# Ask the package which datasets it supports; the result is joined with ", "
# below, so it is expected to be an iterable of strings.
datasets = list_supported_datasets()
# Show the supported dataset names on a single comma-separated line.
print(f"Supported datasets: {', '.join(datasets)}")

### Set data root directory

In [None]:
from pathlib import Path
# Data root: the "data" folder that sits next to the current working directory
# (i.e. <cwd>/../data). Passed as `root` to create_dataset() in the cells below.
root = Path.cwd().parent.joinpath("data")

## MNIST Dataset

### Download the dataset

In [None]:
# Banner marking the start of the MNIST section in the cell output.
print(f"\n{'='*10} Working with {SupportedDatasets.MNIST.name} {'='*10}\n")

# Build the MNIST dataset under `root`. force_download=False is passed
# explicitly (presumably reuses files already on disk; confirm against
# create_dataset()).
mnist_dataset = create_dataset(
    SupportedDatasets.MNIST,
    force_download=False,
    root=root,
)

### Print dataset information and visualization

In [None]:
# Print the dataset's summary information for `mnist_dataset` (created in the
# download cell above); the exact output is defined by print_info().
mnist_dataset.print_info()

# Display illustrative samples of the dataset (presumably a plot of example
# images; confirm against show_illustrative_samples()).
mnist_dataset.show_illustrative_samples()

## FASHION-MNIST Dataset

### Download the dataset

In [None]:
# Banner marking the start of the Fashion-MNIST section in the cell output.
print(f"\n{'='*10} Working with {SupportedDatasets.FASHION_MNIST.name} {'='*10}\n")

# Build the Fashion-MNIST dataset under `root`. force_download=False is passed
# explicitly (presumably reuses files already on disk; confirm against
# create_dataset()).
fashion_mnist_dataset = create_dataset(
    SupportedDatasets.FASHION_MNIST,
    force_download=False,
    root=root,
)

### Print dataset information and visualization

In [None]:
# Print the dataset's summary information for `fashion_mnist_dataset` (created
# in the download cell above); the exact output is defined by print_info().
fashion_mnist_dataset.print_info()

# Display illustrative samples of the dataset (presumably a plot of example
# images; confirm against show_illustrative_samples()).
fashion_mnist_dataset.show_illustrative_samples()

## CIFAR-10 Dataset

### Download the dataset

In [None]:
# Banner marking the start of the CIFAR-10 section in the cell output.
print(f"\n{'='*10} Working with {SupportedDatasets.CIFAR10.name} {'='*10}\n")

# Build the CIFAR-10 dataset under `root`. force_download=False is passed
# explicitly (presumably reuses files already on disk; confirm against
# create_dataset()).
cifar10_dataset = create_dataset(
    SupportedDatasets.CIFAR10,
    force_download=False,
    root=root,
)

### Print dataset information and visualization

In [None]:
# Print the dataset's summary information for `cifar10_dataset` (created in
# the download cell above); the exact output is defined by print_info().
cifar10_dataset.print_info()

# Display illustrative samples of the dataset (presumably a plot of example
# images; confirm against show_illustrative_samples()).
cifar10_dataset.show_illustrative_samples()

## New Dataset

Test out a new implementation of the `ManagedDataset` interface in the following steps:

### Specify dataset information

In [None]:
from src.input_data.structure import DatasetInfo

# Descriptive metadata for the example dataset: five dummy class names and an
# input shape of (3, 32, 32). License and citation are left empty here.
PLACEHOLDER_INFO = DatasetInfo(
    name="Placeholder",
    description="Placeholder dataset",
    classes=[
        "class1",
        "class2",
        "class3",
        "class4",
        "class5",
    ],
    num_classes=5,  # matches the number of entries in `classes`
    input_shape=(3, 32, 32),
    license="",
    citation="",
)

### Specify download information

In [None]:
from src.input_data.downloaders import DownloadInfo

# Files the example dataset needs. Each entry names one file, the URLs it can
# be fetched from, and the MD5 digest recorded for it.
PLACEHOLDER_DOWNLOADS = [
    DownloadInfo(
        name="Placeholder",
        filename="placeholder.tar.gz",
        urls=[
            "https://example.com/placeholder.tar.gz",
            "https://example2.com/placeholder.tar.gz",
        ],
        # MD5 digest of empty input; replace with the real archive's digest.
        md5="d41d8cd98f00b204e9800998ecf8427e",
        description="Placeholder",
    ),
    # Append further DownloadInfo entries here when the dataset has more files.
]

### Implement the class interface

In [None]:
from typing import List, Tuple, override
import numpy as np
from src.input_data.structure import ManagedDataset

class PlaceholderDataset(ManagedDataset):
    """
    Template implementation of the `ManagedDataset` interface.

    Describes itself through `PLACEHOLDER_INFO` and lists its downloadable
    files in `PLACEHOLDER_DOWNLOADS`. `_load_raw_data` is meant to load ALL
    Placeholder data (train + test) into a unified dataset.
    Use get_dataloaders() to split into train/val/test sets.

    Expected items (per PLACEHOLDER_INFO; the item type itself is produced by
    the base class, which is not visible here, so confirm):
        sample: array/tensor of shape (3, 32, 32) with values in [0, 1]
        target: int class label (0-4)
    """

    @override
    @property
    def download_infos(self) -> List[DownloadInfo]:
        """Files that must be downloaded for this dataset."""
        return PLACEHOLDER_DOWNLOADS

    @override
    @property
    def dataset_name(self) -> str:
        """Short identifier of this dataset."""
        return "placeholder"

    @override
    @property
    def dataset_info(self) -> DatasetInfo:
        """Descriptive metadata (classes, input shape, ...) of this dataset."""
        return PLACEHOLDER_INFO

    @override
    def _load_raw_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load the raw Placeholder data from the downloaded archive.

        Returns:
            A ``(data, labels)`` tuple: all loaded samples concatenated along
            the first axis, cast to float32 and divided by 255, and the
            matching labels concatenated along the first axis.

        Raises:
            ValueError: If the extraction logic collected no arrays, e.g.
                because this template has not been filled in yet.
        """
        all_images: List[np.ndarray] = []
        all_labels: List[np.ndarray] = []

        # Archive declared in PLACEHOLDER_DOWNLOADS, expected under the dataset root.
        file_path = self.dataset_root / "placeholder.tar.gz"

        # --- Define the whole extraction and loading logic here ---
        # Append one array per loaded file/split to all_images and all_labels.

        # np.concatenate rejects an empty sequence with a cryptic message;
        # fail early with the same exception type and an actionable one.
        if not all_images or not all_labels:
            raise ValueError(
                f"No Placeholder data was loaded from {file_path}; "
                "implement the extraction logic in _load_raw_data()."
            )

        # Combine all data and normalize to 0-1 range (assumes 8-bit pixel values).
        combined_data = np.concatenate(all_images, axis=0).astype(np.float32) / 255.0
        combined_labels = np.concatenate(all_labels, axis=0)

        print(f"Loaded complete Placeholder dataset: {len(combined_data):,} samples ")

        return combined_data, combined_labels