-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
42 lines (31 loc) · 1.38 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
import os
class SubsetSC(SPEECHCOMMANDS):
    """Speech Commands dataset restricted to one of its predefined splits.

    The parent dataset is created under ``"./"`` with ``download=True``. The
    list of samples (``self._walker``) is then narrowed using the
    ``validation_list.txt`` and ``testing_list.txt`` files found in the
    dataset directory (``self._path``).
    """

    def __init__(self, subset: str = None):
        """Create the dataset and select the requested split.

        Args:
            subset: ``"validation"`` or ``"testing"`` to use exactly the files
                named in the matching list file; ``"training"`` to use every
                file that appears in neither list. Any other value (including
                the default ``None``) leaves the parent's sample list
                untouched.
        """
        super().__init__("./", download=True)

        def load_list(filename):
            # Read a split file and return one normalized path per non-blank
            # line, each joined onto the dataset directory.
            filepath = os.path.join(self._path, filename)
            with open(filepath) as fileobj:
                return [
                    os.path.normpath(os.path.join(self._path, entry))
                    for entry in (line.strip() for line in fileobj)
                    if entry
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(
                load_list("validation_list.txt") + load_list("testing_list.txt")
            )
            # Compare normalized paths on both sides rather than patching a
            # "./" prefix onto walker entries: the result no longer depends on
            # how the parent class spells its paths, so held-out files cannot
            # slip into the training split because of a prefix or separator
            # mismatch.
            self._walker = [
                w for w in self._walker if os.path.normpath(w) not in excludes
            ]
def get_transform(sample_rate, new_sample_rate=8000):
    """Build a transform that resamples audio to a new sample rate.

    Args:
        sample_rate: Sample rate of the incoming audio; passed as
            ``orig_freq`` to ``torchaudio.transforms.Resample``.
        new_sample_rate: Target sample rate; passed as ``new_freq``.
            Defaults to 8000, the value that used to be hard-coded here.

    Returns:
        The configured ``torchaudio.transforms.Resample`` instance.
    """
    return torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=new_sample_rate
    )
def label_to_index(word, labels):
    """Return the position of ``word`` within ``labels`` as a tensor.

    The position is looked up with ``labels.index(word)`` and wrapped with
    ``torch.tensor``; the lookup raises ``ValueError`` when ``word`` is not
    present in a list.
    """
    position = labels.index(word)
    return torch.tensor(position)
def index_to_label(index, labels):
    """Return the entry of ``labels`` found at ``index``.

    This is the inverse of ``label_to_index``: it maps a position back to
    the word stored there.
    """
    word = labels[index]
    return word