In [1]:
import os
import sys
import math
import wave
import librosa
import resampy
import numpy as np
from pathlib import Path
from joblib import Parallel, delayed
from natsort import natsorted
from loguru import logger
from IPython.display import display, Audio

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Drop the handlers loguru has registered so far, then log to stderr at INFO
# and above only (the logger.debug(...) calls further down are therefore hidden).
logger.remove()
# The trailing semicolon keeps the notebook from echoing logger.add's return value.
logger.add(sys.stderr, level="INFO");

In [2]:
def get_num_segments(path, max_segment_length, min_segment_length):
    """Count how many segments a WAV file contributes to the dataset.

    The track is cut into consecutive windows of ``max_segment_length``
    seconds. The trailing partial window is counted only when it lasts at
    least ``min_segment_length`` seconds.

    Args:
        path: Location of the WAV file (passed through ``str()``).
        max_segment_length: Window length in seconds.
        min_segment_length: Minimum length in seconds for the trailing
            partial window to count as a segment.

    Returns:
        int: Number of segments, or 0 when the file cannot be parsed as WAV
        (malformed header, empty or truncated file, non-positive sample rate).
    """
    try:
        with wave.open(str(path), mode="r") as wavefile:
            sr = wavefile.getframerate()
            samples = wavefile.getnframes()
    except (wave.Error, EOFError):
        # The wave module signals empty/truncated files with EOFError rather
        # than wave.Error; both mean "not a usable WAV file".
        return 0

    if sr <= 0:
        # A header with a zero sample rate would otherwise divide by zero.
        return 0

    duration = samples / sr
    num_segments = int(duration / max_segment_length)
    if duration % max_segment_length >= min_segment_length:
        num_segments += 1
    return num_segments


class WavDataset(Dataset):
    """Dataset of fixed-length audio segments read lazily from WAV files.

    Every track is split into consecutive windows of ``max_segment_length``
    seconds; a trailing partial window is kept when it is at least
    ``min_segment_length`` seconds long (see ``get_num_segments``) and is
    zero-padded to the full window length when loaded.

    Args:
        paths: A single file, a single directory (all regular files directly
            inside it are used), or an iterable of file paths.
        max_segment_length: Segment length in seconds.
        min_segment_length: Minimum length in seconds of a trailing partial
            segment for it to be included.
        sample_rate: Target sample rate; tracks with a different rate are
            resampled. ``None`` keeps each track's native rate.
        mono: When true, channels are averaged into a single channel.

    Each item is a tuple ``(audio, sample_rate)`` of float32 tensors. Sample
    values are the raw PCM integers converted to float; they are not rescaled
    (NOTE(review): 8-bit data is read as unsigned and is not re-centred).
    """

    def __init__(self, paths, max_segment_length=10, min_segment_length=4, sample_rate=22050, mono=True):
        super().__init__()
        # Scan for wav file(s)
        if isinstance(paths, (str, Path)):
            # Single file or single directory
            paths = Path(paths)
            if paths.is_dir():
                # Sub-directories cannot be opened as audio, keep regular files only.
                self.paths = [path for path in paths.iterdir() if path.is_file()]
            else:
                self.paths = [paths]
        else:
            # List of files
            self.paths = list(map(Path, paths))
        if any(path.suffix.lower() != ".wav" for path in self.paths):
            logger.warning("One or more filenames have a file extension other than '.wav'")
        # dtype=object keeps the array 1-D and mask-indexable even when it is empty.
        self.paths = np.asarray(natsorted(self.paths), dtype=object)

        self.max_segment_length = max_segment_length
        self.min_segment_length = min_segment_length
        self.sample_rate = sample_rate
        self.mono = mono

        # Only the headers are read here, one job per file.
        segment_counts = Parallel(n_jobs=-1)(
            delayed(get_num_segments)(str(path), self.max_segment_length, self.min_segment_length)
            for path in self.paths
        )
        self.num_track_segments = np.asarray(segment_counts, dtype=np.int64)
        valid_tracks_mask = self.num_track_segments > 0
        invalid_tracks_mask = ~valid_tracks_mask
        if invalid_tracks_mask.any():
            logger.warning(f"Failed to load {invalid_tracks_mask.sum()} tracks.")
            for invalid_index in invalid_tracks_mask.nonzero()[0]:
                logger.debug(f"Failed to load {str(self.paths[invalid_index])}")
            self.paths = self.paths[valid_tracks_mask]
            self.num_track_segments = self.num_track_segments[valid_tracks_mask]

        # cumulative[i] is the number of segments in tracks 0..i; it is strictly
        # increasing because every remaining track has at least one segment.
        self.cumulative_num_track_segments = np.cumsum(self.num_track_segments)
        if self.cumulative_num_track_segments.size:
            self.num_total_segments = int(self.cumulative_num_track_segments[-1])
        else:
            logger.warning("No loadable tracks were found; the dataset is empty.")
            self.num_total_segments = 0

    def _locate_segment(self, index):
        """Map a dataset index to ``(track_index, segment_index_within_track)``.

        Negative indices count from the end. Raises ``IndexError`` when the
        index falls outside ``[-len(self), len(self) - 1]``.
        """
        total = len(self)
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(f"Sample index out of range. Max index is {total - 1}")
        # First track whose cumulative segment count exceeds the position.
        track_index = int(np.searchsorted(self.cumulative_num_track_segments, position, side="right"))
        segments_before = self.cumulative_num_track_segments[track_index - 1] if track_index else 0
        return track_index, int(position - segments_before)

    @staticmethod
    def _decode_frames(buffer, num_channels, sample_width):
        """Decode little-endian PCM bytes into an integer array of shape (frames, channels)."""
        frame_size = num_channels * sample_width
        usable = len(buffer) - len(buffer) % frame_size
        if usable != len(buffer):
            # Ignore a dangling partial frame (truncated file).
            buffer = buffer[:usable]

        if sample_width == 3:
            # NumPy has no 24-bit integer type: widen every sample to 4 bytes and
            # fill the extra high byte with the sign bit (0x00 or 0xFF).
            raw_bytes = np.frombuffer(buffer, dtype=np.uint8).reshape(-1, num_channels, sample_width)
            widened = np.empty(raw_bytes.shape[:2] + (4,), dtype=np.uint8)
            widened[:, :, :sample_width] = raw_bytes
            widened[:, :, sample_width:] = (raw_bytes[:, :, sample_width - 1:sample_width] >> 7) * 255
            return widened.view("<i4").reshape(widened.shape[:-1])

        sample_kind = "u" if sample_width == 1 else "i"
        return np.frombuffer(buffer, dtype=f"<{sample_kind}{sample_width}").reshape(-1, num_channels)

    def __getitem__(self, index):
        """Load one segment.

        Returns:
            tuple: ``(audio, sample_rate)`` float32 tensors. ``audio`` has shape
            ``(samples,)`` when ``mono`` is set, otherwise ``(channels, samples)``.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        track_index, index_remainder = self._locate_segment(index)
        track_path = self.paths[track_index]
        with wave.open(str(track_path), mode="r") as wavefile:
            num_channels, sample_width, track_sample_rate, _, _, _ = wavefile.getparams()
            # Frame counts must be integers even for fractional segment lengths.
            num_samples = int(track_sample_rate * self.max_segment_length)
            start_pos = num_samples * index_remainder

            # Load raw audio; the last segment of a track may be shorter than requested.
            wavefile.setpos(start_pos)
            buffer = wavefile.readframes(num_samples)

        audio = self._decode_frames(buffer, num_channels, sample_width)
        audio = audio.T  # Channel-first index order

        # Zero-pad short trailing segments. `size` is keyword-only in recent librosa.
        audio = librosa.util.fix_length(audio, size=num_samples)
        if self.sample_rate is None or self.sample_rate == track_sample_rate:
            output_sr = track_sample_rate
        else:
            audio = resampy.resample(audio, track_sample_rate, self.sample_rate, filter="kaiser_fast")
            output_sr = self.sample_rate
        if self.mono and audio.ndim > 1:
            audio = audio.mean(0)
        return torch.tensor(audio, dtype=torch.float32), torch.tensor(output_sr, dtype=torch.float32)

    def __len__(self):
        """Total number of segments across all loadable tracks."""
        return int(self.num_total_segments)

In [3]:
# Alternative input: a single original track instead of the pre-converted directory.
# AUDIO_DIR = Path("../data/shetti/original/[WAV] Eurobeat Remix Stand Proud.wav")
WAV_DIR = Path("../data/shetti/wavs-22050/")

# Segmentation and loading options (lengths are in seconds).
MAX_SEGMENT_LENGTH = 5
MIN_SEGMENT_LENGTH = 1
SAMPLE_RATE = 22050
MONO = True

# Scans WAV_DIR and indexes every segment up front; audio itself is read on access.
dataset = WavDataset(
    WAV_DIR,
    max_segment_length=MAX_SEGMENT_LENGTH,
    min_segment_length=MIN_SEGMENT_LENGTH,
    sample_rate=SAMPLE_RATE,
    mono=MONO,
)

In [4]:
from time import time
from tqdm.auto import tqdm

# Throughput check: time one full shuffled pass over the dataset.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
)
start = time()
for audio, sr in tqdm(dataloader):
    # Batches are discarded on purpose; only loading speed is measured.
    # Uncomment to include the host-to-GPU copy in the timing:
    #     audio.cuda(), sr.cuda()
    pass
print("Total iteration time:", time() - start)

# Listen to a single segment:
# audio, sr = dataset[14]
# display(Audio(audio, rate=sr, autoplay=False))

HBox(children=(FloatProgress(value=0.0, max=1079.0), HTML(value='')))


Total iteration time: 47.88576793670654
