Other examples of dataset implementations:
* [torchvision](https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py) and [here](https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py)
* generator for [tarballs](https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractfile) and [zip](https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile.open)
* [AsrDataset](https://github.com/pytorch/fairseq/blob/4812f64b651ab64881510d38d4e35ce4ce22b04f/examples/speech_recognition/data/asr_dataset.py#L14)

In [49]:
import torch
import torchvision 
import torchaudio

import os
import random
from functools import reduce, partial
from warnings import warn
import pickle

import six
import csv
import os
import tarfile
import logging
import re
import sys
import zipfile

In [50]:
def get_data(URL):
    """
    Download a tar archive and expose the CSV files it contains.

    Arguments:
        URL: address of a tar archive; compression is detected by ``tarfile.open``.
    Returns:
        dict mapping ``'train'`` and/or ``'test'`` to a binary file-like object
        for the matching CSV member. A CSV member whose name contains ``train``
        is stored under ``'train'``, any other CSV member under ``'test'``;
        later members overwrite earlier ones with the same key.
    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Imported here because the import cell of this file brings in neither module.
    import io

    import requests

    r = requests.get(URL)
    # Fail on an HTTP error instead of handing an error page to tarfile.
    r.raise_for_status()
    file_like_object = io.BytesIO(r.content)
    # The archive is deliberately left open: the objects returned by
    # extractfile() read from it when they are consumed.
    tar = tarfile.open(fileobj=file_like_object)
    d = {}
    for member in tar.getmembers():
        if member.isfile() and member.name.endswith('csv'):
            k = 'train' if 'train' in member.name else 'test'
            d[k] = tar.extractfile(member)
    return d

In [51]:
def unicode_csv_reader(unicode_csv_data, **kwargs):
    r"""Iterate over the rows of unicode CSV data.

    The standard csv library does not handle unicode in Python 2, hence this wrapper.
    Borrowed and slightly modified from the Python docs:
    https://docs.python.org/2/library/csv.html#csv-examples

    This is a generator: nothing happens until iteration starts. On the first
    step the process-wide csv field size limit is raised to the largest value
    the platform accepts.

    Arguments:
        unicode_csv_data: iterable of unicode lines (see example below)
        **kwargs: forwarded unchanged to ``csv.reader`` (e.g. ``delimiter``)
    Yields:
        list of unicode cells, one list per row
    Examples:
        >>> from torchtext.utils import unicode_csv_reader
        >>> import io
        >>> with io.open(data_path, encoding="utf8") as f:
        >>>     reader = unicode_csv_reader(f)
    """

    # Fix "field larger than field limit" errors: start from sys.maxsize and
    # divide by 10 for as long as the C layer reports an overflow.
    max_int = sys.maxsize
    while True:
        try:
            csv.field_size_limit(max_int)
            break
        except OverflowError:
            # Integer division keeps the value exact (no float round-trip).
            max_int //= 10

    if sys.version_info[0] == 2:
        # csv.py doesn't do Unicode on Python 2; encode temporarily as UTF-8
        # and decode back to unicode cell by cell.
        encoded_lines = (line.encode('utf-8') for line in unicode_csv_data)
        for row in csv.reader(encoded_lines, **kwargs):
            yield [cell.decode('utf-8') for cell in row]
    else:
        for row in csv.reader(unicode_csv_data, **kwargs):
            yield row

suggestions:
* small functional
* get length
* shuffle
* meaningful error on function mismatch
* ~~cache or buffer~~
* ~~generator~~
* stream files from disk or web
* stream archives
* ~~no compose function~~
* ~~curry instead of partial?~~

# Common tools

In [52]:
class Cache:
    """
    Wrap a generator so that, whenever a new item is returned, it is saved to disk in a pickle.

    Iterating first replays the items already saved (read back from disk), then
    keeps pulling new items from the wrapped generator. ``len`` and indexing only
    cover the items saved so far.

    Arguments:
        generator: iterator to wrap; it is advanced with ``next``.
        location: folder receiving one pickle file per item; created on demand.

    NOTE: items are read back with ``pickle.load``; only point ``location`` at a
    folder whose content you trust.
    """

    def __init__(self, generator, location):
        self.generator = generator
        self.location = location

        # id(self) prefixes the file names of this instance.
        self._id = id(self)
        # Paths of the pickles written so far, in production order.
        self._cache = []
        self._internal_index = 0

    def __iter__(self):
        # Every new iteration replays from the first saved item.
        self._internal_index = 0
        return self

    def __next__(self):
        if self._internal_index < len(self):
            item = self[self._internal_index]
        else:
            # StopIteration from the wrapped generator ends the iteration.
            item = next(self.generator)

            name = str(self._id) + "-" + str(len(self))
            path = os.path.join(self.location, name)

            os.makedirs(self.location, exist_ok=True)
            with open(path, 'wb') as handle:
                pickle.dump(item, handle)
            # Register the file only once it has been written, so a failed
            # dump never leaves an entry pointing at a missing or broken pickle.
            self._cache.append(path)

        self._internal_index += 1
        return item

    def __getitem__(self, index):
        """Load the item saved at position ``index`` back from disk."""
        with open(self._cache[index], 'rb') as handle:
            return pickle.load(handle)

    def __len__(self):
        """Return the number of items saved so far."""
        return len(self._cache)

In [53]:
class Buffer:
    """
    Wrap a generator so as to keep the next few items in memory.

    Up to ``capacity + 1`` items are read ahead from the wrapped generator.
    Indexing looks into that read-ahead window without consuming it; iterating
    consumes items in their original order, including the buffered tail once
    the wrapped generator is exhausted.

    Arguments:
        generator: iterator to wrap; it is advanced with ``next``.
        capacity: read-ahead size (the window holds ``capacity + 1`` items).
    """

    def __init__(self, generator, capacity=10):
        self.generator = generator
        self.capacity = capacity
        # A list (not a deque) so that __getitem__ keeps accepting slices.
        self._cache = []
        # Set once the wrapped generator has raised StopIteration.
        self._exhausted = False
        self._fill()

    def _fill(self):
        """Top up the read-ahead window, stopping quietly at the end of the source."""
        while not self._exhausted and len(self._cache) <= self.capacity:
            try:
                self._cache.append(next(self.generator))
            except StopIteration:
                # A short source must not abort the constructor, nor discard
                # the items that are still buffered.
                self._exhausted = True

    def __getitem__(self, n):
        self._fill()
        return self._cache[n]

    def __iter__(self):
        return self

    def __next__(self):
        self._fill()
        if not self._cache:
            raise StopIteration
        item = self._cache.pop(0)
        self._fill()
        return item

In [54]:
def download_to_file(urls, root_path):
    """
    Map each url to the archive path it corresponds to under root_path.

    Input: generator of (url, folder inside archive), destination folder
    Output: archive path under root_path, folder inside archive
    """
    for remote, inner_folder in urls:
        # NOTE: the download step (torchvision.datasets.utils.download_url) is
        # currently disabled here; only the destination path is computed.
        archive_name = os.path.basename(remote)
        yield os.path.join(root_path, archive_name), inner_folder
    
    
def extract(files):
    """
    Map each archive to the folder its content lives in, next to the archive.

    Input: generator of (archive path, folder name inside archive)
    Output: path to the folder inside the archive
    """
    for archive, inner_folder in files:
        # NOTE: the extraction step (torchvision.datasets.utils.extract_archive)
        # is currently disabled here; only the resulting path is computed.
        parent = os.path.dirname(archive)
        yield os.path.join(parent, inner_folder)
          
            
def walk(paths, extension):
    """
    Walk inside each path recursively to find all files with given extension.

    Input: generator of paths, extension (suffix the file name must end with)
    Output: path, file name identifying a row of data

    The yielded path is the top-level path that was walked, not the
    sub-folder in which the file was found.
    """
    for path in paths:
        for _dirpath, _dirnames, filenames in os.walk(path):
            for filename in filenames:
                # Suffix test: a plain substring test would also accept
                # names such as "clip.wav.bak" for the extension ".wav".
                if filename.endswith(extension):
                    yield path, filename

                    
def shuffle(generator):
    """
    Yield the items of a generator in random order.

    Input: generator
    Output: generator

    The whole input is read into memory when iteration starts.
    """
    items = list(generator)
    random.shuffle(items)
    yield from items

        
def filtering(fileids, reference):
    """
    Skip fileids that are not present in given reference file.

    Input: (path, file) generator, name of the reference file inside each path
    Output: path, file

    A fileid is kept when it occurs anywhere in the text of the reference
    file (substring match).
    """

    # None cannot collide with a real path, so the reference file is always
    # read before the first membership test (an empty-string sentinel would
    # skip the read when the first path is "").
    previous_path = None
    contents = ""

    for path, fileid in fileids:

        # Reload only when the path changes, to avoid rereading the file constantly
        if path != previous_path:
            with open(os.path.join(path, reference)) as handle:
                contents = handle.read()
            previous_path = path

        # It would be more efficient to loop through the reference file instead
        if fileid in contents:
            yield path, fileid

# YesNo

[original](https://www.openslr.org/1/), [torchaudio](https://pytorch.org/audio/_modules/torchaudio/datasets/yesno.html)

In [55]:
def load_yesno(fileids):
    """
    Load the audio and label of each YESNO fileid.

    Input: path, file name identifying a row of data
    Output: label, waveform, sample_rate

    The label is the file name without extension, split on underscores.
    """
    for folder, name in fileids:
        audio_path = os.path.join(folder, name)
        waveform, sample_rate = torchaudio.load(audio_path)
        stem = os.path.basename(name).split(".")[0]

        yield {
            "label": stem.split("_"),
            "waveform": waveform,
            "sample_rate": sample_rate,
        }
        

def YESNO(root):
    """
    Build the pipeline loading YESNO.

    Input: root folder holding the archive and its content
    Output: generator of dicts with label, waveform, sample_rate
    """

    url = [
        ("http://www.openslr.org/resources/1/waves_yesno.tar.gz", "waves_yesno")
    ]

    # download_to_file is the helper defined in this file for this step.
    path = download_to_file(url, root_path=root)
    path = extract(path)
    path = walk(path, extension=".wav")
    path = shuffle(path)
    data = load_yesno(path)

    # The result may be wrapped before returning, e.g. Buffer(data) or
    # Cache(data, "tmp/"); it is returned unwrapped here.
    return data


# Build the YESNO pipeline for a local root folder.
data = YESNO("/Users/vincentqb/yesnotest")

# Pull the first sample out of the pipeline.
next(data)

{'label': ['0', '1', '1', '1', '1', '1', '1', '1'],
 'waveform': tensor([[3.0518e-05, 6.1035e-05, 3.0518e-05,  ..., 2.7466e-03, 1.8005e-03,
          2.2888e-03]]),
 'sample_rate': 8000}

In [56]:
# Pull the following sample out of the same pipeline.
next(data)

{'label': ['1', '1', '1', '0', '1', '0', '1', '0'],
 'waveform': tensor([[ 0.0016,  0.0017,  0.0016,  ..., -0.0016, -0.0010, -0.0002]]),
 'sample_rate': 8000}

# VCTK

[original](https://datashare.is.ed.ac.uk/handle/10283/2651), [torchaudio](https://pytorch.org/audio/datasets.html?highlight=dataset#torchaudio.datasets.VCTK)

In [57]:
def load_vctk(fileids):
    """
    Load the transcript and audio of each VCTK fileid.

    Input: path, file name identifying a row of data
    Output: id, content, waveform, sample_rate

    The transcript is looked up under txt/<speaker>/ and the audio under
    wav48/<speaker>/, where <speaker> is the part of the id before the first
    underscore. Entries without a transcript file are skipped with a warning.
    """
    for root, name in fileids:
        utterance = os.path.basename(name).split(".")[0]
        speaker = utterance.split("_")[0]
        transcript_path = os.path.join(root, "txt", speaker, utterance + ".txt")
        audio_path = os.path.join(root, "wav48", speaker, utterance + ".wav")

        try:
            with open(transcript_path) as handle:
                # Only the first line of the transcript file is used.
                content = handle.readlines()[0]
        except FileNotFoundError:
            warn("Translation not found for {}".format(audio_path))
            continue

        # The audio is decoded only once the transcript has been read.
        waveform, sample_rate = torchaudio.load(audio_path)

        yield {
            "id": utterance,
            "content": content,
            "waveform": waveform,
            "sample_rate": sample_rate,
        }
        
        
def VCTK(root):
    """
    Build the pipeline loading VCTK.

    Input: root folder holding the archive and its content
    Output: generator of dicts with id, content, waveform, sample_rate
    """

    url = [
        ('http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz', "VCTK-Corpus/")
    ]

    # download_to_file is the helper defined in this file for this step.
    path = download_to_file(url, root_path=root)
    path = extract(path)
    path = walk(path, extension=".wav")
    path = shuffle(path)
    data = load_vctk(path)

    # The result may be wrapped before returning, e.g. Cache(data, "tmp/");
    # it is returned unwrapped here.
    return data


# Build the VCTK pipeline for a local root folder.
data = VCTK("/Users/vincentqb/vctktest/")

# Pull the first sample out of the pipeline.
next(data)

{'id': 'p231_181',
 'content': 'I am not ready to walk away.\n',
 'waveform': tensor([[-0.0117, -0.0173, -0.0150,  ...,  0.0106,  0.0099,  0.0113]]),
 'sample_rate': 48000}

# LibriSpeech

[original](http://www.openslr.org/12)

In [58]:
def load_librispeech(fileids):
    """
    Load the transcript and audio of each LIBRISPEECH fileid.

    Input: path, file name identifying a row of data
    Output: id, content, waveform, sample_rate

    A fileid has the form <folder1>-<folder2>-<utterance>. Its transcript is
    the line starting with the fileid in
    <path>/<folder1>/<folder2>/<folder1>-<folder2>.trans.txt, and its audio is
    <path>/<folder1>/<folder2>/<fileid>.flac. Entries without a transcript
    line are skipped with a warning.
    """

    text_extension = ".trans.txt"
    audio_extension = ".flac"

    for data_path, fileid in fileids:
        fileid = os.path.basename(fileid).split(".")[0]
        # Unpacking enforces the three-part id layout (ValueError otherwise).
        folder1, folder2, _utterance = fileid.split("-")
        folder = os.path.join(data_path, folder1, folder2)
        file_text = os.path.join(folder, folder1 + "-" + folder2 + text_extension)
        file_audio = os.path.join(folder, fileid + audio_extension)

        content = None
        # The with-block closes the transcript file even when the scan stops early.
        with open(file_text) as transcripts:
            for line in transcripts:
                # partition tolerates blank lines and ids without any text.
                fileid_text, _, text = line.strip().partition(" ")
                if fileid == fileid_text:
                    content = text
                    break
        if content is None:
            warn("Translation not found for {}.".format(fileid))
            continue

        # Decode the audio only for entries that are actually yielded.
        waveform, sample_rate = torchaudio.load(file_audio)

        yield {
            "id": fileid,
            "content": content,
            "waveform": waveform,
            "sample_rate": sample_rate,
        }
        

def LIBRISPEECH(root, selection="dev-clean"):
    """
    Build the pipeline loading one LIBRISPEECH subset.

    Input: root folder holding the archive and its content, name of the subset
    Output: generator of dicts with id, content, waveform, sample_rate

    The subset name selects the archive
    http://www.openslr.org/resources/12/<selection>.tar.gz, for instance:
    dev-clean, test-clean, test-other, train-clean-100, train-clean-360,
    train-other-500.
    """

    base = "http://www.openslr.org/resources/12/"
    url = [
        # The URL is built by concatenation since it is not a filesystem path;
        # the folder inside the archive is one, hence os.path.join.
        (base + selection + ".tar.gz", os.path.join("LibriSpeech", selection))
    ]

    # download_to_file is the helper defined in this file for this step.
    path = download_to_file(url, root_path=root)
    path = extract(path)
    path = walk(path, extension=".flac")
    path = shuffle(path)
    data = load_librispeech(path)

    # The result may be wrapped before returning, e.g. Cache(data, "tmp/");
    # it is returned unwrapped here.
    return data


# Build the LIBRISPEECH pipeline (default subset) for a local root folder.
data = LIBRISPEECH("/Users/vincentqb/librispeechtest/")

# Pull the first sample out of the pipeline.
next(data)

{'id': '7850-73752-0015',
 'content': 'WAS IT NOT ALL A DREAM OF HIS OWN CREATION WHILE HIS EYE HAD BEEN FIXED IN ABSTRACTION ON THAT BRIGHT AND FLOWING RIVER',
 'waveform': tensor([[-0.0017, -0.0019, -0.0016,  ...,  0.0017,  0.0018,  0.0015]]),
 'sample_rate': 16000}

# CommonVoice

[original](https://voice.mozilla.org/en/datasets)

In [61]:
def load_commonvoice(fileids, tsv_file):
    """
    Load the metadata and audio of each COMMONVOICE fileid.

    Input: (path, file name identifying a row of data) generator, name of the
        tsv file inside each path
    Output: one dict per fileid, keyed by the tsv header, plus waveform and
        sample_rate (e.g. client_id, path, sentence, up_votes, down_votes,
        age, gender, accent, waveform, sample_rate)

    The first tsv row holding a cell equal to the fileid provides the
    metadata; fileids without such a row are skipped silently. The audio is
    read from <path>/clips/<fileid>.
    """

    for path, fileid in fileids:
        filename = os.path.join(path, "clips", fileid)
        tsv_path = os.path.join(path, tsv_file)

        header = None
        row = None
        # Explicit UTF-8: the sentences are not ASCII, and the platform
        # default encoding is not UTF-8 everywhere. newline="" leaves line
        # ending handling to the csv module.
        with open(tsv_path, encoding="utf-8", newline="") as handle:
            for line in unicode_csv_reader(handle, delimiter='\t'):
                if header is None:
                    # The first line names the columns.
                    header = line
                    continue
                if fileid in line:
                    row = line
                    break
        if row is None:
            continue

        waveform, sample_rate = torchaudio.load(filename)

        dic = dict(zip(header, row))
        dic["waveform"] = waveform
        dic["sample_rate"] = sample_rate

        yield dic


def COMMONVOICE(root, language="tatar", tsv="train.tsv"):
    """
    Build the pipeline loading one COMMONVOICE language.

    Input: root folder holding the archive and its content, language name (a
        key of the table below), name of the tsv file listing the clips
    Output: generator of dicts keyed by the tsv header, plus waveform and
        sample_rate
    Raises: KeyError if the language is not in the table below.
    """

    web = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-3/"

    # Language name -> code used in the archive file name.
    languages = {
        "tatar": "tt",
        "english": "en",
        "german": "de",
        "french": "fr",
        "welsh": "cy",
        "breton": "br",
        "chuvash": "cv",
        "turkish": "tr",
        "kyrgyz": "ky",
        "irish": "ga-IE",
        "kabyle": "kab",
        "catalan": "ca",
        "taiwanese": "zh-TW",
        "slovenian": "sl",
        "italian": "it",
        "dutch": "nl",
        "hakha chin": "cnh",
        "esperanto": "eo",
        "estonian": "et",
        "persian": "fa",
        "basque": "eu",
        "spanish": "es",
        "chinese": "zh-CN",
        "mongolian": "mn",
        "sakha": "sah",
        "dhivehi": "dv",
        "kinyarwanda": "rw",
        "swedish": "sv-SE",
        "russian": "ru",
    }

    url = web + languages[language] + ".tar.gz"
    # Empty inner folder: the content is looked up next to the archive itself.
    url = [(url, "")]

    # download_to_file is the helper defined in this file for this step.
    path = download_to_file(url, root_path=root)
    path = extract(path)
    path = walk(path, extension=".mp3")
    # Optional steps, left out here: shuffle(path) and
    # filtering(path, reference=tsv).
    data = load_commonvoice(path, tsv)

    # The result may be wrapped before returning, e.g. Cache(data, "tmp/");
    # it is returned unwrapped here.
    return data


# Build the COMMONVOICE pipeline (default language and tsv) for a local root folder.
data = COMMONVOICE("/Users/vincentqb/commonvoicetest/")

# Pull the first sample out of the pipeline.
next(data)

{'client_id': '11d5e99f7bd5b4f8492a06bb1ec22aa9110bba6ea9918f2a9adec05d686304d568ab7063daf8915d3fccfb4dd44b81646bd13a33ca130ac4014560bba4c2db0b',
 'path': 'common_voice_tt_17531596.mp3',
 'sentence': 'Мин анда ялгыз бара алмам бит.',
 'up_votes': '2',
 'down_votes': '0',
 'age': 'thirties',
 'gender': 'male',
 'accent': '',
 'waveform': tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -2.8685e-07,
          -2.3097e-06, -2.8796e-06]]),
 'sample_rate': 48000}