In [None]:
import soundfile as sf
import pandas as pd
import torchaudio
import torchaudio.transforms as transforms
from typing import Union, Callable
import torch as tc
import pathlib
from enum import Enum

class VoiceLoadMode(Enum):
    """Identifiers for the ``load_mode`` option of ``VoiceDataset``.

    NOTE(review): the per-member notes below are inferred from the member
    names only; confirm against the code that consumes ``load_mode``.
    """

    # Presumably: hand back the decoded waveform untouched.
    RAW = 1
    # Presumably: hand back a (linear-frequency) spectrogram.
    SPECTROGRAM = 2
    # Presumably: hand back a mel-scaled spectrogram.
    MELSPECTROGRAM = 3

class VoiceDataset(tc.utils.data.Dataset):
    """Map-style dataset that reads audio files listed in a CSV manifest.

    The manifest is loaded with ``pandas.read_csv`` and must contain a
    ``file_path`` column; a ``label`` column is additionally required when
    ground truth is requested. Each item is decoded with ``soundfile.read``
    and then passed through the optional user-supplied transforms.

    NOTE(review): ``load_mode`` and ``loading_params`` are stored on the
    instance but are not consulted by ``__getitem__``; any spectrogram
    conversion currently has to be done by ``representation_transforms``.

    Args:
        csv_path: Path of the CSV manifest.
        load_mode: Requested representation (stored only, see note above).
        loading_params: Extra parameters (stored only, see note above).
            ``None`` yields a fresh empty dict per instance.
        raw_transforms: Optional callable applied to the decoded audio first.
        representation_transforms: Optional callable applied after
            ``raw_transforms``.
        return_metadata: When true, items include the manifest row as a dict.
        return_ground_truth: When true, items include the row's ``label``.
        ground_truth_mapper: Optional dict used to translate labels; labels
            missing from the dict are returned unchanged.
    """

    def __init__(
        self,
        csv_path: Union[str, pathlib.Path],
        load_mode: VoiceLoadMode,
        loading_params: Union[dict, None] = None,
        raw_transforms: Union[Callable, None] = None,
        representation_transforms: Union[Callable, None] = None,
        return_metadata: bool = False,
        return_ground_truth: bool = False,
        ground_truth_mapper: Union[dict, None] = None
    ):
        super().__init__()

        self.csv_path = csv_path
        self.load_mode = load_mode
        # A new dict per instance: a `{}` default in the signature would be
        # one object shared (and mutated) across every dataset.
        self.loading_params = {} if loading_params is None else loading_params
        self.raw_transforms = raw_transforms
        self.representation_transforms = representation_transforms
        self.return_metadata = return_metadata
        self.return_ground_truth = return_ground_truth
        self.ground_truth_mapper = ground_truth_mapper

        self.metadata_df = pd.read_csv(csv_path)
        self.data_path = self.metadata_df["file_path"].tolist()

        # No built-in transform pipeline is created here; all processing is
        # supplied by the caller through the two *_transforms callables.

        self.trim_dataset()

    def __len__(self) -> int:
        """Return the number of audio files in the dataset."""
        return len(self.data_path)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``audio_data`` alone when neither metadata nor ground truth is
            requested; otherwise a tuple that starts with ``audio_data``
            followed by the metadata dict (if requested) and then the
            ground truth (if requested).
        """
        # soundfile also reports the sample rate; nothing below uses it.
        audio_data, _sample_rate = sf.read(self.data_path[idx])

        # Raw-level transforms run before representation-level ones.
        for stage in (self.raw_transforms, self.representation_transforms):
            if stage is not None:
                audio_data = stage(audio_data)

        extras = []
        if self.return_metadata:
            extras.append(self.get_metadata(idx))
        if self.return_ground_truth:
            ground_truth = self.get_ground_truth(idx)
            if self.ground_truth_mapper is not None and ground_truth is not None:
                # Labels absent from the mapper pass through unchanged.
                ground_truth = self.ground_truth_mapper.get(ground_truth, ground_truth)
            extras.append(ground_truth)

        if not extras:
            return audio_data
        return (audio_data, *extras)

    def get_metadata(self, index):
        """Return the manifest row at position ``index`` as a plain dict."""
        return self.metadata_df.iloc[index].to_dict()

    def get_ground_truth(self, index):
        """Return the ``label`` value of the manifest row at ``index``."""
        return self.metadata_df.iloc[index]["label"]

    def trim_dataset(self):
        """Make ``data_path`` and ``metadata_df`` the same length.

        When the lengths differ, a warning is printed and the longer of the
        two is truncated to the length of the shorter one. Right after
        construction both come from the same CSV, so nothing is trimmed.
        """
        num_files = len(self.data_path)
        num_rows = len(self.metadata_df)

        if num_files == num_rows:
            return

        print(f"Warning: Number of files ({num_files}) does not match number of rows in CSV ({num_rows}).")

        if num_files < num_rows:
            self.metadata_df = self.metadata_df.head(num_files)
        else:
            self.data_path = self.data_path[:num_rows]
