# Introduction
I wanted to share something that worked pretty well for me early on in this competition. The idea comes from a [2018 paper](https://arxiv.org/pdf/1703.01780.pdf) titled *Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results* by Antti Tarvainen and Harri Valpola. 

### Mean Teacher
Biefly, the idea is to use two models. A student model with weights trained the standard way, using backprop. And a teacher model with weights that are an exponential moving average of the student's weights. The teacher is the *mean* of the student \*ba dum tss\*. The student is then trained using two different losses, a standard classification loss and a consistency loss that penalizes student predictions that deviate from the teaher's. 

![](https://raw.githubusercontent.com/CuriousAI/mean-teacher/master/mean_teacher.png)

Mean teachers are useful in a semi-supervised context where we have both labeled and unlabeled samples. The consistency loss on the unlabeled samples acts as a form of regularization and helps the model generalize better. As an added bonus the final teacher model is a temporal ensemble which tends to perform better than the results at the end of a single epoch. 

### Missing Labels
As a few others have pointed out, there are a lot of missing labels. If we were to randomly sample a segment from the training data, we might consider it completely unlabeled rather than rely on the provided labels. We'll train our mean teacher model(s) on two classes of data, carefully selected positive samples and randomly selected unlabeled samples. The classification loss won't apply to the unlabeled samples. 

![](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F4704212%2F9ca088bb386abf7114543c019c1d8a5f%2Ffig.png?generation=1609892974092435&alt=media)

*Thanks to [shinmura0](https://www.kaggle.com/shinmurashinmura) for the great visualization!*

### Results
For me, mean teacher worked a good bit better than baseline models with similar configurations. 

|                                         | Baseline | Mean Teacher |
|-----------------------------------------|----------|--------------|
| Well Tuned, 5 fold, from my local setup | 0.847        | **0.865**            |
| Single fold Expt1 on Kaggle                   | 0.592**        | **0.786**            |
| Single fold Expt2 on Kaggle                   | 0.826        | **0.830**            |
| 5 Fold on Kaggle                        | ?        | ?            |

\*\* I might have accidentally sabatoged this run.

In [1]:
# !pip -q install --upgrade pip
# !pip -q install timm
# !pip -q install torchlibrosa
# !pip -q install audiomentations

In [5]:
!pip freeze | grep contextlib

In [6]:
!pip install cPython

Defaulting to user installation because normal site-packages is not writeable
Collecting cPython
  Downloading cPython-0.0.6.tar.gz (4.7 kB)
Collecting pymongo
  Downloading pymongo-3.11.3-cp36-cp36m-manylinux2014_x86_64.whl (509 kB)
[K     |████████████████████████████████| 509 kB 2.4 MB/s 
Building wheels for collected packages: cPython
  Building wheel for cPython (setup.py) ... [?25ldone
[?25h  Created wheel for cPython: filename=cPython-0.0.6-py3-none-any.whl size=4913 sha256=959eaee8e79a6da4eaa53d305960491cde3363e24bd3fe682e9fd4e6644b8994
  Stored in directory: /home/user/.cache/pip/wheels/b2/0f/fa/faf1ddc5dbe2fbd858f0d53fa7cc7b5f060d9986a3e1f18810
Successfully built cPython
Installing collected packages: pymongo, cPython
Successfully installed cPython-0.0.6 pymongo-3.11.3
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [1]:
import audiomentations as A
import os, time, librosa, random
from functools import partial
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
from timm.models import resnet34d
from timm.models.efficientnet import tf_efficientnet_b0_ns
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from tqdm import tqdm
import soundfile as sf
from pathlib import Path
# from contextlib import nullcontext

# Config
We'll start by setting up some global config variable that we'll access later.

In [2]:
INPUT_ROOT = Path("/home/knikaido/work/Rainforest-Connection/data")
RAW_DATA = INPUT_ROOT / "rfcx-species-audio-detection"
TRAIN_AUDIO_DIR = RAW_DATA / "train_denoise"
# TRAIN_RESAMPLED_AUDIO_DIRS = [
#   INPUT_ROOT / "birdsong-resampled-train-audio-{:0>2}".format(i)  for i in range(5)
# ]
TEST_AUDIO_DIR = RAW_DATA / "test_denoise"

In [3]:
# Global Vars
NO_LABEL = -1
NUM_CLASSES = 24


class config:
    seed = 42
    device = "cuda:0"
    
    train_tp_csv = RAW_DATA / 'train_tp.csv'
    test_csv = RAW_DATA / 'sample_submission.csv'
    save_path = './output/'
    
    encoder = resnet34d
    encoder_features = 1024
    
    percent_unlabeled = 1.0
    consistency_weight = 100.0
    consistency_rampup = 6 # 6 epochs
    ema_decay = 0.995
    positive_weight = 2.0
    
    lr = 1e-3
    epochs = 25
    train_batch_size = 6
    valid_batch_size = 2
    num_workers = 1
    train_5_folds = True
    
    period = 6 # 6 second clips
    step = 1
    model_params = {
        'sample_rate': 48000,
        'window_size': 2048,
        'hop_size': 512,
        'mel_bins': 384,
        'fmin': 20,
        'fmax': 48000 // 2,
        'classes_num': NUM_CLASSES
    }
    
    augmenter = A.Compose([
#         A.AddGaussianNoise(p=0.33, max_amplitude=0.02),
#         A.AddGaussianSNR(p=0.33),
#         A.FrequencyMask(min_frequency_band=0.01,  max_frequency_band=0.25, p=0.33),
#         A.TimeMask(min_band_part=0.01, max_band_part=0.25, p=0.33),
#         A.Gain(p=0.33)
        
        A.AddGaussianNoise(p=0.2),
        A.AddGaussianSNR(p=0.2),
        A.Gain(min_gain_in_db=-15,max_gain_in_db=15,p=0.3),
        A.PitchShift(min_semitones=-1, max_semitones=1, p=0.2)
    ])


In [4]:
## Utils - Not much interesting going on here.

def get_n_fold_df(csv_path, folds=5):
    df = pd.read_csv(csv_path)
    df_group = df.groupby("recording_id")[["species_id"]].first().reset_index()
    df_group = df_group.sample(frac=1, random_state=config.seed).reset_index(drop=True)
    df_group.loc[:, 'fold'] = -1

    X = df_group["recording_id"].values
    y = df_group["species_id"].values

    kfold = StratifiedKFold(n_splits=folds, random_state=config.seed)
    for fold, (t_idx, v_idx) in enumerate(kfold.split(X, y)):
        df_group.loc[v_idx, "fold"] = fold

    return df.merge(df_group[['recording_id', 'fold']], on="recording_id", how="left")
    

def init_layer(layer):
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.0)


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MetricMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.y_true = []
        self.y_pred = []

    def update(self, y_true, y_pred):
        try:
            self.y_true.extend(y_true.detach().cpu().numpy().tolist())
            self.y_pred.extend(torch.sigmoid(y_pred).cpu().detach().numpy().tolist())
        except:
            print("UPDATE FAILURE")

    def update_list(self, y_true, y_pred):
        self.y_true.extend(y_true)
        self.y_pred.extend(y_pred)

    @property
    def avg(self):
        score_class, weight = lwlrap(np.array(self.y_true), np.array(self.y_pred))
        self.score = (score_class * weight).sum()

        return self.score
    

def interpolate(x: torch.Tensor, ratio: int):
    """Interpolate data in time domain. This is used to compensate the
    resolution reduction in downsampling of a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to interpolate
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled

def _one_sample_positive_class_precisions(scores, truth):
    num_classes = scores.shape[0]
    pos_class_indices = np.flatnonzero(truth > 0)

    if not len(pos_class_indices):
        return pos_class_indices, np.zeros(0)

    retrieved_classes = np.argsort(scores)[::-1]

    class_rankings = np.zeros(num_classes, dtype=np.int)
    class_rankings[retrieved_classes] = range(num_classes)

    retrieved_class_true = np.zeros(num_classes, dtype=np.bool)
    retrieved_class_true[class_rankings[pos_class_indices]] = True

    retrieved_cumulative_hits = np.cumsum(retrieved_class_true)

    precision_at_hits = (
            retrieved_cumulative_hits[class_rankings[pos_class_indices]] /
            (1 + class_rankings[pos_class_indices].astype(np.float)))
    return pos_class_indices, precision_at_hits


def lwlrap(truth, scores):
    assert truth.shape == scores.shape
    num_samples, num_classes = scores.shape
    precisions_for_samples_by_classes = np.zeros((num_samples, num_classes))
    for sample_num in range(num_samples):
        pos_class_indices, precision_at_hits = _one_sample_positive_class_precisions(scores[sample_num, :],
                                                                                     truth[sample_num, :])
        precisions_for_samples_by_classes[sample_num, pos_class_indices] = precision_at_hits

    labels_per_class = np.sum(truth > 0, axis=0)
    weight_per_class = labels_per_class / float(np.sum(labels_per_class))

    per_class_lwlrap = (np.sum(precisions_for_samples_by_classes, axis=0) /
                        np.maximum(1, labels_per_class))
    return per_class_lwlrap, weight_per_class


def pretty_print_metrics(fold, epoch, optimizer, train_loss_metrics, val_loss_metrics):
    print(f"""
    {time.ctime()} \n
    Fold:{fold}, Epoch:{epoch}, LR:{optimizer.param_groups[0]['lr']:.7}, Cons. Weight: {train_loss_metrics['consistency_weight']}\n
    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                {train_loss_metrics['loss']:0.4f}   |   {val_loss_metrics['loss']:0.4f}\n
    LWLRAP:              {train_loss_metrics['lwlrap']:0.4f}   |   {val_loss_metrics['lwlrap']:0.4f}\n
    Class Loss:          {train_loss_metrics['class_loss']:0.4f}   |   {val_loss_metrics['class_loss']:0.4f}\n
    Consistency Loss:    {train_loss_metrics['consistency_loss']:0.4f}   |   {val_loss_metrics['consistency_loss']:0.4f}\n
    --------------------------------------------------------\n
    """)
    

In [5]:
class TestDataset(Dataset):
    def __init__(self, df, data_path, period=10, step=1):
        self.data_path = data_path
        self.period = period
        self.step = step
        self.recording_ids = list(df["recording_id"].unique())

    def __len__(self):
        return len(self.recording_ids)

    def __getitem__(self, idx):
        recording_id = self.recording_ids[idx]

        y, sr = sf.read(f"{self.data_path}/{recording_id}_denoise.flac")

        len_y = len(y)
        effective_length = sr * self.period
        effective_step = sr * self.step

        y_ = []
        i = 0
        while i+effective_length <= len_y:
            y__ = y[i:i + effective_length]

            y_.append(y__)
            i = i + effective_step

        y = np.stack(y_)

        label = np.zeros(NUM_CLASSES, dtype='f')

        return {
            "waveform": y,
            "target": torch.tensor(label, dtype=torch.float),
            "id": recording_id
        }


def predict_on_test(model, test_loader):
    model.eval()
    pred_list = []
    id_list = []
    with torch.no_grad():
        t = tqdm(test_loader)
        for i, sample in enumerate(t):
            input = sample["waveform"].to(config.device)
            bs, seq, w = input.shape
            input = input.reshape(bs * seq, w)
            id = sample["id"]
            output, _ = model(input)
            output = output.reshape(bs, seq, -1)
            output, _ = torch.max(output, dim=1)
            
            output = output.cpu().detach().numpy().tolist()
            pred_list.extend(output)
            id_list.extend(id)

    return pred_list, id_list

# Model
The model should look pretty familiar if you're using [SED](https://arxiv.org/abs/1912.04761). (Huge thanks to [Hidehisa Arai](https://www.kaggle.com/hidehisaarai1213) and their [SED Notebook](https://www.kaggle.com/hidehisaarai1213/introduction-to-sound-event-detection)!) You could use any model you'd like here. There's just one small tweak we need to make for our mean teacher setup. We need to "detach" the teacher's parameters so they aren't updated by the optimizer.

In [6]:
encoder_params = {
    "tf_efficientnet_b0_ns": {
        "features": 1280,
        "init_op": partial(tf_efficientnet_b0_ns, pretrained=True, drop_path_rate=0.2)
    }
}

In [7]:
class AttentionHead(nn.Module):
    
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.conv_attention = nn.Conv1d(in_channels=in_features, 
                                        out_channels=out_features,
                                        kernel_size=1, stride=1, 
                                        padding=0, bias=True)
        self.conv_classes = nn.Conv1d(in_channels=in_features, 
                                      out_channels=out_features,
                                      kernel_size=1, stride=1, 
                                      padding=0, bias=True)
        self.batch_norm_attention = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        init_layer(self.conv_attention)
        init_layer(self.conv_classes)
        init_bn(self.batch_norm_attention)

    def forward(self, x):
        norm_att = torch.softmax(torch.tanh(self.conv_attention(x)), dim=-1)
        classes = self.conv_classes(x)
        x = torch.sum(norm_att * classes, dim=2)
        return x, norm_att, classes


class SEDAudioClassifier(nn.Module):

    def __init__(self, sample_rate, window_size, hop_size, 
                 mel_bins, fmin, fmax, classes_num):
        super().__init__()
        self.interpolate_ratio = 32

        self.spectrogram_extractor = Spectrogram(n_fft=window_size, 
                                                 hop_length=hop_size,
                                                 win_length=window_size, 
                                                 window='hann', center=True,
                                                 pad_mode='reflect', 
                                                 freeze_parameters=True)
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
                                                 n_mels=mel_bins, fmin=fmin, 
                                                 fmax=fmax, ref=1.0, 
                                                 amin=1e-10, top_db=None, 
                                                 freeze_parameters=True)

        self.batch_norm = nn.BatchNorm2d(mel_bins)
#         self.encoder = partial(config.encoder, pretrained=True, in_chans=1)()
        self.encoder = encoder_params['tf_efficientnet_b0_ns']["init_op"]()
        self.fc = nn.Linear(encoder_params['tf_efficientnet_b0_ns']["features"], 
                            config.encoder_features, bias=True)
        self.att_head = AttentionHead(config.encoder_features, classes_num)
        self.avg_pool = nn.modules.pooling.AdaptiveAvgPool2d((1, 1))

        self.init_weight()

    def init_weight(self):
        init_bn(self.batch_norm)
        init_layer(self.fc)
        self.att_head.init_weights()

    def forward(self, input, spec_aug=False, 
                mixup_lambda=None, return_encoding=False):
        x = self.spectrogram_extractor(input.float())
        x = self.logmel_extractor(x)
        
        x = x.transpose(1, 3)
        x = self.batch_norm(x)
        x = x.transpose(1, 3)
        x = x.expand(x.shape[0], 3, x.shape[2], x.shape[3])

        x = self.encoder.forward_features(x)
        x = torch.mean(x, dim=3)
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_head(x)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)
        return clipwise_output, framewise_output


def get_model(is_mean_teacher=False):
    model = SEDAudioClassifier(**config.model_params)
    model = model.to(config.device)
    
    # Detach params for Exponential Moving Average Model (aka the Mean Teacher).
    # We'll manually update these params instead of using backprop.
    if is_mean_teacher:
        for param in model.parameters():
            param.detach_()
    return model

# Loss Function
The loss function has 2 components:

1. A classification loss that only applies to labeled samples.
2. A consistency loss that applies to all samples. 

For the consistency loss we'll use the mean square error between the student and teacher predictions. We'll slowly ramp up the influence of the consistency loss since we don't want bad, early predictions having too much influence. 

Notice that we're weighting the positive samples for the classification loss. This is because we know the positives are correct while we're less sure about the negatives due to the missing labels issue. I found that this works better in practice. 

In [8]:
def sigmoid_mse_loss(input_logits, target_logits):
    assert input_logits.size() == target_logits.size()
    input_softmax = torch.sigmoid(input_logits)
    target_softmax = torch.sigmoid(target_logits)
    num_classes = input_logits.size()[1]
    return F.mse_loss(input_softmax, target_softmax, size_average=False
                     ) / num_classes


class MeanTeacherLoss(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.positive_weight = torch.ones(
            NUM_CLASSES).to(config.device) * config.positive_weight
        self.class_criterion = nn.BCEWithLogitsLoss(
            reduction='none', pos_weight=self.positive_weight)
        self.consistency_criterion = sigmoid_mse_loss

    def make_safe(self, pred):
        pred = torch.where(torch.isnan(pred), torch.zeros_like(pred), pred)
        return torch.where(torch.isinf(pred), torch.zeros_like(pred), pred)
        
    def get_consistency_weight(self, epoch):
        # Consistency ramp-up from https://arxiv.org/abs/1610.02242
        return config.consistency_weight * sigmoid_rampup(
            epoch, config.consistency_rampup)
    
    def forward(self, student_pred, teacher_pred, target, classif_weights, epoch):
        student_pred = self.make_safe(student_pred)
        teacher_pred = self.make_safe(teacher_pred).detach().data

        batch_size = len(target)
        labeled_batch_size = target.ne(NO_LABEL).all(axis=1).sum().item() + 1e-3

        student_classif, student_consistency = student_pred, student_pred
        student_class_loss = (self.class_criterion(
            student_classif, target) * classif_weights / labeled_batch_size).sum()

        consistency_weights = self.get_consistency_weight(epoch)
        consistency_loss = consistency_weights * self.consistency_criterion(
            student_consistency, teacher_pred) / batch_size
        loss = student_class_loss + consistency_loss
        return loss, student_class_loss, consistency_loss, consistency_weights

# Data Loader
The data loader produces two types of samples:

1. Labeled samples with the audio centered in the clip.
2. Random unlabeled clips without labels selected from files with at least one true positive label.

Each sample contains 2 different inputs, one for the student and one for the teacher. Different augmentations are applied to each input.

In [9]:
class MeanTeacherDataset(Dataset):
    
    def __init__(self, df, transforms, period=5, 
                 data_path=str(TRAIN_AUDIO_DIR), 
                 val=False, percent_unlabeled=0.0):
        self.period = period
        self.transforms = transforms
        self.data_path = data_path
        self.val = val
        self.percent_unlabeled = percent_unlabeled

        dfgby = df.groupby("recording_id").agg(lambda x: list(x)).reset_index()
        self.recording_ids = dfgby["recording_id"].values
        self.species_ids = dfgby["species_id"].values
        self.t_mins = dfgby["t_min"].values
        self.t_maxs = dfgby["t_max"].values
        dffp = pd.read_csv(RAW_DATA / 'train_fp.csv')
        self.recording_ids_fp = dffp["recording_id"].values

    def __len__(self):
        return int(len(self.recording_ids) * (1 + self.percent_unlabeled))

    def __getitem__(self, idx):
        if idx >= len(self.recording_ids):
            audio, label, rec_id, sr = self.get_unlabeled_item(idx)
            # For unlabeled samples, we zero out the classification loss.
            classif_weights = np.zeros(NUM_CLASSES, dtype='f')
        else:
            audio, label, rec_id, sr = self.get_labeled_item(idx)
            classif_weights = np.ones(NUM_CLASSES, dtype='f')

        audio_teacher = np.copy(audio)

        # The 2 samples fed to the 2 models have should have different augmentations.
        audio = self.transforms(samples=audio, sample_rate=sr)
        audio_teacher = self.transforms(samples=audio_teacher, sample_rate=sr)
        # assert (audio != audio_teacher).any()
        
        return {
            "waveform": audio,
            "teacher_waveform": audio_teacher,
            "target": torch.tensor(label, dtype=torch.float),
            "classification_weights": classif_weights,
            "id": rec_id
        }

    def get_labeled_item(self, idx):
        recording_id = self.recording_ids[idx]
        species_id = self.species_ids[idx]
        t_min, t_max = self.t_mins[idx], self.t_maxs[idx]

        rec, sr = sf.read(f"{self.data_path}/{recording_id}_denoise.flac")

        len_rec = len(rec)
        effective_length = sr * self.period
        rint = np.random.randint(len(t_min))
        tmin, tmax = round(sr * t_min[rint]), round(sr * t_max[rint])
        dur = tmax - tmin
        min_dur = min(dur, round(sr * self.period))
        
#         start_min = max(0, tmax-effective_length)
#         start_max = min(tmin, len_rec - effective_length)
#         start = np.random.randint(start_min, start_max)
#         rec = rec[start:start + effective_length]

        center = round((tmin + tmax) / 2)
        rand_start = center - effective_length + max(min_dur - dur//2, 0)
        if rand_start < 0:
            rand_start = 0
        rand_end = center - max(min_dur - dur//2, 0)
        start = np.random.randint(rand_start, rand_end)
        rec = rec[start:start + effective_length]
        if len(rec) < effective_length:
            new_rec = np.zeros(effective_length, dtype=rec.dtype)
            start1 = np.random.randint(effective_length - len(rec))
            new_rec[start1:start1 + len(rec)] = rec
            rec = new_rec.astype(np.float32)
        else:
            rec = rec.astype(np.float32)

        start_time = start / sr
        end_time = (start + effective_length) / sr

        label = np.zeros(NUM_CLASSES, dtype='f')

        for i in range(len(t_min)):
            if (t_min[i] >= start_time) & (t_max[i] <= end_time):
                label[species_id[i]] = 1
            elif start_time <= ((t_min[i] + t_max[i]) / 2) <= end_time:
                label[species_id[i]] = 1

        return rec, label, recording_id, sr

    def get_unlabeled_item(self, idx, random_sample=False):
        real_idx = idx - len(self.recording_ids)
        # We want our validation set to be fixed.
        if self.val:
            rec_id = self.recording_ids[real_idx]
        else:
            rec_id = random.sample(list(self.recording_ids_fp), 1)[0]

        rec, sr = sf.read(f"{self.data_path}/{rec_id}_denoise.flac")
        effective_length = int(sr * self.period)
        max_end = len(rec) - effective_length
        if self.val:
            # Fixed start for validation. Probaably a better way to do this.
            start = int(idx * 16963 % max_end)
        else:
            start = np.random.randint(0, max_end)
        rec = rec[start:(start+effective_length)]
        rec = rec.astype(np.float32)

        label = np.ones(NUM_CLASSES, dtype='f') * NO_LABEL

        return rec, label, rec_id, sr

    
def get_data_loader(df, b_size, is_val=False):
    dataset = MeanTeacherDataset(
        df=df,
        transforms=config.augmenter,
        period=config.period,
        percent_unlabeled=config.percent_unlabeled
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=b_size,
        shuffle=not is_val,
        drop_last=not is_val,
        num_workers=config.num_workers
    )

# Training
At the end of each training step we update the teacher weights by averaging in the latest student weights.

In [10]:
# Update teacher to be exponential moving average of student params.
def update_teacher_params(student, teacher, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(teacher.parameters(), student.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)


def train_one_epoch(student, mean_teacher, loader, 
                    criterion, optimizer, scheduler, epoch):
    global_step = 0
    losses = AverageMeter()
    consistency_loss_avg = AverageMeter()
    class_loss_avg = AverageMeter()
    comp_metric = MetricMeter()
    
    student.train()
    mean_teacher.train()
#     context = nullcontext()
    
    t = tqdm(loader)
    for i, sample in enumerate(t):
        student_input = sample['waveform'].to(config.device)
        teacher_input = sample['teacher_waveform'].to(config.device)
        target = sample['target'].to(config.device)
        classif_weights = sample['classification_weights'].to(config.device)
        batch_size = len(target)

        student_pred, _  = student(student_input)
        teacher_pred, _  = mean_teacher(teacher_input)

        loss, class_loss, consistency_loss, consistency_weight = criterion(
            student_pred, teacher_pred, target, classif_weights, epoch)

#         if not is_val:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_teacher_params(student, mean_teacher, 
                              config.ema_decay, global_step)

        scheduler.step()

        comp_metric.update(target, student_pred)
        losses.update(loss.item(), batch_size)
        consistency_loss_avg.update(consistency_loss.item(), batch_size)
        class_loss_avg.update(class_loss.item(), batch_size)
        global_step += 1

        t.set_description(f"Epoch:{epoch} - Loss:{losses.avg:0.4f}")
    t.close()
    return {'lwlrap':comp_metric.avg, 
            'loss':losses.avg, 
            'consistency_loss':consistency_loss_avg.avg, 
            'class_loss':class_loss_avg.avg, 
            'consistency_weight':consistency_weight}


def valid_one_epoch(student, mean_teacher, loader, 
                    criterion, optimizer, scheduler, epoch):
    global_step = 0
    losses = AverageMeter()
    consistency_loss_avg = AverageMeter()
    class_loss_avg = AverageMeter()
    comp_metric = MetricMeter()
    
    student.eval()
    mean_teacher.eval()
    context = torch.no_grad()
#     else:
#     student.train()
#     mean_teacher.train()
#     context = nullcontext()
    
    with context:
        t = tqdm(loader)
        for i, sample in enumerate(t):
            student_input = sample['waveform'].to(config.device)
            teacher_input = sample['teacher_waveform'].to(config.device)
            target = sample['target'].to(config.device)
            classif_weights = sample['classification_weights'].to(config.device)
            batch_size = len(target)

            student_pred, _  = student(student_input)
            teacher_pred, _  = mean_teacher(teacher_input)

            loss, class_loss, consistency_loss, consistency_weight = criterion(
                student_pred, teacher_pred, target, classif_weights, epoch)

            comp_metric.update(target, student_pred)
            losses.update(loss.item(), batch_size)
            consistency_loss_avg.update(consistency_loss.item(), batch_size)
            class_loss_avg.update(class_loss.item(), batch_size)
            global_step += 1

            t.set_description(f"Epoch:{epoch} - Loss:{losses.avg:0.4f}")
        t.close()
    return {'lwlrap':comp_metric.avg, 
            'loss':losses.avg, 
            'consistency_loss':consistency_loss_avg.avg, 
            'class_loss':class_loss_avg.avg, 
            'consistency_weight':consistency_weight}

Finally putting everything together...

In [11]:
def train(df, fold):
    train_df = df[df.fold != fold]
    val_df = df[df.fold == fold]
    train_loader = get_data_loader(train_df, config.train_batch_size, is_val=False)
    val_loader = get_data_loader(val_df, config.valid_batch_size, is_val=True)

    student_model = get_model()
    teacher_model = get_model(is_mean_teacher=True)

    optimizer = torch.optim.AdamW(student_model.parameters(), lr=config.lr)
    num_train_steps = int(len(train_loader) * config.epochs)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.001*0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_train_steps)
    criterion = MeanTeacherLoss()

    best_val_metric = -np.inf
    val_metrics = []
    train_metrics = []
    for epoch in range(0, config.epochs):
        train_loss_metrics = train_one_epoch(
            student_model, teacher_model, train_loader, 
            criterion, optimizer, scheduler, epoch)
        val_loss_metrics = valid_one_epoch(
            student_model, teacher_model, val_loader, 
            criterion, optimizer, scheduler, epoch)

        train_metrics.append(train_loss_metrics)
        val_metrics.append(val_loss_metrics)
        pretty_print_metrics(fold, epoch, optimizer, 
                             train_loss_metrics, val_loss_metrics)
        
        if val_loss_metrics['lwlrap'] > best_val_metric:
            print(f"    LWLRAP Improved from {best_val_metric} --> {val_loss_metrics['lwlrap']}\n")
            torch.save(teacher_model.state_dict(), 
                       os.path.join(config.save_path, f'fold-{fold}.bin'))
            best_val_metric = val_loss_metrics['lwlrap']
    


df = get_n_fold_df(config.train_tp_csv)
for fold in range(5 if config.train_5_folds else 1):
    train(df, fold)
                  
    break

  "Empty filters detected in mel frequency basis. "
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
  
Epoch:0 - Loss:7.5761: 100%|██████████| 299/299 [03:35<00:00,  1.39it/s]
Epoch:0 - Loss:3.5738: 100%|██████████| 233/233 [00:58<00:00,  4.02it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:30:21 2021 

    Fold:0, Epoch:0, LR:0.0009960574, Cons. Weight: 0.6737946999085467

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                7.5761   |   3.5738

    LWLRAP:              0.2254   |   0.2823

    Class Loss:          7.5679   |   3.5622

    Consistency Loss:    0.0082   |   0.0116

    --------------------------------------------------------

    
    LWLRAP Improved from -inf --> 0.2822635206999458



Epoch:1 - Loss:6.3060: 100%|██████████| 299/299 [03:24<00:00,  1.47it/s]
Epoch:1 - Loss:3.3115: 100%|██████████| 233/233 [00:52<00:00,  4.40it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:34:38 2021 

    Fold:0, Epoch:1, LR:0.0009842916, Cons. Weight: 3.1047958479329627

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                6.3060   |   3.3115

    LWLRAP:              0.3370   |   0.3939

    Class Loss:          6.2733   |   3.2501

    Consistency Loss:    0.0327   |   0.0614

    --------------------------------------------------------

    
    LWLRAP Improved from 0.2822635206999458 --> 0.3938935766816256



Epoch:2 - Loss:5.5648: 100%|██████████| 299/299 [03:28<00:00,  1.44it/s]
Epoch:2 - Loss:2.5265: 100%|██████████| 233/233 [00:54<00:00,  4.30it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:39:00 2021 

    Fold:0, Epoch:2, LR:0.0009648882, Cons. Weight: 10.836802322189582

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                5.5648   |   2.5265

    LWLRAP:              0.4485   |   0.5662

    Class Loss:          5.4534   |   2.3613

    Consistency Loss:    0.1114   |   0.1653

    --------------------------------------------------------

    
    LWLRAP Improved from 0.3938935766816256 --> 0.5661689256370107



Epoch:3 - Loss:4.6106: 100%|██████████| 299/299 [03:25<00:00,  1.45it/s]
Epoch:3 - Loss:2.1791: 100%|██████████| 233/233 [01:00<00:00,  3.87it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:43:26 2021 

    Fold:0, Epoch:3, LR:0.0009381533, Cons. Weight: 28.650479686019008

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                4.6106   |   2.1791

    LWLRAP:              0.5874   |   0.7281

    Class Loss:          4.3210   |   1.7857

    Consistency Loss:    0.2896   |   0.3934

    --------------------------------------------------------

    
    LWLRAP Improved from 0.5661689256370107 --> 0.728091444738379



Epoch:4 - Loss:4.4658: 100%|██████████| 299/299 [03:14<00:00,  1.54it/s]
Epoch:4 - Loss:2.3939: 100%|██████████| 233/233 [00:52<00:00,  4.46it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:47:33 2021 

    Fold:0, Epoch:4, LR:0.0009045085, Cons. Weight: 57.375342073743276

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                4.4658   |   2.3939

    LWLRAP:              0.6588   |   0.7424

    Class Loss:          3.8940   |   1.7334

    Consistency Loss:    0.5718   |   0.6605

    --------------------------------------------------------

    
    LWLRAP Improved from 0.728091444738379 --> 0.7424106889693557



Epoch:5 - Loss:3.8510: 100%|██████████| 299/299 [03:24<00:00,  1.46it/s]
Epoch:5 - Loss:2.4736: 100%|██████████| 233/233 [00:53<00:00,  4.38it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:51:51 2021 

    Fold:0, Epoch:5, LR:0.0008644843, Cons. Weight: 87.03247258333906

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.8510   |   2.4736

    LWLRAP:              0.7356   |   0.7358

    Class Loss:          3.1494   |   1.6272

    Consistency Loss:    0.7017   |   0.8464

    --------------------------------------------------------

    


Epoch:6 - Loss:3.5935: 100%|██████████| 299/299 [03:31<00:00,  1.41it/s]
Epoch:6 - Loss:2.5143: 100%|██████████| 233/233 [00:53<00:00,  4.34it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 17:56:17 2021 

    Fold:0, Epoch:6, LR:0.000818712, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.5935   |   2.5143

    LWLRAP:              0.7720   |   0.7883

    Class Loss:          2.7888   |   1.5677

    Consistency Loss:    0.8047   |   0.9466

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7424106889693557 --> 0.7883434673501971



Epoch:7 - Loss:3.3958: 100%|██████████| 299/299 [03:30<00:00,  1.42it/s]
Epoch:7 - Loss:2.6089: 100%|██████████| 233/233 [00:56<00:00,  4.14it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:00:43 2021 

    Fold:0, Epoch:7, LR:0.0007679134, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.3958   |   2.6089

    LWLRAP:              0.7910   |   0.7976

    Class Loss:          2.5984   |   1.5204

    Consistency Loss:    0.7974   |   1.0885

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7883434673501971 --> 0.7976160969831856



Epoch:8 - Loss:3.0791: 100%|██████████| 299/299 [03:35<00:00,  1.39it/s]
Epoch:8 - Loss:1.9834: 100%|██████████| 233/233 [00:52<00:00,  4.43it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:05:12 2021 

    Fold:0, Epoch:8, LR:0.0007128896, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.0791   |   1.9834

    LWLRAP:              0.8140   |   0.8440

    Class Loss:          2.3947   |   1.2147

    Consistency Loss:    0.6844   |   0.7688

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7976160969831856 --> 0.8439606279796152



Epoch:9 - Loss:2.6479: 100%|██████████| 299/299 [03:39<00:00,  1.36it/s]
Epoch:9 - Loss:1.8996: 100%|██████████| 233/233 [00:53<00:00,  4.35it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:09:45 2021 

    Fold:0, Epoch:9, LR:0.0006545085, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.6479   |   1.8996

    LWLRAP:              0.8484   |   0.8192

    Class Loss:          2.0204   |   1.2291

    Consistency Loss:    0.6275   |   0.6705

    --------------------------------------------------------

    


Epoch:10 - Loss:2.6733: 100%|██████████| 299/299 [03:31<00:00,  1.41it/s]
Epoch:10 - Loss:2.2472: 100%|██████████| 233/233 [00:51<00:00,  4.49it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:14:08 2021 

    Fold:0, Epoch:10, LR:0.0005936907, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.6733   |   2.2472

    LWLRAP:              0.8516   |   0.8303

    Class Loss:          1.9607   |   1.3245

    Consistency Loss:    0.7126   |   0.9227

    --------------------------------------------------------

    


Epoch:11 - Loss:2.5166: 100%|██████████| 299/299 [03:27<00:00,  1.44it/s]
Epoch:11 - Loss:1.6715: 100%|██████████| 233/233 [00:56<00:00,  4.15it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:18:32 2021 

    Fold:0, Epoch:11, LR:0.0005313953, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.5166   |   1.6715

    LWLRAP:              0.8707   |   0.8621

    Class Loss:          1.8743   |   1.0446

    Consistency Loss:    0.6424   |   0.6269

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8439606279796152 --> 0.8621251864895932



Epoch:12 - Loss:2.3613: 100%|██████████| 299/299 [03:43<00:00,  1.34it/s]
Epoch:12 - Loss:1.5998: 100%|██████████| 233/233 [00:54<00:00,  4.31it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:23:09 2021 

    Fold:0, Epoch:12, LR:0.0004686047, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.3613   |   1.5998

    LWLRAP:              0.8697   |   0.8743

    Class Loss:          1.7275   |   1.0458

    Consistency Loss:    0.6339   |   0.5539

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8621251864895932 --> 0.8743307775087437



Epoch:13 - Loss:2.0675: 100%|██████████| 299/299 [03:20<00:00,  1.49it/s]
Epoch:13 - Loss:1.6500: 100%|██████████| 233/233 [00:54<00:00,  4.24it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:27:25 2021 

    Fold:0, Epoch:13, LR:0.0004063093, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.0675   |   1.6500

    LWLRAP:              0.9010   |   0.8731

    Class Loss:          1.4698   |   0.9937

    Consistency Loss:    0.5977   |   0.6563

    --------------------------------------------------------

    


Epoch:14 - Loss:1.9910: 100%|██████████| 299/299 [03:32<00:00,  1.41it/s]
Epoch:14 - Loss:1.6896: 100%|██████████| 233/233 [00:53<00:00,  4.32it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:31:52 2021 

    Fold:0, Epoch:14, LR:0.0003454915, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.9910   |   1.6896

    LWLRAP:              0.8920   |   0.8722

    Class Loss:          1.4746   |   1.1333

    Consistency Loss:    0.5164   |   0.5563

    --------------------------------------------------------

    


Epoch:15 - Loss:1.8034: 100%|██████████| 299/299 [03:21<00:00,  1.48it/s]
Epoch:15 - Loss:1.7266: 100%|██████████| 233/233 [00:56<00:00,  4.11it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:36:10 2021 

    Fold:0, Epoch:15, LR:0.0002871104, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.8034   |   1.7266

    LWLRAP:              0.9056   |   0.8669

    Class Loss:          1.3173   |   1.1875

    Consistency Loss:    0.4861   |   0.5391

    --------------------------------------------------------

    


Epoch:16 - Loss:1.6614: 100%|██████████| 299/299 [03:27<00:00,  1.44it/s]
Epoch:16 - Loss:1.5380: 100%|██████████| 233/233 [00:52<00:00,  4.46it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:40:30 2021 

    Fold:0, Epoch:16, LR:0.0002320866, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.6614   |   1.5380

    LWLRAP:              0.9267   |   0.8842

    Class Loss:          1.1612   |   1.0252

    Consistency Loss:    0.5002   |   0.5128

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8743307775087437 --> 0.8841571446444327



Epoch:17 - Loss:1.6388: 100%|██████████| 299/299 [03:38<00:00,  1.37it/s]
Epoch:17 - Loss:1.5778: 100%|██████████| 233/233 [00:54<00:00,  4.28it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:45:02 2021 

    Fold:0, Epoch:17, LR:0.000181288, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.6388   |   1.5778

    LWLRAP:              0.9219   |   0.8834

    Class Loss:          1.1639   |   1.0223

    Consistency Loss:    0.4749   |   0.5554

    --------------------------------------------------------

    


Epoch:18 - Loss:1.4036: 100%|██████████| 299/299 [03:35<00:00,  1.39it/s]
Epoch:18 - Loss:1.4773: 100%|██████████| 233/233 [00:58<00:00,  4.01it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:49:36 2021 

    Fold:0, Epoch:18, LR:0.0001355157, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.4036   |   1.4773

    LWLRAP:              0.9335   |   0.8761

    Class Loss:          0.9258   |   1.0275

    Consistency Loss:    0.4778   |   0.4498

    --------------------------------------------------------

    


Epoch:19 - Loss:1.3530: 100%|██████████| 299/299 [03:23<00:00,  1.47it/s]
Epoch:19 - Loss:1.3426: 100%|██████████| 233/233 [00:55<00:00,  4.22it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:53:54 2021 

    Fold:0, Epoch:19, LR:9.54915e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.3530   |   1.3426

    LWLRAP:              0.9366   |   0.8868

    Class Loss:          0.9114   |   0.8928

    Consistency Loss:    0.4416   |   0.4499

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8841571446444327 --> 0.8867563169690829



Epoch:20 - Loss:1.2736: 100%|██████████| 299/299 [03:25<00:00,  1.45it/s]
Epoch:20 - Loss:1.4117: 100%|██████████| 233/233 [01:01<00:00,  3.77it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 18:58:22 2021 

    Fold:0, Epoch:20, LR:6.184666e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2736   |   1.4117

    LWLRAP:              0.9436   |   0.8843

    Class Loss:          0.8175   |   0.9723

    Consistency Loss:    0.4561   |   0.4394

    --------------------------------------------------------

    


Epoch:21 - Loss:1.3866: 100%|██████████| 299/299 [03:24<00:00,  1.46it/s]
Epoch:21 - Loss:1.4847: 100%|██████████| 233/233 [00:58<00:00,  4.00it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 19:02:45 2021 

    Fold:0, Epoch:21, LR:3.511176e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.3866   |   1.4847

    LWLRAP:              0.9371   |   0.8876

    Class Loss:          0.9470   |   1.0439

    Consistency Loss:    0.4397   |   0.4408

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8867563169690829 --> 0.8876261994159367



Epoch:22 - Loss:1.1521: 100%|██████████| 299/299 [03:25<00:00,  1.46it/s]
Epoch:22 - Loss:1.2939: 100%|██████████| 233/233 [00:57<00:00,  4.03it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 19:07:08 2021 

    Fold:0, Epoch:22, LR:1.570842e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.1521   |   1.2939

    LWLRAP:              0.9517   |   0.8961

    Class Loss:          0.7750   |   0.8851

    Consistency Loss:    0.3772   |   0.4088

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8876261994159367 --> 0.8960956416464891



Epoch:23 - Loss:1.2340: 100%|██████████| 299/299 [03:41<00:00,  1.35it/s]
Epoch:23 - Loss:1.2587: 100%|██████████| 233/233 [00:57<00:00,  4.07it/s]
  0%|          | 0/299 [00:00<?, ?it/s]


    Sat Feb 13 19:11:48 2021 

    Fold:0, Epoch:23, LR:3.942649e-06, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2340   |   1.2587

    LWLRAP:              0.9468   |   0.8924

    Class Loss:          0.8148   |   0.8831

    Consistency Loss:    0.4192   |   0.3756

    --------------------------------------------------------

    


Epoch:24 - Loss:1.2822: 100%|██████████| 299/299 [03:37<00:00,  1.37it/s]
Epoch:24 - Loss:1.4945: 100%|██████████| 233/233 [00:53<00:00,  4.35it/s]


    Sat Feb 13 19:16:19 2021 

    Fold:0, Epoch:24, LR:0.0, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2822   |   1.4945

    LWLRAP:              0.9385   |   0.8868

    Class Loss:          0.8665   |   1.0631

    Consistency Loss:    0.4157   |   0.4314

    --------------------------------------------------------

    





# Predict on Test Set
We'll predict using the teacher model but you could also use the student or a combination of the two. Inference works just like it would for a vanilla baseline model.

In [12]:
def test(test_df, train_fold):
    test_dataset = TestDataset(
        df=test_df,
        data_path=TEST_AUDIO_DIR,
        period=config.period,
        step=config.step
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=0
    )
    
    weights_path = os.path.join(config.save_path, f'fold-{train_fold}.bin')
    model = get_model()
    model.load_state_dict(torch.load(weights_path, map_location=config.device), strict=False)
    
    test_pred, ids = predict_on_test(model, test_loader)

    # Build Submission File
    test_pred_df = pd.DataFrame({
        "recording_id": test_df.recording_id.values
    })
    target_cols = test_df.columns[1:].values.tolist()
    test_pred_df = test_pred_df.join(pd.DataFrame(np.array(test_pred), 
                                                  columns=target_cols))
    test_pred_df.to_csv(os.path.join(config.save_path, 
                                     f"fold-{train_fold}-submission.csv"), 
                        index=False)
    
    
test_df = pd.read_csv(config.test_csv)
for fold in range(5 if config.train_5_folds else 1):
    test(test_df, fold)
    break

100%|██████████| 1992/1992 [07:21<00:00,  4.51it/s]


## 5 Fold Ensemble
For 5 fold runs, we'll create a single ensemble prediction by simply averaging all of the folds.

In [13]:
def ensemble(submission_path):
    dfs = [pd.read_csv(os.path.join(
        config.save_path, f"fold-{i}-submission.csv")) for i in range(5)]
    anchor = dfs[0].copy()
    cols = anchor.columns[1:]
   
    for c in cols:
        total = 0
        for df in dfs:
            total += df[c]
        anchor[c] = total / len(dfs)
    anchor.to_csv(submission_path, index=False)


submission_path = os.path.join(config.save_path, f"submission.csv")
if config.train_5_folds:
    ensemble(submission_path)
else:
    fold0_submission = os.path.join(config.save_path, f"fold-0-submission.csv")
    os.rename(fold0_submission, submission_path)
    

FileNotFoundError: [Errno 2] No such file or directory: './output/fold-1-submission.csv'

# Conclusion 
Thanks for reading! I dropped some unrelated tricks from this and didn't spend much time tuning so there's almost definetely room for improvement.

I know it's pretty late in the competition for new notebooks, but considering that there are a few other public notebooks that score higher, I'm hoping this won't cause a significant shakeup. 