In [None]:
import torch
import os
from torch import nn
from torch.utils.data import Dataset, DataLoader
from scipy.io import wavfile
import numpy as np
import simpleaudio as sa




In [None]:

# Data Loader

class AudioDataset(Dataset):
    """Multi-speaker mixture dataset laid out like the wsj0-mix MATLAB scripts produce.

    Each item is a ``(mixture, sources)`` pair:

    * ``mixture`` -- float64 array of ``indexEnd = int(sampleRate * soundLen)`` samples,
    * ``sources`` -- float64 array of shape ``(indexEnd, nMix)``, one column per speaker.

    Every signal is trimmed or zero-padded to ``indexEnd`` samples so items can be batched.
    Sample values are not rescaled (they keep the range of the WAV file, e.g. int16 range).
    NOTE(review): assumes mono WAV files, since each source is stored as a single column.
    """

    def __init__(self, dataRoot, sampleRate = 16000, nMix = 2, soundLen = 5, dataType = 'tr', mixType='max'):
        """
        Args:
            dataRoot: root directory of the WSJ mixtures. The directory layout follows the MATLAB
                script used by Isik, Y., Le Roux, J., Chen, Z., Watanabe, S., & Hershey, J. R. (2016).
                Single-channel multi-speaker separation using deep clustering. Proceedings of the
                Annual Conference of the International Speech Communication Association, INTERSPEECH,
                08-12-Sept, 545–549. https://doi.org/10.21437/Interspeech.2016-1176.
                In short, the mixtures and sources are in
                dataRoot/{wav16k, wav8k}/{max, min}/{tr, cv, tt}/{mix, s1, s2, ..}, and the inventory
                file generated by the MATLAB script is in dataRoot/{wav16k, wav8k}/{max, min}/.
            sampleRate: 16000 or 8000.
            nMix: the number of source speakers.
            soundLen: clip length in seconds; every sound is trimmed or zero-padded to this length.
            dataType: 'tr', 'cv', or 'tt'.
            mixType: 'max' or 'min'; see the MATLAB script used to generate the mixtures.

        Raises:
            ValueError: if sampleRate, mixType or dataType is not one of the supported values.
            OSError: if the inventory file cannot be opened.
        """
        self.sampleRate = sampleRate
        self.nMix = nMix
        self.soundLen = soundLen
        self.dataType = dataType
        self.mixType = mixType
        # Number of samples every returned signal is trimmed / padded to.
        self.indexEnd = int(self.sampleRate * self.soundLen)

        # Construct data directory: dataRoot/wav{16k,8k}/{max,min}/{tr,cv,tt}
        if sampleRate == 16000:
            self.dataPath = os.path.join(dataRoot, 'wav16k')
        elif sampleRate == 8000:
            self.dataPath = os.path.join(dataRoot, 'wav8k')
        else:
            raise ValueError("Sample rate can only be 16k or 8k")

        if mixType not in ('max', 'min'):
            raise ValueError("Mix type can only be max or min")
        # Previously 'min' was mapped to the 'max' directory by mistake.
        self.dataPath = os.path.join(self.dataPath, mixType)

        # The inventory file sits next to the tr/cv/tt split directories.
        self.mixInventoryFile = os.path.join(self.dataPath, 'mix_{:d}_spk_{}_{}_mix'.format(nMix, mixType, dataType))

        if dataType not in ('tr', 'cv', 'tt'):
            raise ValueError("Data type can only be tr, cv or tt")
        self.dataPath = os.path.join(self.dataPath, dataType)

        with open(self.mixInventoryFile, 'r') as inventoryFile:
            # Skip blank lines (e.g. trailing newlines) so they do not become bogus '.wav' entries.
            self.inventory = [line for line in inventoryFile.read().splitlines() if line.strip()]

        self.mixDir = os.path.join(self.dataPath, 'mix')
        self.sDir = [os.path.join(self.dataPath, 's{:d}'.format(s + 1)) for s in range(nMix)]

    def __len__(self):
        """Number of mixtures listed in the inventory file."""
        return len(self.inventory)

    def __getitem__(self, idx):
        """Load mixture ``idx`` and its sources, each trimmed/padded to ``indexEnd`` samples.

        Returns:
            (mixture, sources): mixture is a float64 array of length ``indexEnd``; sources is a
            float64 array of shape ``(indexEnd, nMix)``. Both share a dtype so batches are stable.
        """
        mixName = self.inventory[idx] + '.wav'
        mixture = self._trimOrPadAudio(self._readWav(os.path.join(self.mixDir, mixName)))
        # Match the dtype of ``sources`` regardless of whether the clip was trimmed or padded.
        mixture = mixture.astype(np.float64)
        sources = np.zeros((mixture.shape[0], self.nMix))
        for s, sourceDir in enumerate(self.sDir):
            sources[:, s] = self._trimOrPadAudio(self._readWav(os.path.join(sourceDir, mixName)))
        return mixture, sources

    def _readWav(self, path):
        """Read a WAV file and check that its sample rate matches the dataset's."""
        sr, sound = wavfile.read(path)
        assert sr == self.sampleRate, 'expected {} Hz, got {} Hz in {}'.format(self.sampleRate, sr, path)
        return sound

    def _trimOrPadAudio(self, sound):
        """Return a copy of ``sound`` cut or zero-padded along axis 0 to ``indexEnd`` samples.

        The input dtype is preserved in both branches.
        """
        currentLength = sound.shape[0]
        if currentLength >= self.indexEnd:
            return np.copy(sound[:self.indexEnd])
        # Pad only the time axis; any extra axes (channels) are left untouched.
        padWidth = [(0, self.indexEnd - currentLength)] + [(0, 0)] * (sound.ndim - 1)
        return np.pad(sound, padWidth, mode='constant')

    # Test the __getitem__ method and play sounds
    def playIdx(self, idx):
        """Play mixture ``idx`` followed by each of its sources (debug helper).

        Samples are cast to int16 and played as mono, 2 bytes per sample, at ``sampleRate``.
        """
        mixture, sources = self[idx]
        tracks = [mixture] + [sources[:, s] for s in range(self.nMix)]
        for track in tracks:
            sound = np.ascontiguousarray(track, dtype=np.int16)
            play_obj = sa.play_buffer(sound, 1, 16//8, self.sampleRate)
            play_obj.wait_done()



In [None]:
import torch
import torch.nn as nn
sampleRate = 8000
# Root of the wsj0-mix data; set WSJ0_MIX_ROOT to override (default: current directory, as before).
dataRoot = os.environ.get('WSJ0_MIX_ROOT', '')
trDataset = AudioDataset(dataRoot, sampleRate=sampleRate, nMix=2, dataType='tr', mixType='max')

mixture, sources = trDataset[0]
# To listen to an item (needs an audio device): trDataset.playIdx(1000)


AttributeError: module 'torch.nn' has no attribute 'Ten'

In [None]:

# test dataloader: one full pass checks that every item loads, then inspect a single batch

train_dataloader = DataLoader(trDataset, batch_size=64, shuffle=True)
# Count batches instead of printing one line per batch (hundreds of lines of output).
nBatches = sum(1 for _ in train_dataloader)
assert nBatches == len(train_dataloader)
mixture, sources = next(iter(train_dataloader))
# Only a cell's last expression is displayed, so show both shapes together.
mixture.shape, sources.shape

0


1


2


3


4


5


6


7


8


9


10


11


12


13


14


15


16


17


18


19


20


21


22


23


24


25


26


27


28


29


30


31


32


33


34


35


36


37


38


39


40


41


42


43


44


45


46


47


48


49


50


51


52


53


54


55


56


57


58


59


60


61


62


63


64


65


66


67


68


69


70


71


72


73


74


75


76


77


78


79


80


81


82


83


84


85


86


87


88


89


90


91


92


93


94


95


96


97


98


99


100


101


102


103


104


105


106


107


108


109


110
111


112


113


114


115


116


117


118


119


120


121


122


123


124


125


126


127


128


129


130


131


132


133


134


135


136


137


138


139


140


141


142


143


144


145
146
147


148


149


150


151


152


153


154


155


156


157


158


159


160


161


162


163


164


165


166


167


168


169


170


171


172


173


174


175


176


177


178


179


180
181


182


183


184


185


186


187


188


189


190


191


192


193


194


195


196


197


198


199


200


201


202


203


204


205


206


207


208


209


210


211


212


213


214


215


216


217


218


219


220


221


222


223


224


225


226


227


228


229
230


231


232


233


234


235


236


237


238


239


240


241


242


243


244


245


246


247


248


249


250


251


252


253


254


255


256


257


258


259


260


261


262


263


264


265


266


267


268


269


270


271


272


273


274


275


276


277


278


279


280


281


282


283


284


285


286


287


288


289


290
291


292
293


294
295


296
297


298
299


300
301


302
303


304
305


306
307


308


309
310


311
312


torch.Size([64, 40000, 2])