In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio as ta

import pytorch_lightning as pl

from utils import *

In [10]:
def get_wav_paths(paths: list):
    wav_paths=[]
    if type(paths)==str:
        paths=[paths]
        
    for path in paths:
        for root, dirs, files in os.walk(path):
            wav_paths += [os.path.join(root,file) for file in files if os.path.splitext(file)[-1]=='.wav']
                        
    wav_paths.sort(key=lambda x: os.path.split(x)[-1])
    
    return wav_paths

def check_dir_exist(path_list):
    if type(path_list) == str:
        path_list = [path_list]
        
    for path in path_list:
        if type(path) == str and os.path.splitext(path)[-1] == '' and not os.path.exists(path):
            os.makedirs(path)       

def get_filename(path):
    return os.path.splitext(os.path.basename(path))  

In [13]:
class CodecEnhancement(Dataset): 
  #데이터셋의 전처리를 해주는 부분
    def __init__(self, path_dir_orig, path_dir_dist, seg_len):
        self.path_dir_orig   = path_dir_orig  
        self.path_dir_dist   = path_dir_dist  
        self.seg_len = seg_len

        paths_wav_orig = get_wav_paths(self.path_dir_orig)
        paths_wav_dist= get_wav_paths(self.path_dir_dist)

        for path_wav_orig, path_wav_dist in zip(paths_wav_orig, paths_wav_dist):
            filename=get_filename(path_wav_orig)[0]
            wav_orig,_=ta.load(path_wav_orig)
            wav_dist,_=ta.load(path_wav_dist)
            self.wavs[filename]=(wav_orig, wav_dist)
            self.filenames.append(filename)
        

    # 총 데이터의 개수를 리턴
    def __len__(self):
        return len(self.filenames)



    # 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
    def __getitem__(self, idx):

        filename = self.filenames[idx]
        (wave_orig, wav_dist) = self.wavs[filename]
        
        if self.seg_len>0:
            duration= int(self.seg_len * 16000)

            wav_orig= wav_orig.view(1,-1)
            wav_dist= wav_dist.view(1,-1)

            sig_len = wav_orig[-1]

            t_start = np.random.randint(
                low = 0,
                high= np.max([1, sig_len- duration - 2]),
                size = 1
            )[0]
            t_end = t_start + duration

            wav_orig = wav_orig.repeat(1, t_end // sig_len + 1) [ ..., t_start : t_end]
            wav_dist = wav_dist.repeat(1, t_end// sig_len + 1) [ ..., t_start : t_end]

        return wav_orig, wav_dist, filename


In [14]:
dataset=CodecEnhancement(
    path_dir_orig = "/media/youngwon/Neo/NeoChoi/TIL_Dataset/AECNN_enhancement/target",
    path_dir_dist = "/media/youngwon/Neo/NeoChoi/TIL_Dataset/AECNN_enhancement/decoded",
    seg_len = 2)

AttributeError: 