This file aims to split the original GSC v2 dataset 35 classes to 12 calsses follow the proceduce mentioned in the original paper <https://arxiv.org/pdf/1804.03209.pdf>.

In [1]:
def add_to_class(Class):
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

# Build the base

In [2]:
import torch
import torchaudio
import numpy as np
import pandas as pd

import os
import sys
import tarfile
import hashlib
import re
import glob

import random
import math

import IPython.display as ipd
from tensorflow.python.util import compat

In [3]:
DATA_URL = ['http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
            'http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz']

OFFICIAL_TEST_URL = ['http://download.tensorflow.org/data/speech_commands_test_set_v0.01.tar.gz',
                     'http://download.tensorflow.org/data/speech_commands_test_set_v0.02.tar.gz']

WORDS = ['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes']

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185
SR = 16000

In [4]:
def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words

In [None]:
os.path.basename('/content/aye.wav')

'aye.wav'

In [None]:
re.sub(r'_nohash_.*$','', 'e9287461_nohash_1.wav')

'e9287461'

In [5]:
def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are restarted
  for example. To keep this stability, a hash of the filename is taken and used
  to determine which set it should belong to. This determination only depends on
  the name and the set proportions, so it won't change as other files are added.

  It's also useful to associate particular files as related (for example words
  spoken by the same person), so anything after '_nohash_' in a filename is
  ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result

In [None]:
os.path.dirname('/content/GSC_12/up/e9287461_nohash_1.wav')

'/content/GSC_12/up'

In [6]:
def prepare_data_index(data_dir, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right
    labels for each file based on the name of the subdirectory it belongs to,
    and uses a stable hash to assign it to a data set partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    Returns:
      Dictionary containing a list of file information for each set partition,
      and a lookup map for each class to determine its numeric index.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
        wanted_words_index[wanted_word] = index + 2

    data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}

    # Look through all the subfolders to find audio samples
    search_path = glob.glob(os.path.join(data_dir, '*', '*.wav'))
    for wav_path in search_path:
        _, word = os.path.split(os.path.dirname(wav_path))
        word = word.lower()
        # Treat the '_background_noise_' folder as a special case, since we expect
        # it to contain long audio samples we mix in to improve training.
        if word == BACKGROUND_NOISE_DIR_NAME:
            continue
        all_words[word] = True
        set_index = which_set(wav_path, validation_percentage, testing_percentage)
        # If it's a known class, store its detail, otherwise add it to the list
        # we'll use to train the unknown label.
        if word in wanted_words_index:
            data_index[set_index].append({'label': word, 'file': wav_path})
        else:
            unknown_index[set_index].append({'label': word, 'file': wav_path})

    if not all_words:
        raise Exception('No .wavs found at ' + search_path)

    for index, wanted_word in enumerate(wanted_words):
        if wanted_word not in all_words:
            raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))

    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
        set_size = len(data_index[set_index])
        silence_size = int(math.ceil(set_size * silence_percentage / 100))
        for _ in range(silence_size):
            data_index[set_index].append({
                'label': SILENCE_LABEL,
              'file': silence_wav_path
            })

      # Pick some unknowns to add to each partition of the data set.
        random.shuffle(unknown_index[set_index])
        unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
        data_index[set_index].extend(unknown_index[set_index][:unknown_size])

    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
        random.shuffle(data_index[set_index])

    # Prepare the rest of the result data structure.
    words_list = prepare_words_list(wanted_words)
    word_to_index = {}
    for word in all_words:
        if word in wanted_words_index:
            word_to_index[word] = wanted_words_index[word]
        else:
            word_to_index[word] = UNKNOWN_WORD_INDEX
    word_to_index[SILENCE_LABEL] = SILENCE_INDEX

    return data_index, word_to_index

In [7]:
def prepare_official_test(data_dir, wanted_words):
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
        wanted_words_index[wanted_word] = index + 2
    wanted_words_index[SILENCE_LABEL] = SILENCE_INDEX
    wanted_words_index[UNKNOWN_WORD_LABEL] = UNKNOWN_WORD_INDEX

    test_data = []

    search_path = glob.glob(os.path.join(data_dir, '*', '*.wav'))
    for wav_path in search_path:
        _, word = os.path.split(os.path.dirname(wav_path))
        word = word.lower()
        test_data.append({'label': word, 'file': wav_path})

    return test_data, wanted_words_index

In [15]:
from torchaudio._internal import download_url_to_file

In [None]:
_CHECKSUMS = {
    "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "743935421bb51cccdb6bdd152e04c5c70274e935c82119ad7faeec31780d811d",  # noqa: E501
    "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "af14739ee7dc311471de98f5f9d2c9191b18aedfe957f4a6ff791c709868ff58",  # noqa: E501
}

In [None]:
download_url_to_file(DATA_URL, './GSC', _CHECKSUMS[DATA_URL])

100%|██████████| 2.26G/2.26G [00:24<00:00, 97.8MB/s]


In [8]:
def _download(data_url, dest_directory):
    filename = os.path.split(data_url)[-1]

    if not os.path.exists(dest_directory):
        os.mkdir(dest_directory)
    filepath = os.path.join(dest_directory, filename)

    download_url_to_file(data_url, filepath)
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)

In [None]:
download(DATA_URL, 'GSC_12')

100%|██████████| 2.26G/2.26G [00:13<00:00, 183MB/s]


In [10]:
class SpeechCommands12(torch.utils.data.Dataset):
    """

    Args
    subset: str
        Select a subset of the dataset ['training', 'validation', 'testing', 'official_testing']
    """

    def __init__(self,
                 root: str,
                 download: bool = True,
                 version: int = 2,
                 subset: str = 'training',
                 transform = None) -> None:
        super().__init__()
        self.transform = transform
        self.subset = subset

        if subset != 'official_testing':
            if download:
                url = DATA_URL[version-1]
                filename = os.path.split(url)[-1]
                print('>> Downloading %s' % filename)
                _download(url, root)
            data_index, self.word_to_index = prepare_data_index(root,
                                                                silence_percentage = 10,
                                                                unknown_percentage = 10,
                                                                wanted_words = WORDS,
                                                                validation_percentage = 10,
                                                                testing_percentage = 10)
            self.dataset = data_index[subset]
        else:
            if download:
                url = OFFICIAL_TEST_URL[version-1]
                filename = os.path.split(url)[-1]
                print('>> Downloading %s' % filename)
                _download(url, root)
            self.dataset, self.word_to_index = prepare_official_test(root,
                                                                     wanted_words = WORDS)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        filepath = row['file']
        label = row['label']
        if label == SILENCE_LABEL and self.subset != 'official_testing':
            wav = torch.zeros([1, SR])
        else:
            wav, _ = torchaudio.load(filepath)

        if self.transform:
            wav = self.transform(wav, label)
        return wav, self.word_to_index[label]

In [None]:
val_dataset = SpeechCommands12('/content/GSC_12', download = False, subset = 'validation')

In [None]:
wav, label = val_dataset[1]
print(wav.shape)
print(label)
print(len(val_dataset))

torch.Size([1, 16000])
1
4445


In [None]:
test_dataset = SpeechCommands12('/content/GSC_12_test', download = False, subset = 'official_testing')

In [None]:
wav, label = test_dataset[1]
print(wav.shape)
print(label)
print(len(test_dataset))

torch.Size([1, 16000])
7
4890


# Build the Processing

In [None]:
LIB_PATH = '/content/drive/MyDrive/GSC/GSC_helper'

sys.path.append(LIB_PATH)
from GSC12 import SpeechCommands12

In [None]:
test_dataset = SpeechCommands12('/content/GSC_12_test', download = False, subset = 'official_testing')
wav, label = test_dataset[1]
print(wav.shape)
print(label)
print(len(test_dataset))

torch.Size([1, 16000])
7
4890


In [11]:
from torch import Tensor

def normalzieNoise(wav: Tensor,
                   noise: Tensor,
                   max_length: int = 16000) -> Tensor:
    len_wav = wav.shape[1]
    len_noise = noise.shape[1]
    if len_wav > len_noise:
        buf = torch.zeros_like(wav)
        start_point = int((len_wav - len_noise)*random.uniform(0, 1))
        end_point = start_point + len_noise
        buf[:, start_point: end_point] = noise
        noise = buf
    elif len_wav < len_noise:
        start_point = int((len_noise - len_wav)*random.uniform(0, 1))
        end_point = start_point + len_wav
        noise = noise[:, start_point: end_point]
    return noise[:, :max_length]

def addNoise(wav: Tensor,
             noise: Tensor,
             snr: list) -> Tensor:
    noise = normalzieNoise(wav, noise)
    addnsy = torchaudio.transforms.AddNoise()
    return addnsy(wav, noise, snr = torch.Tensor([random.uniform(*snr)]))

def pad_truncate(wav: Tensor,
                 max_length: int = 16000,
                 pad_value: int = 0) -> Tensor:
    wav_length = wav.shape[1]
    if wav_length < max_length:
        buff = torch.zeros([1, max_length])
        buff[:, :wav_length] = wav
        wav = buff
    elif wav_length > max_length:
        wav = wav[:, :max_length]
    return wav

def time_shift(wav: Tensor,
               shift: list,
               sr: int = 16000) -> Tensor:
    x_shift = int(random.uniform(*shift)*sr)
    padding = torch.zeros(1, np.abs(x_shift))
    if x_shift < 0:
        wav = torch.cat([padding, wav[:, :x_shift]], dim=-1)
    else:
        wav = torch.cat([wav[:, x_shift:], padding], dim=-1)
    return wav

class Preprocessing:
    def __init__(self,
                 noise_dir: str,
                 noise_prob: float,
                 snr: list,
                 shift: list = None,
                 is_train: bool = False,
                 augment: bool = True,
                 transform = None) -> None:
        self.noise_paths = glob.glob(os.path.join(noise_dir, '*.wav'))
        self.is_train = is_train
        self.noise_prob = noise_prob
        self.augment = augment
        self.transform = transform
        self.add_noise = lambda x, noise: addNoise(x, noise, snr)
        self.pad_trunc = lambda x: pad_truncate(x, SR)
        self.shift = shift
        if shift:
            self.time_shift = lambda x: time_shift(x, shift)

    def __call__(self,
                 wav: Tensor,
                 label: str) -> Tensor:
        # padding to SR
        wav = self.pad_trunc(wav)

        if self.augment:
            # time shifting for training
            if self.is_train:
                if self.shift:
                    wav = self.time_shift(wav)

            p = random.random()
            if label == SILENCE_LABEL or (self.is_train and p<= self.noise_prob):
                noise, _ = torchaudio.load(random.choice(self.noise_paths))
                if label == SILENCE_LABEL:
                    p = random.random()
                    wav = normalzieNoise(wav, noise*p)
                else:
                    wav = self.add_noise(wav, noise)

        if self.transform:
            wav = self.transform(wav)

        return wav

In [None]:
wav2 = time_shift(wav,
                  [-0.5, 0.5])
wav2.shape

torch.Size([1, 16000])

In [None]:
ipd.Audio(wav, rate = 16000)

In [None]:
ipd.Audio(wav2, rate = 16000)

In [None]:
f_pre = Preprocessing('/content/GSC_12/_background_noise_',
                      noise_prob = 0.8,
                      snr = [-5, 10],
                      shift = [-0.1, 0.1],
                      is_train = True,
                      augment = True)
wav3 = f_pre(wav, 'on')
wav3.shape

torch.Size([1, 16000])

In [None]:
ipd.Audio(wav3, rate = SR)

In [None]:
wav3.min()

tensor(-0.4501)

In [None]:
mel = torchaudio.transforms.MelSpectrogram(SR,
                                           n_mels = 40,
                                           n_fft = 480,
                                           win_length = 480,
                                           hop_length = 160)

In [None]:
for i in range(200):
    mel(torch.Tensor(128, 1, 16000))

In [None]:
for i in range(128*200):
    mel(torch.Tensor(1, 1, 16000))

In [None]:
specaug = torchaudio.transforms.SpecAugment(2, 20,
                                            2, 20)

In [None]:
for i in range(200):
    a = mel(torch.Tensor(128, 1, 16000))
    specaug(a)

In [None]:
for i in range(200*128):
    a = mel(torch.Tensor(1, 1, 16000))
    specaug(a)

## Make Data

In [13]:
from torch import nn

class LFBE_Delta(nn.Module):
    def __init__(self,
                 sample_rate: int,
                 n_mfcc: int,
                 n_mels: int,
                 melkwargs: dict,
                 ) -> None:
        super().__init__()
        self.mfcc = torchaudio.transforms.MFCC(sample_rate = sample_rate,
                                               n_mfcc = n_mfcc,
                                               melkwargs = melkwargs)
        self.mel = torchaudio.transforms.MelSpectrogram(sample_rate = sample_rate,
                                                        n_mels = n_mels,
                                                        **melkwargs)
        self.todb = torchaudio.transforms.AmplitudeToDB()

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
        input: Tensor input: (N, C, T)
        """
        logmel = self.todb(self.mel(input))
        mfcc = self.mfcc(input)
        delta = torchaudio.functional.compute_deltas(mfcc)
        delta2 = torchaudio.functional.compute_deltas(delta)
        return torch.concat([logmel, delta, delta2], dim = 1)

In [18]:
melkwargs = {
    'n_fft' : 480,
    'hop_length' : 160,
    'f_min' : 20,
    'f_max' : 4000,
}
lfbe = LFBE_Delta(sample_rate = SR,
                  n_mfcc = 13,
                  n_mels = 13,
                  melkwargs = melkwargs)
f_pre = lambda is_train, augment: Preprocessing(noise_dir = '/content/GSC_12/_background_noise_',
                                                noise_prob = 0.8,
                                                snr = [-5, 10],
                                                shift  = [-0.1, 0.1],
                                                is_train = is_train,
                                                augment = augment,
                                                transform = lfbe
                                                )
train_pre = f_pre(True, True)
val_pre = f_pre(False, True)
test_pre = f_pre(False, False)



In [20]:
train_dataset = SpeechCommands12('/content/GSC_12', download = False, subset = 'training', transform = train_pre)
val_dataset = SpeechCommands12('/content/GSC_12', download = False, subset = 'validation', transform = val_pre)
test_dataset = SpeechCommands12('/content/GSC_12_test', download = True, subset = 'official_testing', transform = test_pre)

>> Downloading speech_commands_test_set_v0.02.tar.gz


100%|██████████| 107M/107M [00:00<00:00, 164MB/s]


In [21]:
x, y = train_dataset[1]
print(x.shape)
print(y)
print(len(train_dataset))

torch.Size([1, 39, 101])
5
36923


In [26]:
from tqdm import tqdm

def GSC_preprocessing(dataset, output_directory, num_classes = 12, mul_factor = 1, set = 'train', csv_file_name = 'analysised_spec.csv'):
    """
    Preprocessing for each dataset

    mul_factor: increasing the number of data samples by mul_factor times.
    """
    out_df = {
        'link': [],
        'label': [],
    }
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    for idx in range(mul_factor):
        # def f(ix, ex):
        for ix, (wav, label) in tqdm(enumerate(dataset)):
            #if ix%1000 == 0:
            #    print(f'{ix}/{len(data_df)}')
            #row = data_df.iloc[ix]
            #label = row['label']

            fname = f'{set}_{label}_{ix}_{idx}.npz'
            #row = {
            #    'link': os.path.join(output_directory, f'label_{label}', fname),
            #    'label': label,
            #    'set': set
            #}
            out_df['link'].append(os.path.join(set, fname))
            out_df['label'].append(label)

            if os.path.exists(os.path.join(output_directory, fname)):
                continue

            #out_df = out_df._append(row, ignore_index = True)
            np.savez_compressed(os.path.join(output_directory, fname), wav.squeeze(0).numpy())

        # Parallel(n_jobs = os.cpu_count())(delayed(f)(i, ex) for i, ex in tqdm(enumerate(dataset)))

    out_df = pd.DataFrame(out_df)
    out_df.to_csv(csv_file_name, index = False)

In [27]:
GSC_preprocessing(val_dataset, 'val_3', set = 'val', csv_file_name = '/content/val.csv')

4445it [00:46, 94.62it/s]


In [29]:
GSC_preprocessing(test_dataset, 'test_3', set = 'test', csv_file_name = '/content/test.csv')

4890it [00:45, 107.69it/s]


In [30]:
GSC_preprocessing(train_dataset, 'train_3', set = 'train', csv_file_name = '/content/train.csv', mul_factor = 4)

36923it [19:19, 31.83it/s]
36923it [18:47, 32.76it/s]
36923it [19:31, 31.52it/s]
36923it [22:07, 27.82it/s]


In [31]:
LIB_PATH = '/content/drive/MyDrive/GSC/GSC_helper'

sys.path.append(LIB_PATH)
from utils import zipzip

In [33]:
zipzip('/content/train_3', '/content/drive/MyDrive/GSC12/edgecrnn_train.zip')
zipzip('/content/val_3', '/content/drive/MyDrive/GSC12/edgecrnn_val.zip')
zipzip('/content/test_3', '/content/drive/MyDrive/GSC12/edgecrnn_test.zip')

zipping...: 100%|██████████| 147692/147692 [03:08<00:00, 782.54it/s]


/content/drive/MyDrive/GSC12/edgecrnn_train.zip created


zipping...: 100%|██████████| 4445/4445 [00:05<00:00, 821.41it/s]


/content/drive/MyDrive/GSC12/edgecrnn_val.zip created


zipping...: 100%|██████████| 4890/4890 [00:09<00:00, 523.93it/s]


/content/drive/MyDrive/GSC12/edgecrnn_test.zip created


In [34]:
import shutil
shutil.move('/content/test.csv', '/content/drive/MyDrive/GSC12')
shutil.move('/content/val.csv', '/content/drive/MyDrive/GSC12')
shutil.move('/content/train.csv', '/content/drive/MyDrive/GSC12')

'/content/drive/MyDrive/GSC12/train.csv'

## Test Training

In [None]:
!pip install Lightning

In [None]:
import lightning as L

In [None]:
class EdgeCRNN_training(L.LightningModule):
    def __init__(self,
                 lr: float,
                 in_channels: int,
                 hidden_size: int,
                 num_classes: int,
                 width_multiplier: int = 1,
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        #self.automatic_optimization = False
        self.lr = lr
        self.net = EdgeCRNN(in_channels = in_channels,
                            hidden_size = hidden_size,
                            num_classes = num_classes,
                            width_multiplier = width_multiplier
                            )

    def forward(self, input):
        return self.net(input)

In [None]:
def warm_up(lr_init, lr_end, epoch, total_epochs):
    return lr_init - epoch*(lr_init - lr_end)/total_epochs

In [None]:
@add_to_class(EdgeCRNN_training)
def accuracy(self, Y_hat, Y, averaged = True):
    """
    Compute the number of correct predictions
    """
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(dim = 1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

@add_to_class(EdgeCRNN_training)
def training_step(self, batch, batch_idx):
    x, y = batch
    x = mel(x)
    x = specaug(x)
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)

    # single scheduler
    # sch = self.lr_schedulers()
    #sch.step()
    torch.nn.utils.clip_grad_norm(self.parameters(), 5)

    values = {"train_loss": loss, "train_acc": acc}
    self.log_dict(values, prog_bar = True)
    return loss

@add_to_class(EdgeCRNN_training)
def validation_step(self, batch, batch_idx):
    x, y = batch
    x = mel(x)
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)
    values = {"val_loss": loss, "val_acc": acc}
    self.log_dict(values, prog_bar = True)

@add_to_class(EdgeCRNN_training)
def test_step(self, batch, batch_idx):
    x, y = batch
    x = mel(x)
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)
    values = {"test_loss": loss, "test_acc": acc}
    self.log_dict(values, prog_bar = True)

@add_to_class(EdgeCRNN_training)
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay = 0.00005)
    return optimizer

@add_to_class(EdgeCRNN_training)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # update params
    optimizer.step(closure = optimizer_closure)

    # manually warm up lr withou a scheduler
    lr = warm_up(self.lr, 1e-4, epoch, 500)
    for pg in optimizer.param_groups:
        pg['lr'] = lr

In [None]:
import torch.nn.functional as F
@add_to_class(EdgeCRNN_training)
def loss(self, y_hat, y):
    return F.cross_entropy(y_hat, y, reduction = 'mean')

In [None]:
class SC_12(L.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        f_pre = lambda is_train, augment: Preprocessing(noise_dir = '/content/GSC_12/_background_noise_',
                                                               noise_prob = 0.8,
                                                               snr = [-5, 10],
                                                               shift  = [-0.1, 0.1],
                                                               is_train = is_train,
                                                               augment = augment)
        train_pre = f_pre(True, True)
        val_pre = f_pre(False, True)
        test_pre = f_pre(False, False)

        self.train_dataset = SpeechCommands12(root = '/content/GSC_12',
                                              download = False,
                                              subset = 'training',
                                              transform = train_pre)
        self.val_dataset = SpeechCommands12(root = '/content/GSC_12',
                                            download = False,
                                            subset = 'validation',
                                            transform = val_pre)
        self.test_dataset = SpeechCommands12(root = '/content/GSC_12_test',
                                             download = False,
                                             subset = 'official_testing',
                                             transform = test_pre)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = True,
    #                                       collate_fn = self.collate_fn,
                                           num_workers = 1,
                                           prefetch_factor = 1,
                                           pin_memory = True,
                                           drop_last = True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = False,
     #                                      collate_fn = self.collate_fn,
                                           num_workers = 1,
                                           prefetch_factor = 1,
                                           pin_memory = True,
                                           drop_last = True)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = False,
      #                                     collate_fn = self.collate_fn,
                                           num_workers = 1,
                                           prefetch_factor = 1,
                                           pin_memory = True,
                                           drop_last = True)

In [None]:
data = SC_12(512)

In [None]:
train_dataloader = data.train_dataloader()
X, y = next(iter(train_dataloader))
print(X.shape)
print(y)

torch.Size([128, 1, 16000])
tensor([ 4,  4,  2, 11, 11,  5,  8,  6,  9,  7,  2,  3,  9,  4, 11,  6,  6,  7,
         9,  3,  0,  7,  7, 11,  9,  3,  6,  6, 10,  8,  7,  0,  2,  6,  7,  8,
        10, 11, 11,  1,  4,  0,  4,  7, 11,  9,  8,  6,  1,  4,  5,  5,  1,  6,
         7,  9, 10, 10,  7,  0,  7,  3,  1,  5,  6,  6,  5,  2,  5,  4,  5,  0,
         0,  1,  4, 11,  4,  4,  0,  9,  4,  2, 10,  0,  2,  1,  5,  0,  6,  7,
        11,  1,  8,  5,  6,  7,  2,  2,  3,  3,  3,  4,  3,  7,  1,  9,  0,  9,
         8,  1,  7,  0,  0,  7,  9,  3,  9,  4,  0,  4,  6,  5,  3,  7,  9,  4,
         8,  6])


In [None]:
val_dataloader = data.val_dataloader()

In [None]:
mel = torchaudio.transforms.MelSpectrogram(SR,
                                           n_mels = 40,
                                           n_fft = 480,
                                           win_length = 480,
                                           hop_length = 160).to('cuda')
specaug = torchaudio.transforms.SpecAugment(2, 20,
                                            2, 20).to('cuda')

1 epoch

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
early_stopping_callback = EarlyStopping(monitor = "val_acc", min_delta = 0.00001, patience = 10, mode = "max")
checkpoint_callback = ModelCheckpoint(dirpath = 'best_model',
                                      save_top_k = 5, monitor = 'val_acc',
                                      mode = 'max',
                                      filename = 'edgecrnn-gsc-12-{epoch:02d}-{val_loss:.2f}-{val_acc:.4f}')

In [None]:
from lightning.pytorch import seed_everything

seed_everything(42)

net = EdgeCRNN_training(lr = 1e-3,
                        in_channels = 1,
                        hidden_size = 64,
                        num_classes = 12)

trainer = L.Trainer(accelerator="gpu",
                    callbacks = [early_stopping_callback, checkpoint_callback],
                    enable_checkpointing=True,
                    default_root_dir = "/content/logging",
                    max_epochs=500)
trainer.fit(net, data)

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name | Type     | Params
----------------------------------
0 | net  | EdgeCRNN | 454 K 
----------------------------------
454 K     Trainable params
0         Non-trainable params
454 K     Total params
1.819     Total estimated model params size (MB)
INFO:lightning.pytorch.callback

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

  torch.nn.utils.clip_grad_norm(self.parameters(), 5)


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]