## Keyword Spotting (KWS) using PyTorch Lightning 

We illustrate how to use PyTorch Lightning to train a model for keyword spotting. We also show how to build a model that is data agnostic, i.e. it can be trained on any dataset. A separate module handles the datasets and dataloaders.

Using the Experience, Task, and Performance concept, KWS can be described as follows:

1. **Experience:** The model is trained on a dataset of 1-second audio clips, each containing a single spoken keyword. There are 35 distinct words in the dataset, such as `yes`, `no`, `right`, and `left`. Two additional categories are added to the dataset: `silence` and `unknown`. `silence` covers clips with no speech, such as silence or background noise. `unknown` covers clips that contain a spoken word that is not one of the 35 keywords.

2. **Task:** The model is trained to classify keywords in audio samples. We use a modified ResNet18 model.

3. **Performance:** We measure the performance of the model by calculating the accuracy of the model on the test set. 
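
As a minimal sketch of the performance measure, accuracy is the percentage of test samples whose predicted class matches the ground truth. The helper name `top1_accuracy` and the toy labels are illustrative assumptions and are not part of the notebook's code.

```python
def top1_accuracy(predicted, target):
    """Return the percentage of positions where predicted == target."""
    if len(predicted) != len(target):
        raise ValueError("predicted and target must have the same length")
    if not target:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(p == t for p, t in zip(predicted, target))
    return 100.0 * correct / len(target)


# Toy example: 4 of 5 keyword predictions are correct.
predicted = ["yes", "no", "left", "right", "unknown"]
target = ["yes", "no", "left", "right", "silence"]
print(top1_accuracy(predicted, target))  # 80.0
```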

For simplicity, we do not use the validation set for hyperparameter tuning.

We use KWS Version 2. The original KWS paper can be found [here](https://arxiv.org/pdf/1804.03209.pdf).


In [None]:
%pip install pytorch-lightning --upgrade
%pip install torchmetrics --upgrade

In [None]:
import torch
import torchaudio, torchvision
import os
import matplotlib.pyplot as plt 
import librosa
import argparse
import numpy as np
import wandb
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy
from torchvision.transforms import ToTensor
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.datasets.speechcommands import load_speechcommands_item


# use second GPU
os.environ["CUDA_VISIBLE_DEVICES"]="1"

### Custom Silence and Unknown Datasets

We create a custom `silence` dataset. The dataset randomly samples background audio supplied in the KWS dataset. These files are under the `_background_noise_` folder.

We also create an `unknown` dataset that uses random audio samples from the train set, relabelled as `unknown`.

The creation of these 2 datasets is described in the [KWS paper](https://arxiv.org/pdf/1804.03209.pdf) and implemented below.

We limit each of these two datasets to about the size of the train dataset divided by 35 (the number of distinct words in the KWS dataset), so that each contributes roughly as many samples as an average keyword class.
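
The idea behind both datasets can be sketched in plain Python: the dataset reports a fixed virtual length, ignores the requested index, and returns a random item with a constant label. The class name `RandomRelabelDataset`, the file names, and the 700-clip train set are hypothetical and chosen only to keep the arithmetic easy to follow. Unlike the notebook's classes, this sketch also bounds-checks the index.

```python
import random


class RandomRelabelDataset:
    """Toy stand-in for the custom datasets: fixed virtual length,
    random draw on every access, constant label."""

    def __init__(self, items, label, num_classes=35, seed=None):
        self.items = list(items)
        self.label = label
        # About one class worth of samples: total size // number of words.
        self.length = len(self.items) // num_classes
        self.rng = random.Random(seed)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # The requested index only gets a bounds check; the item is random.
        if not 0 <= index < self.length:
            raise IndexError(index)
        return self.rng.choice(self.items), self.label


# Hypothetical train set of 700 clips gives 700 // 35 = 20 extra samples.
train_clips = [f"clip_{i}.wav" for i in range(700)]
unknown = RandomRelabelDataset(train_clips, "unknown", seed=0)
print(len(unknown))   # 20
print(unknown[0][1])  # unknown
```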

In [None]:
class SilenceDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(SilenceDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35
        path = os.path.join(self._path, torchaudio.datasets.speechcommands.EXCEPT_FOLDER)
        self.paths = [os.path.join(path, p) for p in os.listdir(path) if p.endswith('.wav')]

    def __getitem__(self, index):
        index = np.random.randint(0, len(self.paths))
        filepath = self.paths[index]
        waveform, sample_rate = torchaudio.load(filepath)
        return waveform, sample_rate, "silence", 0, 0

    def __len__(self):
        return self.len

class UnknownDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(UnknownDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35

    def __getitem__(self, index):
        index = np.random.randint(0, len(self._walker))
        fileid = self._walker[index]
        waveform, sample_rate, _, speaker_id, utterance_number = load_speechcommands_item(fileid, self._path)
        return waveform, sample_rate, "unknown", speaker_id, utterance_number

    def __len__(self):
        return self.len

### The PyTorch Lightning Data Module for KWS

`KWSDataModule` cleanly separates the data handling from the model. The data module handles the datasets and dataloaders.

We use the `torchaudio` `SPEECHCOMMANDS` dataset to load the training, validation, and testing sets.

A custom `collate_fn` handles audio samples of different lengths by zero-padding or truncating each waveform to 1 second. The function also converts each waveform into a mel spectrogram for the ResNet18 input layer. The power mel spectrogram is converted to decibels, which gives a log-mel spectrogram: an image of the audio signal's power spectrum in dB on the mel frequency scale. In effect, we convert audio to an image so that we can use an image classifier such as ResNet18.
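
The two per-sample steps in `collate_fn` can be sketched without any audio library. This is a simplified plain-Python illustration, not the notebook's implementation: `pad_or_trim` works on a flat list of samples, and `power_to_db` follows the formula documented for `librosa.power_to_db` with `ref=np.max` and the default `amin` and `top_db` values.

```python
import math


def pad_or_trim(samples, target_len):
    """Zero-pad or truncate a 1-D list of samples to target_len."""
    if len(samples) < target_len:
        return samples + [0.0] * (target_len - len(samples))
    return samples[:target_len]


def power_to_db(power, amin=1e-10, top_db=80.0):
    """Convert power values to dB relative to the maximum value.

    Values more than top_db below the peak are clipped to peak - top_db.
    """
    ref = max(amin, max(power))
    db = [10.0 * math.log10(max(amin, p)) - 10.0 * math.log10(ref)
          for p in power]
    floor = max(db) - top_db
    return [max(d, floor) for d in db]


print(len(pad_or_trim([0.1, 0.2], 4)))   # 4 (padded with zeros)
print(len(pad_or_trim([0.1] * 10, 4)))   # 4 (truncated)
print(power_to_db([1.0, 0.1, 0.0]))      # approximately [0.0, -10.0, -80.0]
```

The clipping step keeps near-silent bins from producing very large negative values, so the resulting image has a bounded dynamic range.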

In [None]:
class KWSDataModule(LightningDataModule):
    def __init__(self, path, batch_size=128, num_workers=0, n_fft=512,
                 n_mels=128, win_length=None, hop_length=256, class_dict=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop_length = hop_length
        # maps each label string to its integer class index
        self.class_dict = class_dict if class_dict is not None else {}

    def prepare_data(self):
        self.train_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                                download=True,
                                                                subset='training')

        silence_dataset = SilenceDataset(self.path)
        unknown_dataset = UnknownDataset(self.path)
        self.train_dataset = torch.utils.data.ConcatDataset([self.train_dataset, silence_dataset, unknown_dataset])
                                                                
        self.val_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                              download=True,
                                                              subset='validation')
        self.test_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                               download=True,
                                                               subset='testing')                                                    
        _, sample_rate, _, _, _ = self.train_dataset[0]
        self.sample_rate = sample_rate
        self.transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                              n_fft=self.n_fft,
                                                              win_length=self.win_length,
                                                              hop_length=self.hop_length,
                                                              n_mels=self.n_mels,
                                                              power=2.0)

    def setup(self, stage=None):
        self.prepare_data()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

    def collate_fn(self, batch):
        mels = []
        labels = []
        wavs = []
        for sample in batch:
            waveform, sample_rate, label, speaker_id, utterance_number = sample
            # ensure that all waveforms are 1sec in length; if not pad with zeros
            if waveform.shape[-1] < sample_rate:
                waveform = torch.cat([waveform, torch.zeros((1, sample_rate - waveform.shape[-1]))], dim=-1)
            elif waveform.shape[-1] > sample_rate:
                waveform = waveform[:,:sample_rate]

            # mel from power to db
            mels.append(ToTensor()(librosa.power_to_db(self.transform(waveform).squeeze().numpy(), ref=np.max)))
            labels.append(torch.tensor(self.class_dict[label]))
            wavs.append(waveform)

        mels = torch.stack(mels)
        labels = torch.stack(labels)
        wavs = torch.stack(wavs)
   
        return mels, labels, wavs

### The PL LightningModule for KWS

The `KWSModel` is a ResNet18 model with its first and last layers modified to support single-channel mel spectrogram inputs and 37-category outputs.
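
To see why a spectrogram can be fed to ResNet18 directly, it helps to trace the spatial size of the input through the network. The sketch below uses the standard convolution output-size formula and the layer settings of the torchvision ResNet18. The 16 kHz sample rate is an assumption (the notebook reads the rate from the data), while `hop_length=512` and `n_mels=128` are the notebook's default arguments. The helper names are illustrative.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_feature_map(height, width):
    """Trace the spatial size of a (height, width) input through ResNet18."""
    # conv1: 7x7, stride 2, padding 3
    h, w = conv_out(height, 7, 2, 3), conv_out(width, 7, 2, 3)
    # maxpool: 3x3, stride 2, padding 1
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    # layer1 keeps the size; layer2, layer3 and layer4 each halve it
    # (3x3 convolution, stride 2, padding 1)
    for _ in range(3):
        h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    return h, w


# Assumed input: 1 s of audio at 16 kHz, hop_length=512, n_mels=128.
n_frames = 1 + 16000 // 512  # frame count with centered STFT windows
print(n_frames)                             # 32
print(resnet18_feature_map(128, n_frames))  # (4, 1)
```

The final feature map is reduced to a single 512-dimensional vector by adaptive average pooling, so the last linear layer only needs its output size changed to 37.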

In [None]:
class KWSModel(LightningModule):
    def __init__(self, num_classes=37, epochs=30, lr=0.001, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model = torchvision.models.resnet18(num_classes=num_classes)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        mels, labels, _ = batch
        preds = self.model(mels)
        loss = self.hparams.criterion(preds, labels)
        # calls to self.log() are recorded in wandb;
        # on_epoch=True logs the average over the epoch
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        # validation reuses the test logic, so it logs the same metric names
        return self.test_step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        mels, labels, wavs = batch
        preds = self.model(mels)
        loss = self.hparams.criterion(preds, labels)
        # top-1 accuracy in percent, computed directly from the logits
        acc = (preds.argmax(dim=1) == labels).float().mean() * 100.
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", acc, on_epoch=True, prog_bar=True)
        # "preds" is consumed by the W&B callback
        return {"preds": preds, 'test_loss': loss, 'test_acc': acc}

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.hparams.lr)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.hparams.epochs)
        return [optimizer], [lr_scheduler]

    def setup(self, stage=None):
        self.hparams.criterion = torch.nn.CrossEntropyLoss()


### PyTorch Lightning Callback

We can instantiate a callback object to perform certain tasks during training. In this case, we log sample audio files, ground truth labels, and predicted labels.

We can also use the `ModelCheckpoint` callback to save the checkpoint with the best monitored accuracy.
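
The logging callback performs two small conversions: it turns logits into a class name, and it scales float waveforms to 16-bit PCM for audio playback. A minimal plain-Python sketch of both follows. The three-class mapping and the logits are hypothetical, and the clipping in `to_int16` is an addition of this sketch that guards against overflow at a sample value of exactly 1.0.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=values.__getitem__)


def to_int16(sample):
    """Scale a float sample in [-1, 1] to 16-bit PCM, clipping at the limits."""
    return max(-32768, min(32767, int(sample * 32768.0)))


# Hypothetical 3-class mapping and logits for one clip.
idx_to_class = {0: "silence", 1: "unknown", 2: "yes"}
logits = [0.1, -1.2, 2.3]
print(idx_to_class[argmax(logits)])              # yes
print([to_int16(s) for s in (0.5, -1.0, 1.0)])   # [16384, -32768, 32767]
```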

In [None]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # log 10 sample audio predictions from the first batch
        if batch_idx == 0:
            n = 10
            mels, labels, wavs = batch
            preds = outputs["preds"]
            preds = torch.argmax(preds, dim=1)

            labels = labels.cpu().numpy()
            preds = preds.cpu().numpy()
            
            wavs = torch.squeeze(wavs, dim=1)
            wavs = [ (wav.cpu().numpy()*32768.0).astype("int16") for wav in wavs]
            
            sample_rate = pl_module.hparams.sample_rate
            idx_to_class = pl_module.hparams.idx_to_class
            
            # log audio samples and predictions as a W&B Table
            columns = ['audio', 'mel', 'ground truth', 'prediction']
            data = [[wandb.Audio(wav, sample_rate=sample_rate), wandb.Image(mel), idx_to_class[label], idx_to_class[pred]] for wav, mel, label, pred in list(
                zip(wavs[:n], mels[:n], labels[:n], preds[:n]))]
            # use the logger attached to the trainer instead of a global variable
            trainer.logger.log_table(
                key='ResNet18 on KWS using PyTorch Lightning',
                columns=columns,
                data=data)


### Program Arguments

The default configuration is shown below. 

We also define `plot_waveform` for plotting the waveform of the audio samples.
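
Because a notebook kernel owns `sys.argv`, the arguments are parsed from an explicit list rather than from the command line. The reduced parser below is only a sketch with three of the notebook's options. It shows that an empty list yields the defaults and that overrides can be passed as a list of strings. The function name `build_parser` is an assumption.

```python
import argparse


def build_parser():
    """Reduced parser with three of the notebook's options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--no-wandb", default=False, action="store_true")
    return parser


# An empty argument list yields the defaults.
defaults = build_parser().parse_args([])
print(defaults.batch_size, defaults.lr, defaults.no_wandb)  # 64 0.001 False

# Overrides are passed as a list of strings, as on a command line.
custom = build_parser().parse_args(["--batch-size", "128", "--no-wandb"])
print(custom.batch_size, custom.no_wandb)  # 128 True
```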



In [None]:
def get_args():
    parser = argparse.ArgumentParser()
    # model training hyperparameters
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--max-epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')

    # where dataset will be stored
    parser.add_argument("--path", type=str, default="data/speech_commands/")

    # 35 keywords + silence + unknown
    parser.add_argument("--num-classes", type=int, default=37)
   
    # mel spectrogram parameters
    parser.add_argument("--n-fft", type=int, default=1024)
    parser.add_argument("--n-mels", type=int, default=128)
    parser.add_argument("--win-length", type=int, default=None)
    parser.add_argument("--hop-length", type=int, default=512)

    # 16-bit mixed precision training to reduce memory use
    parser.add_argument("--precision", default=16)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--devices", default=1)
    parser.add_argument("--num-workers", type=int, default=48)

    parser.add_argument("--no-wandb", default=False, action='store_true')

    args = parser.parse_args("")
    return args

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f'Channel {c+1}')
        if xlim:
            axes[c].set_xlim(xlim)
        if ylim:
            axes[c].set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)

### KWS Training and Evaluation using `Trainer`

The `Trainer` runs the actual training, validation, and testing.

Accuracy on the test set is about `94.5%` (`94.0%` with 16-bit precision). The benchmark accuracy reported in the original paper is `88.2%`. The state of the art (as of 1 Apr 2022) is `97.0%`. For updated benchmarks, see [paperswithcode.com](https://paperswithcode.com/sota/keyword-spotting-on-google-speech-commands).

In [None]:

if __name__ == "__main__":

    args = get_args()
    CLASSES = ['silence', 'unknown', 'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
               'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
               'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
               'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']
    
    # make a dictionary from CLASSES to integers
    CLASS_TO_IDX = {c: i for i, c in enumerate(CLASSES)}

    if not os.path.exists(args.path):
        os.makedirs(args.path, exist_ok=True)

    model = KWSModel(num_classes=args.num_classes, epochs=args.max_epochs, lr=args.lr)
    print(model)
    datamodule = KWSDataModule(batch_size=args.batch_size, num_workers=args.num_workers,
                               path=args.path, n_fft=args.n_fft, n_mels=args.n_mels,
                               win_length=args.win_length, hop_length=args.hop_length,
                               class_dict=CLASS_TO_IDX)
    datamodule.setup()

    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="pl-kws")

    model_checkpoint = ModelCheckpoint(
        dirpath=os.path.join(args.path, "checkpoints"),
        filename="resnet18-kws-best-acc",
        save_top_k=1,
        verbose=True,
        monitor='test_acc',
        mode='max',
    )
    idx_to_class = {v: k for k, v in CLASS_TO_IDX.items()}
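    # only add the W&B callback when W&B logging is enabled;
    # the callbacks list must not contain None
    callbacks = [model_checkpoint]
    if not args.no_wandb:
        callbacks.append(WandbCallback())
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      precision=args.precision,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger if not args.no_wandb else None,
                      callbacks=callbacks)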
    model.hparams.sample_rate = datamodule.sample_rate
    model.hparams.idx_to_class = idx_to_class
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    wandb.finish()
    #trainer.save_checkpoint('../mnist/checkpoint.ckpt')


In [None]:
# https://pytorch-lightning.readthedocs.io/en/stable/common/production_inference.html
# load_from_checkpoint is a class method, so call it on the class
model = KWSModel.load_from_checkpoint(os.path.join(
    args.path, "checkpoints", "resnet18-kws-best-acc.ckpt"))
model.eval()
script = model.to_torchscript()

# save for use in production environment
model_path = os.path.join(args.path, "checkpoints",
                          "resnet18-kws-best-acc.pt")
torch.jit.save(script, model_path)

# list wav files given a folder
label = CLASSES[2:]
label = np.random.choice(label)
path = os.path.join(args.path, "SpeechCommands/speech_commands_v0.02/")
path = os.path.join(path, label)
wav_files = [os.path.join(path, f)
             for f in os.listdir(path) if f.endswith('.wav')]
# select random wav file
wav_file = np.random.choice(wav_files)
waveform, sample_rate = torchaudio.load(wav_file)
transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                 n_fft=args.n_fft,
                                                 win_length=args.win_length,
                                                 hop_length=args.hop_length,
                                                 n_mels=args.n_mels,
                                                 power=2.0)

mel = ToTensor()(librosa.power_to_db(
    transform(waveform).squeeze().numpy(), ref=np.max))
mel = mel.unsqueeze(0)
scripted_module = torch.jit.load(model_path)
pred = torch.argmax(scripted_module(mel), dim=1)
print(f"Ground Truth: {label}, Prediction: {idx_to_class[pred.item()]}")
