# Prepare data

In [1]:
from os.path import join
from pathlib import Path

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import random
import seaborn as sns

from pytorch_lightning.metrics.functional import accuracy, confusion_matrix
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from threadpoolctl import threadpool_info, threadpool_limits 
from torch import nn
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchaudio import transforms
from torchvision.transforms import RandomApply, Compose

from audio_loader.features.raw_audio import WindowedAudio
from audio_loader.features.mfcc import WindowedMFCC
from audio_loader.ground_truth.timit import TimitGroundTruth
from audio_loader.samplers.dynamic_sampler import DynamicSamplerFromGt
from audio_loader.dl_frontends.pytorch.fill_ram import pad_collate_supervised



In [2]:
for lib in threadpool_info(): 
    threadpool_limits(limits=max(2, int(lib['num_threads']/6)), user_api=lib['user_api'])

# Dataloader lightning

In [3]:
def aug_func(p=0.):
    return Compose([
        RandomApply([transforms.TimeMasking(2)], p=p),
        RandomApply([transforms.FrequencyMasking(1)], p=p)
    ])

class TrainValidTestPrototypicalDataset(Dataset):
    """ """

    def __init__(self, data):
        """ 
        Args:
            data: list of tuples (data, gt)
            anchor_data: list of tuples (data, gt) being the anchors
        """
        self.data = data

    def __len__(self):
        """Return the size of the set times the number of anchors used."""
        return len(self.data)

    def __getitem__(self, index):
        """Return a tuple of (data, anchor ) corresponding to the index."""
        datum = self.data[index]
        return datum[0], datum[1]

In [4]:
class TimitMFCCDataModule(pl.LightningDataModule):

    def __init__(self, data_dir, batch_size, shots=5, queries=15, batch_size_test=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.shots = shots
        self.queries = queries
    
    def prepare_data(self):
        self.timit_gt = TimitGroundTruth(self.data_dir, with_silences=False, phon_class="phon_class2", return_original_gt=True)
        self.timit_gt.set_gt_format(phonetic=True, word=False, speaker_id=True)
        self.mfcc_feature_processor = WindowedMFCC(400, 160, 16000, 13, delta_orders=[1, 2], delta_width=9)
        self.sampler = DynamicSamplerFromGt([self.mfcc_feature_processor], self.timit_gt)
        
    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            data_train = [(torch.tensor(x[0]),
                           y[:-1].argmax(),
                           y_label,
                           y[-1],
                           self.timit_gt.index2speaker_id[int(y[-1])][0],
                           self.timit_gt.df_all[self.timit_gt.df_all['speaker_id'] == self.timit_gt.index2speaker_id[int(y[-1])]]["dialect_region"].values[0])
                          for x, (y, y_label) in self.sampler.get_samples_from("train", randomize_files=False)]
            # select the anchors, the rest will be used for valid data
            self.df_data_train = pd.DataFrame(data_train, columns=['datum', 'gt', 'original_label', 'speaker_id', "sex", 'region']) 
            all_speaker_id = self.df_data_train['speaker_id'].unique()
            all_unique_gt = self.df_data_train['gt'].unique()
            
            dict_type_per_class = {}
            for gt in all_unique_gt:
                dict_type_per_class[gt] = self.df_data_train[self.df_data_train['gt'] == gt]['original_label'].unique()
                
            all_original_labels = self.df_data_train['original_label'].unique()
            speaker_used_overall = [[] for i in range(self.number_of_classes)]

            self.anchor_data, self.queries_data, speaker_used = self.select_samples(data_train,
                                                                                  self.shots,
                                                                                  speaker_used=speaker_used_overall,
                                                                                  region_constraint=False,
                                                                                   dict_type_per_class=dict_type_per_class)
            
            self.queries_data, self.valid_data, speaker_used = self.select_samples(self.queries_data,
                                                                                   self.queries,
                                                                                   speaker_used=speaker_used_overall,
                                                                                   region_constraint=False,
                                                                                   dict_type_per_class=dict_type_per_class)
            self.train_dataset = TrainValidTestPrototypicalDataset(self.queries_data)
            
            # take less few samples for the valid set to avoid too long computation time
            i_classes = np.zeros(self.number_of_classes)
            examples_per_class_valid = 20
    
            self.valid_data, self.nu_data, speaker_used = self.select_samples(self.valid_data,
                                                                              examples_per_class_valid,
                                                                              speaker_used=speaker_used_overall,
                                                                              region_constraint=False,
                                                                             dict_type_per_class=dict_type_per_class)
            
            self.valid_dataset = TrainValidTestPrototypicalDataset(self.valid_data)
            return
        
        if stage == 'test' or stage is None:
            data_test = [(torch.tensor(x[0]),
                               y[:-1].argmax(),
                               y[-1],
                               self.timit_gt.index2speaker_id[int(y[-1])][0],
                               self.timit_gt.df_all[self.timit_gt.df_all['speaker_id'] == self.timit_gt.index2speaker_id[int(y[-1])]]["dialect_region"].values[0])
                              for x, (y, y_label) in self.sampler.get_samples_from("test")]
            self.df_data_test = pd.DataFrame(data_test, columns=['datum', 'gt', 'speaker_id', "sex", 'region']) 
            self.test_dataset = TrainValidTestPrototypicalDataset(data_test)
            return

    
    def select_samples(self, data, shots, speaker_used, region_constraint=False, dict_type_per_class=None):
        """Return a tuple with anchors, unused data and the speaker id used per class."""
        
        anchor_data = []
        # initialise counters
        number_of_sample_selected = np.zeros(self.number_of_classes)
        male_per_phon = np.zeros(self.number_of_classes)
        female_per_phon = np.zeros(self.number_of_classes)
        limit_female = round(shots / 2.)
        limit_male = shots - limit_female
        
        dict_limit_per_type = {}
        dict_type_count = {}
        for label in dict_type_per_class:
            limit = math.ceil(shots/len(dict_type_per_class[label]))
            for original_label in dict_type_per_class[label]:
                dict_limit_per_type[original_label] = limit
                dict_type_count[original_label] = 0
        

        # region_used = [[] for i in range(self.number_of_classes)]
        region_used = np.zeros((self.number_of_classes, 8)) # 8 number of regions 
        region_goal = np.zeros((self.number_of_classes, 8)) # 8 number of regions
        for i in range(shots):
            region_goal[:, i%8] += 1

        if region_constraint:
            # fix as there is no example from R8 in some classes
            # don't work for shots > 10
            if self.number_of_classes == 58:
                region_goal[29, 7] = 0
                region_goal[50, 7] = 0
                region_goal[51, 7] = 0
                region_goal[29, 2] += 1
                region_goal[50, 2] += 1
                region_goal[51, 2] += 1

                if region_goal[50, 0] > 1: 
                    region_goal[50, 0] = 1
                    region_goal[51, 0] = 1
                    region_goal[50, 4] += 1
                    region_goal[51, 4] += 1

            elif self.number_of_classes == 47:
                # don't work for shots > 10
                region_goal[40, 7] -= 0
                region_goal[40, 0] -= 0
                region_goal[40, 2] += 1
                region_goal[40, 3] += 1
            elif self.number_of_classes == 39:
                # don't work for shots > 10
                region_goal[33, 7] -= 0
                region_goal[33, 0] -= 0
                region_goal[33, 2] += 1
                region_goal[33, 3] += 1

        i = 0
        number_anchors = 0
        expected_anchors = shots * self.number_of_classes
        # apply strategies to select data
        while i < len(data) and number_anchors < expected_anchors:
            datum, gt, original_gt, speaker_id, sex, region = data[i]
            if region_constraint:
                condition = number_of_sample_selected[gt] < shots and speaker_id not in speaker_used[gt] and region_used[gt, int(region[-1])-1] < region_goal[gt, int(region[-1])-1] and dict_type_count[original_gt] < dict_limit_per_type[original_gt]
            else:
                condition = number_of_sample_selected[gt] < shots and speaker_id not in speaker_used[gt] and dict_type_count[original_gt] < dict_limit_per_type[original_gt]
                
            if condition:
                if sex == "F":
                    if female_per_phon[gt] < limit_female:
                        female_per_phon[gt] += 1
                    else:
                        i += 1
                        continue
                else:
                    if female_per_phon[gt] < limit_male:
                        male_per_phon[gt] += 1
                    else:
                        i += 1
                        continue

                region_used[gt, int(region[-1])-1] += 1
                number_of_sample_selected[gt] += 1
                dict_type_count[original_gt] += 1
                speaker_used[gt].append(speaker_id)
                number_anchors += 1
                anchor_data.append((datum, gt))
                del data[i]
            else:
                i += 1

        if number_anchors < expected_anchors:
            print(region_goal-region_used)
            print(region_goal)
            print(region_used)
            print(f"number_of_sample_selected:\n{number_of_sample_selected}")
            print(f"speaker_used:\n{speaker_used}")
            raise Exception("strategy not possible")

        return anchor_data, data, speaker_used
    
    @property
    def number_of_classes(self):
        return self.sampler.number_of_classes-1  # minus one as the last dimension is for the speaker id
    
    @property
    def anchors(self):
        return self.anchor_data

    def train_dataloader(self):
        if len(self.train_dataset) < self.batch_size:
            raise Exception(f"Error the dataset size ({len(self.train_dataset)}) if inferior to the desired batch size ({self.batch_size})")
    
        data_loader =  DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                                  collate_fn=pad_collate_supervised,
                                  num_workers=1,
                                  drop_last=True)
        print(f"train {len(data_loader)}")
        return data_loader

    def val_dataloader(self):
        data_loader = DataLoader(self.valid_dataset, batch_size=self.batch_size_test, shuffle=False,
                                 collate_fn=pad_collate_supervised,
                                 num_workers=0,
                                 drop_last=False)
        print(f"valid {len(data_loader)}")
        return data_loader

    def test_dataloader(self):
        data_loader = DataLoader(self.test_dataset, self.batch_size_test , shuffle=False,
                                 collate_fn=pad_collate_supervised,
                                 num_workers=0,
                                 drop_last=False)
        print(f"test {len(data_loader)}")
        return data_loader

# Model definition

In [5]:
class lit_mfcc_model(pl.LightningModule):
    def __init__(self, feature_size, anchor_data, num_classes, num_shots, labels):
        """Init all parameters.
        
        feature_size: int
            size of the expected features for the forward step
        """
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.num_shots = num_shots
        self.labels = labels
        
        # prepare the anchors for test and valid
        self.anchors_gt = [gt for data, gt in anchor_data]
        anchors = [data for data, gt in anchor_data]
        
        # sort anchors
        sorted_idx = sorted(range(len(self.anchors_gt)), key=lambda k: self.anchors_gt[k])
        self.anchors_gt = [self.anchors_gt[idx] for idx in sorted_idx]
        anchors = [anchors[idx] for idx in sorted_idx]
        
        
        # prepare anchor data
        anchors_lens = [len(x) for x in anchors]
        anchors_padded = pad_sequence(anchors, batch_first=True, padding_value=0.)
        self.packed_anchors = pack_padded_sequence(anchors_padded, anchors_lens,
                                                   batch_first=True, enforce_sorted=False)

        # Initialize the model
        self.layer_1_grus = nn.GRU(
            feature_size, 256, 5,
            bidirectional=True,
            batch_first=True,
            dropout=0.2
        )
        
        self.bn_fwd = nn.BatchNorm1d(256)
        self.bn_bwd = nn.BatchNorm1d(256)
        self.layer_2_dense = torch.nn.Linear(512, 128)
        
        self.alpha = torch.nn.Linear(128, 128, bias=False)
        
        self.loss = torch.nn.CrossEntropyLoss() #torch.nn.BCEWithLogitsLoss(size_average=True)
        
    def forward_one(self, x):
        batch_size = x.batch_sizes[0]
        # shape: (num_layers*directions, batch_size, hidden_size?)
        h_0 = torch.zeros(5*2, batch_size, 256, device=self.device)
        output, h_n = self.layer_1_grus(x, h_0)

        fwd_h = h_n.view(5, 2, batch_size, 256)[-1, 0]
        bwd_h = h_n.view(5, 2, batch_size, 256)[-1, 1]
    
        fwd_h = self.bn_fwd(fwd_h.view(batch_size, 256))
        bwd_h = self.bn_bwd(bwd_h.view(batch_size, 256))

        h = torch.cat((fwd_h, bwd_h), 1)
        return self.layer_2_dense(h)
        

    def compute_prototypes(self):
        cks = [[] for i in range(len(np.unique(self.anchors_gt)))]
        
        out = self.forward_one(self.packed_anchors.to(self.device))
        for anchor_emb, gt in zip(out, self.anchors_gt):
            cks[gt].append(anchor_emb)
            
        for i in range(len(cks)):
            cks[i] = torch.mean(torch.stack(cks[i]), axis=0)
            
        return cks
    
    def distance(self, x1, x2):
        """Euclidean distance."""
        #return torch.abs(x1-x2).sum(axis=list(range(1, len(x1.shape)))).reshape(len(x1), 1)
        return torch.abs(x1-x2).sum()

    def forward(self, inputs):
        """Forward of the model over the data."""
        xs, cks = inputs
        
        embs = self.forward_one(xs)
        res = []
        for emb in embs:
            distances = []
            for ck in cks:
                distances.append(-self.distance(emb, ck))
            
            res.append(torch.stack(distances))
        
        return torch.stack(res)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0006) # initial fixed:  lr=0.0004)
        return {
            'optimizer': optimizer,
            'lr_scheduler': ReduceLROnPlateau(optimizer, mode='max', factor=0.95, patience=5, threshold=0.0001, min_lr=0.0002),
            'monitor': 'val_accuracy'
        }
    
    def training_step(self, batch, batch_idx):
        xs, ys = batch
        
        cks = self.compute_prototypes()
        y_hat = self.forward([xs.to(self.device), torch.stack(cks).to(self.device)])
        
        loss = self.loss(y_hat, torch.tensor(ys, device=self.device))
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        xs, ys = batch

        cks = self.compute_prototypes()
        y_hat = self.forward([xs.to(self.device), torch.stack(cks).to(self.device)])

        val_loss = self.loss(y_hat, torch.tensor(ys, device=self.device))
        y_hat = y_hat.argmax(axis=1)

        return {'loss': val_loss, 'preds_strat1': y_hat, 'target': torch.tensor(batch[1], device=self.device)}

    def validation_epoch_end(self, outputs):  
        preds = torch.cat([tmp['preds_strat1'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        
        # simple metrics
        self.log('val_loss', torch.stack([tmp['loss'] for tmp in outputs]).mean())
        self.log('val_accuracy', accuracy(preds, targets))
        
    def test_step(self, batch, batch_idx):
        xs, ys = batch

        cks = self.compute_prototypes()
        
        y_hat = self.forward([xs.to(self.device), torch.stack(cks).to(self.device)])
        test_loss = self.loss(y_hat, torch.tensor(ys, device=self.device))
        y_hat = y_hat.argmax(axis=1)
        
        return {'loss': test_loss, 'preds_strat1': y_hat, 'target': torch.tensor(batch[1], device=self.device)}
    
    def test_epoch_end(self, outputs):
        preds = torch.cat([tmp['preds_strat1'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        
        # simple metrics
        self.log('test_loss', torch.stack([tmp['loss'] for tmp in outputs]).mean())
        self.log('test_accuracy', accuracy(preds, targets))
        
        # confusion matrix
        cm = confusion_matrix(preds, targets, num_classes=self.num_classes).cpu().numpy()
        normalized_cm = np.around(cm/cm.sum(axis=1)[:, None], 3)*100 # normalize by line

        df_cm = pd.DataFrame(normalized_cm, index=self.labels, columns=self.labels)
        plt.figure(figsize = (15,12))
        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
        plt.close(fig_)

        self.logger.experiment.add_figure(f"Test - Confusion matrix", fig_, self.current_epoch)

# Train the model

## Load data

In [6]:
batch_size = 32

shots = 10 # number of anchors perclass (for the suport set)
queries = 15 # number of query per class

# the data
data_dir = join(Path.home(), "data/kaggle_TIMIT")

mfcc_timit = TimitMFCCDataModule(join(Path.home(), "data/kaggle_TIMIT"), batch_size, shots, queries=queries)

mfcc_timit.prepare_data()
mfcc_timit.setup("fit") # to be able to call anchors property

keys_timit = [i for i in range(mfcc_timit.timit_gt.phon_size)]
labels_timit = [mfcc_timit.timit_gt.index2phn[i] for i in keys_timit]

## Stats on data

In [7]:
do_stats = False
if do_stats:
    mfcc_timit.setup("test")
    
    mfcc_timit.df_data_train["gt"] = mfcc_timit.df_data_train["gt"].map(lambda x: labels_timit[x])
    mfcc_timit.df_data_test["gt"] = mfcc_timit.df_data_test["gt"].map(lambda x: labels_timit[x])
    
    print('train stats:')
    print(f'number of examples in train: {mfcc_timit.df_data_train["gt"].count()}')
    print(f"number of examples per gt:\n{mfcc_timit.df_data_train['gt'].value_counts()}")
    print(f"number of different speakers: {mfcc_timit.df_data_train['speaker_id'].nunique()}")
    print(f"number of examples per sex:\n{mfcc_timit.df_data_train['sex'].value_counts()}")
    print(f"number of examples per region:\n{mfcc_timit.df_data_train['region'].value_counts()}")
    
    print('test stats:')
    print(f'number of examples in test: {mfcc_timit.df_data_test["gt"].count()}')
    print(f"number of examples per gt:\n{mfcc_timit.df_data_test['gt'].value_counts()}")
    print(f"number of different speakers: {mfcc_timit.df_data_test['speaker_id'].nunique()}")
    print(f"number of examples per sex:\n{mfcc_timit.df_data_test['sex'].value_counts()}")
    print(f"number of examples per region:\n{mfcc_timit.df_data_test['region'].value_counts()}")
    
    # show anchors
    import matplotlib.pyplot as plt
    
    anchors = {}
    for i in labels_timit:
        anchors[i] = []
        
    for anchor in mfcc_timit.anchors:
        anchors[labels_timit[anchor[1]]].append(anchor[0].numpy().transpose())
        
    for anchor in anchors.keys():
        fig, ax = plt.subplots(2, 5, figsize=(30,15), sharey=True)
        fig.suptitle(f'Anchors {anchor}')
        i = 0
        for a in anchors[anchor]:
            ax[i//5, i%5].imshow(a, origin='lower')
            i += 1
        
        plt.show()


## First 1000 epochs

In [8]:
# init model
model = lit_mfcc_model(13*3, mfcc_timit.anchors, mfcc_timit.timit_gt.phon_size, shots, labels_timit) # features = MFCC + delta+ deltas deltas
model.to('cuda')

# trainer definition
trainer = pl.Trainer(
     callbacks=[
#        EarlyStopping(monitor='val_accuracy_argmin', patience=200, mode="max"),
        LearningRateMonitor(logging_interval='epoch')
    ],
    checkpoint_callback=ModelCheckpoint(save_top_k=5, monitor="val_accuracy", mode="max"),
    progress_bar_refresh_rate=None,
    gpus=1, auto_select_gpus=True,
    check_val_every_n_epoch=10,
    precision=16,
    max_epochs=1000
)

trainer.fit(model, mfcc_timit)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | layer_1_grus  | GRU              | 5.2 M 
1 | bn_fwd        | BatchNorm1d      | 512   
2 | bn_bwd        | BatchNorm1d      | 512   
3 | layer_2_dense | Linear           | 65.7 K
4 | alpha         | Linear           | 16.4 K
5 | loss          | CrossEntropyLoss | 0     
---------------------------------------------------
5.3 M     Trainable params
0         Non-trainable params
5.3 M     Total params
21.081    Total estimated model params size (MB)


valid 4
Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]



train 18                                                              
Epoch 0:   0%|          | 0/18 [00:00<?, ?it/s] 



Epoch 9:  82%|████████▏ | 18/22 [00:05<00:01,  3.48it/s, loss=3.57, v_num=423]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/4 [00:00<?, ?it/s][A
Epoch 9:  91%|█████████ | 20/22 [00:05<00:00,  3.54it/s, loss=3.57, v_num=423]
Validating:  50%|█████     | 2/4 [00:00<00:00,  2.19it/s][A
Epoch 9: 100%|██████████| 22/22 [00:06<00:00,  3.32it/s, loss=3.57, v_num=423]
Epoch 19:  82%|████████▏ | 18/22 [00:05<00:01,  3.53it/s, loss=2.79, v_num=423]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/4 [00:00<?, ?it/s][A
Epoch 19:  91%|█████████ | 20/22 [00:05<00:00,  3.59it/s, loss=2.79, v_num=423]
Validating:  50%|█████     | 2/4 [00:00<00:00,  2.20it/s][A
Epoch 19: 100%|██████████| 22/22 [00:06<00:00,  3.37it/s, loss=2.79, v_num=423]
Epoch 29:  82%|████████▏ | 18/22 [00:05<00:01,  3.47it/s, loss=1.69, v_num=423]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/4 [00:00<?, ?it/s][A
Epoch 29:  91%|█████████ | 20/22 [00:05<00:00,  3.54it/s, lo

1

## Last epochs

In [9]:
# trainer = pl.Trainer(
#    callbacks=[
# #        EarlyStopping(monitor='val_accuracy_argmin', patience=200, mode="max"),
#        LearningRateMonitor(logging_interval='epoch')
#    ],
#    checkpoint_callback=ModelCheckpoint(save_top_k=5, monitor="val_accuracy", mode="max"),
#    progress_bar_refresh_rate=None,
#    gpus=1, auto_select_gpus=True,
#    check_val_every_n_epoch=5,
#    precision=16,
#    max_epochs=2000
#)
# trainer.fit(model, mfcc_timit)

## Test

In [10]:
if not do_stats:
    mfcc_timit.setup("test")

In [11]:
trainer.test(test_dataloaders=mfcc_timit.test_dataloader(), ckpt_path='best')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


test 218
Testing: 0it [00:00, ?it/s]



Testing: 100%|██████████| 218/218 [01:49<00:00,  1.99it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.3975563645362854, 'test_loss': 14.415218353271484}
--------------------------------------------------------------------------------


[{'test_loss': 14.415218353271484, 'test_accuracy': 0.3975563645362854}]