In [33]:
import argparse
import sys

import librosa
import numpy as np
import torch
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

from feature.feature_prepare import thchs30,MagicData,primewords,ST_CMDS


def load_audio(path):
    """Read an audio file with torchaudio and return it as a mono numpy array.

    The sample rate reported by torchaudio is discarded, so callers must
    already know (or assume) the file's rate.

    NOTE(review): the transpose assumes torchaudio returns (channels, time),
    giving (time, channels) here -- confirm for the installed torchaudio version.
    """
    waveform = torchaudio.load(path)[0].numpy().T
    if waveform.ndim <= 1:
        return waveform
    if waveform.shape[1] == 1:
        # single channel: drop the channel axis
        return waveform.squeeze()
    # several channels: average them down to one
    return waveform.mean(axis=1)

class ASRDataset(Dataset):
    """Map-style dataset yielding ``(spectrogram, label codes, label length)`` per utterance.

    :param label_dic: mapping file name -> dict whose ``'code'`` entry is the
        label sequence (a list of ints in the notebook output -- TODO confirm
        against ``feature.feature_prepare``).
    :param file_dic: mapping with one key per split (e.g. ``"train"``) listing
        file names, plus ``file_dic["file_dic"]["wav"]`` mapping file name -> wav path.
    :param usage: which split of ``file_dic`` to iterate.
    :param transform: optional callable applied to each feature tensor in
        ``__getitem__`` (previously accepted but silently ignored).
    """

    def __init__(self, label_dic, file_dic, usage="train", transform=None):
        self.label_dic = label_dic
        self.file_dic = file_dic
        self.usage = usage
        self.transform = transform
        # STFT / augmentation settings; this class reads only sample_rate,
        # window_size, window_stride and window.
        self.audio_conf = dict(sample_rate=16000,
                               window_size=0.02,
                               window_stride=0.01,
                               window="hamming",
                               noise_dir=None,
                               noise_prob=0.4,
                               noise_levels=(0.0, 0.5))
        self.normalize = True

    def __len__(self):
        """Return the number of utterances in the selected split (required by ``Dataset``)."""
        return len(self.file_dic[self.usage])

    def __getitem__(self, idx):
        """Return ``(feature, label, len(label))`` for the ``idx``-th utterance of the split.

        ``feature`` is a float32 tensor from ``load_wav_feature`` (after the
        optional ``transform``); ``label`` is the raw ``'code'`` entry of
        ``label_dic``.
        """
        file_name = self.file_dic[self.usage][idx]
        wav_path = self.file_dic["file_dic"]['wav'][file_name]
        feature = self.load_wav_feature(wav_path)
        if self.transform is not None:
            feature = self.transform(feature)
        # as_tensor avoids a copy when the feature is already a float32 tensor
        feature = torch.as_tensor(feature, dtype=torch.float32)
        label = self.label_dic[file_name]['code']
        return feature, label, len(label)

    def _stft_params(self):
        """Return ``(n_fft, win_length, hop_length)`` in samples, derived from ``audio_conf``."""
        sample_rate = self.audio_conf["sample_rate"]
        n_fft = int(sample_rate * self.audio_conf["window_size"])
        # The hop is the window *stride*; deriving it from window_size made
        # frames non-overlapping and left window_stride unused.
        hop_length = int(sample_rate * self.audio_conf["window_stride"])
        return n_fft, n_fft, hop_length

    @staticmethod
    def _normalize(spect):
        """Standardize ``spect`` in place with whole-utterance mean/std and return it."""
        mean = spect.mean()
        std = spect.std()
        spect.add_(-mean)
        # Constant (e.g. silent) input has std == 0; skip scaling instead of producing NaN/inf.
        if std > 0:
            spect.div_(std)
        return spect

    def load_wav_feature(self, file_dir):
        """Compute a (optionally normalized) log1p-magnitude STFT spectrogram for one file.

        :param file_dir: path to the audio file, read with ``load_audio``.
        :return: float32 tensor in ``librosa.stft`` layout (frequency bins x frames).
        """
        y = load_audio(file_dir)
        n_fft, win_length, hop_length = self._stft_params()
        D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                         win_length=win_length, window=self.audio_conf["window"])
        spect, _phase = librosa.magphase(D)
        # log(1 + |S|) compresses the dynamic range
        spect = torch.FloatTensor(np.log1p(spect))
        if self.normalize:
            spect = self._normalize(spect)
        return spect

        


def aligin_collate(batch_size):
    """Collate ``(feature, label, label_len)`` samples into one batch with variable-length labels.

    Despite its name, ``batch_size`` is the list of samples produced by
    ``ASRDataset.__getitem__`` (DataLoader passes it positionally as ``collate_fn``).

    Features may differ in their last dimension (the frame axis of a
    ``librosa.stft`` spectrogram); they are right-padded with zeros to the
    longest one before stacking, since ``torch.stack`` requires equal shapes.
    All other dimensions must match. Per-sample frame counts are not returned.

    :return: ``(stacked_wave, label, length)``: ``stacked_wave`` has shape
        ``(batch, *feature.shape[:-1], max_frames)``; ``label`` is the 1-D int32
        concatenation of every label sequence; ``length`` holds the int32
        label length of each sample.
    """
    wave_list = []
    label_list = []
    length_list = []
    for wave, label, length in batch_size:
        wave_list.append(wave)
        label_list.extend(label)
        length_list.append(length)

    max_frames = max((wave.shape[-1] for wave in wave_list), default=0)
    padded = [torch.nn.functional.pad(wave, (0, max_frames - wave.shape[-1]))
              for wave in wave_list]
    stacked_wave = torch.stack(padded, dim=0)
    label = torch.tensor(label_list, dtype=torch.int32)
    length = torch.tensor(length_list, dtype=torch.int32)

    return stacked_wave, label, length

def get_data_loader(data_dir="/data/jiaxin.gu/jupyter/asr_test/dataset/data_thchs30",
                    usage="train", batch_size=3, num_workers=8, shuffle=True):
    """Build a DataLoader over the THCHS-30 corpus.

    All arguments default to the previously hard-coded values, so
    ``get_data_loader()`` behaves as before.

    :param data_dir: THCHS-30 root directory passed to ``thchs30``.
    :param usage: split of ``file_dic`` to load (e.g. ``"train"``).
    :param batch_size: samples per batch; incomplete last batches are dropped.
    :param num_workers: DataLoader worker processes.
    :param shuffle: whether to reshuffle each epoch.
    :return: ``DataLoader`` yielding ``aligin_collate`` batches.
    """
    thchs30_dataset = thchs30(data_dir, label_file_type='trn')
    thchs30_dataset.read_label_file()
    dataset = ASRDataset(thchs30_dataset.label_dic, thchs30_dataset.file_dic, usage)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      collate_fn=aligin_collate,
                      shuffle=shuffle,
                      drop_last=True,
                      num_workers=num_workers)

In [26]:

# Load THCHS-30 transcripts ('trn' label files) for interactive inspection.
THCHS30_DIR = "/data/jiaxin.gu/jupyter/asr_test/dataset/data_thchs30"
thchs30_dataset = thchs30(THCHS30_DIR, label_file_type='trn')
thchs30_dataset.read_label_file()


/data/jiaxin.gu/jupyter/asr_test/dataset/data_thchs30/


In [34]:
# NOTE(review): the tuple shown below is one dataset item (feature, label, len);
# this assignment alone displays nothing, so the cell was presumably edited after running.
ASRDataset_ = ASRDataset(label_dic=thchs30_dataset.label_dic,
                         file_dic=thchs30_dataset.file_dic,
                         usage="train")


(tensor([[ 4.9414,  3.0051,  4.0349,  ...,  3.6628,  4.0328,  0.7338],
         [ 3.8549,  1.4800,  2.1494,  ...,  1.8797,  2.1140,  0.6266],
         [ 1.2094,  0.7675,  0.7176,  ..., -0.3736, -0.3496,  0.3540],
         ...,
         [-0.4184, -0.4353, -0.4386,  ..., -0.4373, -0.4356, -0.4284],
         [-0.4187, -0.4370, -0.4383,  ..., -0.4416, -0.4322, -0.4306],
         [-0.4164, -0.4399, -0.4301,  ..., -0.4391, -0.4257, -0.4346]]),
 [198,
  1159,
  1140,
  317,
  799,
  245,
  1039,
  58,
  345,
  443,
  239,
  75,
  331,
  1148,
  1080,
  799,
  559,
  1039,
  342,
  1159,
  842,
  844,
  1049,
  1221,
  1139,
  1084,
  164,
  219,
  1064,
  8,
  1233,
  561,
  75,
  52,
  271,
  326],
 36)

In [43]:
# Smoke test: this only constructs the loader (see the DataLoader repr below);
# no batch is drawn, so feature extraction and collation are not exercised here.
get_data_loader()

/data/jiaxin.gu/jupyter/asr_test/dataset/data_thchs30/


<torch.utils.data.dataloader.DataLoader at 0x7fb5c83c70f0>

In [44]:
# NOTE(review): called with no arguments, MagicData resolved its data dir to
# ".../python36.zip/../dataset/train/" and raised
# TypeError: 'NoneType' object is not iterable (see output below).
# TODO pass the real MagicData root explicitly (presumably a path argument like thchs30 -- confirm).
MagicData()

/data/jiaxin.gu/anaconda3/envs/tf/lib/python36.zip/../dataset/train/


TypeError: 'NoneType' object is not iterable

In [None]:
# Scratch cell: only echoes the imported dataset classes; has no effect and can be removed.
thchs30,MagicData,primewords,ST_CMDS