In [1]:
def add_to_class(Class):
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torch
from torch import nn
from torch.nn import functional as F
import torchaudio
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import os
import sys
import IPython.display as ipd

In [4]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.2.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.11.0-py3-none-any.whl (25 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=1.13.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [5]:
import lightning as L

In [6]:
sys.path.append('/content/drive/MyDrive/GSC/GSC_helper')

# import GSC
#from GSC_preprocessing import GSC_TrainAugment, GSC_TestAugment
from MDTC import MDTC
from GSC_zip import unzipzip, zipzip
#from mdtc_1 import MDTC
#from mdtc_git import MDTC_git

## BC_Resnet Augment

In [None]:
ZIP_MAP_12 = {
    'train': '/content/drive/MyDrive/GSC/Background_Noise/GSC_12_BC/train_12.zip',
    'val': '/content/drive/MyDrive/GSC/Background_Noise/GSC_12_BC/val_12.zip',
    'test': '/content/drive/MyDrive/GSC/Background_Noise/GSC_12_BC/test_12.zip'
}
ZIP_MAP_35 = {
    'train': '/content/drive/MyDrive/GSC/Background_Noise/GSC_35_BC_Resnet/train.zip',
    'val': '/content/drive/MyDrive/GSC/Background_Noise/GSC_35_BC_Resnet/val.zip',
    'test': '/content/drive/MyDrive/GSC/Background_Noise/GSC_35_BC_Resnet/test.zip'
}

CSV_MAP_12 = {
    'train': '/content/drive/MyDrive/GSC/Background_Noise/GSC_12_BC/train_12.csv',
    'val': '/content/drive/MyDrive/GSC/Background_Noise/GSC_12_BC/val_12.csv',
    'test': '/content/drive/MyDrive/GSC/Background_Noise/GSC_12_BC/test_12.csv'
}

CSV_MAP_35 = {
    'train': '/content/drive/MyDrive/GSC/Background_Noise/GSC_35_BC_Resnet/train.csv',
    'val': '/content/drive/MyDrive/GSC/Background_Noise/GSC_35_BC_Resnet/val.csv',
    'test': '/content/drive/MyDrive/GSC/Background_Noise/GSC_35_BC_Resnet/test.csv'
}

In [None]:
ZIP_MAP_12 = {
    'train': '/content/drive/MyDrive/Dataset/GSC_12_BC/train_12.zip',
    'val': '/content/drive/MyDrive/Dataset/GSC_12_BC/val_12.zip',
    'test': '/content/drive/MyDrive/Dataset/GSC_12_BC/test_12.zip'
}

CSV_MAP_12 = {
    'train': '/content/drive/MyDrive/Dataset/GSC_12_BC/train_12.csv',
    'val': '/content/drive/MyDrive/Dataset/GSC_12_BC/val_12.csv',
    'test': '/content/drive/MyDrive/Dataset/GSC_12_BC/test_12.csv'
}


In [None]:
class GSC(torch.utils.data.Dataset):
    def __init__(self, root, subset = 'train', zip_map = None, csv_map = None, unzip = True):
        super().__init__()
        local_path = os.path.join(root, subset)
        self.root = root
        if not os.path.exists(local_path):
            os.mkdir(local_path)
            unzipzip(zip_map[subset], local_path)
        if unzip:
            unzipzip(zip_map[subset], local_path)
        self.csv = pd.read_csv(csv_map[subset])

    def __getitem__(self, idx):
        row = self.csv.iloc[idx]
        spec = np.load(os.path.join(self.root, row['link']))['arr_0']
        return torch.from_numpy(spec).unsqueeze(0), row['label']

    def __len__(self):
        return len(self.csv)

In [None]:
class SC_12(L.LightningDataModule):
    def __init__(self, root, batch_size, zip_map, csv_map):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.train_dataset = GSC(root,
                                 subset = 'train',
                                 zip_map = zip_map,
                                 csv_map = csv_map,
                                 unzip = False)
        self.val_dataset = GSC(root,
                               subset = 'val',
                               zip_map = zip_map,
                               csv_map = csv_map,
                               unzip = False)
        self.test_dataset = GSC(root,
                                subset = 'test',
                                zip_map = zip_map,
                                csv_map = csv_map,
                                unzip = False)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = True,
    #                                       collate_fn = self.collate_fn,
    #                                       num_workers = 1,
    #                                       prefetch_factor = 1,
                                           pin_memory = True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = False,
     #                                      collate_fn = self.collate_fn,
     #                                      num_workers = 1,
     #                                      prefetch_factor = 1,
                                           pin_memory = True)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = False,
      #                                     collate_fn = self.collate_fn,
      #                                     num_workers = 1,
                                           prefetch_factor = 1)

In [None]:
data_12 = SC_12('/content/GSC_12', 128, ZIP_MAP_12, CSV_MAP_12)

Extracted /content/drive/MyDrive/Dataset/GSC_12_BC/train_12.zip
Extracted /content/drive/MyDrive/Dataset/GSC_12_BC/val_12.zip
Extracted /content/drive/MyDrive/Dataset/GSC_12_BC/test_12.zip


In [None]:
X, y = next(iter(data_12.train_dataloader()))
X.shape

torch.Size([128, 1, 40, 101])

In [8]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
from torchinfo import summary

summary(model, input_size = (128, 40, 81), device = 'cpu')

Layer (type:depth-idx)                        Output Shape              Param #
MDTC                                          [128, 12]                 --
├─DTCBlock: 1-1                               [128, 64, 81]             --
│    └─CausalConv1d: 2-1                      [128, 40, 81]             240
│    └─BatchNorm1d: 2-2                       [128, 40, 81]             80
│    └─Conv1d: 2-3                            [128, 64, 81]             2,624
│    └─ReLU: 2-4                              [128, 64, 81]             --
│    └─BatchNorm1d: 2-5                       [128, 64, 81]             128
│    └─Conv1d: 2-6                            [128, 64, 81]             4,160
│    └─BatchNorm1d: 2-7                       [128, 64, 81]             128
│    └─ReLU: 2-8                              [128, 64, 81]             --
├─ModuleList: 1-2                             --                        --
│    └─DTCStack: 2-9                          [128, 64, 81]             --
│    │    └

In [10]:
class MDTC_training(L.LightningModule):
    def __init__(self, lr, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        #self.automatic_optimization = False
        self.lr = lr
        self.linear = nn.Linear(40, 64)
        self.net = MDTC(in_channels = 64,
             out_channels = 64,
             kernel_size = 5,
             stack_num = 4,
             stack_size = 4,
             classification = True,
             hidden_size = 64,
             num_classes = num_classes,
             dropout = 0.5)

    def forward(self, input):
        input = self.linear(input.squeeze(1).transpose(1, 2))
        return self.net(input.transpose(1, 2))

In [11]:
from torchinfo import summary

net = MDTC_training(lr = 0.001, num_classes = 12)
summary(net, input_size = (1, 1, 40, 81))

Layer (type:depth-idx)                             Output Shape              Param #
MDTC_training                                      [12]                      --
├─Linear: 1-1                                      [1, 81, 64]               2,624
├─MDTC: 1-2                                        [12]                      --
│    └─DTCBlock: 2-1                               [1, 64, 81]               --
│    │    └─CausalConv1d: 3-1                      [1, 64, 81]               384
│    │    └─BatchNorm1d: 3-2                       [1, 64, 81]               128
│    │    └─Conv1d: 3-3                            [1, 64, 81]               4,160
│    │    └─BatchNorm1d: 3-4                       [1, 64, 81]               128
│    │    └─ReLU: 3-5                              [1, 64, 81]               --
│    │    └─Conv1d: 3-6                            [1, 64, 81]               4,160
│    │    └─BatchNorm1d: 3-7                       [1, 64, 81]               128
│    │    └─ReLU: 3-8 

In [None]:
class MDTC_training(L.LightningModule):
    def __init__(self, lr, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        #self.automatic_optimization = False
        self.lr = lr
        self.net = MDTC_git(4, 4, 40, 64, 5, True, 64, num_classes)

    def forward(self, input):
        return self.net(input)

In [None]:
@add_to_class(MDTC_training)
def accuracy(self, Y_hat, Y, averaged = True):
    """
    Compute the number of correct predictions
    """
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(dim = 1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

@add_to_class(MDTC_training)
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)

    # single scheduler
    # sch = self.lr_schedulers()
    #sch.step()
    torch.nn.utils.clip_grad_norm(self.parameters(), 5)

    values = {"train_loss": loss, "train_acc": acc}
    self.log_dict(values, prog_bar = True)
    return loss

@add_to_class(MDTC_training)
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)
    values = {"val_loss": loss, "val_acc": acc}
    self.log_dict(values, prog_bar = True)
    return values

@add_to_class(MDTC_training)
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay = 0.00005)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 3, gamma = 0.8)
    #return [optimizer], [lr_scheduler]
    return optimizer

#@add_to_class(MDTC_training)
#def lr_schedulers(self):
#    lr_scheduler = torch.optim.lr_scheduler.StepLR(self.configure_optimizers(), step_size = 3, gamma = 0.8)
#    return lr_scheduler

#@add_to_class(MDTC_training)
#def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # update params
#    optimizer.step(closure = optimizer_closure)

    # manually warm up lr withou a scheduler
#    if self.trainer.global_step < 2000:
#        lr_scale = self.trainer.global_step/2000
#    else:
#        lr_scale = (2000/self.trainer.global_step)**0.5

#    for pg in optimizer.param_groups:
#        pg['lr'] = lr_scale*self.lr

In [None]:
@add_to_class(MDTC_training)
def loss(self, y_hat, y):
    return F.cross_entropy(y_hat, y, reduction = 'mean')

In [None]:
from typing import Tuple

import torch
from torch import nn, Tensor


def convert_label_to_similarity(normed_feature: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
    similarity_matrix = normed_feature @ normed_feature.transpose(1, 0)
    label_matrix = label.unsqueeze(1) == label.unsqueeze(0)

    positive_matrix = label_matrix.triu(diagonal=1)
    negative_matrix = label_matrix.logical_not().triu(diagonal=1)

    similarity_matrix = similarity_matrix.view(-1)
    positive_matrix = positive_matrix.view(-1)
    negative_matrix = negative_matrix.view(-1)
    return similarity_matrix[positive_matrix], similarity_matrix[negative_matrix]


class CircleLoss(nn.Module):
    def __init__(self, m: float, gamma: float) -> None:
        super(CircleLoss, self).__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor) -> Tensor:
        ap = torch.clamp_min(- sp.detach() + 1 + self.m, min=0.)
        an = torch.clamp_min(sn.detach() + self.m, min=0.)

        delta_p = 1 - self.m
        delta_n = self.m

        logit_p = - ap * (sp - delta_p) * self.gamma
        logit_n = an * (sn - delta_n) * self.gamma

        loss = self.soft_plus(torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0))

        return loss

In [None]:
class CircleLossLikeCE(nn.Module):
    def __init__(self, m: float, gamma: float) -> None:
        super(CircleLossLikeCE, self).__init__()
        self.m = m
        self.gamma = gamma
        self.loss = nn.CrossEntropyLoss()

    def forward(self, inp: Tensor, label: Tensor) -> Tensor:
        a = torch.clamp_min(inp + self.m, min=0).detach()
        src = torch.clamp_min(
            - inp.gather(dim=1, index=label.unsqueeze(1)) + 1 + self.m,
            min=0,
        ).detach()
        a.scatter_(dim=1, index=label.unsqueeze(1), src=src)

        sigma = torch.ones_like(inp, device=inp.device, dtype=inp.dtype) * self.m
        src = torch.ones_like(label.unsqueeze(1), dtype=inp.dtype, device=inp.device) - self.m
        sigma.scatter_(dim=1, index=label.unsqueeze(1), src=src)

        return self.loss(a * (inp - sigma) * self.gamma, label)

In [None]:
@add_to_class(MDTC_training)
def loss(self, y_hat, y):
    loss1 = CircleLossLikeCE(0.25, 80)
    y_hat = nn.functional.normalize(y_hat, dim = -1)
    loss2 = CircleLoss(0.4, 30)
    sp, sn = convert_label_to_similarity(y_hat, y)

    loss = loss1(y_hat, y) + loss2(sp, sn)

    return loss

In [None]:
data = SC_12('/content/GSC_12', 128, ZIP_MAP_12, CSV_MAP_12)

Extracted /content/drive/MyDrive/Dataset/GSC_12_BC/train_12.zip
Extracted /content/drive/MyDrive/Dataset/GSC_12_BC/val_12.zip
Extracted /content/drive/MyDrive/Dataset/GSC_12_BC/test_12.zip


In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
early_stopping_callback = EarlyStopping(monitor = "val_acc", min_delta = 0.0001, patience = 5, mode = "max")
checkpoint_callback = ModelCheckpoint(dirpath = '/content/best_model',
                                      save_top_k = 5, monitor = 'val_acc',
                                      mode = 'max',
                                      filename = 'mdtc-gsc-12-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}')
lr_monitor = LearningRateMonitor(logging_interval='step')

In [None]:
from lightning.pytorch import seed_everything

seed_everything(42)

net = MDTC_training(0.001, 12)

trainer = L.Trainer(accelerator="gpu",
                    callbacks = [early_stopping_callback, checkpoint_callback,
                                 lr_monitor],
                    enable_checkpointing=True,
                    default_root_dir = "/content/mdtc1",
                    max_epochs=100)
trainer.fit(net, data)

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name   | Type   | Params
----------------------------------
0 | linear | Linear | 2.6 K 
1 | net    | MDTC   | 159 K 
----------------------------------
162 K     Trainable params
0         Non-trainable params
162 K     Total params
0.648     Total estimated model params size (MB)
IN

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

  torch.nn.utils.clip_grad_norm(self.parameters(), 5)


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
from lightning.pytorch import seed_everything

seed_everything(42)

net = MDTC_training(0.001, 12)

trainer = L.Trainer(accelerator="gpu",
                    callbacks = [early_stopping_callback, checkpoint_callback,
                                 lr_monitor],
                    enable_checkpointing=True,
                    default_root_dir = "/content/mdtc1",
                    max_epochs=100)
trainer.fit(net, data, ckpt_path = '/content/best_model/mdtc-gsc-12-epoch=22-val_loss=0.26-val_acc=0.92.ckpt')

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /content/best_model exists and is not empty.
INFO: Restoring states from the checkpoint path at /content/best_model/mdtc-gsc-12-epoch=22-val_loss=0.26-val_acc=0.92.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /content/best_model/mdtc-gsc-12-epoch=22-val_loss=0.26-v

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

  torch.nn.utils.clip_grad_norm(self.parameters(), 5)


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
zipzip('/content/best_model', '/content/drive/MyDrive/Dataset/best_model_mdtc.zip')


zipping...: 100%|██████████| 5/5 [00:00<00:00, 55.45it/s]

/content/drive/MyDrive/Dataset/best_model_mdtc.zip created





# Test

In [None]:
labels = ['backward',
 'bed',
 'bird',
 'cat',
 'dog',
 'down',
 'eight',
 'five',
 'follow',
 'forward',
 'four',
 'go',
 'happy',
 'house',
 'learn',
 'left',
 'marvin',
 'nine',
 'no',
 'off',
 'on',
 'one',
 'right',
 'seven',
 'sheila',
 'six',
 'stop',
 'three',
 'tree',
 'two',
 'up',
 'visual',
 'wow',
 'yes',
 'zero']
len(labels)

35

In [None]:
import lightning as L
from torchaudio import datasets
import torchtext
from torch.utils.data import Dataset, DataLoader

In [None]:
from MDTC import MDTC

In [None]:
train_dataset = datasets.SPEECHCOMMANDS('./', 'speech_commands_v0.02', download = True, subset = 'training')

100%|██████████| 2.26G/2.26G [01:47<00:00, 22.6MB/s]


In [None]:
class SC(L.LightningDataModule):
    def __init__(self, root, batch_size, train_transform = None, test_transform = None):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.train_dataset = datasets.SPEECHCOMMANDS(root, 'speech_commands_v0.02', download = True, subset = 'training')
        self.val_dataset = datasets.SPEECHCOMMANDS(root, 'speech_commands_v0.02', download = True, subset = 'validation')
        self.test_dataset = datasets.SPEECHCOMMANDS(root, 'speech_commands_v0.02', download = True, subset = 'testing')
        self.train_transform = train_transform
        self.test_transform = test_transform

    def collate_fn(self, batch, transform):
        data_batch = []
        tar_batch = []
        #for waveform, _, label, *_ in batch:
        for wav, _, label, *_ in batch:
            wav = transform(wav)
            data_batch.append(wav)
            tar_batch.append(labels.index(label))
        return torch.stack(data_batch), torch.tensor(tar_batch)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size = self.batch_size,
                          shuffle = True,
                          collate_fn = lambda x: self.collate_fn(x, self.train_transform),
                          num_workers = 1,
                          prefetch_factor = 1)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size = self.batch_size,
                          shuffle = False,
                          collate_fn = lambda x: self.collate_fn(x, self.test_transform),
                          num_workers = 1,
                          prefetch_factor = 1)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size = self.batch_size,
                          shuffle = False,
                          collate_fn = lambda x: self.collate_fn(x, self.test_transform),
                          num_workers = 1,
                          prefetch_factor = 1)

In [None]:
import random

def pad_truncate(wav, max_length = 16000, pad_value = 0):
    wav_length = len(wav)
    if wav_length < max_length:
        pad = torchtext.transforms.PadTransform(max_length, 0)
        wav = pad(wav)
    return wav

def time_shift(wav, shift, sr = 16000, max_length = 16000):
    """

    shift: float
        Unit: Seconds
    """
    wav = torch.roll(wav, int(shift*sr))
    return wav[:, :max_length]

## Add Noise

def normalzieNoise(wav, noise, max_length = 16000):
    len_wav = wav.shape[1]
    len_noise = noise.shape[1]
    if len_wav > len_noise:
        buf = torch.zeros_like(wav)
        start_point = int((len_wav - len_noise)*random.uniform(0, 1))
        end_point = start_point + len_noise
        buf[:, start_point: end_point] = noise
        noise = buf
    elif len_wav < len_noise:
        start_point = int((len_noise - len_wav)*random.uniform(0, 1))
        end_point = start_point + len_wav
        noise = noise[:, start_point: end_point]
    return noise[:, :max_length]

def randomNoise(noise_directory):
    listnoise = [f for f in os.listdir(noise_directory)
                 if f.endswith('.wav')]
    noise = random.choice(listnoise)
    noise, sr = torchaudio.load(os.path.join(noise_directory, noise))
    return noise

def addNoise(wav, noise):
    noise = normalzieNoise(wav, noise)
    addnsy = torchaudio.transforms.AddNoise()
    return addnsy(wav, noise, snr = torch.Tensor([random.uniform(0, 15)]))

class AddBGNoise(nn.Module):
    def __init__(self):
        super().__init__()
        self.NOISE_PATH = '/content/SpeechCommands/speech_commands_v0.02/_background_noise_'

    def forward(self, x):
        p = random.uniform(0, 1)
        if p >= 0.8:
            noise = randomNoise(self.NOISE_PATH)
            x = addNoise(x, noise)
        return x

class GSC_TrainAugment(nn.Module):
    def __init__(self, sr):
        super().__init__()
        #self.resample = torchaudio.transforms.Resample(sr, int(sr*random.uniform(0.85, 1.15)))
        self.time_shift = lambda x: time_shift(x, random.uniform(-0.1, 0.1))
        self.pad_trunc = lambda x: pad_truncate(x, sr)
        self.add_noise = AddBGNoise()
        self.mel = torchaudio.transforms.MelSpectrogram(sr, n_mels = 40,
                                                        win_length=480,
                                                        n_fft = 480,
                                                        hop_length = 160)
        self.specaugment = torchaudio.transforms.SpecAugment(n_time_masks = 2,
                                      time_mask_param = 20,
                                      n_freq_masks = 2,
                                      freq_mask_param = 7)
    def forward(self, x):
        #x = self.resample(x)
        x = self.time_shift(x)
        x = self.pad_trunc(x)
        x = self.add_noise(x)
        x = self.mel(x)
        x = self.specaugment(x)
        return x

class GSC_TestAugment(nn.Module):
    def __init__(self, sr):
        super().__init__()
        self.pad_trunc = lambda x: pad_truncate(x, sr)
        self.mel = torchaudio.transforms.MelSpectrogram(sr, n_mels = 40,
                                                        win_length = 480,
                                                        n_fft = 480,
                                                        hop_length = 160)

    def forward(self, x):
        x = self.pad_trunc(x)
        x = self.mel(x)
        return x

train_transform = GSC_TrainAugment(16000)
test_transform = GSC_TestAugment(16000)

In [None]:
x = torch.Tensor(1, 15000)
train_transform(x).shape

torch.Size([1, 40, 101])

In [None]:
class MDTC_training(L.LightningModule):
    def __init__(self, lr, num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)
        #self.automatic_optimization = False
        self.lr = lr
        self.net = MDTC(in_channels = 40,
             out_channels = 64,
             kernel_size = 5,
             stack_num = 4,
             stack_size = 4,
             classification = True,
             hidden_size = 64,
             num_classes = num_classes)

    def forward(self, input):
        return self.net(input.squeeze())

In [None]:
@add_to_class(MDTC_training)
def accuracy(self, Y_hat, Y, averaged = True):
    """
    Compute the number of correct predictions
    """
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(dim = 1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

@add_to_class(MDTC_training)
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)

    # single scheduler
    #sch.step()

    values = {"train_loss": loss, "train_acc": acc}
    self.log_dict(values, prog_bar = True)
    return loss

@add_to_class(MDTC_training)
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.forward(x)
    loss = self.loss(y_hat, y)
    acc = self.accuracy(y_hat, y)
    values = {"val_loss": loss, "val_acc": acc}
    self.log_dict(values, prog_bar = True)
    return values

@add_to_class(MDTC_training)
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), self.lr)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 3, gamma = 0.8)
    #return [optimizer], [lr_scheduler]
    return optimizer

#@add_to_class(MDTC_training)
#def lr_schedulers(self):
#    lr_scheduler = torch.optim.lr_scheduler.StepLR(self.configure_optimizers(), step_size = 3, gamma = 0.8)
#    return lr_scheduler

@add_to_class(MDTC_training)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # update params
    optimizer.step(closure = optimizer_closure)

    # manually warm up lr withou a scheduler
    if self.trainer.global_step < 2000:
        lr_scale = self.trainer.global_step/2000
    else:
        lr_scale = (2000/self.trainer.global_step)**0.5

    for pg in optimizer.param_groups:
        pg['lr'] = lr_scale*self.lr

@add_to_class(MDTC_training)
def loss(self, y_hat, y):
    return F.cross_entropy(y_hat, y, reduction = 'mean')

In [None]:
data = SC('./', 128, train_transform, test_transform)

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
early_stopping_callback = EarlyStopping(monitor = "val_acc", min_delta = 0.001, patience = 5, mode = "max")
checkpoint_callback = ModelCheckpoint(dirpath = '/content/best_model',
                                      save_top_k = 5, monitor = 'val_acc',
                                      mode = 'max',
                                      filename = 'mdtc-gsc-12-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}')
lr_monitor = LearningRateMonitor(logging_interval='step')

In [None]:
from lightning.pytorch import seed_everything

seed_everything(42)

net = MDTC_training(0.004, 35)

trainer = L.Trainer(accelerator="gpu",
                    callbacks = [early_stopping_callback, checkpoint_callback,
                                 lr_monitor],
                    enable_checkpointing=True,
                    default_root_dir = "/content",
                    max_epochs=100)
trainer.fit(net, data)

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name | Type | Params
------------------------------
0 | net  | MDTC | 159 K 
------------------------------
159 K     Trainable params
0         Non-trainable params
159 K     Total params
0.637     Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]