# In [3]:  (Jupyter notebook cell marker)
import pandas as pd
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from collections import Counter
import matplotlib.pyplot as plt

class WildScenesDataset:
    """Indexable dataset of WildScenes 2D image / label pairs.

    Each split ('train', 'valid', 'test') is described by a CSV file with
    'image' and 'label' columns holding file paths.  The CSV files are
    produced by :meth:`make_data_list`.  Indexing an instance loads the
    image as RGB and the label with raw ids remapped to train ids.
    """

    # Directories scanned by make_data_list(); each one is expected to
    # contain an 'image' and an 'indexLabel' sub-directory.
    root_dirs = [
        os.path.join('..', 'data', 'WildScenes', 'WildScenes2d', 'V-01'),
        os.path.join('..', 'data', 'WildScenes', 'WildScenes2d', 'V-02'),
        os.path.join('..', 'data', 'WildScenes', 'WildScenes2d', 'V-03')
    ]
    _data_list_dir = os.path.join('datasets', 'data_list')
    _csv_files = {
        'train': os.path.join(_data_list_dir, 'train.csv'),
        'valid': os.path.join(_data_list_dir, 'valid.csv'),
        'test': os.path.join(_data_list_dir, 'test.csv'),
    }
    # Public alias of the split -> CSV path mapping.
    csv = _csv_files
    # Raw label id -> train id.  Raw ids 1, 11 and 17 share train id 16.
    _label_to_trainid = {
        0: 15, 1: 16, 2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7,
        10: 8, 11: 16, 12: 9, 13: 10, 14: 11, 15: 12, 16: 13, 17: 16, 18: 14,
    }
    # Train id assigned to any raw label id missing from _label_to_trainid.
    _ignore_index = 255

    def __init__(self, dataset_type, transform=None):
        """Load the path list of one split.

        Args:
            dataset_type: One of 'train', 'valid' or 'test'.
            transform: Optional callable invoked as
                ``transform(image, label)`` that returns the transformed
                ``(image, label)`` pair.

        Raises:
            ValueError: If ``dataset_type`` is not a known split name.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if dataset_type not in ('train', 'valid', 'test'):
            raise ValueError(
                f"dataset_type must be 'train', 'valid' or 'test', got {dataset_type!r}")
        self._dataset_type = dataset_type
        # The CSV must already exist, i.e. make_data_list() has been run.
        self._data_frame = pd.read_csv(WildScenesDataset._csv_files[self._dataset_type])
        self._transform = transform

    def __len__(self):
        """Return the number of image / label pairs in this split."""
        return len(self._data_frame)

    @classmethod
    def _map_labels(cls, label_np):
        """Map an array of raw label ids (0..255) to train ids.

        Uses a 256-entry lookup table so the remapping is one vectorised
        indexing operation instead of a Python call per pixel.  Ids that
        are not in ``_label_to_trainid`` become ``_ignore_index``.

        Args:
            label_np: Integer array whose values lie in 0..255.

        Returns:
            An int64 array with the same shape as ``label_np``.
        """
        lut = np.full(256, cls._ignore_index, dtype=np.int64)
        raw_ids = list(cls._label_to_trainid)
        lut[raw_ids] = [cls._label_to_trainid[raw_id] for raw_id in raw_ids]
        return lut[label_np]

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            ``(image, label_trainId)``: the RGB PIL image and an array of
            train ids, or whatever the transform returns for that pair
            when a transform was supplied.

        Raises:
            IndexError: If ``index`` is past the end of the dataset.
        """
        if index >= len(self):
            raise IndexError(f"Index {index} out of bounds for dataset of size {len(self)}")

        image_path = self._data_frame['image'].iloc[index]
        label_path = self._data_frame['label'].iloc[index]

        # convert() returns a new, fully loaded image, so the underlying
        # file handles can be released as soon as the conversion is done.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        with Image.open(label_path) as label_file:
            label_np = np.array(label_file.convert('L'))

        label_trainId = self._map_labels(label_np)

        if self._transform is not None:
            image, label_trainId = self._transform(image, label_trainId)

        return image, label_trainId

    @staticmethod
    def get_label_distribution(label_path):
        """
        Get the distribution of classes in a single label image.

        Returns a dict mapping each raw label id present in the image (as
        read in mode 'L') to its pixel count, using plain Python ints.
        """
        with Image.open(label_path) as label_file:
            label_np = np.array(label_file.convert('L'))
        unique, counts = np.unique(label_np, return_counts=True)
        return dict(zip(unique.tolist(), counts.tolist()))

    @staticmethod
    def print_class_distribution(df, set_name):
        """
        Print the distribution of classes in a dataset.

        Args:
            df: DataFrame with a 'distribution' column of per-image
                ``{class_id: pixel_count}`` dicts.
            set_name: Name used in the printed header.

        Returns:
            A Counter with the pixel count of every class summed over
            all rows of ``df``.
        """
        all_classes = Counter()
        for dist in df['distribution']:
            all_classes.update(dist)

        print(f"\nClass distribution in {set_name} set:")
        # Sorted so the report is stable regardless of row order.
        for class_id in sorted(all_classes):
            print(f"Class {class_id}: {all_classes[class_id]} pixels")

        return all_classes

    @staticmethod
    def plot_class_distribution(train_dist, valid_dist, test_dist,
                                output_path='class_distribution.png'):
        """
        Plot the distribution of classes across train, validation, and test sets.

        Args:
            train_dist, valid_dist, test_dist: Mappings of class id to
                pixel count, e.g. as returned by print_class_distribution().
            output_path: File the grouped bar chart is saved to.
        """
        classes = sorted(set(train_dist) | set(valid_dist) | set(test_dist))
        train_counts = [train_dist.get(c, 0) for c in classes]
        valid_counts = [valid_dist.get(c, 0) for c in classes]
        test_counts = [test_dist.get(c, 0) for c in classes]

        x = np.arange(len(classes))
        width = 0.25

        fig, ax = plt.subplots(figsize=(15, 8))
        ax.bar(x - width, train_counts, width, label='Train')
        ax.bar(x, valid_counts, width, label='Validation')
        ax.bar(x + width, test_counts, width, label='Test')

        ax.set_ylabel('Pixel Count')
        ax.set_title('Class Distribution Across Datasets')
        ax.set_xticks(x)
        ax.set_xticklabels(classes)
        ax.legend()

        fig.tight_layout()
        fig.savefig(output_path)
        # Close this figure specifically rather than whichever is current.
        plt.close(fig)

    @staticmethod
    def _collect_samples():
        """Scan root_dirs and return (image_paths, label_paths, distributions)."""
        image_paths = []
        label_paths = []
        distributions = []

        for root_dir in WildScenesDataset.root_dirs:
            image_base = os.path.join(root_dir, 'image')
            label_base = os.path.join(root_dir, 'indexLabel')

            if not (os.path.isdir(image_base) and os.path.isdir(label_base)):
                print(f"Error: Image or label directory does not exist in {root_dir}")
                continue

            # os.listdir() order is arbitrary; sorting keeps the seeded
            # split reproducible across machines and filesystems.
            for file_name in sorted(os.listdir(image_base)):
                image_path = os.path.join(image_base, file_name)
                # The label is expected under the same file name.
                label_path = os.path.join(label_base, file_name)

                if not (os.path.isfile(image_path) and os.path.isfile(label_path)):
                    print(f"Warning: Skipping invalid file pair {image_path}, {label_path}")
                    continue

                image_paths.append(image_path)
                label_paths.append(label_path)
                distributions.append(WildScenesDataset.get_label_distribution(label_path))

        return image_paths, label_paths, distributions

    @staticmethod
    def make_data_list(train_rate=0.7, valid_rate=0.2):
        """
        Generate data_list CSV files with stratified sampling and step-by-step output.

        Every image is assigned a stratum equal to its most frequent raw
        label id.  The data is split into train / valid / test with
        ``train_test_split`` (seed 42), the per-split class distributions
        are printed and plotted, and the 'image' / 'label' columns of each
        split are written to the CSV files listed in ``csv``.

        Args:
            train_rate: Fraction of the images used for training.
            valid_rate: Fraction of the images used for validation.  The
                remainder ``1 - train_rate - valid_rate`` is the test set.

        Raises:
            ValueError: If the rates do not leave a positive share for
                every split, or if no image / label pair is found.
                ``train_test_split`` itself also raises ValueError when a
                stratum is too small to be split.
        """
        test_rate = 1 - train_rate - valid_rate
        # Validate before touching the filesystem.
        if train_rate <= 0 or valid_rate <= 0 or test_rate <= 0:
            raise ValueError(
                "train_rate and valid_rate must be positive and sum to less than 1, "
                f"got train_rate={train_rate}, valid_rate={valid_rate}")

        os.makedirs(WildScenesDataset._data_list_dir, exist_ok=True)

        print("Step 1: Collecting image and label paths, and computing label distributions")
        all_image_paths, all_label_paths, all_label_distributions = (
            WildScenesDataset._collect_samples())

        print(f"Total images processed: {len(all_image_paths)}")
        if not all_image_paths:
            # Fail here with a clear message instead of deep inside the split.
            raise ValueError(
                f"No image/label pairs found under {WildScenesDataset.root_dirs}")

        print("\nStep 2: Creating DataFrame with image paths, label paths, and distributions")
        df = pd.DataFrame({
            'image': all_image_paths,
            'label': all_label_paths,
            'distribution': all_label_distributions
        })

        print("DataFrame shape:", df.shape)
        print("Sample of the DataFrame:")
        print(df.head())

        print("\nStep 3: Creating stratification column based on the most common class in each image")
        df['strat'] = df['distribution'].apply(lambda x: max(x, key=x.get))
        print("Unique stratification values:", df['strat'].unique())

        print("\nStep 4: Performing stratified split")
        train_valid, test = train_test_split(
            df, test_size=test_rate, stratify=df['strat'], random_state=42)
        # valid_rate is relative to the whole set; rescale it to the
        # train+valid remainder for the second split.
        train, valid = train_test_split(
            train_valid, test_size=valid_rate / (train_rate + valid_rate),
            stratify=train_valid['strat'], random_state=42)

        print('Total: {:d} | Train: {:d} | Validation: {:d} | Test: {:d}'.format(
            len(df), len(train), len(valid), len(test)))

        train_dist = WildScenesDataset.print_class_distribution(train, "Train")
        valid_dist = WildScenesDataset.print_class_distribution(valid, "Validation")
        test_dist = WildScenesDataset.print_class_distribution(test, "Test")

        WildScenesDataset.plot_class_distribution(train_dist, valid_dist, test_dist)

        print("\nStep 5: Saving train, valid, and test sets to CSV files")
        train[['image', 'label']].to_csv(WildScenesDataset.csv['train'], index=False)
        valid[['image', 'label']].to_csv(WildScenesDataset.csv['valid'], index=False)
        test[['image', 'label']].to_csv(WildScenesDataset.csv['test'], index=False)
        print("CSV files saved successfully.")

def main():
    """Build the split CSV files, then load the first training sample as a check."""
    print("Starting WildScenes dataset processing...")
    WildScenesDataset.make_data_list()
    print("Dataset processing completed.")

    # Smoke test: the freshly written train split must be loadable.
    print("\nTesting dataset loading...")
    dataset = WildScenesDataset('train')
    dataset_size = len(dataset)
    print(f"Train dataset size: {dataset_size}")

    # Fetch the first pair and report its basic properties.
    first_image, first_label = dataset[0]
    image_shape = np.array(first_image).shape
    print(f"Sample image shape: {image_shape}")
    print(f"Sample label shape: {first_label.shape}")
    label_values = np.unique(first_label)
    print(f"Unique values in sample label: {label_values}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

# Recorded cell output: ValueError: If using all scalar values, you must pass an index