# Prepare data

In [1]:
from os.path import join
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
import seaborn as sns

from torch import nn
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from pytorch_lightning.metrics.functional import accuracy, confusion_matrix
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

from audio_loader.features.raw_audio import WindowedAudio
from audio_loader.features.mfcc import WindowedMFCC
from audio_loader.ground_truth.timit import TimitGroundTruth
from audio_loader.samplers.dynamic_sampler import DynamicSampler
from audio_loader.dl_frontends.pytorch.fill_ram import get_dataset_dynamic_size

# Dataloader lightning

In [2]:
class TimitMFCCDataModule(pl.LightningDataModule):

    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
    
    def prepare_data(self):
        self.timit_gt = TimitGroundTruth(self.data_dir, phon_class="phon_class2", with_silences=False)
        self.mfcc_feature_processor = WindowedMFCC(400, 160, 16000, 13, delta_orders=[1, 2], delta_width=9)
        self.mfcc_sampler = DynamicSampler([self.mfcc_feature_processor], self.timit_gt)
        self.original_train_dataset, self.collate_func = get_dataset_dynamic_size(self.mfcc_sampler, "train")
        self.test_dataset, self.collate_func = get_dataset_dynamic_size(self.mfcc_sampler, "test")
        
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.val_nb_samples = round(len(self.original_train_dataset)/100)
            self.train_nb_samples = len(self.original_train_dataset) - self.val_nb_samples
            self.train_dataset, self.val_dataset = random_split(
                self.original_train_dataset,
                [self.train_nb_samples, self.val_nb_samples]
            )
        
        if stage == 'test' or stage is None:
            return
            

    def train_dataloader(self):
        print(f"Train number of examples: {len(self.train_dataset)}\n")
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          collate_fn=self.collate_func,
                          drop_last=True)

    def val_dataloader(self):
        #return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
        #                  collate_fn=self.collate_func,
        #                  drop_last=False)
        return DataLoader(self.test_dataset, self.batch_size , shuffle=False,
                          collate_fn=self.collate_func,
                          drop_last=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, self.batch_size , shuffle=False,
                          collate_fn=self.collate_func,
                          drop_last=False)

In [3]:
# the data
mfcc_timit = TimitMFCCDataModule(join(Path.home(), "data/kaggle_TIMIT"), 16)

# Model definition

In [4]:
class lit_mfcc_model(pl.LightningModule):
    def __init__(self, feature_size, labels):
        """Init all parameters.
        
        feature_size: int
            size of the expected features for the forward step
        labels: list
            Ground truth labels to display in the confusion matrix
        """
        super().__init__()
        self.feature_size = feature_size
        self.labels = labels
        
        self.layer_1_grus = nn.GRU(
            feature_size, 550, 5,
            bidirectional=True,
            batch_first=True,
            dropout=0.2
        )
        
        self.bn_fwd = nn.BatchNorm1d(550)
        self.bn_bwd = nn.BatchNorm1d(550)
        self.layer_2_dense = torch.nn.Linear(1100, 128)
        self.bn_layer_2 = nn.BatchNorm1d(128)
        self.act_layer_2 = nn.LeakyReLU(0.1) # in pytorch kaldi it is softmax
        
        self.layer_3_dense = torch.nn.Linear(128, 58)
        
    def forward(self, x):
        """Forward of the model over the data."""
        batch_size = x.batch_sizes[0]
        # shape: (num_layers*directions, batch_size, hidden_size?)
        h_0 = torch.zeros(5*2, batch_size, 550, device="cuda")
        output, h_n = self.layer_1_grus(x, h_0)

        fwd_h = h_n.view(5, 2, batch_size, 550)[-1, 0]
        bwd_h = h_n.view(5, 2, batch_size, 550)[-1, 1]
    
        fwd_h = self.bn_fwd(fwd_h.view(batch_size, 550))
        bwd_h = self.bn_bwd(bwd_h.view(batch_size, 550))

        h = torch.cat((fwd_h, bwd_h), 1)
        dense1 = self.bn_layer_2(self.act_layer_2(self.layer_2_dense(h)))
        return self.layer_3_dense(dense1)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0004)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x.to('cuda'))
        _, y = torch.stack(y).max(dim=1)
        
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x.to('cuda'))
        _, target = torch.stack(y).max(dim=1)
        _, pred = y_hat.max(dim=1)

        val_loss = F.cross_entropy(y_hat, target)
        return {'loss': val_loss, 'preds_strat1': pred, 'target': target}

    def validation_epoch_end(self, outputs):  
        preds = torch.cat([tmp['preds_strat1'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        
        # simple metrics
        self.log('val_loss', torch.stack([tmp['loss'] for tmp in outputs]).mean())
        self.log('val_accuracy', accuracy(preds, targets))
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x.to('cuda'))
        _, target = torch.stack(y).max(dim=1)
        _, pred = y_hat.max(dim=1)
        
        test_loss = F.cross_entropy(y_hat, target)
        return {'loss': test_loss, 'preds_strat1': pred, 'target': target}
    
    def test_epoch_end(self, outputs):
        preds = torch.cat([tmp['preds_strat1'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        
        # simple metrics
        self.log('test_loss', torch.stack([tmp['loss'] for tmp in outputs]).mean())
        self.log('test_accuracy', accuracy(preds, targets))
        
        # confusion matrix
        num_classes = len(self.labels)
        
        cm = confusion_matrix(preds, targets, num_classes=num_classes).cpu().numpy()
        normalized_cm = np.around(cm/cm.sum(axis=1)[:, None], 3)*100 # normalize by line

        df_cm = pd.DataFrame(normalized_cm, index=self.labels, columns=self.labels)
        plt.figure(figsize = (15,12))
        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
        plt.close(fig_)

        self.logger.experiment.add_figure(f"Test - Confusion matrix", fig_, self.current_epoch)

In [5]:
# get labels
mfcc_timit.prepare_data()
keys_timit = [i for i in range(mfcc_timit.timit_gt.phon_size)]
labels = [mfcc_timit.timit_gt.index2phn[i] for i in keys_timit]

# init model
model = lit_mfcc_model(13*3, labels) # MFCC + delta+ deltas deltas
model.to('cuda')

lit_mfcc_model(
  (layer_1_grus): GRU(39, 550, num_layers=5, batch_first=True, dropout=0.2, bidirectional=True)
  (bn_fwd): BatchNorm1d(550, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_bwd): BatchNorm1d(550, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer_2_dense): Linear(in_features=1100, out_features=128, bias=True)
  (bn_layer_2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act_layer_2): LeakyReLU(negative_slope=0.1)
  (layer_3_dense): Linear(in_features=128, out_features=58, bias=True)
)

# Train model

In [6]:
# trainer definition
trainer = pl.Trainer(
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=10, mode="min")
    ],
    checkpoint_callback=ModelCheckpoint(save_top_k=5, monitor="val_loss", mode="min"),
    progress_bar_refresh_rate=1000,
    gpus=1, auto_select_gpus=True,
    precision=16,
    max_epochs=100
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [None]:
trainer.fit(model, mfcc_timit)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type        | Params
----------------------------------------------
0 | layer_1_grus  | GRU         | 23.8 M
1 | bn_fwd        | BatchNorm1d | 1.1 K 
2 | bn_bwd        | BatchNorm1d | 1.1 K 
3 | layer_2_dense | Linear      | 140 K 
4 | bn_layer_2    | BatchNorm1d | 256   
5 | act_layer_2   | LeakyReLU   | 0     
6 | layer_3_dense | Linear      | 7.5 K 
----------------------------------------------
23.9 M    Trainable params
0         Non-trainable params
23.9 M    Total params
95.630    Total estimated model params size (MB)


Train number of examples: 139791                              

Epoch 0:   0%|          | 0/11947 [00:00<?, ?it/s] 



Epoch 0:  75%|███████▌  | 9000/11947 [06:20<02:04, 23.67it/s, loss=1.28, v_num=469]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/3211 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▎ | 10000/11947 [06:26<01:15, 25.85it/s, loss=1.28, v_num=469]
Epoch 0:  92%|█████████▏| 11000/11947 [06:33<00:33, 27.95it/s, loss=1.28, v_num=469]
Validating:  93%|█████████▎| 3000/3211 [00:19<00:01, 150.51it/s][A
Epoch 0: 100%|██████████| 11947/11947 [06:41<00:00, 29.75it/s, loss=1.21, v_num=469]
Epoch 1:  75%|███████▌  | 9000/11947 [04:47<01:34, 31.33it/s, loss=0.979, v_num=469]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/3211 [00:00<?, ?it/s][A
Epoch 1:  84%|████████▎ | 10000/11947 [04:53<00:57, 34.02it/s, loss=0.979, v_num=469]
Epoch 1:  92%|█████████▏| 11000/11947 [05:00<00:25, 36.59it/s, loss=0.979, v_num=469]
Validating:  93%|█████████▎| 3000/3211 [00:19<00:01, 150.41it/s][A
Epoch 1: 100%|██████████| 11947/11947 [05:08<00:00, 38.71it/s, loss=1.05, v_num=469] 
Ep

In [None]:
trainer.test(model, mfcc_timit.test_dataloader())