In [1]:
import os
from torch.utils.tensorboard import SummaryWriter
datadir = "speech_commands"

samples_by_target = {
    cls: [os.path.join(datadir, cls, name) for name in os.listdir("./speech_commands/{}".format(cls))]
    for cls in os.listdir(datadir)
    if os.path.isdir(os.path.join(datadir, cls))
}
print('Classes:', ', '.join(sorted(samples_by_target.keys())[1:]))

Classes: absfly, four, go, happy, house, left, marvin, nine, no, off, on, one, right, seven, six, stop, three, tree, two, up, wow, yes, zero


In [2]:
import pandas as pd
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

import torchaudio
from IPython import display as display_

In [3]:
def add_rand_noise(audio):
    background_noises = [
        'speech_commands/_background_noise_/white_noise.wav',
       'speech_commands/_background_noise_/dude_miaowing.wav',
       'speech_commands/_background_noise_/doing_the_dishes.wav',
       'speech_commands/_background_noise_/exercise_bike.wav',
       'speech_commands/_background_noise_/pink_noise.wav',
       'speech_commands/_background_noise_/running_tap.wav'
    ]
    
    noise_num = torch.randint(low=0, high=len(background_noises), size=(1,)).item()    
    noise = torchaudio.load(background_noises[noise_num])[0].squeeze()    
    
    noise_level = torch.Tensor([1])  # [0, 40]

    noise_energy = torch.norm(noise)
    audio_energy = torch.norm(audio)
    alpha = (audio_energy / noise_energy) * torch.pow(10, -noise_level / 20)

    start = torch.randint(low=0, high=int(noise.size(0) - audio.size(0) - 1), size=(1,)).item()
    noise_sample = noise[start : start + audio.shape[0]]

    audio_new = audio + alpha * noise_sample
    audio_new.clamp_(-1, 1)
    return audio_new

In [4]:
import torch
import random
import numpy as np

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch import distributions

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import WeightedRandomSampler
from torch.nn.utils.rnn import pad_sequence

In [5]:
BATCH_SIZE = 256
NUM_EPOCHS = 35
N_MELS     = 40

In [6]:
class TrainDataset(torch.utils.data.Dataset):
    """Custom competition dataset."""

    def __init__(self, root='', csv_path='labels_sheila.csv', kw='sheila', transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root = root
        self.kw = kw
        self.csv = pd.read_csv(csv_path)
        self.transform = transform
        

    def __len__(self):
        return self.csv.shape[0]
    

    def __getitem__(self, idx):
        utt_name = self.root + self.csv.loc[idx, 'name']
        utt = torchaudio.load(utt_name)[0].squeeze()
        word = self.csv.loc[idx, 'word']
        label = self.csv.loc[idx, 'label']
        
        if self.transform:
            utt = self.transform(utt)

        sample = {'utt': utt, 'word': word, 'label': label}
        return sample

In [7]:
def transform_tr(wav):
    aug_num = torch.randint(low=0, high=4, size=(1,)).item()
    augs = [
        lambda x: x,
        lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
        lambda x: torchaudio.transforms.Vol(.25)(x),
        lambda x: add_rand_noise(x)
    ]
    
    return augs[aug_num](wav)

In [8]:

my_dataset = TrainDataset(csv_path='labels.csv', transform=transform_tr)
print('all train+val samples:', len(my_dataset))

all train+val samples: 50972


In [9]:
train_len = 44855
val_len = 50972 - train_len 
train_set, val_set = torch.utils.data.random_split(my_dataset, [train_len, val_len])
print(train_set.dataset.csv['label'])

0        0
1        0
2        0
3        0
4        0
        ..
50967    1
50968    1
50969    1
50970    1
50971    1
Name: label, Length: 50972, dtype: int64


In [10]:
def get_sampler(target):
    class_sample_count = np.array(
        [len(np.where(target == t)[0]) for t in np.unique(target)])
    weight = 1. / class_sample_count
    samples_weight = np.array([weight[t] for t in target])

    samples_weight = torch.from_numpy(samples_weight)
    samples_weigth = samples_weight.double()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return sampler

In [11]:
train_sampler = get_sampler(train_set.dataset.csv['label'][train_set.indices].values)
val_sampler   = get_sampler(val_set.dataset.csv['label'][val_set.indices].values)

In [12]:
def preprocess_data(data):
    wavs = []
    labels = []    
        
    for el in data:
        wavs.append(el['utt'])
        labels.append(el['label'])
    wavs = pad_sequence(wavs, batch_first=True)
    labels = torch.Tensor(labels).type(torch.long)
    return wavs, labels

In [13]:
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                          shuffle=False, collate_fn=preprocess_data, 
                          sampler=train_sampler, drop_last=False,
                          num_workers=1, pin_memory=True)

val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=preprocess_data, 
                        sampler=val_sampler, drop_last=False,
                        num_workers=1, pin_memory=True)

In [14]:
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
set_seed(21)

In [15]:
def count_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [17]:
torch.cuda.is_available()

True

In [18]:
# with augmentations
melspec_train = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000,  n_mels=N_MELS),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
    torchaudio.transforms.TimeMasking(time_mask_param=35),
).to(device)


# no augmentations
melspec_val = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_mels=N_MELS
).to(device)

In [19]:
def count_FA_FR(preds, labels):
    FA = torch.sum(preds[labels == 0])
    FR = torch.sum(labels[preds == 0])
    return FA.item()/torch.numel(preds), FR.item()/torch.numel(preds)

In [20]:
IN_SIZE = 40
HIDDEN_SIZE = 128
KERNEL_SIZE = (20, 5)
STRIDE = (8, 2)
GRU_NUM_LAYERS = 2
NUM_DIRS = 2
NUM_CLASSES = 2

In [21]:
def sepconv(in_size, out_size, kernel_size, stride=1, dilation=1, padding=0):
    return nn.Sequential(
        torch.nn.Conv1d(in_size, in_size, kernel_size[1], 
                        stride=stride[1], dilation=dilation, groups=in_size,
                        padding=padding),
        
        torch.nn.Conv1d(in_size, out_size, kernel_size=1, 
                        stride=stride[0], groups=int(in_size/kernel_size[0])),
    )

In [22]:
class CRNN(nn.Module):
    def __init__(self, in_size, hidden_size, kernel_size, stride, gru_nl, ):
        super(CRNN, self).__init__()
          
        self.sepconv = sepconv(in_size=in_size, out_size=hidden_size, kernel_size=kernel_size, stride=stride)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, num_layers=gru_nl, dropout=0.1, bidirectional=True)
        self.init_weights()
        

    def init_weights(self):
        pass

    
    def forward(self, x, hidden):
        x = self.sepconv(x)
        
        # (BS, HS, ?) -> (HS, BS, ?) ->(seq_len, BS, HS)
        x = x.transpose(0, 1).transpose(0, 2)
        
        x, hidden = self.gru(x, hidden)
        # x : (seq_len, BS, HS * num_dirs)
        # hidden : (num_layers * num_dirs, BS, HS)
                        
        return x, hidden

In [23]:
class ApplyAttn(nn.Module):
    def __init__(self, in_size, num_classes):
        super(ApplyAttn, self).__init__()
        self.U = nn.Linear(in_size, num_classes, bias=False)
        
    
    def init_weights(self):
        pass
    
    
    def forward(self, e, data):
        data = data.transpose(0, 1)           # (BS, seq_len, hid_size*num_dirs)
        a = F.softmax(e, dim=-1).unsqueeze(1)
        c = torch.bmm(a, data).squeeze()
        Uc = self.U(c)        
        return F.log_softmax(Uc, dim=-1)

In [24]:
class FullModel(nn.Module):
    def __init__(self, CRNN_model, attn_layer, apply_attn):
        super(FullModel, self).__init__()
        
        self.CRNN_model = CRNN_model
        self.attn_layer = attn_layer
        self.apply_attn = apply_attn

        
    def forward(self, batch, hidden):
        output, hidden = self.CRNN_model(batch, hidden)
        # output: (seq_len, BS, hidden*num_dir)
        
        e = []
        for el in output:
            e_t = self.attn_layer(el)       # -> (BS, 1)
            e.append(e_t)
        e = torch.cat(e, dim=1)        # -> (BS, seq_len)
        
        probs = self.apply_attn(e, output)
        return probs

In [25]:
class AttnMech(nn.Module):
    def __init__(self, lin_size):
        super(AttnMech, self).__init__()
        
        self.Wx_b = nn.Linear(lin_size, lin_size)
        self.Vt   = nn.Linear(lin_size, 1, bias=False)
        
        
    def init_weights(self):
        pass
    
    
    def forward(self, x):
        x = torch.tanh(self.Wx_b(x))
        e = self.Vt(x)
        return e

In [26]:
CRNN_model = CRNN(IN_SIZE, HIDDEN_SIZE, KERNEL_SIZE, STRIDE, GRU_NUM_LAYERS)
attn_layer = AttnMech(HIDDEN_SIZE * NUM_DIRS)
apply_attn = ApplyAttn(HIDDEN_SIZE * 2, NUM_CLASSES)

full_model = FullModel(CRNN_model, attn_layer, apply_attn)
writer  = SummaryWriter('runs/experiment_2')
print(full_model.to(device))

FullModel(
  (CRNN_model): CRNN(
    (sepconv): Sequential(
      (0): Conv1d(40, 40, kernel_size=(5,), stride=(2,), groups=40)
      (1): Conv1d(40, 128, kernel_size=(1,), stride=(8,), groups=2)
    )
    (gru): GRU(128, 128, num_layers=2, dropout=0.1, bidirectional=True)
  )
  (attn_layer): AttnMech(
    (Wx_b): Linear(in_features=256, out_features=256, bias=True)
    (Vt): Linear(in_features=256, out_features=1, bias=False)
  )
  (apply_attn): ApplyAttn(
    (U): Linear(in_features=256, out_features=2, bias=False)
  )
)


In [41]:
def get_au_fa_fr(probs, labels, device):
    sorted_probs, indices = torch.sort(probs)

    sorted_probs = torch.cat((torch.Tensor([0]).to(device), sorted_probs))
    sorted_probs = torch.cat((sorted_probs, torch.Tensor([1]).to(device)))
    labels = torch.cat(labels, dim=0)
        
    FAs, FRs = [], []
    for prob in sorted_probs:
        ones = (probs >= prob) * 1
        FA, FR = count_FA_FR(ones, labels)        
        FAs.append(FA)
        FRs.append(FR)
    # plt.plot(FAs, FRs)
    # plt.show()
    return -np.trapz(FRs, x=FAs)

In [42]:
opt = torch.optim.Adam(full_model.parameters(), weight_decay=1e-5)

In [43]:
def train_epoch(model, opt, loader, melspec, gru_nl, hidden_size, epoch, device):
    model.train()
    loss = None
    for i, (batch, labels) in tqdm(enumerate(loader)):
        batch, labels = batch.to(device), labels.to(device)
        batch = torch.log(melspec(batch) + 1e-9).to(device)

        opt.zero_grad()
        
        # define frist hidden with 0
        hidden = torch.zeros(gru_nl*2, batch.size(0), hidden_size).to(device)    # (num_layers*num_dirs,  BS, HS)
        # run model
        probs = model(batch, hidden)
        loss = F.nll_loss(probs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        
        opt.step()
        

        # logging
        argmax_probs = torch.argmax(probs, dim=-1)                
        FA, FR = count_FA_FR(argmax_probs, labels)
        acc = torch.true_divide(
                            torch.sum(argmax_probs == labels), 
                            torch.numel(argmax_probs)
        )
        #wandb.log({'loss':loss.item(), 'train_FA':FA, 'train_FR':FR, 'train_acc':acc})
        writer.add_scalar('Loss/train', loss.item(), epoch * len(loader) + i)

    return loss

In [44]:
def save_checkpoint(model, optimizer, epoch, loss, filename="checkpoint.pth"):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")

In [51]:
def validation(model, loader, melspec, gru_nl, hidden_size, epoch, device):
    model.eval()
    with torch.no_grad():
        val_losses, accs, FAs, FRs = [], [], [], []
        all_probs, all_labels = [], []
        for i, (batch, labels) in tqdm(enumerate(loader)):
            batch, labels = batch.to(device), labels.to(device)
            batch = torch.log(melspec(batch) + 1e-9).to(device)  

            # define frist hidden with 0
            hidden = torch.zeros(gru_nl*2, batch.size(0), hidden_size).to(device)    # (num_layers*num_dirs,  BS, HS)
            # run model
            probs = model(batch, hidden)
            loss = F.nll_loss(probs, labels)
            
            # logging
            argmax_probs = torch.argmax(probs, dim=-1)
            all_probs.append(torch.exp(probs)[:, 1])
            all_labels.append(labels)
            val_losses.append(loss.item())
            accs.append(torch.true_divide(
                                torch.sum(argmax_probs == labels), 
                                torch.numel(argmax_probs)).item()
                       )

            FA, FR = count_FA_FR(argmax_probs, labels)
            FAs.append(FA)
            FRs.append(FR)
            
        # area under FA/FR curve for whole loader
        au_fa_fr = get_au_fa_fr(torch.cat(all_probs, dim=0), all_labels, device)    
        writer.add_scalar('Accuracy/train', np.mean(accs), epoch * len(loader) + i)


In [52]:
for n in range(NUM_EPOCHS):
    
    loss = train_epoch(full_model, opt, train_loader, melspec_train, 
          GRU_NUM_LAYERS, HIDDEN_SIZE, n, device=device)           


    validation(full_model, val_loader, melspec_val, 
        GRU_NUM_LAYERS, HIDDEN_SIZE, n, device=device)

    print('END OF EPOCH', n)
    if n % 1 == 0:
        save_checkpoint(full_model, opt, n, loss)

176it [00:52,  3.34it/s]
24it [00:07,  3.29it/s]


END OF EPOCH 0
Checkpoint saved: checkpoint.pth


  return -np.trapz(FRs, x=FAs)
176it [00:53,  3.27it/s]
24it [00:07,  3.25it/s]


END OF EPOCH 1
Checkpoint saved: checkpoint.pth


176it [00:54,  3.26it/s]
24it [00:07,  3.20it/s]


END OF EPOCH 2
Checkpoint saved: checkpoint.pth


176it [00:52,  3.36it/s]
24it [00:07,  3.21it/s]


END OF EPOCH 3
Checkpoint saved: checkpoint.pth


176it [00:55,  3.19it/s]
24it [00:07,  3.28it/s]


END OF EPOCH 4
Checkpoint saved: checkpoint.pth


176it [00:51,  3.40it/s]
24it [00:07,  3.20it/s]


END OF EPOCH 5
Checkpoint saved: checkpoint.pth


176it [00:56,  3.14it/s]
24it [00:07,  3.16it/s]


END OF EPOCH 6
Checkpoint saved: checkpoint.pth


176it [00:54,  3.20it/s]
24it [00:07,  3.23it/s]


END OF EPOCH 7
Checkpoint saved: checkpoint.pth


176it [00:56,  3.10it/s]
24it [00:07,  3.29it/s]


END OF EPOCH 8
Checkpoint saved: checkpoint.pth


176it [00:54,  3.22it/s]
24it [00:07,  3.13it/s]


END OF EPOCH 9
Checkpoint saved: checkpoint.pth


176it [00:53,  3.31it/s]
24it [00:07,  3.32it/s]


END OF EPOCH 10
Checkpoint saved: checkpoint.pth


176it [00:53,  3.32it/s]
24it [00:07,  3.27it/s]


END OF EPOCH 11
Checkpoint saved: checkpoint.pth


176it [00:53,  3.27it/s]
24it [00:07,  3.23it/s]


END OF EPOCH 12
Checkpoint saved: checkpoint.pth


176it [00:53,  3.26it/s]
24it [00:07,  3.37it/s]


END OF EPOCH 13
Checkpoint saved: checkpoint.pth


176it [00:52,  3.36it/s]
24it [00:07,  3.34it/s]


END OF EPOCH 14
Checkpoint saved: checkpoint.pth


176it [00:53,  3.27it/s]
24it [00:07,  3.25it/s]


END OF EPOCH 15
Checkpoint saved: checkpoint.pth


176it [00:53,  3.30it/s]
24it [00:07,  3.15it/s]


END OF EPOCH 16
Checkpoint saved: checkpoint.pth


176it [00:53,  3.30it/s]
24it [00:07,  3.13it/s]


END OF EPOCH 17
Checkpoint saved: checkpoint.pth


176it [00:52,  3.35it/s]
24it [00:07,  3.24it/s]


END OF EPOCH 18
Checkpoint saved: checkpoint.pth


176it [00:53,  3.30it/s]
24it [00:07,  3.19it/s]


END OF EPOCH 19
Checkpoint saved: checkpoint.pth


176it [00:53,  3.29it/s]
24it [00:07,  3.18it/s]


END OF EPOCH 20
Checkpoint saved: checkpoint.pth


176it [00:52,  3.35it/s]
24it [00:07,  3.22it/s]


END OF EPOCH 21
Checkpoint saved: checkpoint.pth


176it [00:53,  3.26it/s]
24it [00:07,  3.26it/s]


END OF EPOCH 22
Checkpoint saved: checkpoint.pth


176it [00:54,  3.20it/s]
24it [00:07,  3.13it/s]


END OF EPOCH 23
Checkpoint saved: checkpoint.pth


176it [00:53,  3.30it/s]
24it [00:07,  3.22it/s]


END OF EPOCH 24
Checkpoint saved: checkpoint.pth


176it [00:52,  3.38it/s]
24it [00:07,  3.22it/s]


END OF EPOCH 25
Checkpoint saved: checkpoint.pth


176it [00:52,  3.33it/s]
24it [00:07,  3.11it/s]


END OF EPOCH 26
Checkpoint saved: checkpoint.pth


176it [00:53,  3.28it/s]
24it [00:07,  3.20it/s]


END OF EPOCH 27
Checkpoint saved: checkpoint.pth


176it [00:55,  3.16it/s]
24it [00:07,  3.27it/s]


END OF EPOCH 28
Checkpoint saved: checkpoint.pth


176it [00:53,  3.26it/s]
24it [00:07,  3.14it/s]


END OF EPOCH 29
Checkpoint saved: checkpoint.pth


176it [00:54,  3.24it/s]
24it [00:07,  3.26it/s]


END OF EPOCH 30
Checkpoint saved: checkpoint.pth


176it [00:53,  3.27it/s]
24it [00:07,  3.20it/s]


END OF EPOCH 31
Checkpoint saved: checkpoint.pth


176it [00:55,  3.16it/s]
24it [00:07,  3.15it/s]


END OF EPOCH 32
Checkpoint saved: checkpoint.pth


176it [00:53,  3.26it/s]
24it [00:07,  3.10it/s]


END OF EPOCH 33
Checkpoint saved: checkpoint.pth


176it [00:52,  3.34it/s]
24it [00:07,  3.18it/s]


END OF EPOCH 34
Checkpoint saved: checkpoint.pth
