# Dataset class
A custom PyTorch `Dataset` that knows how to load spectrogram images.
(PyTorch expects this structure; it encapsulates all of the data-loading logic in one place.)
Contains:
- `__init__`: setup (read the CSV, store the transform, build the label mapping)
- `__len__`: total number of samples
- `__getitem__`: load and return a single spectrogram-label pair

In [1]:
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

In [2]:
class DS(Dataset):
    """Map-style dataset of labelled spectrogram images listed in a CSV file.

    The CSV is expected to contain an ``image_path`` column (file to open)
    and a ``label`` column (class name). Class names are mapped to integer
    indices in sorted order, so the mapping is stable across runs.

    Attributes:
        df: The DataFrame read from the CSV, one row per sample.
        transform: Optional callable applied to each loaded image.
        label_to_idx: Mapping from class name to integer index.
        idx_to_label: Inverse of ``label_to_idx``.
    """

    def __init__(self, csv='/Users/hela/Code/pata/data_labeled.csv', transform=None):
        """Read the sample table and build the label <-> index mappings.

        Args:
            csv: Path to the CSV file listing the samples.
            transform: Optional callable applied to every image returned by
                ``__getitem__``; ``None`` leaves the image untouched.

        A short summary (sample count, label mapping, class distribution)
        is printed as a sanity check.
        """
        super().__init__()
        self.df = pd.read_csv(csv)
        self.transform = transform

        # Sort the class names so the label -> index mapping is deterministic.
        class_names = sorted(self.df['label'].unique())
        self.label_to_idx = {label: idx for idx, label in enumerate(class_names)}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

        print(f'Dataset initialized with: {len(self.df)} samples')
        print(f'Label mapping: {self.label_to_idx}')
        # Computed outside the f-string: reusing the same quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        class_counts = self.df['label'].value_counts()
        print(f'Class distribution:\n{class_counts}')

    def __len__(self):
        """Return the total number of samples (rows in the CSV)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Integer position of the sample in the CSV.

        Returns:
            A ``(image, label)`` tuple, where ``image`` is the RGB image
            (after ``transform``, if one was given) and ``label`` is the
            integer class index from ``label_to_idx``.
        """
        # Fetch the row once instead of once per column.
        row = self.df.iloc[idx]

        # convert() returns a fully loaded copy, so the source file can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(row['image_path']) as source:
            image = source.convert('RGB')

        # Compare against None so a callable that happens to be falsy
        # (e.g. an empty container module) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        label = self.label_to_idx[row['label']]
        return image, label

In [5]:
# Smoke test: instantiate DS with its default arguments to check that it works.
# Construction prints the sample count, the label mapping and the class distribution.
# (Each line of the class was also tested separately, by copy-pasting it with 'self.' removed.)

full_ds = DS()

Dataset initialized with: 1600 samples
Label mapping: {'pa': 0, 'ta': 1}
Class distribution:
label
pa    800
ta    800
Name: count, dtype: int64
