In [58]:
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio as ta
from torch.utils.data import DataLoader, Dataset

In [59]:
def get_wav_paths(paths):
    """Recursively collect ``.wav`` files under one or more directories.

    Args:
        paths: A single directory (``str`` or ``os.PathLike``) or an iterable
            of directories.

    Returns:
        A list of full paths to every file whose extension is exactly ``.wav``
        (case-sensitive), sorted by file basename. Files that share a basename
        (e.g. the same name in two sub-directories) are ordered by their full
        path, so the result never depends on ``os.walk`` traversal order.
    """
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    wav_paths = []
    for path in paths:
        for root, _dirs, files in os.walk(path):
            wav_paths.extend(
                os.path.join(root, name)
                for name in files
                if os.path.splitext(name)[-1] == ".wav"
            )

    # Primary key: basename, so parallel directory trees holding identically
    # named files yield position-aligned lists. Secondary key: full path, to
    # make ties deterministic.
    wav_paths.sort(key=lambda p: (os.path.basename(p), p))
    return wav_paths

def check_dir_exist(path_list):
    """Create any missing directories among ``path_list``.

    Args:
        path_list: A single path or a list of paths (``str`` or
            ``os.PathLike``).

    Only entries without a file extension are treated as directories. Entries
    such as ``out/model.ckpt`` (or even ``runs/v1.2``) are assumed to be files
    and skipped, as are non-string entries and paths that already exist.
    """
    if isinstance(path_list, (str, os.PathLike)):
        path_list = [path_list]

    for path in path_list:
        if isinstance(path, os.PathLike):
            path = os.fspath(path)
        if not isinstance(path, str):
            continue
        # Heuristic: a path with an extension is considered a file, not a dir.
        if os.path.splitext(path)[-1] == "" and not os.path.exists(path):
            # exist_ok closes the race where another process creates the
            # directory between the exists() check and makedirs().
            os.makedirs(path, exist_ok=True)

def get_filename(path):
    """Split the basename of ``path`` into a ``(stem, extension)`` tuple.

    Despite the name, a 2-tuple is returned, e.g. ``"/a/b/c.wav"`` gives
    ``("c", ".wav")``; callers index ``[0]`` to get the bare stem.
    """
    base_name = os.path.basename(path)
    stem_and_ext = os.path.splitext(base_name)
    return stem_and_ext

In [68]:
class CEDataset(Dataset):
    """Paired (original, codec-decoded) waveform dataset for codec enhancement.

    Every ``.wav`` file found under ``path_dir_orig`` is paired with the
    ``.wav`` file under ``path_dir_dist`` that has the same file stem. All
    audio is loaded into memory once, in ``__init__``.

    Args:
        path_dir_orig: Directory (or list of directories) of original audio.
        path_dir_dist: Directory (or list of directories) of decoded audio.
        seg_len: Segment length in seconds returned by ``__getitem__``. If
            ``<= 0``, the full waveforms are returned exactly as loaded.
        sample_rate: Samples per second used to convert ``seg_len`` into a
            sample count (default 16000, the value previously hard-coded).
            NOTE(review): the rate reported by ``torchaudio.load`` is not
            checked against it.

    Raises:
        ValueError: if a stem occurs twice in one tree, or the two trees do
            not contain the same set of stems.
    """

    def __init__(self, path_dir_orig, path_dir_dist, seg_len=2, sample_rate=16000):
        self.path_dir_orig = path_dir_orig
        self.path_dir_dist = path_dir_dist
        self.seg_len = seg_len
        self.sample_rate = sample_rate
        self.wavs = {}
        self.filenames = []

        # Pair files by stem rather than by list position: zipping two sorted
        # lists silently mis-pairs every file after the first one missing from
        # either side.
        orig_by_name = self._index_by_stem(get_wav_paths(self.path_dir_orig), path_dir_orig)
        dist_by_name = self._index_by_stem(get_wav_paths(self.path_dir_dist), path_dir_dist)

        only_orig = sorted(orig_by_name.keys() - dist_by_name.keys())
        only_dist = sorted(dist_by_name.keys() - orig_by_name.keys())
        if only_orig or only_dist:
            raise ValueError(
                f"Unpaired wav files: {len(only_orig)} only in {path_dir_orig!r} "
                f"(e.g. {only_orig[:5]}), {len(only_dist)} only in {path_dir_dist!r} "
                f"(e.g. {only_dist[:5]})"
            )

        for filename in sorted(orig_by_name):
            wav_orig, _ = ta.load(orig_by_name[filename])
            wav_dist, _ = ta.load(dist_by_name[filename])
            self.wavs[filename] = (wav_orig, wav_dist)
            self.filenames.append(filename)

    @staticmethod
    def _index_by_stem(paths, root):
        """Map file stem -> path, rejecting stems that occur more than once."""
        index = {}
        for path in paths:
            stem = get_filename(path)[0]
            if stem in index:
                raise ValueError(
                    f"Duplicate wav stem {stem!r} under {root!r}: "
                    f"{index[stem]!r} and {path!r}"
                )
            index[stem] = path
        return index

    def __len__(self):
        """Return the number of (original, decoded) pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(wav_orig, wav_dist, filename)`` for item ``idx``.

        With ``seg_len > 0`` both waveforms are flattened to shape
        ``(1, num_samples)`` and cut at the same random offset to exactly
        ``int(seg_len * sample_rate)`` samples; clips shorter than that are
        tiled (looped) to fill the segment. With ``seg_len <= 0`` the stored
        tensors are returned unchanged.
        """
        filename = self.filenames[idx]
        wav_orig, wav_dist = self.wavs[filename]

        if self.seg_len > 0:
            duration = int(self.seg_len * self.sample_rate)

            # NOTE(review): flattening assumes mono audio; multi-channel files
            # would have their channels concatenated end to end.
            wav_orig = wav_orig.reshape(1, -1)
            wav_dist = wav_dist.reshape(1, -1)

            # Decoded audio can be a few samples longer or shorter than the
            # original. Trim both to the common length so the same offsets
            # index both signals and the tiled segments stay aligned and
            # equally long.
            sig_len = min(wav_orig.shape[-1], wav_dist.shape[-1])
            if sig_len == 0:
                raise ValueError(f"Empty waveform for {filename!r}")
            wav_orig = wav_orig[..., :sig_len]
            wav_dist = wav_dist[..., :sig_len]

            # Uniform start in [0, sig_len - duration] (0 when the clip is
            # shorter than the segment). torch's RNG is used because the
            # DataLoader reseeds it per worker; older PyTorch releases did not
            # do this for numpy's global RNG, giving identical crops per worker.
            max_start = max(0, sig_len - duration)
            t_start = int(torch.randint(0, max_start + 1, (1,)).item())
            t_end = t_start + duration

            # Tile enough copies to reach t_end, then cut the segment.
            n_repeat = t_end // sig_len + 1
            wav_orig = wav_orig.repeat(1, n_repeat)[..., t_start:t_end]
            wav_dist = wav_dist.repeat(1, n_repeat)[..., t_start:t_end]

        return wav_orig, wav_dist, filename


In [69]:
# Build the paired dataset with the CEDataset class defined above
# (`CodecEnhancement` is not defined anywhere in this notebook).
DATA_DIR = "/media/youngwon/Neo/NeoChoi/TIL_Dataset/AECNN_enhancement"
dataset = CEDataset(
    path_dir_orig=f"{DATA_DIR}/target",
    path_dir_dist=f"{DATA_DIR}/decoded",
    seg_len=2,
)

In [70]:
# Fetch the first (wav_orig, wav_dist, filename) sample directly by index.
a = dataset[0]

In [71]:
# Drop singleton dimensions from the original-waveform segment.
a1 = torch.squeeze(a[0])

In [72]:
# Number of samples in the segment (seg_len * 16000 expected).
a1.shape[0]

32000

In [78]:
class CEDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping ``CEDataset`` with a train/validation split.

    Reads ``<data_dir>/target`` (original audio) and ``<data_dir>/decoded``
    (decoded audio).

    Args:
        data_dir: Root directory containing ``target`` and ``decoded``.
        batch_size: Batch size for every dataloader.
        seg_len: Segment length in seconds, forwarded to ``CEDataset``.
        train_size: Number of pairs in the training split; the remainder
            becomes the validation split (default 4000, previously hard-coded).
        seed: Seed for the split generator, so repeated ``setup`` calls give
            the same partition and validation items never leak into training.
        num_workers: Forwarded to every ``DataLoader``.
    """

    def __init__(
        self,
        data_dir="/media/youngwon/Neo/NeoChoi/TIL_Dataset/AECNN_enhancement",
        batch_size=4,
        seg_len=2,
        train_size=4000,
        seed=42,
        num_workers=0,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seg_len = seg_len
        self.train_size = train_size
        self.seed = seed
        self.num_workers = num_workers
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        # Nothing to download or preprocess; audio is read from disk in setup().
        pass

    def setup(self, stage=None):
        """Load all audio pairs once and split them into train/val subsets.

        Raises:
            ValueError: if ``train_size`` is negative or exceeds the number of
                pairs (random_split would otherwise silently give train
                everything and val nothing).
        """
        if self.train_dataset is not None:
            # Already set up: avoid reloading every wav file and re-splitting.
            return

        full_dataset = CEDataset(
            path_dir_orig=f"{self.data_dir}/target",
            path_dir_dist=f"{self.data_dir}/decoded",
            seg_len=self.seg_len,
        )
        n_total = len(full_dataset)
        if not 0 <= self.train_size <= n_total:
            raise ValueError(
                f"train_size={self.train_size} is out of range for a dataset "
                f"of {n_total} pairs"
            )

        generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset,
            [self.train_size, n_total - self.train_size],
            generator=generator,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Loader over the validation split (no shuffling)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Loader for testing; reuses the validation split (no separate test set)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

In [79]:
# Instantiate with defaults and run setup() by hand so the dataloaders can be
# inspected outside of a Trainer.
datamodule = CEDataModule()
datamodule.setup(stage=None)

In [80]:
# Training DataLoader from the datamodule (note: rebinds `a`, which previously
# held a single dataset sample).
a=datamodule.train_dataloader()

In [81]:
# Display the DataLoader's repr as a quick sanity check.
a

<torch.utils.data.dataloader.DataLoader at 0x7f3571b08160>

In [82]:
# Pull and display one collated batch: original segments, decoded segments,
# and the corresponding filenames.
next(iter(a))

[tensor([[[-0.0490, -0.0603,  0.0074,  ...,  0.0085,  0.0077,  0.0066]],
 
         [[ 0.0284,  0.0374,  0.0317,  ...,  0.0054,  0.0041,  0.0026]],
 
         [[ 0.0030, -0.0042, -0.0036,  ..., -0.0134, -0.0077, -0.0028]],
 
         [[-0.0052, -0.0043,  0.0000,  ..., -0.0070,  0.0160,  0.0123]]]),
 tensor([[[-2.1484e-02, -3.9032e-02,  1.2634e-02,  ...,  1.0071e-03,
            2.1057e-03,  3.0212e-03]],
 
         [[ 1.6785e-02,  2.4017e-02,  2.1576e-02,  ...,  4.6692e-03,
            3.0518e-03,  1.1292e-03]],
 
         [[ 4.1504e-03, -4.7302e-03, -5.8289e-03,  ..., -1.1658e-02,
           -7.8125e-03, -4.4861e-03]],
 
         [[-1.2543e-02, -1.0651e-02,  4.8828e-04,  ..., -3.0518e-05,
            9.2163e-03,  9.5215e-03]]]),
 ('MADC0_0001', 'MSAS0_0001', 'MDSS0_0007', 'MDED0_0007')]