In [1]:
!pip install torch==1.6.0
!pip install opencv-python
!pip install torchvision==0.7.0
!pip install albumentations
!pip install tensorflow
!pip install pytorch-lightning

Defaulting to user installation because normal site-packages is not writeable
Collecting torch==1.6.0
  Downloading torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl (748.8 MB)
[K     |████████████████████████████████| 748.8 MB 16 kB/s 
Installing collected packages: torch
Successfully installed torch-1.6.0
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
Collecting opencv-python
  Downloading opencv_python-4.5.1.48-cp36-cp36m-manylinux2014_x86_64.whl (50.4 MB)
[K     |████████████████████████████████| 50.4 MB 4.1 MB/s 
Installing collected packages: opencv-python
Successfully installed opencv-python-4.5.1.48
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade 

In [4]:
!pip install audiomentations

Defaulting to user installation because normal site-packages is not writeable
Collecting audiomentations
  Downloading audiomentations-0.15.0-py3-none-any.whl (24 kB)
Installing collected packages: audiomentations
Successfully installed audiomentations-0.15.0
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import typing as tp
import yaml
import random
import os
import sys
import soundfile as sf
import librosa
import cv2
import matplotlib.pyplot as plt
import time
import glob

import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# import resnest.torch as resnest_torch

from torchvision import models

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
# from resnet import ResNet, Bottleneck

from albumentations.core.transforms_interface import DualTransform, BasicTransform
import albumentations as A
import audiomentations as AA


from sklearn.model_selection import StratifiedKFold

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

## util

In [2]:
# Central configuration for the dataset and the two DataLoaders.
# config["dataset"]["params"] is passed to the Spectrogram*Dataset constructors;
# config["loader"]["train"/"valid"] are splatted into torch DataLoader(...).
config_set = {
    'dataset': {
          'name': 'SpectrogramDataset',
          'params': {
            # Target image height in pixels after resizing the mel spectrogram.
            'img_size': 224, 
            # Extra keyword arguments forwarded to librosa.feature.melspectrogram.
            'melspectrogram_parameters': {
              'n_mels': 128, 
              'fmin': 50, 
              'fmax': 15000, 
            }
      }
    },
    'loader': {
      'train': {
        'batch_size': 6,
        'shuffle': True,
        'num_workers': 2,
        'pin_memory': True,
        'drop_last': True,
      },
      # NOTE(review): drop_last=True on the validation loader discards a trailing
      # partial batch, so not every validation row is necessarily scored -- confirm intended.
      'valid': {
        'batch_size': 2,
        'shuffle': False,
        'num_workers': 2,
        'pin_memory': True,
        'drop_last': True,
      }
    }
}
# Seed handed to set_seed() and to StratifiedKFold.
SEED=1213
# Length, in seconds, of the audio window cut out of each recording.
PERIOD = 5
# Percentage of the annotated event that a random training crop must contain.
OK_RANGE = 60
# Number of target classes (length of the multi-hot label vector).
SPECIES_NUM = 24
# Maximum number of training epochs given to the Trainer.
EPOCH = 80
# Trainer root directory and location of the saved model weights.
OUTPUT_DIR = './output/'
# Not referenced by any cell visible in this notebook section.
HOP_LEN = 512
# Sample rate (Hz) used to convert annotation times in seconds to sample indices.
SR = 48000

In [3]:
# Alias (not a copy) of the active configuration; later cells read `config`.
config = config_set

In [4]:
def set_seed(seed: int = 42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN into
    deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects child processes; the running interpreter's hash seed is
    # fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick kernels by timing them, which can differ
    # from run to run and therefore contradicts deterministic mode.
    torch.backends.cudnn.benchmark = False
# Seed all RNGs once, up front, with the notebook-wide SEED constant.
set_seed(SEED)

In [5]:
# Filesystem layout of the competition data.
# NOTE(review): INPUT_ROOT is a machine-specific absolute path; it has to be
# edited (or made configurable) before the notebook runs anywhere else.
INPUT_ROOT = Path("/home/knikaido/work/Rainforest-Connection/data")
RAW_DATA = INPUT_ROOT / "rfcx-species-audio-detection"
# Directory of training recordings.
TRAIN_AUDIO_DIR = RAW_DATA / "train"
# TRAIN_RESAMPLED_AUDIO_DIRS = [
#   INPUT_ROOT / "birdsong-resampled-train-audio-{:0>2}".format(i)  for i in range(5)
# ]
# Directory of test recordings.
TEST_AUDIO_DIR = RAW_DATA / "test"

In [6]:
class AudioTransform(BasicTransform):
    """Base class for albumentations-style transforms that operate on audio.

    Subclasses implement ``apply``; this class registers it for the ``data``
    target and forwards a couple of optional attributes through the params.
    """

    @property
    def targets(self):
        # Map the ``data`` keyword onto this transform's ``apply`` method.
        return dict(data=self.apply)

    def update_params(self, params, **kwargs):
        # Copy optional attributes into ``params`` only when the instance
        # actually defines them.
        for attr_name in ("interpolation", "fill_value"):
            if hasattr(self, attr_name):
                params[attr_name] = getattr(self, attr_name)
        return params
    
class PitchShift(AudioTransform):
    """Shift the pitch of an audio signal by a fixed number of semitones."""

    def __init__(self, always_apply=False, p=0.5 , n_steps=None):
        # ``n_steps`` is the number of semitones to shift by; it is stored
        # as-is and handed to librosa when the transform is applied.
        super().__init__(always_apply, p)
        self.n_steps = n_steps

    def apply(self, data, **params):
        """Return ``data`` pitch-shifted by ``self.n_steps`` semitones.

        ``data`` is an ndarray holding the audio time series; the module-level
        ``SR`` constant is passed to librosa as the sample rate.
        """
        shifted = librosa.effects.pitch_shift(data, sr=SR, n_steps=self.n_steps)
        return shifted
    
class AddGaussianNoise(AudioTransform):
    """Add white Gaussian noise, scaled by a fixed factor of 0.05, to audio."""

    def __init__(self, always_apply=False, p=0.5):
        super().__init__(always_apply, p)

    def apply(self, data, **params):
        """Return ``data`` plus 0.05 times standard-normal noise.

        ``data`` is an ndarray holding the audio time series; one noise value
        is drawn per element along its first axis.
        """
        white_noise = np.random.randn(len(data))
        return data + 0.05 * white_noise
    
class NoAugment(AudioTransform):
    """Identity transform: hands the audio back untouched."""

    def __init__(self, always_apply=False):
        super().__init__(always_apply)

    def apply(self, data, **params):
        """Return ``data`` (an ndarray holding the audio time series) unchanged."""
        return data
    
def get_augmentation():
    """Build the albumentations-based audio augmentation used for training.

    Returns an ``A.OneOf`` wrapping the candidate transforms; at present the
    only candidate is Gaussian noise applied with probability 0.2.
    """
    # PitchShift(p=1.0, n_steps=4) is a further candidate, currently disabled.
    candidates = [AddGaussianNoise(p=0.2)]
    # OneOf is used here; Compose would be the alternative container.
    return A.OneOf(candidates)

In [7]:
# Waveform-level augmentation pipeline (audiomentations) used by
# SpectrogramTrainDataset. Only Gaussian noise with a fixed amplitude of 0.05,
# applied with probability 0.2, is active; the remaining entries are
# candidates that are currently switched off.
train_audio_transform = AA.Compose([
    AA.AddGaussianNoise(p=0.2, min_amplitude=0.05, max_amplitude=0.05),
#     AA.AddGaussianSNR(p=0.2),
    #AA.AddBackgroundNoise("../input/train_audio/", p=1)
    #AA.AddImpulseResponse(p=0.1),
    #AA.AddShortNoises("../input/train_audio/", p=1)
    #AA.FrequencyMask(min_frequency_band=0.0,  max_frequency_band=0.2, p=0.1),
    #AA.TimeMask(min_band_part=0.0, max_band_part=0.2, p=0.1),
    #AA.PitchShift(min_semitones=-0.5, max_semitones=0.5, p=0.1),
    #AA.Shift(p=0.1),
    #AA.Normalize(p=0.1),
    #AA.ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=1, p=0.05),
    #AA.PolarityInversion(p=0.05),
    #AA.Gain(p=0.2)
])

In [8]:
# Annotation table with scalar columns; judging by the displayed output it
# holds one row per annotated event (recording_id, species_id, t_min, t_max,
# ..., name = path to the audio file).
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
train_df = pd.read_pickle(RAW_DATA / "train_gby_wav_raw.pkl")
train_df

Unnamed: 0,recording_id,species_id,songtype_id,t_min,f_min,t_max,f_max,name
0,003bec244,14,1,44.5440,2531.250,45.1307,5531.25,/home/knikaido/work/Rainforest-Connection/data...
1,006ab765f,23,1,39.9615,7235.160,46.0452,11283.40,/home/knikaido/work/Rainforest-Connection/data...
2,007f87ba2,12,1,39.1360,562.500,42.2720,3281.25,/home/knikaido/work/Rainforest-Connection/data...
3,0099c367b,17,4,51.4206,1464.260,55.1996,4565.04,/home/knikaido/work/Rainforest-Connection/data...
4,009b760e6,10,1,50.0854,947.461,52.5293,10852.70,/home/knikaido/work/Rainforest-Connection/data...
...,...,...,...,...,...,...,...,...
1211,fe8d9ac40,13,1,53.4720,93.750,54.0960,843.75,/home/knikaido/work/Rainforest-Connection/data...
1212,fea6b438a,4,1,43.5787,2531.250,45.7653,4031.25,/home/knikaido/work/Rainforest-Connection/data...
1213,ff2eb9ce5,0,1,15.2267,5906.250,16.0213,8250.00,/home/knikaido/work/Rainforest-Connection/data...
1214,ffb8d8391,5,1,14.3467,4781.250,16.6987,10406.20,/home/knikaido/work/Rainforest-Connection/data...


In [9]:
# Same annotations with list-valued columns; judging by the displayed output
# this is grouped to one row per recording. It is used below to build the
# cross-validation folds.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
train_gby = pd.read_pickle(RAW_DATA / "train_gby.pkl")
train_gby

Unnamed: 0,recording_id,species_id,songtype_id,t_min,f_min,t_max,f_max,name
0,003bec244,[14],[1],[44.544],[2531.25],[45.1307],[5531.25],/home/knikaido/work/Rainforest-Connection/data...
1,006ab765f,[23],[1],[39.9615],[7235.16],[46.0452],[11283.4],/home/knikaido/work/Rainforest-Connection/data...
2,007f87ba2,[12],[1],[39.135999999999996],[562.5],[42.272],[3281.25],/home/knikaido/work/Rainforest-Connection/data...
3,0099c367b,[17],[4],[51.4206],[1464.26],[55.1996],[4565.04],/home/knikaido/work/Rainforest-Connection/data...
4,009b760e6,[10],[1],[50.0854],[947.461],[52.5293],[10852.7],/home/knikaido/work/Rainforest-Connection/data...
...,...,...,...,...,...,...,...,...
1127,fe8d9ac40,[13],[1],[53.472],[93.75],[54.096000000000004],[843.75],/home/knikaido/work/Rainforest-Connection/data...
1128,fea6b438a,[4],[1],[43.5787],[2531.25],[45.7653],[4031.25],/home/knikaido/work/Rainforest-Connection/data...
1129,ff2eb9ce5,[0],[1],[15.2267],[5906.25],[16.0213],[8250.0],/home/knikaido/work/Rainforest-Connection/data...
1130,ffb8d8391,[5],[1],[14.3467],[4781.25],[16.6987],[10406.2],/home/knikaido/work/Rainforest-Connection/data...


In [10]:
def mono_to_color(
    X: np.ndarray, mean=None, std=None,
    norm_max=None, norm_min=None, eps=1e-6
):
    """Convert a single-channel array into a 3-channel ``uint8`` image.

    The input is replicated onto a new trailing axis of size 3, standardised,
    clipped to ``[norm_min, norm_max]`` and rescaled to ``[0, 255]``.

    Args:
        X: Input array (e.g. a dB-scaled mel spectrogram).
        mean: Value subtracted before scaling; defaults to the mean of ``X``.
        std: Divisor applied after centring; defaults to the std of ``X``.
        norm_max: Upper clipping bound in standardised units; defaults to the
            maximum of the standardised data.
        norm_min: Lower clipping bound in standardised units; defaults to the
            minimum of the standardised data.
        eps: Guards the division by ``std`` and is the threshold below which
            the input is treated as constant.

    Returns:
        ``uint8`` array of shape ``X.shape + (3,)``; all zeros when the
        standardised data has a range of ``eps`` or less.
    """
    # Stack X as [X, X, X] on a new last axis.
    X = np.stack([X, X, X], axis=-1)

    # Standardize. Defaults are detected with ``is None`` so that an explicit
    # 0 / 0.0 is honoured instead of being silently replaced.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min

    if (_max - _min) > eps:
        # Clip to the requested range, then normalize to [0, 255].
        V = np.clip(Xstd, norm_min, norm_max)
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Constant input: just zero.
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V

In [11]:
class SpectrogramTrainDataset(data.Dataset):
    """Training dataset yielding (mel-spectrogram image, multi-hot label) pairs.

    Each item reads one recording, cuts a ``PERIOD``-second window that
    overlaps the annotated event, applies the waveform augmentation and turns
    the result into a 3-channel image.

    Args:
        gby_df: Annotation frame with at least the columns ``name`` (audio
            path), ``t_min``, ``t_max`` (event bounds in seconds) and
            ``species_id``.
        setting: Dict with ``img_size`` and ``melspectrogram_parameters``.
    """

    def __init__(
        self,
        gby_df: pd.DataFrame,
        setting: tp.Dict
    ):
        self.img_size = setting['img_size']
        self.melspectrogram_parameters = setting['melspectrogram_parameters']
        # Waveform-level augmentation defined at module level.
        self.audio_transform = train_audio_transform

        self.gby_df = gby_df

    def __len__(self):
        return len(self.gby_df)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the row at position ``idx``.

        ``image`` is a float32 array of shape (3, img_size, width) scaled to
        [0, 1]; ``label`` is a float32 vector of length ``SPECIES_NUM`` with a
        1 at the row's ``species_id``.
        """
        # Take the path from the same positional row as everything else, so
        # the lookup does not depend on the frame having a 0..n-1 index.
        train_element = self.gby_df.iloc[idx]
        wav_path = train_element['name']

        y, sr = sf.read(wav_path)

        len_y = len(y)
        effective_length = sr * PERIOD

        # Convert event bounds with the file's own sample rate so they are in
        # the same units as ``effective_length`` and ``len_y``.
        tmin = int(sr * train_element['t_min'])
        tmax = int(sr * train_element['t_max'])

        tgt_len = int(tmax - tmin)
        if effective_length < tgt_len:
            # Event longer than the window: take a window centred inside the
            # event, clamped to the recording.
            half_slack = (effective_length - tgt_len) / 2
            latest_start = min(tmax + half_slack, len_y) - effective_length
            start = int(max(0.0, min(tmin - half_slack, latest_start)))
            end = start + effective_length
        else:
            # Rejection-sample a window that contains at least OK_RANGE % of
            # the event measured from its start or from its end. This is slow.
            tgt_len = int((tmax - tmin) * OK_RANGE / 100)
            while True:
                start = np.random.randint(len_y - effective_length)
                end = start + effective_length
                covers_head = start < tmin and tmin + tgt_len < end
                covers_tail = start < tmax - tgt_len and tmax < end
                if covers_head or covers_tail:
                    break

        y = y[start:end].astype(np.float32)
        y = self.audio_transform(samples=y, sample_rate=sr)

        # ``y`` is passed by keyword: newer librosa releases no longer accept
        # it positionally.
        melspec = librosa.feature.melspectrogram(y=y, sr=sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        # Resize to img_size rows, keeping the aspect ratio.
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)

        label = np.zeros(SPECIES_NUM, dtype="f")
        label[train_element['species_id']] = 1

        return image, label
    
class SpectrogramValidDataset(data.Dataset):
    """Validation dataset yielding (mel-spectrogram image, multi-hot label) pairs.

    Works like the training dataset but without waveform augmentation, and
    the sampled window must contain at least 50 % of the annotated event.
    The window is still drawn with ``np.random``, so validation crops are not
    fixed between epochs.

    Args:
        gby_df: Annotation frame with at least the columns ``name`` (audio
            path), ``t_min``, ``t_max`` (event bounds in seconds) and
            ``species_id``.
        setting: Dict with ``img_size`` and ``melspectrogram_parameters``.
    """

    def __init__(
        self,
        gby_df: pd.DataFrame,
        setting: tp.Dict
    ):
        self.img_size = setting['img_size']
        self.melspectrogram_parameters = setting['melspectrogram_parameters']

        self.gby_df = gby_df

    def __len__(self):
        return len(self.gby_df)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the row at position ``idx``.

        ``image`` is a float32 array of shape (3, img_size, width) scaled to
        [0, 1]; ``label`` is a float32 vector of length ``SPECIES_NUM`` with a
        1 at the row's ``species_id``.
        """
        # Take the path from the same positional row as everything else, so
        # the lookup does not depend on the frame having a 0..n-1 index.
        train_element = self.gby_df.iloc[idx]
        wav_path = train_element['name']

        y, sr = sf.read(wav_path)

        len_y = len(y)
        effective_length = sr * PERIOD

        # Convert event bounds with the file's own sample rate so they are in
        # the same units as ``effective_length`` and ``len_y``.
        tmin = int(sr * train_element['t_min'])
        tmax = int(sr * train_element['t_max'])

        # Rejection-sample a window that contains at least half of the event
        # measured from its start or from its end. This is slow.
        tgt_len = int((tmax - tmin) * 50 / 100)
        while True:
            start = np.random.randint(len_y - effective_length)
            end = start + effective_length
            covers_head = start < tmin and tmin + tgt_len < end
            covers_tail = start < tmax - tgt_len and tmax < end
            if covers_head or covers_tail:
                break

        y = y[start:end].astype(np.float32)

        # ``y`` is passed by keyword: newer librosa releases no longer accept
        # it positionally.
        melspec = librosa.feature.melspectrogram(y=y, sr=sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        # Resize to img_size rows, keeping the aspect ratio.
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)

        label = np.zeros(SPECIES_NUM, dtype="f")
        label[train_element['species_id']] = 1

        return image, label

In [12]:
def get_criterion():
    """Return the loss function used by the model.

    Plain binary cross-entropy on probabilities (``nn.BCELoss``), which fits
    the sigmoid-activated outputs of the attention head.
    """
    # A class-weighted nn.BCEWithLogitsLoss(pos_weight=...) was tried as an
    # alternative; the unused weight tensor that was built for it is removed.
    return nn.BCELoss()

In [13]:
# Early stopping on the epoch-level validation metric 'valid_epoch_lwlap'
# (logged by LitModule.validation_epoch_end): stop once it has not improved
# for 5 consecutive checks; higher is better.
# NOTE: the Trainer below has its `callbacks=[early_stop_callback]` line
# commented out, so this callback is currently not in effect.
early_stop_callback = EarlyStopping(
   monitor='valid_epoch_lwlap',
   min_delta=0.00,
   patience=5,
   verbose=True,
   mode='max'
)

In [14]:
def _one_sample_positive_class_precisions(scores, truth):
    num_classes = scores.shape[0]
    pos_class_indices = np.flatnonzero(truth > 0)

    if not len(pos_class_indices):
        return pos_class_indices, np.zeros(0)

    retrieved_classes = np.argsort(scores)[::-1]

    class_rankings = np.zeros(num_classes, dtype=np.int)
    class_rankings[retrieved_classes] = range(num_classes)

    retrieved_class_true = np.zeros(num_classes, dtype=np.bool)
    retrieved_class_true[class_rankings[pos_class_indices]] = True

    retrieved_cumulative_hits = np.cumsum(retrieved_class_true)

    precision_at_hits = (
            retrieved_cumulative_hits[class_rankings[pos_class_indices]] /
            (1 + class_rankings[pos_class_indices].astype(np.float)))
    return pos_class_indices, precision_at_hits

def lwlrap(truth, scores):
    """Compute per-class label-weighted label-ranking average precision.

    Args:
        truth: 2-D array (samples x classes); entries > 0 mark true labels.
        scores: 2-D array of the same shape holding the predicted scores.

    Returns:
        ``(per_class_lwlrap, weight_per_class)``. The overall metric is
        ``(per_class_lwlrap * weight_per_class).sum()``.
    """
    assert truth.shape == scores.shape
    num_samples, num_classes = scores.shape

    # precision_matrix[s, c] is the precision at class c's rank in sample s;
    # it stays 0 wherever c is not a true label of s.
    precision_matrix = np.zeros((num_samples, num_classes))
    for row, (sample_scores, sample_truth) in enumerate(zip(scores, truth)):
        hit_classes, hit_precisions = _one_sample_positive_class_precisions(
            sample_scores, sample_truth)
        precision_matrix[row, hit_classes] = hit_precisions

    # How often each class occurs as a true label, and its share of all labels.
    labels_per_class = np.sum(truth > 0, axis=0)
    total_labels = float(np.sum(labels_per_class))
    weight_per_class = labels_per_class / total_labels

    # Average precision per class; the max(1, .) keeps absent classes at 0.
    per_class_lwlrap = precision_matrix.sum(axis=0) / np.maximum(1, labels_per_class)
    return per_class_lwlrap, weight_per_class

def lwlap_wrapper(y_true, y_score):
    """Tensor front end for ``lwlrap``.

    Moves both tensors to the CPU as NumPy copies, evaluates the metric and
    returns ``(score_class, weight)`` as fresh float32 CPU tensors.
    """
    truth_np = y_true.detach().cpu().numpy().copy()
    score_np = y_score.detach().cpu().numpy().copy()

    per_class, class_weight = lwlrap(truth_np, score_np)

    # Convert each result back to an independent float32 tensor.
    score_class, weight = (
        torch.from_numpy(arr.astype(np.float32)).clone()
        for arr in (per_class, class_weight)
    )
    return score_class, weight

# Smoke test of the metric on a tiny hand-made example.
# Sample 1: the true class (0) is ranked 2nd         -> precision 1/2.
# Sample 2: the true class (2) is ranked 3rd (last)  -> precision 1/3.
# Each class carries half of the labels, so the score is (1/2 + 1/3) / 2 ~= 0.4167.
y_true = np.array([[1, 0, 0], [0, 0, 1]])
y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]])
y_true = torch.from_numpy(y_true.astype(np.float32)).clone()
y_score = torch.from_numpy(y_score.astype(np.float32)).clone()

score_class, weight = lwlap_wrapper(y_true, y_score)
# Weighted sum of the per-class values gives the overall metric.
score = (score_class * weight).sum()
score

tensor(0.4167)

In [15]:
def interpolate(x: torch.Tensor, ratio: int):
    """Upsample along the time axis by repeating every time step.

    Used to compensate for the time-resolution reduction caused by the
    downsampling inside a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, number of copies made of each time step
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape
    # Insert a helper axis after time, tile it `ratio` times, then fold it
    # back into the time axis so the copies of a step end up adjacent.
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return tiled.reshape(n_batch, n_steps * ratio, n_classes)


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Extend a frame-wise output to ``frames_num`` frames.

    The missing frames at the end are filled with copies of the last frame.

    Args:
      framewise_output: (batch_size, frames, classes_num)
      frames_num: int, total number of frames wanted after padding
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    missing = frames_num - framewise_output.shape[1]
    # Keep the time axis (slice, not index) so the frame can be tiled along it.
    last_frame = framewise_output[:, -1:, :]
    padding = last_frame.repeat(1, missing, 1)
    return torch.cat((framewise_output, padding), dim=1)

def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero its bias, if it has one."""
    nn.init.xavier_uniform_(layer.weight)

    # Covers both "no bias attribute" and "bias is None".
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)
            
def init_bn(bn):
    """Reset a batch-norm layer's affine parameters: weight to 1, bias to 0."""
    bn.weight.data.fill_(1.0)
    bn.bias.data.fill_(0.)


class AttBlock(nn.Module):
    """Attention pooling head over the time axis.

    Two 1x1 convolutions produce, per class and time step, an attention
    logit (``att``) and a classification value (``cla``). The attention is
    soft-maxed over time and used to pool ``cla`` into one value per class.

    Args:
        in_features: Number of input channels.
        out_features: Number of output classes.
        activation: ``"linear"`` or ``"sigmoid"``; applied to ``cla``.
        temperature: Stored on the module; not used by ``forward``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear",
                 temperature=1.0):
        super().__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        # Kept (although forward() does not use it) so that the module's
        # parameters and state_dict keys stay unchanged.
        self.bn_att = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        """Initialise both convolutions and the batch-norm layer."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool ``x`` of shape (n_samples, n_in, n_time) over time.

        Returns:
            ``(pooled, norm_att, cla)``: the attention-weighted sum over time
            (n_samples, n_out), the attention weights and the activated
            per-step classification values (both n_samples, n_out, n_time).
        """
        # Clamp the logits before the softmax to keep the weights well-behaved.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Raises:
            ValueError: If ``self.activation`` is neither ``"linear"`` nor
                ``"sigmoid"`` (previously ``None`` was returned silently and
                ``forward`` failed later with an unrelated error).
        """
        if self.activation == 'linear':
            return x
        if self.activation == 'sigmoid':
            return torch.sigmoid(x)
        raise ValueError(
            "Unsupported activation {!r}; expected 'linear' or 'sigmoid'".format(self.activation))

In [16]:
class LitModule(pl.LightningModule):
    """Lightning module: ResNeSt-50 encoder followed by an attention head.

    The encoder turns a spectrogram image into a feature map, the frequency
    axis is averaged away, and ``AttBlock`` pools the remaining time axis
    into clip-level class probabilities.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes=SPECIES_NUM):
        super().__init__()
        # Factor by which segment-wise outputs are repeated in time.
        # NOTE(review): presumably meant to match the encoder's temporal
        # downsampling factor -- confirm that 30 is right for this backbone.
        self.interpolate_ratio = 30  # Downsampled ratio
        # Load the pretrained ResNeSt-50 (downloads/caches through torch.hub).
        base_model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
        # Drop the last two children of the backbone to keep the feature map.
        layers = list(base_model.children())[:-2]
        self.encoder = nn.Sequential(*layers)

        in_features = base_model.fc.in_features

        self.fc1 = nn.Linear(in_features, in_features, bias=True)
        self.att_block = AttBlock(in_features, num_classes, activation="sigmoid")

        self.init_weight()
        self.criterion = get_criterion()

    def init_weight(self):
        """Initialise the extra fully connected layer."""
        init_layer(self.fc1)

    def forward(self, input):
        """Run the network on a batch of spectrogram images.

        Args:
            input: Tensor whose last axis (dim 3) is the frame axis.

        Returns:
            Dict with ``framewise_output`` (batch, frames, classes),
            ``logit`` (batch, classes; pooled pre-activation values) and
            ``clipwise_output`` (batch, classes; pooled sigmoid values).
        """
        frames_num = input.size(3)

        # (batch_size, channels, freq, frames)
        x = self.encoder(input)

        # (batch_size, channels, frames)
        x = torch.mean(x, dim=2)

        # channel smoothing
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        # nn.Linear acts on the last axis, so move channels there and back.
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)
        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        # Same attention pooling, applied to the un-activated classifier output.
        logit = torch.sum(norm_att * self.att_block.cla(x), dim=2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        output_dict = {
            "framewise_output": framewise_output,
            "logit": logit,
            "clipwise_output": clipwise_output
        }

        return output_dict

    def configure_optimizers(self):
        """Return the SGD optimizer for this module's own parameters."""
        # Use self.parameters(): referring to a global `model` only worked
        # when the instance happened to be bound to that name.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, weight_decay=0.01, momentum=0.9)
#         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
#         lr_scheduler = {"scheduler": scheduler }
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Compute and log the clip-level BCE loss for one training batch."""
        x, y = train_batch
        y_pred = self.forward(x)
        loss = self.criterion(y_pred['clipwise_output'], y)
        self.log('train_loss', loss,  on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute loss and lwlrap for one validation batch.

        Returns:
            ``(loss, lwlap_step)`` where ``loss`` is the clip-level BCE plus
            0.5 times the BCE of the per-class maximum over frames.
        """
        x, y = val_batch
        y_pred = self.forward(x)
        loss_max = self.criterion(torch.max(y_pred['framewise_output'], 1)[0], y)
        loss_clip = self.criterion(y_pred['clipwise_output'], y)
        loss = loss_clip + 0.5*loss_max
        # clipwise_output already consists of sigmoid-activated probabilities;
        # a second sigmoid is not needed (the metric depends on ranking only).
        lwlap_step, weight_step = lwlap_wrapper(y, y_pred['clipwise_output'])
        lwlap_step = (lwlap_step * weight_step).sum()
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('lwlap_score', lwlap_step, on_epoch=True, prog_bar=True, logger=True)
        return loss, lwlap_step

    def validation_epoch_end(self, validation_step_outputs):
        """Average the per-batch losses and scores and log them.

        The logged ``valid_epoch_lwlap`` is the mean of the per-batch values,
        not the metric computed over the whole validation set at once.
        """
        # Stack each component directly instead of pushing tuples of tensors
        # (on different devices) through np.array.
        mean_loss = torch.stack([loss for loss, _ in validation_step_outputs]).mean()
        mean_score = torch.stack([score for _, score in validation_step_outputs]).mean()

        print('valid_epoch_loss = ', mean_loss)
        print('valid_epoch_lwlap = ', mean_score)
        self.log('valid_epoch_loss', mean_loss, prog_bar=True, logger=True)
        self.log('valid_epoch_lwlap', mean_score, prog_bar=True, logger=True)
        return mean_loss, mean_score

In [17]:
# 5-fold stratified splitter, shuffled with the notebook-wide seed.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

In [18]:
# Stratification key: the first species id listed for each row of train_gby.
species_id_head_list = [species_ids[0] for species_ids in train_gby['species_id']]

In [19]:
for fold_id, (train_index, val_index) in enumerate(skf.split(train_gby, species_id_head_list)):
    # Picking only first fold to train/val on
    # This means loss of 20% training data
    # To avoid this, you can train 5 different models on 5 folds and average predictions

    # Split by audio file so that all annotations of one recording end up on
    # the same side of the split.
    valid_names = train_gby.iloc[val_index]['name']
    is_valid_row = train_df['name'].isin(valid_names)

    # reset_index returns fresh frames with a 0..n-1 index (no in-place
    # mutation of a filtered slice).
    train_data = train_df[~is_valid_row].reset_index(drop=True)
    valid_data = train_df[is_valid_row].reset_index(drop=True)

    train_dataset = SpectrogramTrainDataset(train_data, config["dataset"]["params"])
    valid_dataset = SpectrogramValidDataset(valid_data, config["dataset"]["params"])

    train_loader = data.DataLoader(train_dataset, **config["loader"]["train"])
    valid_loader = data.DataLoader(valid_dataset, **config["loader"]["valid"])

    model = LitModule()

    trainer = pl.Trainer(
        max_epochs=EPOCH,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
#         callbacks=[early_stop_callback],
        deterministic=True,
        # cuDNN benchmark mode selects kernels by timing and can vary between
        # runs, which contradicts deterministic=True.
        benchmark=False
    )
    trainer.fit(model, train_loader, valid_loader)

    # Save the weights BEFORE leaving the loop; with the save placed after
    # `break` it was unreachable and the model was never written.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    torch.save(model.state_dict(), OUTPUT_DIR + 'model')

    # Only the first fold is trained.
    break

Using cache found in /home/user/.cache/torch/hub/zhanghang1989_ResNeSt_master
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | Sequential | 25.4 M
1 | fc1       | Linear     | 4.2 M 
2 | att_block | AttBlock   | 98.4 K
3 | criterion | BCELoss    | 0     
-----------------------------------------
29.7 M    Trainable params
0         Non-trainable params
29.7 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

valid_epoch_loss =  tensor(2.3934, device='cuda:0')
valid_epoch_lwlap =  tensor(0.0670)




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2939, device='cuda:0')
valid_epoch_lwlap =  tensor(0.2017)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2827, device='cuda:0')
valid_epoch_lwlap =  tensor(0.2415)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2682, device='cuda:0')
valid_epoch_lwlap =  tensor(0.3271)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2637, device='cuda:0')
valid_epoch_lwlap =  tensor(0.3412)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2489, device='cuda:0')
valid_epoch_lwlap =  tensor(0.3892)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2469, device='cuda:0')
valid_epoch_lwlap =  tensor(0.4048)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2315, device='cuda:0')
valid_epoch_lwlap =  tensor(0.4796)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2343, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5107)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2366, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5040)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2210, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5219)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2221, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5453)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2166, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5552)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2117, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5926)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2094, device='cuda:0')
valid_epoch_lwlap =  tensor(0.5893)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2015, device='cuda:0')
valid_epoch_lwlap =  tensor(0.6232)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1987, device='cuda:0')
valid_epoch_lwlap =  tensor(0.6505)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1951, device='cuda:0')
valid_epoch_lwlap =  tensor(0.6661)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1934, device='cuda:0')
valid_epoch_lwlap =  tensor(0.6754)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1876, device='cuda:0')
valid_epoch_lwlap =  tensor(0.6857)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1758, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7130)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1691, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7357)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1692, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7558)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1650, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7678)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1663, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7476)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1603, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7779)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1556, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7957)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1598, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7847)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1595, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8048)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1553, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8084)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1576, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8048)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1567, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8212)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1559, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8199)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1488, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8215)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1582, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8403)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1549, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8426)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1599, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8316)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1669, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8428)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1703, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8460)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1618, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8462)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1746, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8590)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1807, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8403)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1696, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8602)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1837, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8469)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1961, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8367)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.1892, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8364)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2054, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8366)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2025, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8527)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2245, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8245)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2029, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8493)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2214, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8375)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2205, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8461)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2597, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7883)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2510, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8485)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2561, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8193)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2465, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8605)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2520, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8107)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2572, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8300)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3175, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8097)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2577, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8666)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2682, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7734)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2813, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8439)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2712, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8354)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2849, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8544)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2970, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7818)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2599, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8621)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3012, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8079)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2894, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8547)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3276, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7974)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3160, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8223)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3180, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8281)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3060, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8026)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3474, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8013)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3384, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7996)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3073, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8208)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3771, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8278)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3210, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8286)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.2928, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8247)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3249, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8266)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.3677, device='cuda:0')
valid_epoch_lwlap =  tensor(0.7896)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

valid_epoch_loss =  tensor(0.4061, device='cuda:0')
valid_epoch_lwlap =  tensor(0.8092)



In [21]:
# Persist the model's parameters (its state_dict only, not the whole module object)
# to a file named by appending 'model' to OUTPUT_DIR.
# NOTE(review): the destination is built by plain `+` concatenation, so this
# assumes OUTPUT_DIR is a str that already ends with a path separator — TODO confirm.
# NOTE(review): this writes whatever weights `model` holds when the cell runs; the
# validation log above shows valid_epoch_loss at 0.1488 at one point and 0.4061 in
# the last logged epoch, so verify these are the weights intended for export.
torch.save(
    model.state_dict(),    # parameter/buffer mapping to serialize
    OUTPUT_DIR + 'model',  # destination path
)