In [1]:
# Colab only: mount Google Drive at /content/drive so the dataset files under
# /content/drive/MyDrive can be read from this notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import pandas as pd
import wave

In [24]:
class DatasetCombinerTTS:
    """Load Persian TTS dataset folders into DataFrames and combine them.

    Each dataset folder under ``base_path`` is read from a pipe-delimited,
    header-less ``metadata.csv`` (columns: transcript, filename) plus a
    ``wavs/`` directory containing the referenced WAV files.
    """

    def __init__(self, base_path):
        """
        Args:
            base_path: Directory containing the ``PersianTTSDataset*`` folders.
        """
        self.base_path = base_path

    def read_audio_file(self, file_path):
        """Read all frames of a WAV file.

        Returns:
            tuple: ``(audio_content, frame_rate)`` -- the raw frame bytes
            returned by ``wave.readframes`` and the file's frame rate.
        """
        # NOTE(review): sample width and channel count are not kept, so the raw
        # bytes cannot be decoded later without re-opening the file.
        with wave.open(file_path, 'rb') as wave_file:
            n_frames = wave_file.getnframes()
            frame_rate = wave_file.getframerate()
            audio_content = wave_file.readframes(n_frames)
        return audio_content, frame_rate

    def _load_tts_folder(self, folder_name):
        """Load ``<base_path>/<folder_name>`` into a DataFrame.

        Returns:
            pd.DataFrame with columns ``audio``, ``frame_rate``, ``transcript``.
            Metadata rows whose WAV file is missing are reported and skipped.
        """
        dataset_path = os.path.join(self.base_path, folder_name)
        csv_path = os.path.join(dataset_path, 'metadata.csv')
        wav_folder = os.path.join(dataset_path, 'wavs')

        # NOTE(review): with default CSV quoting, a '"' inside a transcript is
        # treated as a quote character; consider quoting=csv.QUOTE_NONE if so.
        metadata = pd.read_csv(csv_path, header=None, names=['transcript', 'filename'], delimiter='|')

        audio_contents = []
        frame_rates = []
        transcripts = []
        for transcript, file_name in zip(metadata['transcript'], metadata['filename']):
            audio_path = os.path.join(wav_folder, str(file_name))
            if not os.path.exists(audio_path):
                print(f"File {audio_path} does not exist")
                continue
            audio_content, frame_rate = self.read_audio_file(audio_path)
            audio_contents.append(audio_content)
            frame_rates.append(frame_rate)
            # Collected in the same loop so the transcript stays paired with
            # its audio and all three columns have equal length even when
            # some files are missing.
            transcripts.append(transcript)

        return pd.DataFrame({
            'audio': audio_contents,
            'frame_rate': frame_rates,
            'transcript': transcripts,
        })

    def load_Persian_TTS(self):
        """Load and process the ``PersianTTSDataset`` folder."""
        return self._load_tts_folder('PersianTTSDataset')

    def load_Persian_TTS_female(self):
        """Load and process the ``PersianTTSDataset_female`` folder."""
        return self._load_tts_folder('PersianTTSDataset_female')

    def load_Persian_TTS_male(self):
        """Load and process the ``PersianTTSDataset_male`` folder."""
        return self._load_tts_folder('PersianTTSDataset_male')

    def combine_datasets(self, include_male=False):
        """Load the datasets, print their shapes and concatenate them.

        Args:
            include_male: Also load ``PersianTTSDataset_male``. Defaults to
                False, so by default only the base and female sets are used.

        Returns:
            pd.DataFrame with a fresh 0..n-1 index.
        """
        labelled_frames = [
            ("Persian TTS dataset size", self.load_Persian_TTS()),
            ("Persian TTS female dataset size", self.load_Persian_TTS_female()),
        ]
        if include_male:
            labelled_frames.append(("Persian TTS male dataset size", self.load_Persian_TTS_male()))

        for label, frame in labelled_frames:
            print(f"{label}: {frame.shape}")

        return pd.concat([frame for _, frame in labelled_frames], ignore_index=True)

    def get_combined_dataset(self, include_male=False):
        """Get the combined dataset (see ``combine_datasets``)."""
        return self.combine_datasets(include_male=include_male)



In [None]:
# Root folder on the mounted Drive holding the PersianTTSDataset* folders.
base_path = '/content/drive/MyDrive/FarsiTTS'

combiner = DatasetCombinerTTS(base_path=base_path)
combined_dataset = combiner.get_combined_dataset()

In [None]:
# Bare last expression -> rich DataFrame display instead of plain-text print().
combined_dataset.head()

In [None]:
print("Combined dataset size:", combined_dataset.shape)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class PersianTTSDataset(Dataset):
    """Map-style torch ``Dataset`` over a DataFrame with ``audio``,
    ``frame_rate`` and ``transcript`` columns.

    Each item is a dict with those three keys, taken as-is from the row;
    no decoding or tensor conversion happens here.
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the row at position ``idx`` as a sample dict."""
        # Look the row up once rather than building a new row Series per column.
        row = self.dataframe.iloc[idx]
        return {
            'audio': row['audio'],
            'frame_rate': row['frame_rate'],
            'transcript': row['transcript'],
        }

In [None]:
# Reuse the frame built earlier: calling combiner.get_combined_dataset() again
# would re-read every WAV file from Drive.
persian_tts_dataset = PersianTTSDataset(combined_dataset)

BATCH_SIZE = 32
dataloader = DataLoader(persian_tts_dataset, batch_size=BATCH_SIZE, shuffle=True)