# This notebook shows the custom training of RFCX data on Tensorflow TPU.
 
In my earlier [notebook](http://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2) I trained the model using the Keras `fit` method, but if we want to take control of every little detail we need to write a custom loop, so here I have trained the model using an optimized custom training loop.


The dataset used in this notebook is the 10-fold GroupKFold, true-positive-only (tp) TFRecords that I created [here](http://www.kaggle.com/ashusma/rfcx-audio-detection); the simple script used to create them is [this](https://www.kaggle.com/ashusma/rfcx-audio-creating-tfrecords?scriptVersionId=51531240).

Training description :

* training with 10 sec clip around true positives
* taking full spectrogram size 
* label smoothing, random_augmentation and gaussian noise
* stepwise cosine decay with warm restarts and early stopping
* for inference, 10 sec clips are used; the per-clip predictions are then aggregated by taking the max over each audio wav


* version1 : efficientnet b4 , image_size = (512, 2000)
* version3 : resnet50, image_size= (512, 1280)
* version4 : some hyperparameter tweaking, image_size = (256, 1024), no_aug and label smoothing

In [None]:
! pip install -q efficientnet

In [None]:
import math, os, re, warnings, random , time
from collections import namedtuple
import tensorflow as tf
import numpy as np
import pandas as pd
import librosa
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
from IPython.display import Audio
from tensorflow.keras import Model, layers , optimizers
from sklearn.model_selection import KFold
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.layers import GlobalAveragePooling2D, Input, Dense, Dropout, GaussianNoise, concatenate
from tensorflow.keras.applications import ResNet50
import efficientnet.keras as efn
import seaborn as sns

# TPU Detection And Initialization

In [None]:
# TPU or GPU detection
# Detect hardware and pick the matching distribution strategy.
# TPUClusterResolver() raising ValueError is treated as "no TPU attached".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if tpu:
    # Connect to the TPU, initialise it and build a TPU strategy around it.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # Fall back to whatever strategy TensorFlow currently has in scope.
    strategy = tf.distribute.get_strategy()

# AUTO is passed to tf.data calls below so parallelism/prefetch sizes are tuned by TensorFlow.
AUTO = tf.data.experimental.AUTOTUNE
# REPLICAS scales the global batch size defined in the hyperparameter cell.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and set determinism env vars.

    Args:
        seed: Integer seed handed to every generator and written (as a string)
            to ``PYTHONHASHSEED``.
    """
    # Same order as before: Python's `random`, then NumPy, then TensorFlow.
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })

# Seed shared by seed_everything() here and by the KFold split further down.
seed = 42
seed_everything(seed)
# Silence every Python warning for the rest of the notebook.
warnings.filterwarnings('ignore')

In [None]:
def count_data_items(filenames):
    """Sum the per-file sample counts encoded in TFRecord filenames.

    The count is the run of digits between a ``-`` and the following ``.``
    (for example ``tp00-120.tfrec`` contributes 120).

    Args:
        filenames: Iterable of file names / paths.

    Returns:
        The total count as a NumPy ``int64`` scalar (0 for an empty input).

    Raises:
        ValueError: If a filename does not contain a ``-<digits>.`` segment.
    """
    # Compile once instead of once per filename.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            # Fail with the offending name instead of an AttributeError on None.
            raise ValueError(
                f'Cannot read a sample count from filename {filename!r}; '
                'expected a "-<count>." segment.')
        counts.append(int(match.group(1)))
    # Explicit dtype keeps the result integral even when `counts` is empty.
    return np.sum(counts, dtype=np.int64)

# Training files: every 'tp*.tfrec' shard in the GCS copy of the training dataset.

TRAIN_DATA_DIR = 'rfcx-audio-detection'
TRAIN_GCS_PATH = KaggleDatasets().get_gcs_path(TRAIN_DATA_DIR)
FILENAMES = tf.io.gfile.glob(TRAIN_GCS_PATH + '/tp*.tfrec')


# Test files: the test TFRecords shipped with the competition dataset.
TEST_DATA_DIR = 'rfcx-species-audio-detection'
TEST_GCS_PATH =  KaggleDatasets().get_gcs_path(TEST_DATA_DIR)
TEST_FILES = tf.io.gfile.glob(TEST_GCS_PATH + '/tfrecords/test/*.tfrec')

# Total number of training samples, read from the counts embedded in the filenames.
no_of_training_samples = count_data_items(FILENAMES)

print('num_training_samples are', no_of_training_samples)

In [None]:
# Length in seconds of the training crop taken around each labeled event.
CUT = 10
# Length in seconds of each clip the test recordings are split into.
TIME = 10
EPOCHS = 25
# Four samples per replica.
GLOBAL_BATCH_SIZE = 4 * REPLICAS
# Peak learning rate of the schedule; warm-up starts from WARMUP_LEARNING_RATE.
LEARNING_RATE = 0.0015
WARMUP_LEARNING_RATE = 1e-5
# int() truncates, so 25 epochs give 2 warm-up epochs.
WARMUP_EPOCHS = int(EPOCHS*0.1)
# Early-stopping patience, counted in epochs by train_one_fold.
PATIENCE = 10
# Training steps run per "epoch"; the training dataset repeats indefinitely.
STEPS_PER_EPOCH = 64
# Number of KFold splits made over the TFRecord files.
N_FOLDS = 5
NUM_TRAINING_SAMPLES = no_of_training_samples

class params:
    """Namespace of audio front-end and model constants, read as class attributes."""

    # --- Spectrogram front-end (read by waveform_to_log_mel_spectrogram) ---
    sample_rate = 48000
    stft_window_seconds: float = 0.025
    stft_hop_seconds: float = 0.005
    # Passed to tf.signal.stft as frame_length; equals sample_rate * stft_window_seconds.
    frame_length: int =  1200
    mel_bands: int = 256
    mel_min_hz: float = 50.0
    mel_max_hz: float = 16000.0
    # Added to the mel spectrogram before the log.
    log_offset: float = 0.001


    # NOTE(review): patch_bands, conv_padding and the batchnorm_* values are not
    # referenced anywhere else in the visible notebook.
    patch_bands = mel_bands
    conv_padding: str = 'same'
    batchnorm_center: bool = True
    batchnorm_scale: bool = False
    batchnorm_epsilon: float = 1e-4
    # --- Model head (read by RFCX_MODEL and the label one-hot encoding) ---
    num_classes: int = 24
    dropout = 0.40
    classifier_activation: str = 'sigmoid'
    # --- Image size the spectrogram is resized to in preprocess() ---
    height = mel_bands
    width = 1024

In [None]:
# Schema used by read_labeled_tfrecord to parse one training example.
# 'wav' holds the encoded audio bytes; every numeric field is stored as float32.
feature_description = {
    'wav': tf.io.FixedLenFeature([], tf.string),
    'recording_id': tf.io.FixedLenFeature([], tf.string ),
    'target' : tf.io.FixedLenFeature([], tf.float32),
    'song_id': tf.io.FixedLenFeature([], tf.float32),
     'tmin' : tf.io.FixedLenFeature([], tf.float32),
     'fmin' : tf.io.FixedLenFeature([], tf.float32),
     'tmax' : tf.io.FixedLenFeature([], tf.float32),
     'fmax' : tf.io.FixedLenFeature([], tf.float32),
}
# NOTE(review): this mapping is not referenced elsewhere in the visible notebook,
# and its time/frequency keys ('t_min', ...) do not match the names used in
# feature_description ('tmin', ...) - confirm before relying on it.
feature_dtype = {
    'wav': tf.float32,
    'recording_id': tf.string,
    'target': tf.float32,
    'song_id': tf.float32,
    't_min': tf.float32,
    'f_min': tf.float32,
    't_max': tf.float32,
    'f_max':tf.float32,
}

In [None]:
def waveform_to_log_mel_spectrogram(waveform,target_or_rec_id):
    """Turn a 1-D waveform into a log mel spectrogram image.

    Args:
        waveform: 1-D tensor of audio samples.
        target_or_rec_id: Label or recording id; returned untouched so the
            function can be used directly in ``Dataset.map``.

    Returns:
        Tuple ``(log_mel_spectrogram, target_or_rec_id)`` where the spectrogram
        has shape ``[params.mel_bands, <# STFT frames>, 1]``.
    """
    sample_rate = params.sample_rate
    window_samples = int(round(sample_rate * params.stft_window_seconds))
    hop_samples = int(round(sample_rate * params.stft_hop_seconds))
    # Smallest power of two that is at least the window length.
    fft_size = 2 ** int(np.ceil(np.log(window_samples) / np.log(2.0)))

    # Short-Time Fourier Transform; tf.signal.stft() defaults to a periodic
    # Hann window. Result shape: [<# STFT frames>, fft_size // 2 + 1].
    stft = tf.signal.stft(
        signals=waveform,
        frame_length=params.frame_length,
        frame_step=hop_samples,
        fft_length=fft_size)
    magnitudes = tf.abs(stft)

    # Project the linear-frequency bins onto the mel scale.
    mel_filterbank = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=params.mel_bands,
        num_spectrogram_bins=fft_size // 2 + 1,
        sample_rate=sample_rate,
        lower_edge_hertz=params.mel_min_hz,
        upper_edge_hertz=params.mel_max_hz)
    mel = tf.matmul(magnitudes, mel_filterbank)

    # Log-compress, then transpose [frames, mel] -> [mel, frames].
    log_mel = tf.transpose(tf.math.log(mel + params.log_offset))

    # Append a trailing channel axis so the result looks like a grayscale image.
    mel_shape = tf.shape(log_mel)
    log_mel_spectrogram = tf.reshape(log_mel, [mel_shape[0], mel_shape[1], 1])

    return log_mel_spectrogram, target_or_rec_id

# Data augmentation

In [None]:
def frequency_masking(mel_spectrogram):
    """Zero out random bands along axis 1 of a rank-3 spectrogram tensor.

    Two masks are applied in sequence. For each one a width ``f`` is drawn
    uniformly from [0, 80) and an offset ``f0`` from [0, v - f), where ``v`` is
    the size of axis 1; a band of width ``f`` along axis 1 is then multiplied
    by zero.

    NOTE(review): waveform_to_log_mel_spectrogram returns [mel, frames, 1], so
    axis 1 is the frame (time) axis there, despite this function's name -
    confirm the intended layout.

    Args:
        mel_spectrogram: Tensor indexed as [axis0, axis1, 1].

    Returns:
        The masked spectrogram cast to float32.
    """
    
    # The trailing comma makes this a 1-tuple; tf.squeeze below reduces it to a scalar.
    frequency_masking_para = 80, 
    frequency_mask_num = 2
    
    fbank_size = tf.shape(mel_spectrogram)
#     print(fbank_size)
    n, v = fbank_size[0], fbank_size[1]

    for i in range(frequency_mask_num):
        f = tf.random.uniform([], minval=0, maxval= tf.squeeze(frequency_masking_para), dtype=tf.int32)
        v = tf.cast(v, dtype=tf.int32)
        f0 = tf.random.uniform([], minval=0, maxval= tf.squeeze(v-f), dtype=tf.int32)

        # Mask laid out along axis 1 as [ones(v - f0 - f), zeros(f), ones(f0)],
        # i.e. the zeroed band covers indices [v - f0 - f, v - f0).
        mask = tf.concat((tf.ones(shape=(n, v - f0 - f,1)),
                          tf.zeros(shape=(n, f,1)),
                          tf.ones(shape=(n, f0,1)),
                          ),1)
        mel_spectrogram = mel_spectrogram * mask
    return tf.cast(mel_spectrogram, dtype=tf.float32)


def time_masking(mel_spectrogram):
    """Zero out one random band along axis 0 of a rank-3 spectrogram tensor.

    A width ``t`` is drawn uniformly from [0, 40) and an offset ``t0`` from
    [0, n - t), where ``n`` is the size of axis 0; a band of width ``t`` along
    axis 0 is then multiplied by zero.

    NOTE(review): waveform_to_log_mel_spectrogram returns [mel, frames, 1], so
    axis 0 is the mel-band axis there, despite this function's name - confirm
    the intended layout.

    Args:
        mel_spectrogram: Tensor indexed as [axis0, axis1, 1].

    Returns:
        The masked spectrogram cast to float32.
    """
    # The trailing comma makes this a 1-tuple; tf.squeeze below reduces it to a scalar.
    time_masking_para = 40, 
    time_mask_num = 1
    
    fbank_size = tf.shape(mel_spectrogram)
    n, v = fbank_size[0], fbank_size[1]

   
    for i in range(time_mask_num):
        t = tf.random.uniform([], minval=0, maxval=tf.squeeze(time_masking_para), dtype=tf.int32)
        t0 = tf.random.uniform([], minval=0, maxval= n-t, dtype=tf.int32)

        # Mask laid out along axis 0 as [ones(n - t0 - t), zeros(t), ones(t0)],
        # i.e. the zeroed band covers indices [n - t0 - t, n - t0).
        mask = tf.concat((tf.ones(shape=(n-t0-t, v,1)),
                          tf.zeros(shape=(t, v,1)),
                          tf.ones(shape=(t0, v,1)),
                          ), 0)
        
        mel_spectrogram = mel_spectrogram * mask
    return tf.cast(mel_spectrogram, dtype=tf.float32)


def random_brightness(image):
    """Shift the image brightness by a random delta of magnitude at most 0.2."""
    brightened = tf.image.random_brightness(image, max_delta=0.2)
    return brightened

def random_gamma(image):
    """Apply a random contrast factor drawn between 0.1 and 0.3.

    Despite the name, this calls ``tf.image.random_contrast``, not a gamma
    adjustment.
    """
    contrast_range = {'lower': 0.1, 'upper': 0.3}
    return tf.image.random_contrast(image, **contrast_range)

def random_flip_right(image):
    """Randomly mirror the image left-to-right (along its width axis)."""
    mirrored = tf.image.random_flip_left_right(image)
    return mirrored

def random_flip_up_down(image):
    """Randomly flip the image upside down (along its height axis).

    Previously this delegated to ``tf.image.random_flip_left_right``, which made
    it a duplicate of ``random_flip_right`` and never produced a vertical flip.
    """
    return tf.image.random_flip_up_down(image)

# Candidate augmentation ops; each maps an image tensor to an image tensor.
available_ops = [
          frequency_masking ,
          time_masking, 
          random_brightness, 
          random_flip_up_down,
          random_flip_right 
         ]

def apply_augmentation(image, target):
    """Apply 0, 1 or 2 randomly chosen ops from ``available_ops`` to ``image``.

    The number of ops is drawn with a TensorFlow random op so it is re-sampled
    for every example. A NumPy draw would run only once, when ``Dataset.map``
    traces this function, and freeze the same count for the whole dataset.

    Args:
        image: Spectrogram image tensor.
        target: Label tensor; returned unchanged.

    Returns:
        Tuple ``(augmented_image, target)``.
    """
    # int(np.random.uniform(0, 3)) used to yield 0, 1 or 2; keep that range.
    max_layers = 2
    num_layers = tf.random.uniform([], minval=0, maxval=max_layers + 1, dtype=tf.int32)

    for layer_num in range(max_layers):
        # Layers beyond the sampled count are built into the graph but skipped.
        layer_active = tf.less(layer_num, num_layers)
        op_to_select = tf.random.uniform([], maxval=len(available_ops), dtype=tf.int32)
        for (i, op_name) in enumerate(available_ops):
            # Bind the op and the current image as defaults so each branch
            # captures the values of this iteration.
            image = tf.cond(
                tf.logical_and(layer_active, tf.equal(i, op_to_select)),
                lambda selected_func=op_name, current=image: selected_func(current),
                lambda current=image: current)
    return image, target

# Training Data Pipeline

In [None]:
def preprocess(image, target_or_rec_id):
    """Convert a single-channel spectrogram into a standardized RGB image.

    Args:
        image: Spectrogram tensor with a trailing channel axis of size 1.
        target_or_rec_id: Label or recording id; returned untouched.

    Returns:
        Tuple ``(image, target_or_rec_id)`` where the image has been replicated
        to three channels, resized to ``[params.height, params.width]`` and
        standardized per image.
    """
    target_size = [params.height,params.width]
    rgb = tf.image.grayscale_to_rgb(image)
    resized = tf.image.resize(rgb, target_size)
    return tf.image.per_image_standardization(resized), target_or_rec_id


def read_labeled_tfrecord(example_proto):
    """Parse one labeled example and crop a ``CUT``-second window around its event.

    Args:
        example_proto: Serialized ``tf.train.Example`` following
            ``feature_description``.

    Returns:
        Tuple ``(wav, target)``: the cropped mono waveform (squeezed) and a
        one-hot float vector of length ``params.num_classes``.
    """
    sample = tf.io.parse_single_example(example_proto, feature_description)
    wav, _ = tf.audio.decode_wav(sample['wav'], desired_channels=1) # mono

    # 'target' is stored as float32, but tf.one_hot only accepts integer indices.
    class_index = tf.cast(sample['target'], tf.int32)
    target = tf.one_hot(class_index, depth = params.num_classes)

    tmin = tf.cast(sample['tmin'], tf.float32)
    tmax = tf.cast(sample['tmax'], tf.float32)

    # Work in samples from here on.
    tmax_s = tmax * tf.cast(params.sample_rate, tf.float32)
    tmin_s = tmin * tf.cast(params.sample_rate, tf.float32)
    cut_s = tf.cast(CUT * params.sample_rate, tf.float32)
    # The clamp below assumes 60-second recordings.
    all_s = tf.cast(60 * params.sample_rate, tf.float32)
    tsize_s = tmax_s - tmin_s
    # Centre the window on the event, then clamp it so it starts at or after 0
    # and ends at or before the 60-second mark.
    cut_min = tf.cast(
    tf.maximum(0.0, 
        tf.minimum(tmin_s - (cut_s - tsize_s) / 2,
                   tf.minimum(tmax_s + (cut_s - tsize_s) / 2, all_s) - cut_s)
    ), tf.int32
      )
    cut_max = cut_min + CUT * params.sample_rate
    wav = tf.squeeze(wav[cut_min : cut_max] )
    
    return wav, target

def read_unlabeled_tfrecord(example):
    """Parse one test example and split its audio into ``TIME``-second clips.

    Args:
        example: Serialized ``tf.train.Example`` with ``recording_id`` and
            ``audio_wav`` string features.

    Returns:
        Dict of stacked tensors with ``60 // TIME`` entries along axis 0:
        ``'audio_wav'`` (float32 clips of ``params.sample_rate * TIME`` samples)
        and ``'recording_id'`` (the same id repeated for every clip).
    """
    # Named differently from the module-level `feature_description` (the
    # training schema) so the two are not confused.
    test_feature_description = {
    'recording_id': tf.io.FixedLenFeature([], tf.string),
    'audio_wav': tf.io.FixedLenFeature([], tf.string),
    }
    sample = tf.io.parse_single_example(example, test_feature_description)
    wav, _ = tf.audio.decode_wav(sample['audio_wav'], desired_channels=1) # mono

    clip_samples = params.sample_rate*TIME

    def _cut_audio(i):
        """Return the i-th clip together with the recording id."""
        _sample = {
            # The reshape drops the channel axis and pins the clip length.
            'audio_wav': tf.reshape(wav[i*clip_samples:(i+1)*clip_samples], [clip_samples]),
            'recording_id': sample['recording_id']
        }
        return _sample

    return tf.map_fn(_cut_audio, tf.range(60//TIME), dtype={
        'audio_wav': tf.float32,
        'recording_id': tf.string
    })

In [None]:
def load_dataset(filenames, labeled = True, ordered = False , training = True):
    """Build the unbatched (spectrogram image, one-hot target) dataset.

    Args:
        filenames: TFRecord files to read.
        labeled: Currently unused; kept for interface compatibility.
        ordered: When False (default) the reader may yield records in any order,
            which is faster. Order does not matter for training because the
            data is shuffled afterwards.
        training: Only referenced by the disabled augmentation step below.

    Returns:
        A ``tf.data.Dataset`` of preprocessed ``(image, target)`` pairs.
    """
    ignore_order = tf.data.Options()
    if not ordered:
        # disable order, increase speed
        ignore_order.experimental_deterministic = False

    # automatically interleaves reads from multiple files
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO )
    # use data as soon as it streams in, rather than in its original order.
    # The options object has no effect unless it is attached to the dataset.
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(read_labeled_tfrecord , num_parallel_calls = AUTO )
    dataset = dataset.map(waveform_to_log_mel_spectrogram , num_parallel_calls = AUTO)
    # Augmentation is currently switched off:
    # if training:
    #     dataset = dataset.map(apply_augmentation, num_parallel_calls = AUTO)
    dataset = dataset.map(preprocess, num_parallel_calls = AUTO)
    return dataset

In [None]:
def get_dataset(filenames, training = True):
    """Return an endlessly repeating, batched and prefetched dataset.

    Args:
        filenames: TFRecord files to read.
        training: If truthy, shuffle with a 256-element buffer and drop
            incomplete batches; otherwise just repeat and batch.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``GLOBAL_BATCH_SIZE`` examples.
    """
    if not training:
        eval_batches = load_dataset(filenames , training = False).repeat().batch(GLOBAL_BATCH_SIZE)
        return eval_batches.prefetch(AUTO)

    shuffled = load_dataset(filenames , training = True).shuffle(256).repeat()
    train_batches = shuffled.batch(GLOBAL_BATCH_SIZE, drop_remainder = True)
    return train_batches.prefetch(AUTO)

In [None]:
# mel spectrogram visualization
# Sanity check of the input pipeline: show the first channel of four
# preprocessed training examples in a 2x2 grid.

train_dataset = get_dataset(FILENAMES, training = True)

plt.figure(figsize=(16,6))
# `wav` here is the preprocessed spectrogram image yielded by the pipeline,
# not a raw waveform; `target` is its one-hot label.
for i, (wav, target) in enumerate(train_dataset.unbatch().take(4)):
    plt.subplot(2,2,i+1)
    plt.imshow(wav[:, :, 0])
plt.show()

# Competition Metric

In [None]:
# from https://www.kaggle.com/carlthome/l-lrap-metric-for-tf-keras
@tf.function
def _one_sample_positive_class_precisions(example):
    """Compute, for one sample, the ranking precision at each positive class.

    Args:
        example: Tuple ``(y_true, y_pred)`` of per-class labels and scores for a
            single sample; ``y_true`` is reshaped to the shape of ``y_pred``.

    Returns:
        A vector of length ``y_pred.shape[0]`` holding the precision value at
        the index of every class selected by ``tf.where(y_true)`` and zero
        everywhere else.
    """
    y_true, y_pred = example
    y_true = tf.reshape(y_true, tf.shape(y_pred))
    # Class indices ordered from highest to lowest score.
    retrieved_classes = tf.argsort(y_pred, direction='DESCENDING')
#     shape = tf.shape(retrieved_classes)
    # Inverse permutation: class_rankings[c] is the 0-based rank of class c.
    class_rankings = tf.argsort(retrieved_classes)
    # Labels re-ordered by rank, and the running count of hits down the ranking.
    retrieved_class_true = tf.gather(y_true, retrieved_classes)
    retrieved_cumulative_hits = tf.math.cumsum(tf.cast(retrieved_class_true, tf.float32))

    # idx: indices of the positive classes; i: their ranks.
    idx = tf.where(y_true)[:, 0]
    i = tf.boolean_mask(class_rankings, y_true)
    # Precision at a positive's rank = hits up to that rank / (rank + 1).
    r = tf.gather(retrieved_cumulative_hits, i)
    c = 1 + tf.cast(i, tf.float32)
    precisions = r / c

    # Scatter the precisions back to their class positions.
    dense = tf.scatter_nd(idx[:, None], precisions, [y_pred.shape[0]])
    return dense


class LWLRAP(tf.keras.metrics.Metric):
    """Streaming label-weighted label-ranking average precision (LWLRAP).

    Keeps, per class, the cumulative ranking precision and the number of
    positive occurrences seen since the last reset.
    """

    def __init__(self, num_classes, name='lwlrap'):
        """Create the per-class accumulators.

        Args:
            num_classes: Length of the label / prediction vectors.
            name: Metric name.
        """
        super().__init__(name=name)

        self._precisions = self.add_weight(
            name='per_class_cumulative_precision',
            shape=[num_classes],
            initializer='zeros',
        )

        self._counts = self.add_weight(
            name='per_class_cumulative_count',
            shape=[num_classes],
            initializer='zeros',
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate precisions for a batch.

        Args:
            y_true: Batch of per-class labels.
            y_pred: Batch of per-class scores.
            sample_weight: Accepted for API compatibility; ignored.
        """
        precisions = tf.map_fn(
            fn=_one_sample_positive_class_precisions,
            elems=(y_true, y_pred),
            dtype=(tf.float32),
        )

        # A class counts as "seen" for a sample when its precision is non-zero.
        increments = tf.cast(precisions > 0, tf.float32)
        total_increments = tf.reduce_sum(increments, axis=0)
        total_precisions = tf.reduce_sum(precisions, axis=0)

        self._precisions.assign_add(total_precisions)
        self._counts.assign_add(total_increments)

    def result(self):
        """Return the count-weighted average of the per-class precisions.

        Returns 0.0 rather than NaN when nothing has been accumulated yet.
        """
        per_class_lwlrap = self._precisions / tf.maximum(self._counts, 1.0)
        # divide_no_nan avoids the 0/0 that arises before any positive is seen.
        per_class_weight = tf.math.divide_no_nan(
            self._counts, tf.reduce_sum(self._counts))
        overall_lwlrap = tf.reduce_sum(per_class_lwlrap * per_class_weight)
        return overall_lwlrap

    def reset_states(self):
        """Zero both accumulators."""
        self._precisions.assign(tf.zeros_like(self._precisions))
        self._counts.assign(tf.zeros_like(self._counts))

# Stepwise Cosine Decay Callback

In [None]:
# Module-level inputs read by cosine_decay_with_warmup below.
learning_rate_base = LEARNING_RATE
# Length of the whole schedule, in optimizer steps.
total_steps = STEPS_PER_EPOCH * EPOCHS
warmup_learning_rate = WARMUP_LEARNING_RATE
# Number of optimizer steps spent in the linear warm-up.
warmup_steps= WARMUP_EPOCHS * STEPS_PER_EPOCH


@tf.function
def cosine_decay_with_warmup(global_step,
                             hold_base_rate_steps=0):
    """Learning rate at ``global_step``: linear warm-up followed by cosine decay.

    Reads the module-level ``learning_rate_base``, ``total_steps``,
    ``warmup_learning_rate`` and ``warmup_steps``.

    Args:
        global_step: Current optimizer step.
        hold_base_rate_steps: Steps to stay at ``learning_rate_base`` after the
            warm-up before the decay starts.

    Returns:
        The learning rate; 0.0 once ``global_step`` exceeds ``total_steps``.

    Raises:
        ValueError: If ``total_steps < warmup_steps`` or, when warm-up is used,
            ``learning_rate_base < warmup_learning_rate``.
    """
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                     'warmup_steps.')

    step_as_float = tf.cast(global_step, tf.float32)

    # Cosine segment, measured from the end of warm-up + hold.
    decay_steps = float(total_steps - warmup_steps - hold_base_rate_steps)
    cosine_phase = np.pi * (step_as_float - warmup_steps - hold_base_rate_steps) / decay_steps
    learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(cosine_phase))

    if hold_base_rate_steps > 0:
        # Stay at the base rate until the hold period is over.
        decay_started = global_step > warmup_steps + hold_base_rate_steps
        learning_rate = tf.where(decay_started, learning_rate, learning_rate_base)

    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                         'warmup_learning_rate.')
        # Straight line from warmup_learning_rate up to learning_rate_base.
        warmup_slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = warmup_slope * step_as_float + warmup_learning_rate
        in_warmup = global_step < warmup_steps
        learning_rate = tf.where(in_warmup, warmup_rate, learning_rate)

    schedule_finished = global_step > total_steps
    return tf.where(schedule_finished, 0.0, learning_rate,
                    name='learning_rate')


# Preview of the schedule over the whole training run.
# The step is passed as a tensor: a tf.function called with a different Python
# int every time is retraced once per distinct value (one trace per step here).
rng = [i for i in range(int(EPOCHS * STEPS_PER_EPOCH))]
WARMUP_STEPS =  int(WARMUP_EPOCHS * STEPS_PER_EPOCH)
y = [float(cosine_decay_with_warmup(tf.constant(x))) for x in rng]

sns.set(style='whitegrid')
fig, ax = plt.subplots(figsize=(20, 6))
plt.plot(rng, y)
ax.set(xlabel='step', ylabel='learning rate');

# Model Definition

In [None]:
# custom model

class RFCX_MODEL(tf.keras.Model):
    """ResNet50 backbone with a pooled, dropout-regularized classification head.

    Pipeline: Gaussian noise -> ImageNet-initialized ResNet50 (no top) ->
    global average pooling -> dropout -> dense layer with
    ``params.num_classes`` outputs and ``params.classifier_activation``.
    """

    def __init__(self):
        super().__init__()
        self.gaussian_noise = GaussianNoise(0.05)
        self.resnet_model = ResNet50(include_top=False, weights='imagenet')
        self.model_output = GlobalAveragePooling2D()
        self.dropout = Dropout(params.dropout)
        self.predictions = Dense(
            params.num_classes,
            activation = params.classifier_activation,
        )

    def call(self, inputs):
        """Return per-class scores for a batch of 3-channel spectrogram images."""
        features = self.resnet_model(self.gaussian_noise(inputs))
        pooled = self.dropout(self.model_output(features))
        return self.predictions(pooled)

# Training And Validation Loop

In [None]:
def train_one_fold(train_dataset, valid_dataset):
    """Run the custom training/validation loop for one fold.

    Relies on notebook-level names bound by the fold loop before the call:
    ``train_multiple_steps``, ``val_multiple_steps``, ``VAL_STEPS``, ``model``,
    ``model_path``, ``train_loss``, ``val_loss``, ``train_lwlrap``,
    ``val_lwlrap`` and ``history_list``.

    The weights are saved whenever the validation LWLRAP matches or beats the
    best value so far; training stops early after ``PATIENCE`` consecutive
    epochs without such an improvement. The fold history is appended to
    ``history_list``.

    Args:
        train_dataset: Repeating, batched training dataset.
        valid_dataset: Repeating, batched validation dataset.
    """
    print('Start fine-tuning!', flush=True)
    # Distribute the datasets according to the strategy (TPUStrategy on TPU).
    train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
    valid_dist_dataset = strategy.experimental_distribute_dataset(valid_dataset)
    epoch_start_time = time.time()

    print("Steps per epoch:", STEPS_PER_EPOCH)
    History = namedtuple('History', 'history')
    history = History(history={'train_loss': [], 'val_loss': [], 'train_lwlrap' : [], 'val_lwlrap' : []})

    train_iterator = iter(train_dist_dataset)
    val_iterator = iter(valid_dist_dataset)

    steps = 0
    best_val_lwlrap = 0
    # Initialised up front: if the first validation score is NaN the ">=" test
    # below is False and the counter would otherwise be read before assignment.
    patience_cnt = 1
    for epoch in range(EPOCHS):
        print('\nEPOCH {:d}/{:d}'.format(epoch+1, EPOCHS))
        # Each iteration of the distributed dataset yields per-replica data.
        train_multiple_steps(train_iterator , tf.convert_to_tensor(STEPS_PER_EPOCH))
        steps += STEPS_PER_EPOCH

        val_multiple_steps(val_iterator, tf.convert_to_tensor(VAL_STEPS))

        # The loss metrics are running sums, so divide by the number of steps.
        history.history['train_loss'].append(train_loss.result().numpy() / (STEPS_PER_EPOCH) )
        history.history['val_loss'].append((val_loss.result().numpy() / VAL_STEPS))
        history.history['train_lwlrap'].append(train_lwlrap.result().numpy())
        history.history['val_lwlrap'].append(val_lwlrap.result().numpy())

        # show metrics
        epoch_time = time.time() - epoch_start_time

        print('time: {:0.1f}s'.format(epoch_time),
              'train_loss: {:0.4f}'.format(history.history['train_loss'][-1]),
              'val_loss: {:0.4f}'.format(history.history['val_loss'][-1]),
              'train_lwlrap: {:0.4f}'.format(history.history['train_lwlrap'][-1]),
              'val_lwlrap: {:0.4f}'.format(history.history['val_lwlrap'][-1]),
              'lr : {:0.6f}'.format(cosine_decay_with_warmup(steps))
              )

        # Early stopping monitor
        if history.history['val_lwlrap'][-1] >= best_val_lwlrap:
            best_val_lwlrap = history.history['val_lwlrap'][-1]
            model.save_weights(model_path)
            print(f'Saved model weights at "{model_path}"')
            patience_cnt = 1
        else:
            patience_cnt += 1
        if patience_cnt > PATIENCE:
            print(f'Epoch {epoch:05d}: early stopping')
            break

        # set up next epoch
        epoch_start_time = time.time()
        train_loss.reset_states()
        val_loss.reset_states()
        train_lwlrap.reset_states()
        val_lwlrap.reset_states()

    history_list.append(history)

In [None]:
# Cross-validation over the TFRecord files: each fold trains a fresh model and
# appends its history to `history_list`.
kfold = KFold(n_splits = N_FOLDS,shuffle = True ,random_state = seed)
history_list = [] 
# Split over the files actually found instead of a hard-coded count of 10.
for fold, (train_idx, test_idx) in enumerate(kfold.split(np.arange(len(FILENAMES)))):
    # Reset the TPU between folds, using the same resolver as the initial setup.
    if tpu : tf.tpu.experimental.initialize_tpu_system(tpu)
    K.clear_session()
    train_files = [FILENAMES[i] for i in train_idx]
    test_files = [FILENAMES[i] for i in test_idx]
    # Number of full validation batches available in this fold.
    VAL_STEPS = count_data_items(test_files) // GLOBAL_BATCH_SIZE 
    train_dataset = get_dataset(train_files, training = True)
    valid_dataset = get_dataset(test_files, training = False)
    print('fold', fold+1)
    
    @tf.function
    def train_multiple_steps(train_iterator, steps):
        """Run `steps` distributed training steps inside a single graph call."""
        def train_step(wav, target):
            """Forward/backward pass on one replica; updates the train metrics."""
            with tf.GradientTape() as tape:
                predictions = model(wav, training = True)
                total_loss = loss_fn(target, predictions)
            gradients = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(list(zip(gradients, model.trainable_variables)))
            train_loss.update_state(total_loss)
            train_lwlrap.update_state(target, predictions)

        for _ in tf.range(steps):
            # strategy.run executes train_step on every replica. next() yields an
            # (image, target) tuple, which is unpacked as the two arguments.
            strategy.run(train_step, args = (next(train_iterator)))

    @tf.function
    def val_multiple_steps(val_iterator, val_steps):
        """Run `val_steps` distributed validation steps inside a single graph call."""
        def val_step(wav, target):
            """Forward pass on one replica; updates the validation metrics."""
            predictions = model(wav, training = False)
            total_loss = loss_fn(target, predictions)
            val_loss.update_state(total_loss)
            val_lwlrap.update_state(target, predictions)
        for _ in tf.range(val_steps):
            strategy.run(val_step, args = (next(val_iterator)))

    # defining model and variables under strategy to allow tpu to track them
    with strategy.scope():
        model = RFCX_MODEL() 
        model.build((None,None,None, 3))
        model.summary()
        # Per-example losses, averaged over the global batch so that summing the
        # per-replica results gives the correct mean.
        loss = tf.keras.losses.BinaryCrossentropy(reduction= tf.keras.losses.Reduction.NONE)
        loss_fn = lambda target, predict : tf.nn.compute_average_loss(loss(target, predict) , global_batch_size = GLOBAL_BATCH_SIZE)
        class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            """Adapter exposing cosine_decay_with_warmup as a Keras schedule."""
            def __call__(self, step):
                return cosine_decay_with_warmup(step)
        optimizer = tf.keras.optimizers.Adam(learning_rate=LRSchedule())
        # Running sums; train_one_fold divides them by the number of steps.
        train_loss = tf.keras.metrics.Sum()
        val_loss = tf.keras.metrics.Sum()
        train_lwlrap = LWLRAP(params.num_classes)
        val_lwlrap = LWLRAP(params.num_classes)
        
       
    # The inference cell loads the weights back using this exact file name.
    model_path = f'RFCX_model_fold {fold}.h5'
    # training for one fold
    train_one_fold(train_dataset, valid_dataset)

# Plot curve

In [None]:
def plot_history(history):
    """Plot train/validation loss and LWLRAP curves side by side.

    Args:
        history: Object whose ``history`` attribute is a dict with the keys
            ``train_loss``, ``val_loss``, ``train_lwlrap`` and ``val_lwlrap``.
    """
    plt.figure(figsize=(8,3))
    # (panel title, train key, validation key), drawn left to right.
    panels = (
        ("loss", "train_loss", "val_loss"),
        ("lwlrap", "train_lwlrap", "val_lwlrap"),
    )
    for position, (title, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1,2,position)
        plt.plot(history.history[train_key])
        plt.plot(history.history[val_key])
        plt.legend(['Train', 'Test'], loc='upper left')
        plt.title(title)
    
# Draw one loss/LWLRAP figure per trained fold.
for hist in history_list:
#     print(hist)
    plot_history(hist)

# Inference

In [None]:
def get_test_dataset(filenames, training = False):
    """Build the batched, cached inference dataset of (image, recording_id) pairs.

    Every test recording is split into ``TIME``-second clips by
    ``read_unlabeled_tfrecord``; each clip becomes one dataset element.

    Args:
        filenames: Test TFRecord files.
        training: Unused.

    Returns:
        A cached ``tf.data.Dataset`` batched by ``GLOBAL_BATCH_SIZE*4``.
    """
    def _clip_to_spectrogram(clip):
        return waveform_to_log_mel_spectrogram(clip['audio_wav'], clip['recording_id'])

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO )
    # One element per clip instead of one per recording.
    clips = records.map(read_unlabeled_tfrecord , num_parallel_calls = AUTO ).unbatch()
    spectrograms = clips.map(_clip_to_spectrogram , num_parallel_calls = AUTO)
    images = spectrograms.map(preprocess, num_parallel_calls = AUTO)
    return images.batch(GLOBAL_BATCH_SIZE*4).cache()

In [None]:
# One prediction array per fold, collected in fold order.
test_predict = []

test_data = get_test_dataset(TEST_FILES, training = False)
# Drop the recording ids so the dataset yields model inputs only.
test_audio = test_data.map(lambda frames, recording_id: frames)

# `model` is the instance left bound by the last iteration of the fold loop;
# each fold's saved weights are loaded into it in turn.
for fold in range(N_FOLDS):
    model.load_weights(f'./RFCX_model_fold {fold}.h5')
    test_predict.append(model.predict(test_audio, verbose = 1 ))

# Submission

In [None]:
# Shape check of the stacked fold predictions; the first axis is the fold.
np.array(test_predict).shape

In [None]:
SUB = pd.read_csv('../input/rfcx-species-audio-detection/sample_submission.csv')

# Regroup the flat clip predictions as [fold, recording, clip, class].
# NOTE(review): assumes the test dataset yields exactly len(SUB) recordings with
# their 60 // TIME clips stored consecutively - confirm.
predict = np.array(test_predict).reshape(N_FOLDS, len(SUB), 60 // TIME, params.num_classes)
# Max over the clips of a recording, then mean over the folds.
predict = np.mean(np.max(predict ,axis = 2) , axis = 0)
# predict = np.mean(predict, axis =  0)

# Second pass over the cached test dataset, keeping only the ids.
recording_id = test_data.map(lambda frames, recording_id: recording_id).unbatch()
# # all in one batch
test_ids = next(iter(recording_id.batch(len(SUB) * 60 // TIME))).numpy().astype('U').reshape(len(SUB), 60 // TIME)

# Every clip of a recording carries the same id, so the first column is used.
pred_df = pd.DataFrame({ 'recording_id' : test_ids[:, 0],
             **{f's{i}' : predict[:, i] for i in range(params.num_classes)} })

In [None]:
# Sort the rows by recording id (in place) and write the submission file.
pred_df.sort_values('recording_id', inplace = True) 
pred_df.to_csv('submission.csv', index = False)    

In [None]:
# Display the final submission table.
pred_df