WERSJA ROBOCZA (draft version):

In [None]:
# Draft usage: build the dataset, then configure it through plain attributes
# (the draft VoiceDataset defines no __init__ parameters).
# NOTE(review): this cell references VoiceDataset, raw_transforms and
# representation_transforms, which are defined in a later cell; run that one first.
dataset = VoiceDataset()
_draft_settings = {
    "csv_path": "sciezka/do/pliku.csv",
    "data_path": ["sciezka/do/audio1.wav", "sciezka/do/audio2.wav"],
    "return_metadata": True,
    "return_ground_truth": True,
    "raw_transforms": raw_transforms,
    "representation_transforms": representation_transforms,
    "ground_truth_mapper": {"Healthy": 0, "Parkinson": 1},
}
for _setting_name, _setting_value in _draft_settings.items():
    setattr(dataset, _setting_name, _setting_value)
del _draft_settings, _setting_name, _setting_value


In [None]:
import pandas as pd
import soundfile as sf
import torch as tc
import torchaudio
import torchaudio.transforms as transforms

def get_metadata(self, index):
    """Return row ``index`` (positional, via ``iloc``) of the CSV at ``self.csv_path`` as a dict.

    The CSV file is read from disk on every call.
    """
    return pd.read_csv(self.csv_path).iloc[index].to_dict()

def get_ground_truth(self, index):
    """Return the ``label`` value of row ``index`` (positional) of the CSV at ``self.csv_path``.

    The CSV file is read from disk on every call.
    """
    row = pd.read_csv(self.csv_path).iloc[index]
    return row["label"]

class VoiceDataset(tc.utils.data.Dataset):
    """Draft voice dataset, configured by assigning attributes after construction.

    There is no ``__init__``; callers set ``data_path`` (indexable sequence of
    audio file paths), ``csv_path``, ``raw_transforms``,
    ``representation_transforms``, ``return_metadata``, ``return_ground_truth``
    and ``ground_truth_mapper`` on the instance.
    """

    # ``get_metadata`` / ``get_ground_truth`` are module-level functions taking
    # ``self``; bind them as methods so ``self.get_metadata(...)`` resolves.
    get_metadata = get_metadata
    get_ground_truth = get_ground_truth

    def __len__(self) -> int:
        """Return the number of audio files in ``self.data_path``."""
        return len(self.data_path)

    def __getitem__(self, index):
        """Load audio file ``index`` and return it with the optional extras.

        Returns ``audio``, ``(audio, metadata)``, ``(audio, ground_truth)`` or
        ``(audio, metadata, ground_truth)`` depending on ``return_metadata`` and
        ``return_ground_truth``.
        """
        audio_file_path = self.data_path[index]
        audio_data, sample_rate = sf.read(audio_file_path)

        if self.raw_transforms is not None:
            audio_data = self.raw_transforms(audio_data)

        metadata = None
        if self.return_metadata:
            metadata = self.get_metadata(index)

        ground_truth = None
        if self.return_ground_truth:
            ground_truth = self.get_ground_truth(index)

        if self.representation_transforms is not None:
            audio_data = self.representation_transforms(audio_data)

        if self.ground_truth_mapper is not None and ground_truth is not None:
            # Labels missing from the mapper are passed through unchanged.
            ground_truth = self.ground_truth_mapper.get(ground_truth, ground_truth)

        if self.return_metadata and self.return_ground_truth:
            return audio_data, metadata, ground_truth
        elif self.return_metadata:
            return audio_data, metadata
        elif self.return_ground_truth:
            return audio_data, ground_truth
        else:
            return audio_data


def raw_transforms(audio_data, transform_type="normalize", **kwargs):
    """Apply one waveform-level transformation and return a float32 tensor.

    Args:
        audio_data: waveform as a NumPy array (as returned by ``soundfile.read``)
            or a tensor. A 2-D NumPy array is treated as soundfile's
            ``(frames, channels)`` layout and transposed to torchaudio's
            ``(channels, frames)``; tensors are used with their layout unchanged.
        transform_type: one of
            "normalize"      peak-normalize to [-1, 1] (silent/empty input unchanged);
            "filter"         ``torchaudio.functional.lowpass_biquad``; kwargs
                             ``sample_rate`` and ``cutoff_freq`` (optional ``Q``);
            "change_volume"  ``torchaudio.transforms.Vol``; kwargs ``gain``
                             (optional ``gain_type``);
            "time_stretch"   phase-vocoder stretch; kwarg ``rate`` (required),
                             optional ``n_fft`` (default 400) and ``hop_length``
                             (default ``n_fft // 4``);
            "spectrogram"    ``torchaudio.transforms.Spectrogram(**kwargs)``;
            "melspectrogram" ``torchaudio.transforms.MelSpectrogram(**kwargs)``.
        **kwargs: forwarded as described above.

    Returns:
        torch.Tensor with the transformed audio.

    Raises:
        ValueError: for an unsupported ``transform_type`` or a missing ``rate``
            for "time_stretch".
    """
    import torch  # this notebook cell does not import torch at module level

    supported = (
        "normalize",
        "filter",
        "change_volume",
        "time_stretch",
        "spectrogram",
        "melspectrogram",
    )
    if transform_type not in supported:
        raise ValueError("Invalid transform_type. Please choose one of the supported transformations.")

    # torchaudio transforms need float32 tensors (spectrogram windows are float32).
    if isinstance(audio_data, torch.Tensor):
        waveform = audio_data.to(torch.float32)
    else:
        waveform = torch.as_tensor(audio_data, dtype=torch.float32)
        if waveform.ndim == 2:
            waveform = waveform.T.contiguous()

    if transform_type == "normalize":
        if waveform.numel() == 0:
            return waveform
        peak = waveform.abs().max()
        return waveform / peak if peak > 0 else waveform

    if transform_type == "filter":
        return torchaudio.functional.lowpass_biquad(waveform, **kwargs)

    if transform_type == "change_volume":
        # Vol takes its gain in the constructor, not in forward().
        return transforms.Vol(**kwargs)(waveform)

    if transform_type == "time_stretch":
        if "rate" not in kwargs:
            raise ValueError("time_stretch requires a 'rate' keyword argument.")
        n_fft = kwargs.get("n_fft", 400)
        hop_length = kwargs.get("hop_length", n_fft // 4)
        window = torch.hann_window(n_fft)
        # TimeStretch operates on complex spectrograms: STFT -> stretch -> inverse STFT.
        spec = torch.stft(
            waveform, n_fft, hop_length=hop_length, window=window, return_complex=True
        )
        stretch = transforms.TimeStretch(
            hop_length=hop_length, n_freq=n_fft // 2 + 1, fixed_rate=kwargs["rate"]
        )
        return torch.istft(stretch(spec), n_fft, hop_length=hop_length, window=window)

    if transform_type == "spectrogram":
        return transforms.Spectrogram(**kwargs)(waveform)

    # Only "melspectrogram" remains after the validation above.
    return transforms.MelSpectrogram(**kwargs)(waveform)


def representation_transforms(audio_representation, transform_type="crop", **kwargs):
    """Apply one transformation to a waveform ``[C x L]`` or spectrogram ``[C x Y x X]``.

    torchaudio.transforms has no ``Crop``, ``Resize``, ``AdditiveNoise`` or
    ``Contrast`` classes, so each operation is implemented directly with torch.

    Args:
        audio_representation: array or tensor; operations act on trailing axes.
        transform_type: one of
            "crop"            keep ``[..., start:start + length]`` along the last
                              (time) axis; kwargs ``start`` (default 0) and
                              ``length`` (default: up to the end);
            "resize"          ``torch.nn.functional.interpolate`` over every axis
                              after the first (channel) one; kwargs are forwarded,
                              e.g. ``size=(64, 128)`` or ``scale_factor=0.5``;
            "add_noise"       add zero-mean Gaussian noise; kwargs ``std``
                              (default 0.01) and optional ``generator``;
            "adjust_contrast" scale values around their global mean,
                              ``(x - mean) * factor + mean``; kwarg ``factor``
                              (default 1.0, i.e. unchanged).
        **kwargs: see above.

    Returns:
        The transformed representation ("crop" keeps the input's type; the other
        operations return a tensor).

    Raises:
        ValueError: for an unsupported ``transform_type``.
    """
    import torch  # this notebook cell does not import torch at module level

    if transform_type == "crop":
        start = kwargs.get("start", 0)
        length = kwargs.get("length")
        stop = None if length is None else start + length
        transformed_representation = audio_representation[..., start:stop]

    elif transform_type == "resize":
        tensor = torch.as_tensor(audio_representation)
        # interpolate expects (N, C, *spatial): add a batch axis, then drop it.
        transformed_representation = torch.nn.functional.interpolate(
            tensor.unsqueeze(0), **kwargs
        ).squeeze(0)

    elif transform_type == "add_noise":
        tensor = torch.as_tensor(audio_representation)
        noise = torch.randn(
            tensor.shape,
            generator=kwargs.get("generator"),
            dtype=tensor.dtype,
            device=tensor.device,
        )
        transformed_representation = tensor + noise * kwargs.get("std", 0.01)

    elif transform_type == "adjust_contrast":
        tensor = torch.as_tensor(audio_representation)
        mean = tensor.mean()
        transformed_representation = (tensor - mean) * kwargs.get("factor", 1.0) + mean

    else:
        raise ValueError("Invalid transform_type. Please choose one of the supported transformations.")

    return transformed_representation


In [None]:
def __len__(self) -> int:
    """Return how many audio files the dataset holds, i.e. ``len(self.data_path)``.

    TODO: ``self.iteration_size`` is not taken into account yet.
    """
    paths = self.data_path
    return len(paths)


In [None]:
class VoiceDataset(tc.utils.data.Dataset):
    """Second draft of the voice dataset, configured through attributes set after construction."""

    # ``get_metadata`` / ``get_ground_truth`` are module-level functions taking
    # ``self``; bind them as methods so ``self.get_metadata(...)`` resolves.
    get_metadata = get_metadata
    get_ground_truth = get_ground_truth

    def __len__(self) -> int:
        """Return the number of audio files in ``self.data_path``."""
        return len(self.data_path)

    def __getitem__(self, idx):
        """
        Loads the given case and returns it with the appropriate representation.
        For raw data: [CxL], where C is the number of channels, L is the number of samples
        For spectrogram/melspectrogram data: [CxYxX], where C is the number of channels, Y is the number of rows, X is the number of columns

        The inputs will then be collated to [BxCxL] or [BxCxYxX] by the collating function during loading (it does not need to be defined now).

        Returns:
            (audio_data, metadata, ground_truth); metadata / ground_truth are None
            when the corresponding return_* flag is False.
        """
        audio_file_path = self.data_path[idx]
        audio_data, sample_rate = sf.read(audio_file_path)

        if self.raw_transforms is not None:
            audio_data = self.raw_transforms(audio_data)

        metadata = None
        if self.return_metadata:
            metadata = self.get_metadata(idx)

        ground_truth = None
        if self.return_ground_truth:
            ground_truth = self.get_ground_truth(idx)

        if self.representation_transforms is not None:
            audio_data = self.representation_transforms(audio_data)

        if self.ground_truth_mapper is not None and ground_truth is not None:
            # Labels missing from the mapper are passed through unchanged.
            ground_truth = self.ground_truth_mapper.get(ground_truth, ground_truth)

        # Items stay unbatched ([CxL] / [CxYxX]); the DataLoader's collate function
        # adds the batch axis, so only a missing channel axis is added here.
        audio_data = tc.as_tensor(audio_data)
        if audio_data.ndim == 1:
            # Mono waveform [L] -> [1xL].
            audio_data = audio_data.unsqueeze(0)
        # NOTE(review): soundfile returns multi-channel audio as (frames, channels);
        # unless raw_transforms transposes it, a 2-D item here is [LxC] — confirm.
        # NOTE(review): torch's default collate cannot batch None entries, so with a
        # return_* flag False a DataLoader needs a custom collate_fn.

        return audio_data, metadata, ground_truth


OSTATECZNA WERSJA (final version):

In [None]:
import soundfile as sf
import pandas as pd
import torchaudio
import torchaudio.transforms as transforms
from typing import Union, Callable
import torch as tc
import pathlib
from enum import Enum

class VoiceLoadMode(Enum):
    """Selects how ``VoiceDataset`` should load audio (passed as its ``load_mode``)."""

    RAW, SPECTROGRAM, MELSPECTROGRAM = 1, 2, 3

class VoiceDataset(tc.utils.data.Dataset):
    """Map-style dataset of voice recordings with optional per-file CSV metadata.

    Each item's audio is a float32 tensor shaped ``[C x L]`` for
    ``VoiceLoadMode.RAW`` or ``[C x Y x X]`` for the spectrogram modes (before
    ``representation_transforms``). No batch axis is added; batching is left to
    the DataLoader's collate function. Row ``i`` of the CSV is taken to describe
    ``data_path[i]``, so the CSV must list files in the same order.

    Args:
        data_path: sequence of audio file paths, a single audio file, or a
            directory. A directory is expanded to its files whose extension
            matches a soundfile format name (e.g. ``.wav``, ``.flac``), sorted.
        csv_path: CSV with one row per audio file; ground truth is read from its
            ``label`` column.
        load_mode: representation built from the (raw-transformed) waveform.
        loading_params: keyword arguments for ``torchaudio.transforms.Spectrogram``
            or ``MelSpectrogram``; for ``MelSpectrogram`` ``sample_rate`` defaults
            to the file's own sample rate.
        raw_transforms: callable applied to the ``[C x L]`` waveform tensor.
        representation_transforms: callable applied to the final representation.
        return_metadata: also return the CSV row as a dict.
        return_ground_truth: also return the ``label`` value.
        ground_truth_mapper: optional dict mapping labels to targets; labels
            missing from it are returned unchanged.

    Items are ``audio``, ``(audio, metadata)``, ``(audio, ground_truth)`` or
    ``(audio, metadata, ground_truth)`` depending on the two ``return_*`` flags.
    """

    def __init__(
        self,
        data_path: Union[str, pathlib.Path, list, tuple],
        csv_path: Union[str, pathlib.Path],
        load_mode: "VoiceLoadMode",
        loading_params: dict = None,
        raw_transforms: Callable = None,
        representation_transforms: Callable = None,
        return_metadata: bool = False,
        return_ground_truth: bool = False,
        ground_truth_mapper: dict = None
    ):
        self.data_path = self._resolve_audio_paths(data_path)
        self.csv_path = csv_path
        self.load_mode = load_mode
        # A new dict per instance; a ``{}`` default would be shared by all instances.
        self.loading_params = {} if loading_params is None else loading_params
        self.raw_transforms = raw_transforms
        self.representation_transforms = representation_transforms
        self.return_metadata = return_metadata
        self.return_ground_truth = return_ground_truth
        self.ground_truth_mapper = ground_truth_mapper
        # (csv_path, DataFrame) of the most recently read CSV.
        self._metadata_cache = None

    def __len__(self) -> int:
        return len(self.data_path)

    def __getitem__(self, idx):
        waveform, sample_rate = self._load_waveform(self.data_path[idx])

        if self.raw_transforms is not None:
            waveform = self.raw_transforms(waveform)

        audio_data = self._to_representation(waveform, sample_rate)
        if self.representation_transforms is not None:
            audio_data = self.representation_transforms(audio_data)

        metadata = self.get_metadata(idx) if self.return_metadata else None

        ground_truth = self.get_ground_truth(idx) if self.return_ground_truth else None
        if self.ground_truth_mapper is not None and ground_truth is not None:
            ground_truth = self.ground_truth_mapper.get(ground_truth, ground_truth)

        if self.return_metadata and self.return_ground_truth:
            return audio_data, metadata, ground_truth
        if self.return_metadata:
            return audio_data, metadata
        if self.return_ground_truth:
            return audio_data, ground_truth
        return audio_data

    def get_metadata(self, index):
        """Return CSV row ``index`` (positional) as a ``{column: value}`` dict."""
        return self._metadata_frame().iloc[index].to_dict()

    def get_ground_truth(self, index):
        """Return the ``label`` value of CSV row ``index`` (positional)."""
        return self._metadata_frame().iloc[index]["label"]

    @staticmethod
    def _resolve_audio_paths(data_path):
        """Normalize ``data_path`` to an indexable sequence of audio file paths."""
        if not isinstance(data_path, (str, pathlib.Path)):
            return data_path
        path = pathlib.Path(data_path)
        if not path.is_dir():
            # A single file: without this, a str would be indexed character by character.
            return [data_path]
        extensions = {f".{fmt.lower()}" for fmt in sf.available_formats()}
        return sorted(
            p for p in path.iterdir() if p.is_file() and p.suffix.lower() in extensions
        )

    @staticmethod
    def _load_waveform(audio_file_path):
        """Read an audio file as a float32 ``[C x L]`` tensor plus its sample rate."""
        # always_2d yields (frames, channels) even for mono; transpose to (channels, frames).
        samples, sample_rate = sf.read(audio_file_path, dtype="float32", always_2d=True)
        return tc.from_numpy(samples.T.copy()), sample_rate

    def _to_representation(self, waveform, sample_rate):
        """Turn a waveform into the representation selected by ``self.load_mode``."""
        if self.load_mode is VoiceLoadMode.RAW:
            return waveform
        if self.load_mode is VoiceLoadMode.SPECTROGRAM:
            return transforms.Spectrogram(**self.loading_params)(tc.as_tensor(waveform))
        if self.load_mode is VoiceLoadMode.MELSPECTROGRAM:
            params = {"sample_rate": sample_rate, **self.loading_params}
            return transforms.MelSpectrogram(**params)(tc.as_tensor(waveform))
        raise ValueError(f"Unsupported load_mode: {self.load_mode!r}")

    def _metadata_frame(self):
        """Return the CSV at ``self.csv_path`` as a DataFrame, re-reading only when the path changes."""
        cache = getattr(self, "_metadata_cache", None)
        if cache is None or cache[0] != self.csv_path:
            cache = (self.csv_path, pd.read_csv(self.csv_path))
            self._metadata_cache = cache
        return cache[1]
