Other examples of dataset:
* [torchvision](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py)
* generator for [tarballs](https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractfile) and [zip](https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile.open)

In [1]:
import torch
import torchvision 
import torchaudio

import os
import random
from functools import reduce, partial
from warnings import warn
import pickle

In [2]:
def get_data(URL):
    r = requests.get(URL)
    file_like_object = io.BytesIO(r.content)
    tar = tarfile.open(fileobj=file_like_object)
    d = {}
    for member in tar.getmembers():
        if member.isfile() and member.name.endswith('csv'):
            k = 'train' if 'train' in member.name else 'test'
            d[k] = tar.extractfile(member)
    return d

# Common tools

In [3]:
class Cache:
    """
    Wrap a generator so that, whenever a new item is returned, it is saved to disk in a pickle.
    """

    def __init__(self, generator, location):
        self.generator = generator
        self.location = location

        self._id = id(self)
        self._cache = []
        self._internal_index = 0

    def __iter__(self):
        self._internal_index = 0
        return self

    def __next__(self):
        if self._internal_index < len(self):
            item = self[self._internal_index]
        else:
            item = next(self.generator)
        
            file = str(self._id) + "-" + str(len(self))
            file = os.path.join(self.location, file)
            self._cache.append(file)
        
            os.makedirs(self.location, exist_ok=True)
            with open(file, 'wb') as file:
                pickle.dump(item, file)

        self._internal_index += 1
        return item
    
    def __getitem__(self, index):
        file = self._cache[index]
        with open(file, 'rb') as file:
            item = pickle.load(file)
        return item
    
    def __len__(self):
        # Return length of cache
        return len(self._cache)

In [4]:
def compose(*funcs):
    """
    Compose multiple generator.
    """
    return lambda x: reduce(lambda f, g: g(f), list(funcs), x)


def download(urls, root_path):
    """
    Download each url to root_path.
    
    Input: url generator, folder inside archive
    Output: downloaded archive, folder inside archive
    """
    for url, folder in urls:
        # torchvision.datasets.utils.download_url(url, root_path)
        file = os.path.join(root_path, os.path.basename(url))
        yield file, folder
    
    
def extract(files):
    """
    Extract each archive to their respective folder.
    
    Input: (url, folder name inside archive) generator
    Output: path to inside archive
    """
    for file, folder in files:
        # torchvision.datasets.utils.extract_archive(file)
        path = os.path.dirname(file)
        path = os.path.join(path, folder)
        yield path
          
            
def walk(paths, extension):
    """
    Walk inside a path recursively to find all files with given extension.
    
    Input: path
    Output: path, file name identifying a row of data
    """
    for path in paths:
        for dp, dn, fn in os.walk(path):
            for f in fn:
                if extension in f:
                    yield path, f

                    
def shuffle(generator):
    """
    Shuffle the order of a generator.
    
    Input: generator
    Output: generator
    """

    # Need to load the whole list in memory
    generator = list(generator)
    # print(len(generator))
    random.shuffle(generator)
    for g in generator:
        yield g

        
def filtering(fileids, reference):
    """
    Skip fileids that are not present in given reference file.
    
    Output: (path, file) generator, reference file
    Output: path, file
    """
    
    path_old = ""
    
    for path, fileid in fileids:
        
        if path != path_old:
            # Check if same path to avoid reloading the file constantly
            ref = os.path.join(path, reference)
            with open(ref) as ref:
                r = "".join(ref.readlines())

        # It would be more efficient to loop through the reference file instead
        if fileid in r:
            yield path, fileid

# YesNo

In [5]:
def load_yesno(fileids):
    """
    Load data corresponding to each YESNO fileids.
    
    Input: path, file name identifying a row of data
    Output: label, waveform, sample_rate
    """
    
    extension = ".wav"
    for path, fileid in fileids:
        file = os.path.join(path, fileid)
        waveform, sample_rate = torchaudio.load(file)
        label = os.path.basename(fileid).split(".")[0].split("_")
    
        yield label, waveform, sample_rate
        

def YESNO(root):
    """
    Cache a pipeline loading YESNO.
    """
    
    url = [
        ("http://www.openslr.org/resources/1/waves_yesno.tar.gz", "waves_yesno")
    ]
     
    pipeline = compose(
        partial(download, root_path=root),
        extract,
        partial(walk, extension=".wav"),
        shuffle,
        load_yesno,
    )
    
    return Cache(pipeline(url), "tmp/")


data = YESNO("/Users/vincentqb/yesnotest")

next(data)
data[0]

(['1', '1', '1', '0', '1', '0', '1', '0'],
 tensor([[ 0.0016,  0.0017,  0.0016,  ..., -0.0016, -0.0010, -0.0002]]),
 8000)

# VCTK

In [6]:
def load_vctk(fileids):
    """
    Load data corresponding to each VCTK fileids.

    Input: path, file name identifying a row of data
    Output: id, content, waveform, sample_rate
    """
    
    txt_folder = "txt"
    txt_extension = ".txt"
    
    audio_folder = "wav48"
    audio_extension = ".wav"
    
    for path, fileid in fileids:
        
        fileid = os.path.basename(fileid).split(".")[0]
        folder = fileid.split("_")[0]
        txt_file = os.path.join(path, txt_folder, folder, fileid + txt_extension)        
        audio_file = os.path.join(path, audio_folder, folder, fileid + audio_extension)        
        
        try:
            with open(txt_file) as txt_file:
                content = txt_file.readlines()[0]
        except FileNotFoundError:
            warn("Translation not found for {}".format(audio_file))
            # warn("File not found: {}".format(txt_file))
            continue

        waveform, sample_rate = torchaudio.load(audio_file)
        
        yield fileid, content, waveform, sample_rate
        
        
def VCTK(root):
    """
    Cache a pipeline loading VCTK.
    """
    
    url = [
        ('http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz', "VCTK-Corpus/")
    ]
    
    pipeline = compose(
        partial(download, root_path=root),
        extract,
        partial(walk, extension=".wav"),
        shuffle,
        load_vctk,
    )
    
    return Cache(pipeline(url), "tmp/")


data = VCTK("/Users/vincentqb/vctktest/")

next(data)
data[0]

('p284_077',
 'Even the one she loved.\n',
 tensor([[ 0.0007,  0.0013,  0.0009,  ..., -0.0007, -0.0012, -0.0011]]),
 48000)

# LibriSpeech

In [7]:
def load_librispeech(fileids):
    """
    Load data corresponding to each LIBRISPEECH fileids.
    
    Input: path, file name identifying a row of data
    Output: id, waveform, sample_rate, translation
    """
    
    text_extension = ".trans.txt"
    audio_extension = ".flac"
    for data_path, fileid in fileids:
        fileid = os.path.basename(fileid).split(".")[0]
        folder1, folder2, file = fileid.split("-")
        file_text = folder1 + "-" + folder2 + text_extension
        file_text = os.path.join(data_path, folder1, folder2, file_text)
        file_audio = folder1 + "-"+ folder2 + "-" + file + audio_extension
        file_audio = os.path.join(data_path, folder1, folder2, file_audio)
        waveform, sample_rate = torchaudio.load(file_audio)
        
        found = False
        for line in open(file_text):
            fileid_text, content = line.strip().split(" ", 1)
            if fileid == fileid_text:
                found = True
                break
        if not found:
            from warnings import warn
            warn("Translation not found for {}.".format(fileid))
            continue

        yield fileid, waveform, sample_rate, content
        

def LIBRISPEECH(root, selection="dev-clean"):
    """
    Cache a pipeline loading LIBRISPEECH.
    """
    
    # http://www.openslr.org/resources/12/dev-clean.tar.gz
    # http://www.openslr.org/resources/12/test-clean.tar.gz
    # http://www.openslr.org/resources/12/test-other.tar.gz
    # http://www.openslr.org/resources/12/train-clean-100.tar.gz
    # http://www.openslr.org/resources/12/train-clean-360.tar.gz
    # http://www.openslr.org/resources/12/train-other-500.tar.gz

    selections = [
        "dev-clean",
        "test-clean",
        "test-other",
        "train-clean-100",
        "train-clean-360",
        "train-other-500"
    ]
        
    base = "http://www.openslr.org/resources/12/"
    url = [
        (os.path.join(base, selection + ".tar.gz"), os.path.join("LibriSpeech", selection))
    ]
     
    pipeline = compose(
        partial(download, root_path=root),
        extract,
        partial(walk, extension=".flac"),
        shuffle,
        load_librispeech,
    )

    return Cache(pipeline(url), "tmp/")


data = LIBRISPEECH("/Users/vincentqb/librispeechtest/")

next(data)
data[0]

('3853-163249-0032',
 tensor([[-0.0011, -0.0030, -0.0018,  ...,  0.0016, -0.0002, -0.0073]]),
 16000,
 'CAN YOU REMEMBER WHAT HEPSEY TOLD US AND CALL THEM POOR LONG SUFFERIN CREETERS NAMES')

# CommonVoice

In [8]:
def load_commonvoice(fileids, tsv):
    """
    Load data corresponding to each COMMONVOICE fileids.
    
    Input: path, file name identifying a row of data
    Output: client_id, path, sentence, up_votes, down_votes, age, gender, accent, waveform, sample_rate
    """
    
    for path, fileid in fileids:
        filename = os.path.join(path, "clips", fileid)
        tsv = os.path.join(path, tsv)
        
        found = False
        for line in open(tsv):
            if fileid in line:
                # client_id, path, sentence, up_votes, down_votes, age, gender, accent
                line = line.strip().split("\t")
                found = True
                break
        if not found:
            continue

        # waveform, sample_rate
        output = torchaudio.load(filename)    

        line.extend(output)
        yield line
        

def COMMONVOICE(root, language="tatar", tsv="train.tsv"):
    """
    Cache a pipeline loading COMMONVOICE.
    """
    
    web = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-3/"

    languages = {
        "tatar": "tt",
        "english": "en",
        "german": "de",
        "french": "fr",
        "welsh": "cy",
        "breton": "br",
        "chuvash": "cv",
        "turkish": "tr",
        "kyrgyz": "ky",
        "irish": "ga-IE",
        "kabyle": "kab",
        "catalan": "ca",
        "taiwanese": "zh-TW",
        "slovenian": "sl",
        "italian": "it",
        "dutch": "nl",
        "hakha chin": "cnh",
        "esperanto": "eo",
        "estonian": "et",
        "persian": "fa",
        "basque": "eu",
        "spanish": "es",
        "chinese": "zh-CN",
        "mongolian": "mn",
        "sakha": "sah",
        "dhivehi": "dv",
        "kinyarwanda": "rw",
        "swedish": "sv-SE",
        "russian": "ru",
    }

    url = web + languages[language] + ".tar.gz"
    url = [(url, "")]
     
    pipeline = compose(
        partial(download, root_path=root),
        extract,
        partial(walk, extension=".mp3"),
        # partial(filtering, reference=tsv),
        shuffle,
        partial(load_commonvoice, tsv=tsv),
    )
    
    return Cache(pipeline(url), "tmp/")


data = COMMONVOICE("/Users/vincentqb/commonvoicetest/")

next(data)
data[0]

['bb10e83bdf015da18144f427509d8cb56cfa4884527dc0cb3da927c845b733e48d3c451ae9538723b747fd6e34b15a863635e71b09a7611b7484f09e4cd109be',
 'common_voice_tt_17759554.mp3',
 'Авыл башлыкларына исем ошамады булса кирәк.',
 '2',
 '0',
 'thirties',
 'male',
 tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -2.9039e-05,
          -1.0319e-06,  2.4986e-05]]),
 48000]