In [30]:
import os
import re
from tqdm.auto import tqdm
from glob import glob

import pandas as pd
import numpy as np
import sklearn
from sklearn.metrics import accuracy_score, log_loss, f1_score

import librosa
from audiomentations import Compose, AddGaussianNoise, Shift, TimeStretch, PitchShift

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch 
import torch.nn as nn
import torchvision

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [31]:
# Dataset root and the train/test sub-folders that sit underneath it.
data_dir = './dataset'
train_npy_dir = f'{data_dir}/train_npy'
test_npy_dir = f'{data_dir}/test_npy'

In [32]:
# Load the pickled train/test metadata tables stored under data_dir.
# NOTE: unpickling can execute arbitrary code -- only read files you trust.
train_df, test_df = (
    pd.read_pickle(os.path.join(data_dir, pkl_name))
    for pkl_name in ('new_train.pkl', 'test.pkl')
)

In [33]:
import easydict

# Hyper-parameters shared by the datasets, the model and the trainer.
args = easydict.EasyDict(dict(
    sr=16000,            # sample rate handed to librosa and to the augmentations
    n_mels=128,          # number of mel bands per spectrogram
    n_fft=[1024],        # zipped with win_length: one spectrogram channel per pair
    win_length=[600],
    hop_length=120,
    min_length=120000,   # waveforms are padded / cropped to this many samples
    min_level_db=-80,    # dB floor used when rescaling spectrograms into [0, 1]
    lr=1e-4,             # AdamW learning rate
    epochs=20,           # maximum number of epochs per fold
    seed=2021,
    batch_num=32,        # batch size of the train / validation loaders
    fp16=True,           # 16-bit precision when True, 32-bit otherwise
))

pl.seed_everything(args.seed)

Global seed set to 2021


2021

In [None]:
# def get_length(path_list):
#     length = []
#     for i in tqdm(path_list):
#         audio = np.load(i)
#         length.append(audio.shape[0])
#     return length

# train_length = get_length(train_df['path'])

In [27]:
class AudioDataset(torch.utils.data.Dataset):
    """Labelled dataset that turns stored waveforms into normalised mel-spectrograms.

    Args:
        hparams: attribute-style settings; reads ``sr``, ``n_mels``, ``n_fft``,
            ``win_length``, ``hop_length``, ``min_length`` and ``min_level_db``.
        csv: DataFrame with one row per clip. ``__getitem__`` reads the label from
            the second column and the ``.npy`` path from the last column.
        transform: when truthy, the random waveform augmentations in ``self.aug``
            are applied before the spectrogram is computed.
    """

    def __init__(self, hparams, csv, transform=None):
        self.hparams = hparams
        # Fresh 0..n-1 index so positional lookups are independent of fold slicing.
        self.csv = csv.reset_index(drop=True)
        # Each augmentation fires independently with probability 0.5.
        self.aug = Compose([AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
                            TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
                            PitchShift(min_semitones=-4,
                                       max_semitones=4, p=0.5),
                            Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5)])
        self.transform = transform

    def __len__(self):
        """Return the number of clips (rows of the metadata table)."""
        return len(self.csv)

    def generate_mel(self, audio, sr, n_fft, win_length, hop_length, n_mels):
        """Compute one mel-spectrogram of ``audio`` rescaled into [0, 1].

        The power spectrogram is converted to dB relative to its own maximum,
        then mapped linearly so that ``hparams.min_level_db`` becomes 0 and
        0 dB becomes 1; values outside that range are clipped.
        """
        S = librosa.feature.melspectrogram(y=audio,
                                           sr=sr,
                                           n_fft=n_fft,
                                           win_length=win_length,
                                           hop_length=hop_length,
                                           n_mels=n_mels)
        S = librosa.power_to_db(S, ref=np.max)
        S = np.clip((S - self.hparams.min_level_db) / -
                    self.hparams.min_level_db, 0, 1)
        return S

    def features_extractor(self, audio_path):
        """Load a waveform from ``audio_path`` and return its stacked spectrograms.

        Returns:
            ``np.ndarray`` whose first axis has one entry per
            ``(n_fft, win_length)`` pair in the hyper-parameters.
        """
        audio = np.load(audio_path)
        # Pad or crop to a fixed number of samples. `size` is passed by keyword
        # because librosa >= 0.10 made it keyword-only; older versions accept it too.
        audio = librosa.util.fix_length(audio, size=self.hparams.min_length)

        if self.transform:
            audio = self.aug(audio, sample_rate=self.hparams.sr)

        mel = []
        for n_fft, win_length in zip(self.hparams.n_fft, self.hparams.win_length):
            S = self.generate_mel(audio,
                                  self.hparams.sr,
                                  n_fft, win_length, self.hparams.hop_length,
                                  self.hparams.n_mels)
            mel.append(S)
        return np.array(mel)

    def __getitem__(self, index):
        """Return ``(spectrogram float tensor, label long tensor)`` for row ``index``."""
        # Positional column access: path is the last column, label the second one.
        # NOTE(review): depends on the column order of the training table -- confirm.
        path = self.csv.iloc[index, -1]
        label = self.csv.iloc[index, 1]
        mel = self.features_extractor(path)
        return (
            torch.tensor(mel, dtype=torch.float),
            torch.tensor(label, dtype=torch.long)
        )

In [34]:
class audio_resnet34(nn.Module):
    """ResNet-34 feature extractor for 1-channel inputs with a 6-way linear head."""

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept so saved state dicts still match.
        self.org_model = torchvision.models.resnet34()
        # Replace the stem convolution so the network takes a single input channel.
        self.org_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)
        # Every child of the ResNet except its last one (the original classifier).
        backbone_layers = list(self.org_model.children())[:-1]
        self.features = nn.Sequential(*backbone_layers)
        self.fc = nn.Linear(512, 6)

    def forward(self, x):
        """Return the 6 raw class scores for a batch ``x``."""
        pooled = self.features(x)
        # Collapse everything after the batch axis before the linear head.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

In [35]:
class build_fn(pl.LightningModule):
    """LightningModule that trains ``audio_resnet34`` with cross-entropy loss.

    Args:
        hparams: attribute-style settings; ``lr`` is read by the optimizer.
        train_loader: DataLoader returned by ``train_dataloader``.
        val_loader: DataLoader returned by ``val_dataloader``.
    """

    def __init__(self, hparams, train_loader=None, val_loader=None):
        super().__init__()
        # NOTE(review): direct assignment to `hparams` is rejected by newer
        # Lightning releases (save_hyperparameters is the replacement) -- confirm
        # against the installed version before upgrading.
        self.hparams = hparams
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = audio_resnet34()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw class scores of the wrapped model."""
        return self.model(x)

    def step(self, batch, batch_idx):
        """Shared train/validation step.

        Returns:
            dict with the batch ``loss`` plus per-sample ``y_true`` labels and
            ``y_pred`` class probabilities (lists of numpy values) for the
            epoch-level metrics.
        """
        x, labels = batch
        output = self(x)
        # CrossEntropyLoss consumes the raw scores directly.
        loss = self.loss_fn(output, labels)

        # Explicit class axis: relying on the implicit `dim` is deprecated.
        probs = nn.functional.softmax(output, dim=1)

        y_true = list(labels.detach().cpu().numpy())
        y_pred = list(probs.detach().cpu().numpy())

        return {
            'loss': loss,
            'y_true': y_true,
            'y_pred': y_pred,
        }

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx)

    def epoch_end(self, outputs, state='train'):
        """Aggregate step outputs and log ``{state}_loss``, ``{state}_acc`` and ``{state}_f1``.

        Args:
            outputs: list of the dicts produced by ``step`` during the epoch.
            state: prefix for the logged metric names.

        Returns:
            dict with the mean batch loss under ``'loss'``.
        """
        loss = 0.0
        y_true = []
        y_pred = []

        for i in outputs:
            loss += i['loss'].item()
            y_true += i['y_true']
            y_pred += i['y_pred']

        loss = loss / len(outputs)
        # Predicted class per sample, computed once for both metrics.
        y_hat = np.argmax(y_pred, axis=-1)

        self.log(state+'_loss', float(loss), on_epoch=True, prog_bar=True)
        self.log(state+'_acc', accuracy_score(y_true, y_hat), on_epoch=True, prog_bar=True, logger=True)
        self.log(state+'_f1', f1_score(y_true, y_hat, average='weighted'), on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Lightning hook: log the train metrics at the end of each training epoch.

        Returns nothing on purpose; Lightning expects ``None`` from this hook.
        """
        self.epoch_end(outputs, state='train')

    def train_epoch_end(self, outputs):
        """Backward-compatible alias; Lightning itself only calls ``training_epoch_end``."""
        return self.epoch_end(outputs, state='train')

    def validation_epoch_end(self, outputs):
        """Lightning hook: log the validation metrics monitored by the callbacks."""
        return self.epoch_end(outputs, state='val')

    def configure_optimizers(self):
        """AdamW plus a plateau scheduler stepped once per epoch on ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = {'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                                             patience=2, 
                                                                             mode='min', verbose=True),
                     'interval': 'epoch',
                     'monitor': 'val_loss'}
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

In [None]:
# Import the submodule explicitly: `import sklearn` alone does not guarantee
# that `sklearn.model_selection` is loaded.
from sklearn.model_selection import StratifiedKFold

# 5-fold cross-validation stratified on the accent label; one model is trained per fold.
skf = StratifiedKFold(n_splits=5, random_state=args['seed'], shuffle=True)

for fold_, (trn_idx, val_idx) in enumerate(skf.split(train_df.values, train_df['accent'])):
    trn_df, val_df = train_df.iloc[trn_idx], train_df.iloc[val_idx]

    # AudioDataset only takes (hparams, csv, transform); the former `train=` keyword
    # raised a TypeError. Augmentation is enabled for the training split only.
    train_ds = AudioDataset(args, trn_df, transform=True)
    valid_ds = AudioDataset(args, val_df, transform=False)

    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=args['batch_num'], shuffle=True, num_workers=4, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size=args['batch_num'], shuffle=False, num_workers=4, pin_memory=False)

    # Keep only the checkpoint with the lowest validation loss of this fold.
    checkpoint_callback = ModelCheckpoint(
        filename= '{epoch}-{val_acc:.2f}-{val_loss:.3f}',
        monitor='val_loss',
        save_top_k=1,
        mode='min')

    # Stop the fold once val_loss has not improved for 4 epochs.
    early_stop_callback = EarlyStopping(monitor='val_loss', 
                                        patience=4, 
                                        verbose=True, 
                                        mode='min')

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stop_callback],
        max_epochs=args['epochs'],
        deterministic=torch.cuda.is_available(),
        gpus=-1 if torch.cuda.is_available() else None,
        precision= 16 if args['fp16'] else 32)

    pl_model = build_fn(args, train_loader, valid_loader)

    trainer.fit(pl_model)
    # Path of the fold's best checkpoint; these paths feed the inference PATHS list.
    print(checkpoint_callback.best_model_path)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.

  | Name    | Type             | Params
---------------------------------------------
0 | model   | audio_resnet34   | 21.8 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
21.8 M    Trainable params
0         Non-trainable params
21.8 M    Total params
87.178    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Inference

In [36]:
# Preview the test metadata: one row per clip with its .npy path and id.
test_df.head()

Unnamed: 0,path,id
0,./dataset/test_npy/1.npy,1
1,./dataset/test_npy/2.npy,2
2,./dataset/test_npy/3.npy,3
3,./dataset/test_npy/4.npy,4
4,./dataset/test_npy/5.npy,5


In [37]:
class Test_AudioDataset(torch.utils.data.Dataset):
    """Unlabelled counterpart of the training dataset: yields spectrograms only.

    Args:
        hparams: attribute-style settings; reads ``sr``, ``n_mels``, ``n_fft``,
            ``win_length``, ``hop_length``, ``min_length`` and ``min_level_db``.
        csv: DataFrame with one row per clip whose first column is the ``.npy`` path.
        transform: when truthy, the random waveform augmentations in ``self.aug``
            are applied before the spectrogram is computed.
    """

    def __init__(self, hparams, csv, transform=None):
        self.hparams = hparams
        self.csv = csv.reset_index(drop=True)
        # Same augmentation chain as the training dataset; only used if `transform` is truthy.
        self.aug = Compose([AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
                            TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
                            PitchShift(min_semitones=-4,
                                       max_semitones=4, p=0.5),
                            Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5)])
        self.transform = transform

    def __len__(self):
        """Return the number of clips (rows of the metadata table)."""
        return len(self.csv)

    def generate_mel(self, audio, sr, n_fft, win_length, hop_length, n_mels):
        """Compute one mel-spectrogram of ``audio`` rescaled into [0, 1].

        The power spectrogram is converted to dB relative to its own maximum,
        then mapped linearly so that ``hparams.min_level_db`` becomes 0 and
        0 dB becomes 1; values outside that range are clipped.
        """
        S = librosa.feature.melspectrogram(y=audio,
                                           sr=sr,
                                           n_fft=n_fft,
                                           win_length=win_length,
                                           hop_length=hop_length,
                                           n_mels=n_mels)
        S = librosa.power_to_db(S, ref=np.max)
        S = np.clip((S - self.hparams.min_level_db) / -
                    self.hparams.min_level_db, 0, 1)
        return S

    def features_extractor(self, audio_path):
        """Load a waveform from ``audio_path`` and return its stacked spectrograms.

        Returns:
            ``np.ndarray`` whose first axis has one entry per
            ``(n_fft, win_length)`` pair in the hyper-parameters.
        """
        audio = np.load(audio_path)
        # Pad or crop to a fixed number of samples. `size` is passed by keyword
        # because librosa >= 0.10 made it keyword-only; older versions accept it too.
        audio = librosa.util.fix_length(audio, size=self.hparams.min_length)

        if self.transform:
            audio = self.aug(audio, sample_rate=self.hparams.sr)

        mel = []
        for n_fft, win_length in zip(self.hparams.n_fft, self.hparams.win_length):
            S = self.generate_mel(audio,
                                  self.hparams.sr,
                                  n_fft, win_length, self.hparams.hop_length,
                                  self.hparams.n_mels)
            mel.append(S)
        return np.array(mel)

    def __getitem__(self, index):
        """Return the spectrogram float tensor for row ``index`` (no label)."""
        # The path is read positionally from the first column of the table.
        path = self.csv.iloc[index, 0]
        mel = self.features_extractor(path)
        return torch.tensor(mel, dtype=torch.float)

In [38]:
# Inference data pipeline: augmentation disabled, batches of 256 clips.
test_ds = Test_AudioDataset(hparams=args, csv=test_df, transform=False)
test_loader = torch.utils.data.DataLoader(
    dataset=test_ds,
    batch_size=256,
    num_workers=4,
)

In [None]:
# PATHS = ["./lightning_logs/version_0/checkpoints/epoch=15-val_acc=0.78-val_loss=0.679.ckpt",
#          "./lightning_logs/version_1/checkpoints/epoch=8-val_acc=0.72-val_loss=0.782.ckpt",
#          "./lightning_logs/version_2/checkpoints/epoch=10-val_acc=0.76-val_loss=0.707.ckpt",
#          "./lightning_logs/version_3/checkpoints/epoch=7-val_acc=0.73-val_loss=0.780.ckpt",
#          "./lightning_logs/version_4/checkpoints/epoch=10-val_acc=0.76-val_loss=0.676.ckpt"]

In [None]:
# PATHS = ["./lightning_logs/version_5/checkpoints/epoch=16-val_acc=0.80-val_loss=0.651.ckpt",
#          "./lightning_logs/version_6/checkpoints/epoch=16-val_acc=0.82-val_loss=0.581.ckpt",
#          "./lightning_logs/version_7/checkpoints/epoch=14-val_acc=0.81-val_loss=0.564.ckpt",
#          "./lightning_logs/version_8/checkpoints/epoch=15-val_acc=0.84-val_loss=0.492.ckpt",
#          "./lightning_logs/version_9/checkpoints/epoch=18-val_acc=0.84-val_loss=0.499.ckpt"]

In [40]:
# Best checkpoint of each fold; fold k was written to lightning_logs/version_k.
PATHS = [
    f"./lightning_logs/version_{fold}/checkpoints/{ckpt_file}"
    for fold, ckpt_file in enumerate([
        "epoch=17-val_acc=0.82-val_loss=0.539.ckpt",
        "epoch=19-val_acc=0.82-val_loss=0.519.ckpt",
        "epoch=18-val_acc=0.84-val_loss=0.439.ckpt",
        "epoch=19-val_acc=0.83-val_loss=0.495.ckpt",
        "epoch=19-val_acc=0.81-val_loss=0.555.ckpt",
    ])
]

In [41]:
def ensemble_fn(test_loader, ckpt_paths, device):
    """Average the class probabilities predicted by several checkpoints.

    Args:
        test_loader: DataLoader yielding input batches (no labels) in a fixed order.
        ckpt_paths: paths of ``build_fn`` checkpoints to ensemble.
        device: torch device (or device string) the models and batches are moved to.

    Returns:
        ``np.ndarray`` of dtype float64 with one row per test sample and one
        column per class, holding the mean softmax probability over all checkpoints.

    Raises:
        ValueError: if ``ckpt_paths`` is empty.
    """
    ckpt_paths = list(ckpt_paths)
    if not ckpt_paths:
        raise ValueError("ckpt_paths must contain at least one checkpoint path")

    # Sized from the first model's predictions instead of a hard-coded (6100, 6).
    summed_probs = None

    for path in ckpt_paths:
        model = build_fn.load_from_checkpoint(path)
        model.to(device)
        model.eval()

        batch_probs = []
        with torch.no_grad():
            for x in tqdm(test_loader):
                output = model(x.to(device))
                # Explicit class axis: relying on the implicit `dim` is deprecated.
                probs = torch.nn.functional.softmax(output, dim=1)
                batch_probs.append(probs.cpu().numpy())

        # Accumulate in float64, matching the former np.zeros accumulator.
        model_probs = np.concatenate(batch_probs).astype(np.float64)
        summed_probs = model_probs if summed_probs is None else summed_probs + model_probs

    return summed_probs / len(ckpt_paths)

In [42]:
# Use the first GPU when one is visible, otherwise fall back to the CPU
# (mirrors the availability check used when building the Trainer).
final = ensemble_fn(test_loader, PATHS, 'cuda:0' if torch.cuda.is_available() else 'cpu')

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

In [43]:
# Sanity check on the averaged predictions: (number of rows, number of class columns).
final.shape

(6100, 6)

In [44]:
sub_df = pd.read_csv(os.path.join(data_dir, 'sample_submission.csv'))

# Every column after the first one (the id) receives one class probability.
assert final.shape == (len(sub_df), sub_df.shape[1] - 1), (
    f"prediction shape {final.shape} does not match the submission template "
    f"({len(sub_df)} rows x {sub_df.shape[1] - 1} class columns)"
)
# Replace the placeholder columns wholesale rather than writing into them in
# place, so their original dtype (possibly integer) cannot clash with floats.
sub_df[list(sub_df.columns[1:])] = final

In [45]:
# Preview the filled-in submission: id followed by one probability column per accent.
sub_df.head()

Unnamed: 0,id,africa,australia,canada,england,hongkong,us
0,1,0.018227,0.036662,0.005901,0.366862,0.03234,0.540008
1,2,0.005408,0.006099,0.003706,0.827213,0.003215,0.154358
2,3,0.104748,0.015337,0.000781,0.801053,0.00052,0.077561
3,4,0.026231,0.008206,0.001176,0.632881,0.001815,0.329692
4,5,0.146706,0.002311,0.00517,0.381705,0.014126,0.449983


In [46]:
# Create the output folder first: to_csv does not create missing directories.
os.makedirs('./infer', exist_ok=True)
sub_df.to_csv('./infer/resnet34_ensemble.csv', index=False)