# Introduction
I wanted to share something that worked pretty well for me early on in this competition. The idea comes from a [2018 paper](https://arxiv.org/pdf/1703.01780.pdf) titled *Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results* by Antti Tarvainen and Harri Valpola. 

### Mean Teacher
Biefly, the idea is to use two models. A student model with weights trained the standard way, using backprop. And a teacher model with weights that are an exponential moving average of the student's weights. The teacher is the *mean* of the student \*ba dum tss\*. The student is then trained using two different losses, a standard classification loss and a consistency loss that penalizes student predictions that deviate from the teaher's. 

![](https://raw.githubusercontent.com/CuriousAI/mean-teacher/master/mean_teacher.png)

Mean teachers are useful in a semi-supervised context where we have both labeled and unlabeled samples. The consistency loss on the unlabeled samples acts as a form of regularization and helps the model generalize better. As an added bonus the final teacher model is a temporal ensemble which tends to perform better than the results at the end of a single epoch. 

### Missing Labels
As a few others have pointed out, there are a lot of missing labels. If we were to randomly sample a segment from the training data, we might consider it completely unlabeled rather than rely on the provided labels. We'll train our mean teacher model(s) on two classes of data, carefully selected positive samples and randomly selected unlabeled samples. The classification loss won't apply to the unlabeled samples. 

![](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F4704212%2F9ca088bb386abf7114543c019c1d8a5f%2Ffig.png?generation=1609892974092435&alt=media)

*Thanks to [shinmura0](https://www.kaggle.com/shinmurashinmura) for the great visualization!*

### Results
For me, mean teacher worked a good bit better than baseline models with similar configurations. 

|                                         | Baseline | Mean Teacher |
|-----------------------------------------|----------|--------------|
| Well Tuned, 5 fold, from my local setup | 0.847        | **0.865**            |
| Single fold Expt1 on Kaggle                   | 0.592**        | **0.786**            |
| Single fold Expt2 on Kaggle                   | 0.826        | **0.830**            |
| 5 Fold on Kaggle***                        | 0.844        | **0.857**           |

\*\* I might have accidentally sabatoged this run.

\*\*\* There was a major bug in v21 of the notebook where the consistence_ramp was set to 1000 which means it was just normal / non-mean-teacher training. Setting consisteny_ramp to 6 and using the mean teacher, we get an improvement of 0.13.

In [1]:
!pip -q install --upgrade pip
!pip -q install timm
!pip -q install torchlibrosa
!pip -q install audiomentations

In [2]:
import audiomentations as A
import os, time, librosa, random
from functools import partial
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
from timm.models import resnet34d, resnest26d, resnest50d
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns, tf_efficientnet_b0_ns, tf_efficientnet_b1_ns
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from tqdm import tqdm
import soundfile as sf
from contextlib import nullcontext


from ranger import Ranger

In [3]:
class MixUp(nn.Module):
    def __init__(self, prob=0.33, alpha=0.6, mixup_mode="basic"):
        super().__init__()
        self.alpha = alpha
        self.prob = prob
        self.mixup_mode = mixup_mode
        
    def forward(self, waveforms, labels):
        inds = np.arange(waveforms.shape[0])
        new_inds = inds.copy()
        np.random.shuffle(new_inds)
        aug_count = int(inds[inds != new_inds].shape[0] * self.prob)
        to_augment = np.random.choice(inds[inds != new_inds], aug_count, replace=False)
        betas = torch.tensor(np.random.beta(self.alpha, self.alpha, size=aug_count),
                             dtype=torch.float).unsqueeze(1).to(waveforms.device)
        # new_inds = torch.tensor(new_inds)
        # to_augment = torch.tensor(to_augment)
        waveforms[to_augment] = betas * waveforms[to_augment] + (1 - betas) * waveforms[new_inds][to_augment]
        if self.mixup_mode == "basic":
            labels[to_augment] = betas * labels[to_augment] + (1 - betas) * labels[new_inds][to_augment]
        elif self.mixup_mode == "or":
            labels[to_augment] = torch.clamp_max(labels[to_augment] + labels[new_inds][to_augment], max=1.)
        return waveforms, labels

# Config
We'll start by setting up some global config variable that we'll access later.

In [4]:
# Global Vars
NO_LABEL = -1
NUM_CLASSES = 24

LOSS_TYPE = "LSEP"

class config:
    seed = 42
    device = "cuda:0"
    
    train_tp_csv = '/media/paniquex/samsung_2tb/rfcx_kaggle/rfcx-species-audio-detection/train_tp.csv'
    test_csv = '/media/paniquex/samsung_2tb/rfcx_kaggle/rfcx-species-audio-detection/sample_submission.csv'
    save_path = '../experiments/mean_teacher_mels=224_35_epoches_BCE_eff_b0_adamw_period=3_val=3_encoder_percent_unlabeled=1.0_consistency_weight=50_consistency_rampup=6_ema_decay=0.995'
    
    encoder = tf_efficientnet_b0_ns
    encoder_features = 1280
    
    percent_unlabeled = 1.0
    consistency_weight = 100.0
    consistency_rampup = 6
    
    ema_decay = 0.995
    positive_weight = 2.0
    
    lr = 1e-3
    epochs = 35
    batch_size = 16
    num_workers = 8
    train_5_folds = True
    
    period = 3 # 6 second clips
    period_val = 3

    
    step = 1
    model_params = {
        'sample_rate': 48000,
        'window_size': 1024,
        'hop_size': 345,
        'mel_bins': 224,
        'fmin': 20,
        'fmax': 48000 // 2,
        'classes_num': NUM_CLASSES,
        'mixup_module': None
    }
    
    augmenter = A.Compose([
        A.AddGaussianNoise(p=0.33, max_amplitude=0.02),
        A.AddGaussianSNR(p=0.33),
        A.FrequencyMask(min_frequency_band=0.01,  max_frequency_band=0.25, p=0.33),
        A.TimeMask(min_band_part=0.01, max_band_part=0.25, p=0.33),
        A.Gain(p=0.33)
    ])


In [5]:
try:
    shutil.rmtree(config.save_path)
except:
    pass

try:
    os.mkdir(config.save_path)
except:
    pass

In [6]:
## Utils - Not much interesting going on here.

def get_n_fold_df(csv_path, folds=5):
    df = pd.read_csv(csv_path)
    df_group = df.groupby("recording_id")[["species_id"]].first().reset_index()
    df_group = df_group.sample(frac=1, random_state=config.seed).reset_index(drop=True)
    df_group.loc[:, 'fold'] = -1

    X = df_group["recording_id"].values
    y = df_group["species_id"].values

    kfold = StratifiedKFold(n_splits=folds, random_state=config.seed)
    for fold, (t_idx, v_idx) in enumerate(kfold.split(X, y)):
        df_group.loc[v_idx, "fold"] = fold

    return df.merge(df_group[['recording_id', 'fold']], on="recording_id", how="left")
    

def init_layer(layer):
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.0)


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MetricMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.y_true = []
        self.y_pred = []

    def update(self, y_true, y_pred):
        try:
            self.y_true.extend(y_true.detach().cpu().numpy().tolist())
            self.y_pred.extend(torch.sigmoid(y_pred).cpu().detach().numpy().tolist())
        except:
            print("UPDATE FAILURE")

    def update_list(self, y_true, y_pred):
        self.y_true.extend(y_true)
        self.y_pred.extend(y_pred)

    @property
    def avg(self):
        score_class, weight = lwlrap(np.array(self.y_true), np.array(self.y_pred))
        self.score = (score_class * weight).sum()

        return self.score
    

def interpolate(x: torch.Tensor, ratio: int):
    """Interpolate data in time domain. This is used to compensate the
    resolution reduction in downsampling of a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to interpolate
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled

def _one_sample_positive_class_precisions(scores, truth):
    num_classes = scores.shape[0]
    pos_class_indices = np.flatnonzero(truth > 0)

    if not len(pos_class_indices):
        return pos_class_indices, np.zeros(0)

    retrieved_classes = np.argsort(scores)[::-1]

    class_rankings = np.zeros(num_classes, dtype=np.int)
    class_rankings[retrieved_classes] = range(num_classes)

    retrieved_class_true = np.zeros(num_classes, dtype=np.bool)
    retrieved_class_true[class_rankings[pos_class_indices]] = True

    retrieved_cumulative_hits = np.cumsum(retrieved_class_true)

    precision_at_hits = (
            retrieved_cumulative_hits[class_rankings[pos_class_indices]] /
            (1 + class_rankings[pos_class_indices].astype(np.float)))
    return pos_class_indices, precision_at_hits


def lwlrap(truth, scores):
    assert truth.shape == scores.shape
    num_samples, num_classes = scores.shape
    precisions_for_samples_by_classes = np.zeros((num_samples, num_classes))
    for sample_num in range(num_samples):
        pos_class_indices, precision_at_hits = _one_sample_positive_class_precisions(scores[sample_num, :],
                                                                                     truth[sample_num, :])
        precisions_for_samples_by_classes[sample_num, pos_class_indices] = precision_at_hits

    labels_per_class = np.sum(truth > 0, axis=0)
    weight_per_class = labels_per_class / float(np.sum(labels_per_class))

    per_class_lwlrap = (np.sum(precisions_for_samples_by_classes, axis=0) /
                        np.maximum(1, labels_per_class))
    return per_class_lwlrap, weight_per_class


def pretty_print_metrics(fold, epoch, optimizer, train_loss_metrics, val_loss_metrics):
    print(f"""
    {time.ctime()} \n
    Fold:{fold}, Epoch:{epoch}, LR:{optimizer.param_groups[0]['lr']:.7}, Cons. Weight: {train_loss_metrics['consistency_weight']}\n
    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                {train_loss_metrics['loss']:0.4f}   |   {val_loss_metrics['loss']:0.4f}\n
    LWLRAP:              {train_loss_metrics['lwlrap']:0.4f}   |   {val_loss_metrics['lwlrap']:0.4f}\n
    Class Loss:          {train_loss_metrics['class_loss']:0.4f}   |   {val_loss_metrics['class_loss']:0.4f}\n
    Consistency Loss:    {train_loss_metrics['consistency_loss']:0.4f}   |   {val_loss_metrics['consistency_loss']:0.4f}\n
    --------------------------------------------------------\n
    """)
    

class TestDataset(Dataset):
    def __init__(self, df, data_path, period=10, step=1):
        self.data_path = data_path
        self.period = period
        self.step = step
        self.recording_ids = list(df["recording_id"].unique())

    def __len__(self):
        return len(self.recording_ids)

    def __getitem__(self, idx):
        recording_id = self.recording_ids[idx]

        y, sr = sf.read(f"{self.data_path}/{recording_id}.flac")

        len_y = len(y)
        effective_length = sr * self.period
        effective_step = sr * self.step

        y_ = []
        i = 0
        while i+effective_length <= len_y:
            y__ = y[i:i + effective_length]

            y_.append(y__)
            i = i + effective_step

        y = np.stack(y_)

        label = np.zeros(NUM_CLASSES, dtype='f')

        return {
            "waveform": y,
            "target": torch.tensor(label, dtype=torch.float),
            "id": recording_id
        }


def predict_on_test(model, test_loader):
    model.eval()
    pred_list = []
    id_list = []
    with torch.no_grad():
        t = tqdm(test_loader)
        for i, sample in enumerate(t):
            input = sample["waveform"].to(config.device)
            bs, seq, w = input.shape
            input = input.reshape(bs * seq, w)
            id = sample["id"]
            output, _ = model(input)
            output = output.reshape(bs, seq, -1)
            output, _ = torch.max(output, dim=1)
            
            output = output.cpu().detach().numpy().tolist()
            pred_list.extend(output)
            id_list.extend(id)

    return pred_list, id_list

# Model
The model should look pretty familiar if you're using [SED](https://arxiv.org/abs/1912.04761). (Huge thanks to [Hidehisa Arai](https://www.kaggle.com/hidehisaarai1213) and their [SED Notebook](https://www.kaggle.com/hidehisaarai1213/introduction-to-sound-event-detection)!) You could use any model you'd like here. There's just one small tweak we need to make for our mean teacher setup. We need to "detach" the teacher's parameters so they aren't updated by the optimizer.

In [7]:
class AttBlockV2(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        super().__init__()

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        # x: (n_samples, n_in, n_time)
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)



    
    
class SEDAudioClassifier(nn.Module):

    def __init__(self, sample_rate, window_size, hop_size, 
                 mel_bins, fmin, fmax, classes_num, mixup_module=None):
        super().__init__()
        self.interpolate_ratio = 32

        self.spectrogram_extractor = Spectrogram(n_fft=window_size, 
                                                 hop_length=hop_size,
                                                 win_length=window_size, 
                                                 window='hann', center=True,
                                                 pad_mode='reflect', 
                                                 freeze_parameters=True)
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
                                                 n_mels=mel_bins, fmin=fmin, 
                                                 fmax=fmax, ref=1.0, 
                                                 amin=1e-10, top_db=None, 
                                                 freeze_parameters=True)

        self.batch_norm = nn.BatchNorm2d(mel_bins)
        self.encoder = partial(config.encoder, pretrained=True, in_chans=1)()
        self.fc = nn.Linear(config.encoder_features, 
                            config.encoder_features, bias=True)
        self.att_head = AttBlockV2(config.encoder_features, classes_num)
        self.avg_pool = nn.modules.pooling.AdaptiveAvgPool2d((1, 1))
        
        self.mixup_module = mixup_module
        
        self.init_weight()

    def init_weight(self):
        init_bn(self.batch_norm)
        init_layer(self.fc)
        self.att_head.init_weights()

    def forward(self, input, labels=None, spec_aug=False, return_encoding=False):
        
        with torch.no_grad():
            if self.mixup_module and labels is not None:
                input, labels = self.mixup_module(input, labels)
        
        x = self.spectrogram_extractor(input.float())
        x = self.logmel_extractor(x)
        
        x = x.transpose(1, 3)
        x = self.batch_norm(x)
        x = x.transpose(1, 3)

        x = self.encoder.forward_features(x)
        x = torch.mean(x, dim=3)
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)
#         print(x.shape)
        x = F.relu_(self.fc(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_head(x)
        logit = torch.sum(norm_att * self.att_head.cla(x), dim=2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)
        if labels is not None:
            return clipwise_output, framewise_output, logit, labels
        else:
            return clipwise_output, framewise_output, labels


def get_model(is_mean_teacher=False):
    model = SEDAudioClassifier(**config.model_params)
    model = model.to(config.device)
    
    # Detach params for Exponential Moving Average Model (aka the Mean Teacher).
    # We'll manually update these params instead of using backprop.
    if is_mean_teacher:
        for param in model.parameters():
            param.detach_()
    return model

# Loss Function
The loss function has 2 components:

1. A classification loss that only applies to labeled samples.
2. A consistency loss that applies to all samples. 

For the consistency loss we'll use the mean square error between the student and teacher predictions. We'll slowly ramp up the influence of the consistency loss since we don't want bad, early predictions having too much influence. 

Notice that we're weighting the positive samples for the classification loss. This is because we know the positives are correct while we're less sure about the negatives due to the missing labels issue. I found that this works better in practice. 

In [8]:
import torch.nn as nn




In [9]:
class ImprovedPANNsLoss(nn.Module):
    def __init__(self, pos_weight, weights=[1, 0.5]):
        super().__init__()

        self.normal_loss = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)

        self.bce = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)
        self.weights = weights

    def forward(self, input, framewise_output, target):
        input_ = input
        target = target.float()

        clipwise_output_with_max, _ = framewise_output.max(dim=1)

        normal_loss = self.normal_loss(input_, target)
        auxiliary_loss = self.bce(clipwise_output_with_max, target)

        return self.weights[0] * normal_loss + self.weights[1] * auxiliary_loss

In [10]:
def sigmoid_mse_loss(input_logits, target_logits):
    assert input_logits.size() == target_logits.size()
    input_softmax = torch.sigmoid(input_logits)
    target_softmax = torch.sigmoid(target_logits)
    num_classes = input_logits.size()[1]
    return F.mse_loss(input_softmax, target_softmax, size_average=False
                     ) / num_classes

def lsep_loss_stable(input, target, average=True):

    n = input.size(0)

    differences = input.unsqueeze(1) - input.unsqueeze(2)
    where_lower = (target.unsqueeze(1) < target.unsqueeze(2)).float()

    differences = differences.view(n, -1)
    where_lower = where_lower.view(n, -1)

    max_difference, index = torch.max(differences, dim=1, keepdim=True)
    differences = differences - max_difference
    exps = differences.exp() * where_lower

    lsep = max_difference + torch.log(torch.exp(-max_difference) + exps.sum(-1))

    if average:
        return lsep.mean()
    else:
        return lsep


class MeanTeacherLoss(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.positive_weight = torch.ones(
            NUM_CLASSES).to(config.device) * config.positive_weight
        self.class_criterion = nn.BCEWithLogitsLoss(
            reduction='none', pos_weight=self.positive_weight)
        self.consistency_criterion = sigmoid_mse_loss

    def make_safe(self, pred):
        pred = torch.where(torch.isnan(pred), torch.zeros_like(pred), pred)
        return torch.where(torch.isinf(pred), torch.zeros_like(pred), pred)
        
    def get_consistency_weight(self, epoch):
        # Consistency ramp-up from https://arxiv.org/abs/1610.02242
        return config.consistency_weight * sigmoid_rampup(
            epoch, config.consistency_rampup)
    
    def forward(self, student_pred, teacher_pred, target, classif_weights, epoch):
        student_pred = self.make_safe(student_pred)
        teacher_pred = self.make_safe(teacher_pred).detach().data

        batch_size = len(target)
        labeled_batch_size = target.ne(NO_LABEL).all(axis=1).sum().item() + 1e-3

        student_classif, student_consistency = student_pred, student_pred
        student_class_loss = (self.class_criterion(
            student_classif, target) * classif_weights / labeled_batch_size).sum()

        consistency_weights = self.get_consistency_weight(epoch)
        consistency_loss = consistency_weights * self.consistency_criterion(
            student_consistency, teacher_pred) / batch_size
        loss = student_class_loss + consistency_loss
        return loss, student_class_loss, consistency_loss, consistency_weights

# Data Loader
The data loader produces two types of samples:

1. Labeled samples with the audio centered in the clip.
2. Random unlabeled clips without labels selected from files with at least one true positive label.

Each sample contains 2 different inputs, one for the student and one for the teacher. Different augmentations are applied to each input.

In [11]:
class MeanTeacherDataset(Dataset):
    
    def __init__(self, df, transforms, period=5, 
                 data_path="/media/paniquex/samsung_2tb/rfcx_kaggle/rfcx-species-audio-detection/train", 
                 val=False, percent_unlabeled=0.0):
        self.period = period
        self.transforms = transforms
        self.data_path = data_path
        self.val = val
        self.percent_unlabeled = percent_unlabeled

        dfgby = df.groupby("recording_id").agg(lambda x: list(x)).reset_index()
        self.recording_ids = dfgby["recording_id"].values
        self.species_ids = dfgby["species_id"].values
        self.t_mins = dfgby["t_min"].values
        self.t_maxs = dfgby["t_max"].values

    def __len__(self):
        return int(len(self.recording_ids) * (1 + self.percent_unlabeled))

    def __getitem__(self, idx):
        if idx >= len(self.recording_ids):
            audio, label, rec_id, sr = self.get_unlabeled_item(idx)
            # For unlabeled samples, we zero out the classification loss.
            classif_weights = np.zeros(NUM_CLASSES, dtype='f')
        else:
            audio, label, rec_id, sr = self.get_labeled_item(idx)
            classif_weights = np.ones(NUM_CLASSES, dtype='f')

        audio_teacher = np.copy(audio)

        # The 2 samples fed to the 2 models have should have different augmentations.
        audio = self.transforms(samples=audio, sample_rate=sr)
        audio_teacher = self.transforms(samples=audio_teacher, sample_rate=sr)
        # assert (audio != audio_teacher).any()
        
        return {
            "waveform": audio,
            "teacher_waveform": audio_teacher,
            "target": torch.tensor(label, dtype=torch.float),
            "classification_weights": classif_weights,
            "id": rec_id
        }

    def get_labeled_item(self, idx):
        recording_id = self.recording_ids[idx]
        species_id = self.species_ids[idx]
        t_min, t_max = self.t_mins[idx], self.t_maxs[idx]

        rec, sr = sf.read(f"{self.data_path}/{recording_id}.flac")

        len_rec = len(rec)
        effective_length = sr * self.period
        rint = np.random.randint(len(t_min))
        tmin, tmax = round(sr * t_min[rint]), round(sr * t_max[rint])
        dur = tmax - tmin
        min_dur = min(dur, round(sr * self.period))

        center = round((tmin + tmax) / 2)
        rand_start = center - effective_length + max(min_dur - dur//2, 0)
        if rand_start < 0:
            rand_start = 0
        rand_end = center - max(min_dur - dur//2, 0)
        start = np.random.randint(rand_start, rand_end)
        rec = rec[start:start + effective_length]
        if len(rec) < effective_length:
            new_rec = np.zeros(effective_length, dtype=rec.dtype)
            start1 = np.random.randint(effective_length - len(rec))
            new_rec[start1:start1 + len(rec)] = rec
            rec = new_rec.astype(np.float32)
        else:
            rec = rec.astype(np.float32)

        start_time = start / sr
        end_time = (start + effective_length) / sr

        label = np.zeros(NUM_CLASSES, dtype='f')

        for i in range(len(t_min)):
            if (t_min[i] >= start_time) & (t_max[i] <= end_time):
                label[species_id[i]] = 1
            elif start_time <= ((t_min[i] + t_max[i]) / 2) <= end_time:
                label[species_id[i]] = 1

        return rec, label, recording_id, sr

    def get_unlabeled_item(self, idx, random_sample=False):
        real_idx = idx - len(self.recording_ids)
        # We want our validation set to be fixed.
        if self.val:
            rec_id = self.recording_ids[real_idx]
        else:
            rec_id = random.sample(list(self.recording_ids), 1)[0]

        rec, sr = sf.read(f"{self.data_path}/{rec_id}.flac")
        effective_length = int(sr * self.period)
        max_end = len(rec) - effective_length
        if self.val:
            # Fixed start for validation. Probaably a better way to do this.
            start = int(idx * 16963 % max_end)
        else:
            start = np.random.randint(0, max_end)
        rec = rec[start:(start+effective_length)]
        rec = rec.astype(np.float32)

        label = np.ones(NUM_CLASSES, dtype='f') * NO_LABEL

        return rec, label, rec_id, sr

    
def get_data_loader(df, is_val=False):
    if is_val:
        period = config.period_val
    else:
        period = config.period
    dataset = MeanTeacherDataset(
        df=df,
        transforms=config.augmenter,
        period=period,
        percent_unlabeled=config.percent_unlabeled
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=not is_val,
        drop_last=not is_val,
        num_workers=config.num_workers
    )

# Training
At the end of each training step we update the teacher weights by averaging in the latest student weights.

In [12]:
# Update teacher to be exponential moving average of student params.
def update_teacher_params(student, teacher, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(teacher.parameters(), student.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)


def train_one_epoch(student, mean_teacher, loader, 
                    criterion, optimizer, scheduler, epoch, is_val=False):
    global_step = 0
    losses = AverageMeter()
    consistency_loss_avg = AverageMeter()
    class_loss_avg = AverageMeter()
    comp_metric = MetricMeter()
    
    if is_val:
        student.eval()
        mean_teacher.eval()
        context = torch.no_grad()
    else:
        student.train()
        mean_teacher.train()
        context = nullcontext()
    
    with context:
        t = tqdm(loader)
        for i, sample in enumerate(t):
            student_input = sample['waveform'].to(config.device)
            teacher_input = sample['teacher_waveform'].to(config.device)
            target = sample['target'].to(config.device)
            classif_weights = sample['classification_weights'].to(config.device)
            batch_size = len(target)
            
            if student.mixup_module:
                student_pred, framewise_output, logit, target = student(student_input, labels=target)
            else:
                student_pred, framewise_output, logit, target = student(student_input, labels=target)
#             if teacher.mixup_module:
#                 teacher_pred, _, target  = mean_teacher(teacher_input, labels=target)
#             else:
            teacher_pred, _, _ = mean_teacher(teacher_input)

            loss, class_loss, consistency_loss, consistency_weight = criterion(
                student_pred, teacher_pred, target, classif_weights, epoch)

            if not is_val:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                update_teacher_params(student, mean_teacher, 
                                      config.ema_decay, global_step)

                scheduler.step()

            comp_metric.update(target, student_pred)
            losses.update(loss.item(), batch_size)
            consistency_loss_avg.update(consistency_loss.item(), batch_size)
            class_loss_avg.update(class_loss.item(), batch_size)
            global_step += 1

            t.set_description(f"Epoch:{epoch} - Loss:{losses.avg:0.4f}")
        t.close()
    return {'lwlrap':comp_metric.avg, 
            'loss':losses.avg, 
            'consistency_loss':consistency_loss_avg.avg, 
            'class_loss':class_loss_avg.avg, 
            'consistency_weight':consistency_weight}

Finally putting everything together...

In [None]:
def train(df, fold):
    train_df = df[df.fold != fold]
    val_df = df[df.fold == fold]
    train_loader = get_data_loader(train_df)
    val_loader = get_data_loader(val_df)

    student_model = get_model()
    teacher_model = get_model(is_mean_teacher=True)

#     optimizer = Ranger(student_model.parameters(),
#                lr=config.lr,
#                k=4,
#                betas=(.9, 0.999), weight_decay=0)
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=config.lr)
    num_train_steps = int(len(train_loader) * config.epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_train_steps)
    criterion = MeanTeacherLoss()

    best_val_metric = -np.inf
    val_metrics = []
    train_metrics = []
    for epoch in range(0, config.epochs):
        train_loss_metrics = train_one_epoch(
            student_model, teacher_model, train_loader, 
            criterion, optimizer, scheduler, epoch)
        val_loss_metrics = train_one_epoch(
            student_model, teacher_model, val_loader, 
            criterion, optimizer, scheduler, epoch, is_val=True)

        train_metrics.append(train_loss_metrics)
        val_metrics.append(val_loss_metrics)
        pretty_print_metrics(fold, epoch, optimizer, 
                             train_loss_metrics, val_loss_metrics)
        
        if val_loss_metrics['lwlrap'] > best_val_metric:
            print(f"    LWLRAP Improved from {best_val_metric} --> {val_loss_metrics['lwlrap']}\n")
            best_val_metric = val_loss_metrics['lwlrap']
            
            torch.save(teacher_model.state_dict(), 
                       os.path.join(config.save_path, f'fold-{fold}_{best_val_metric:.3f}.bin'))
    


df = get_n_fold_df(config.train_tp_csv)
for fold in range(5 if config.train_5_folds else 1):
    train(df, fold)

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  
Epoch:0 - Loss:7.0537: 100%|██████████| 113/113 [00:40<00:00,  2.76it/s]
Epoch:0 - Loss:6.2640: 100%|██████████| 28/28 [00:10<00:00,  2.67it/s]



    Thu Feb 18 01:55:31 2021 

    Fold:0, Epoch:0, LR:0.0009979871, Cons. Weight: 0.6737946999085467

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                7.0537   |   6.2640

    LWLRAP:              0.2995   |   0.4273

    Class Loss:          7.0450   |   6.2498

    Consistency Loss:    0.0086   |   0.0143

    --------------------------------------------------------

    
    LWLRAP Improved from -inf --> 0.4273342285131242



Epoch:1 - Loss:4.9667: 100%|██████████| 113/113 [00:36<00:00,  3.07it/s]
Epoch:1 - Loss:4.3109: 100%|██████████| 28/28 [00:07<00:00,  3.60it/s]



    Thu Feb 18 01:56:16 2021 

    Fold:0, Epoch:1, LR:0.0009919648, Cons. Weight: 3.1047958479329627

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                4.9667   |   4.3109

    LWLRAP:              0.5384   |   0.6389

    Class Loss:          4.9266   |   4.2456

    Consistency Loss:    0.0401   |   0.0653

    --------------------------------------------------------

    
    LWLRAP Improved from 0.4273342285131242 --> 0.6389454079279313



Epoch:2 - Loss:4.0934: 100%|██████████| 113/113 [00:35<00:00,  3.14it/s]
Epoch:2 - Loss:4.2757: 100%|██████████| 28/28 [00:11<00:00,  2.51it/s]



    Thu Feb 18 01:57:04 2021 

    Fold:0, Epoch:2, LR:0.0009819814, Cons. Weight: 10.836802322189582

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                4.0934   |   4.2757

    LWLRAP:              0.6608   |   0.7071

    Class Loss:          3.9388   |   4.1024

    Consistency Loss:    0.1546   |   0.1732

    --------------------------------------------------------

    
    LWLRAP Improved from 0.6389454079279313 --> 0.7071110243194867



Epoch:3 - Loss:3.5930: 100%|██████████| 113/113 [00:36<00:00,  3.07it/s]
Epoch:3 - Loss:3.5951: 100%|██████████| 28/28 [00:12<00:00,  2.31it/s]



    Thu Feb 18 01:57:53 2021 

    Fold:0, Epoch:3, LR:0.0009681174, Cons. Weight: 28.650479686019008

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.5930   |   3.5951

    LWLRAP:              0.7329   |   0.7691

    Class Loss:          3.2493   |   3.1102

    Consistency Loss:    0.3437   |   0.4849

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7071110243194867 --> 0.7690916401268825



Epoch:4 - Loss:3.2782: 100%|██████████| 113/113 [00:36<00:00,  3.13it/s]
Epoch:4 - Loss:3.4264: 100%|██████████| 28/28 [00:12<00:00,  2.23it/s]



    Thu Feb 18 01:58:42 2021 

    Fold:0, Epoch:4, LR:0.0009504844, Cons. Weight: 57.375342073743276

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.2782   |   3.4264

    LWLRAP:              0.7910   |   0.8034

    Class Loss:          2.7356   |   2.6367

    Consistency Loss:    0.5426   |   0.7896

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7690916401268825 --> 0.8034350693948908



Epoch:5 - Loss:3.2192: 100%|██████████| 113/113 [00:38<00:00,  2.92it/s]
Epoch:5 - Loss:4.0853: 100%|██████████| 28/28 [00:12<00:00,  2.31it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 01:59:34 2021 

    Fold:0, Epoch:5, LR:0.0009292244, Cons. Weight: 87.03247258333906

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.2192   |   4.0853

    LWLRAP:              0.8054   |   0.7730

    Class Loss:          2.4769   |   2.8822

    Consistency Loss:    0.7423   |   1.2031

    --------------------------------------------------------

    


Epoch:6 - Loss:3.0912: 100%|██████████| 113/113 [00:36<00:00,  3.07it/s]
Epoch:6 - Loss:4.3820: 100%|██████████| 28/28 [00:12<00:00,  2.28it/s]



    Thu Feb 18 02:00:23 2021 

    Fold:0, Epoch:6, LR:0.0009045085, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.0912   |   4.3820

    LWLRAP:              0.8320   |   0.8050

    Class Loss:          2.2786   |   2.9483

    Consistency Loss:    0.8126   |   1.4337

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8034350693948908 --> 0.8050310066829891



Epoch:7 - Loss:2.7883: 100%|██████████| 113/113 [00:39<00:00,  2.84it/s]
Epoch:7 - Loss:4.0450: 100%|██████████| 28/28 [00:12<00:00,  2.23it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:01:16 2021 

    Fold:0, Epoch:7, LR:0.0008765357, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.7883   |   4.0450

    LWLRAP:              0.8566   |   0.8038

    Class Loss:          1.9957   |   2.9127

    Consistency Loss:    0.7926   |   1.1322

    --------------------------------------------------------

    


Epoch:8 - Loss:2.7856: 100%|██████████| 113/113 [00:38<00:00,  2.95it/s]
Epoch:8 - Loss:4.2154: 100%|██████████| 28/28 [00:12<00:00,  2.29it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:02:07 2021 

    Fold:0, Epoch:8, LR:0.0008455313, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.7856   |   4.2154

    LWLRAP:              0.8642   |   0.8029

    Class Loss:          1.9142   |   3.0602

    Consistency Loss:    0.8714   |   1.1552

    --------------------------------------------------------

    


Epoch:9 - Loss:2.7279: 100%|██████████| 113/113 [00:37<00:00,  3.04it/s]
Epoch:9 - Loss:3.7756: 100%|██████████| 28/28 [00:11<00:00,  2.54it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:02:55 2021 

    Fold:0, Epoch:9, LR:0.0008117449, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.7279   |   3.7756

    LWLRAP:              0.8635   |   0.7944

    Class Loss:          1.8965   |   2.6722

    Consistency Loss:    0.8314   |   1.1034

    --------------------------------------------------------

    


Epoch:10 - Loss:2.6987: 100%|██████████| 113/113 [00:35<00:00,  3.21it/s]
Epoch:10 - Loss:3.6911: 100%|██████████| 28/28 [00:10<00:00,  2.80it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:03:41 2021 

    Fold:0, Epoch:10, LR:0.0007754485, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.6987   |   3.6911

    LWLRAP:              0.8611   |   0.8021

    Class Loss:          1.8778   |   2.5780

    Consistency Loss:    0.8209   |   1.1130

    --------------------------------------------------------

    


Epoch:11 - Loss:2.1878: 100%|██████████| 113/113 [00:37<00:00,  3.01it/s]
Epoch:11 - Loss:3.5718: 100%|██████████| 28/28 [00:10<00:00,  2.64it/s]



    Thu Feb 18 02:04:29 2021 

    Fold:0, Epoch:11, LR:0.0007369343, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.1878   |   3.5718

    LWLRAP:              0.8997   |   0.8145

    Class Loss:          1.4844   |   2.5593

    Consistency Loss:    0.7034   |   1.0124

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8050310066829891 --> 0.8144796782689159



Epoch:12 - Loss:2.0438: 100%|██████████| 113/113 [00:38<00:00,  2.91it/s]
Epoch:12 - Loss:3.4218: 100%|██████████| 28/28 [00:11<00:00,  2.54it/s]



    Thu Feb 18 02:05:19 2021 

    Fold:0, Epoch:12, LR:0.0006965125, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.0438   |   3.4218

    LWLRAP:              0.9027   |   0.8201

    Class Loss:          1.3786   |   2.4656

    Consistency Loss:    0.6652   |   0.9562

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8144796782689159 --> 0.8200629784754965



Epoch:13 - Loss:2.0790: 100%|██████████| 113/113 [00:40<00:00,  2.78it/s]
Epoch:13 - Loss:3.3578: 100%|██████████| 28/28 [00:11<00:00,  2.49it/s]



    Thu Feb 18 02:06:12 2021 

    Fold:0, Epoch:13, LR:0.0006545085, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.0790   |   3.3578

    LWLRAP:              0.9014   |   0.8415

    Class Loss:          1.4054   |   2.3994

    Consistency Loss:    0.6736   |   0.9584

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8200629784754965 --> 0.8415035338643186



Epoch:14 - Loss:2.0164: 100%|██████████| 113/113 [00:38<00:00,  2.92it/s]
Epoch:14 - Loss:3.2189: 100%|██████████| 28/28 [00:10<00:00,  2.62it/s]



    Thu Feb 18 02:07:02 2021 

    Fold:0, Epoch:14, LR:0.0006112605, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.0164   |   3.2189

    LWLRAP:              0.9008   |   0.8478

    Class Loss:          1.3689   |   2.5402

    Consistency Loss:    0.6475   |   0.6786

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8415035338643186 --> 0.847816463050838



Epoch:15 - Loss:1.7302: 100%|██████████| 113/113 [00:38<00:00,  2.93it/s]
Epoch:15 - Loss:2.9640: 100%|██████████| 28/28 [00:10<00:00,  2.64it/s]



    Thu Feb 18 02:07:52 2021 

    Fold:0, Epoch:15, LR:0.0005671166, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.7302   |   2.9640

    LWLRAP:              0.9253   |   0.8619

    Class Loss:          1.1043   |   2.1574

    Consistency Loss:    0.6260   |   0.8066

    --------------------------------------------------------

    
    LWLRAP Improved from 0.847816463050838 --> 0.8618898791839968



Epoch:16 - Loss:1.6336: 100%|██████████| 113/113 [00:39<00:00,  2.83it/s]
Epoch:16 - Loss:3.1836: 100%|██████████| 28/28 [00:10<00:00,  2.65it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:08:43 2021 

    Fold:0, Epoch:16, LR:0.0005224324, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.6336   |   3.1836

    LWLRAP:              0.9335   |   0.8404

    Class Loss:          1.0315   |   2.4641

    Consistency Loss:    0.6021   |   0.7194

    --------------------------------------------------------

    


Epoch:17 - Loss:1.7780: 100%|██████████| 113/113 [00:39<00:00,  2.88it/s]
Epoch:17 - Loss:3.1222: 100%|██████████| 28/28 [00:10<00:00,  2.74it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:09:32 2021 

    Fold:0, Epoch:17, LR:0.0004775676, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.7780   |   3.1222

    LWLRAP:              0.9204   |   0.8515

    Class Loss:          1.1039   |   2.2622

    Consistency Loss:    0.6741   |   0.8600

    --------------------------------------------------------

    


Epoch:18 - Loss:1.6232: 100%|██████████| 113/113 [00:38<00:00,  2.95it/s]
Epoch:18 - Loss:3.0522: 100%|██████████| 28/28 [00:09<00:00,  2.82it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:10:21 2021 

    Fold:0, Epoch:18, LR:0.0004328834, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.6232   |   3.0522

    LWLRAP:              0.9296   |   0.8459

    Class Loss:          0.9853   |   2.2913

    Consistency Loss:    0.6378   |   0.7609

    --------------------------------------------------------

    


Epoch:19 - Loss:1.4591: 100%|██████████| 113/113 [00:39<00:00,  2.86it/s]
Epoch:19 - Loss:3.2081: 100%|██████████| 28/28 [00:09<00:00,  2.98it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:11:10 2021 

    Fold:0, Epoch:19, LR:0.0003887395, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.4591   |   3.2081

    LWLRAP:              0.9407   |   0.8448

    Class Loss:          0.8997   |   2.3814

    Consistency Loss:    0.5594   |   0.8266

    --------------------------------------------------------

    


Epoch:20 - Loss:1.3531: 100%|██████████| 113/113 [00:38<00:00,  2.90it/s]
Epoch:20 - Loss:2.6256: 100%|██████████| 28/28 [00:09<00:00,  3.04it/s]



    Thu Feb 18 02:11:58 2021 

    Fold:0, Epoch:20, LR:0.0003454915, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.3531   |   2.6256

    LWLRAP:              0.9489   |   0.8668

    Class Loss:          0.7806   |   2.0437

    Consistency Loss:    0.5726   |   0.5819

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8618898791839968 --> 0.8668118216109286



Epoch:21 - Loss:1.1925: 100%|██████████| 113/113 [00:40<00:00,  2.80it/s]
Epoch:21 - Loss:2.8851: 100%|██████████| 28/28 [00:08<00:00,  3.16it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:12:48 2021 

    Fold:0, Epoch:21, LR:0.0003034875, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.1925   |   2.8851

    LWLRAP:              0.9608   |   0.8529

    Class Loss:          0.6515   |   2.1019

    Consistency Loss:    0.5410   |   0.7832

    --------------------------------------------------------

    


Epoch:22 - Loss:1.2582: 100%|██████████| 113/113 [00:40<00:00,  2.79it/s]
Epoch:22 - Loss:2.8728: 100%|██████████| 28/28 [00:09<00:00,  2.97it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:13:38 2021 

    Fold:0, Epoch:22, LR:0.0002630657, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2582   |   2.8728

    LWLRAP:              0.9520   |   0.8651

    Class Loss:          0.6995   |   2.1192

    Consistency Loss:    0.5587   |   0.7536

    --------------------------------------------------------

    


Epoch:23 - Loss:1.1411: 100%|██████████| 113/113 [00:40<00:00,  2.79it/s]
Epoch:23 - Loss:3.1869: 100%|██████████| 28/28 [00:09<00:00,  2.99it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:14:28 2021 

    Fold:0, Epoch:23, LR:0.0002245515, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.1411   |   3.1869

    LWLRAP:              0.9616   |   0.8608

    Class Loss:          0.6171   |   2.4599

    Consistency Loss:    0.5241   |   0.7270

    --------------------------------------------------------

    


Epoch:24 - Loss:1.1124: 100%|██████████| 113/113 [00:40<00:00,  2.76it/s]
Epoch:24 - Loss:2.6663: 100%|██████████| 28/28 [00:09<00:00,  2.90it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:15:18 2021 

    Fold:0, Epoch:24, LR:0.0001882551, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.1124   |   2.6663

    LWLRAP:              0.9617   |   0.8630

    Class Loss:          0.5676   |   1.9665

    Consistency Loss:    0.5448   |   0.6997

    --------------------------------------------------------

    


Epoch:25 - Loss:1.2087: 100%|██████████| 113/113 [00:40<00:00,  2.77it/s]
Epoch:25 - Loss:2.9754: 100%|██████████| 28/28 [00:09<00:00,  3.09it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:16:09 2021 

    Fold:0, Epoch:25, LR:0.0001544687, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2087   |   2.9754

    LWLRAP:              0.9587   |   0.8346

    Class Loss:          0.6319   |   2.3110

    Consistency Loss:    0.5768   |   0.6644

    --------------------------------------------------------

    


Epoch:26 - Loss:0.9711: 100%|██████████| 113/113 [00:41<00:00,  2.74it/s]
Epoch:26 - Loss:2.8723: 100%|██████████| 28/28 [00:09<00:00,  3.04it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:16:59 2021 

    Fold:0, Epoch:26, LR:0.0001234643, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.9711   |   2.8723

    LWLRAP:              0.9684   |   0.8424

    Class Loss:          0.4877   |   2.3186

    Consistency Loss:    0.4834   |   0.5537

    --------------------------------------------------------

    


Epoch:27 - Loss:1.0107: 100%|██████████| 113/113 [00:41<00:00,  2.71it/s]
Epoch:27 - Loss:2.7256: 100%|██████████| 28/28 [00:08<00:00,  3.11it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:17:50 2021 

    Fold:0, Epoch:27, LR:9.54915e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.0107   |   2.7256

    LWLRAP:              0.9598   |   0.8485

    Class Loss:          0.5080   |   2.1836

    Consistency Loss:    0.5028   |   0.5420

    --------------------------------------------------------

    


Epoch:28 - Loss:0.9335: 100%|██████████| 113/113 [00:40<00:00,  2.79it/s]
Epoch:28 - Loss:3.1969: 100%|██████████| 28/28 [00:08<00:00,  3.21it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:18:39 2021 

    Fold:0, Epoch:28, LR:7.07756e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.9335   |   3.1969

    LWLRAP:              0.9726   |   0.8391

    Class Loss:          0.4801   |   2.6064

    Consistency Loss:    0.4534   |   0.5905

    --------------------------------------------------------

    


Epoch:29 - Loss:1.0337: 100%|██████████| 113/113 [00:40<00:00,  2.81it/s]
Epoch:29 - Loss:2.9282: 100%|██████████| 28/28 [00:08<00:00,  3.16it/s]



    Thu Feb 18 02:19:28 2021 

    Fold:0, Epoch:29, LR:4.951557e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.0337   |   2.9282

    LWLRAP:              0.9629   |   0.8695

    Class Loss:          0.5449   |   2.3734

    Consistency Loss:    0.4888   |   0.5548

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8668118216109286 --> 0.8694790680404716



Epoch:30 - Loss:0.8274: 100%|██████████| 113/113 [00:40<00:00,  2.77it/s]
Epoch:30 - Loss:2.4010: 100%|██████████| 28/28 [00:08<00:00,  3.22it/s]



    Thu Feb 18 02:20:18 2021 

    Fold:0, Epoch:30, LR:3.188256e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.8274   |   2.4010

    LWLRAP:              0.9782   |   0.8766

    Class Loss:          0.3849   |   1.8373

    Consistency Loss:    0.4425   |   0.5637

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8694790680404716 --> 0.8766118836915298



Epoch:31 - Loss:0.9392: 100%|██████████| 113/113 [00:39<00:00,  2.86it/s]
Epoch:31 - Loss:3.2217: 100%|██████████| 28/28 [00:09<00:00,  3.05it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:21:08 2021 

    Fold:0, Epoch:31, LR:1.801857e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.9392   |   3.2217

    LWLRAP:              0.9716   |   0.8470

    Class Loss:          0.5049   |   2.7542

    Consistency Loss:    0.4343   |   0.4675

    --------------------------------------------------------

    


Epoch:32 - Loss:0.8652: 100%|██████████| 113/113 [00:40<00:00,  2.80it/s]
Epoch:32 - Loss:2.7599: 100%|██████████| 28/28 [00:09<00:00,  2.88it/s]



    Thu Feb 18 02:21:58 2021 

    Fold:0, Epoch:32, LR:8.035206e-06, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.8652   |   2.7599

    LWLRAP:              0.9711   |   0.8812

    Class Loss:          0.4233   |   2.2295

    Consistency Loss:    0.4419   |   0.5303

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8766118836915298 --> 0.8811676464307696



Epoch:33 - Loss:0.9626: 100%|██████████| 113/113 [00:39<00:00,  2.89it/s]
Epoch:33 - Loss:2.8398: 100%|██████████| 28/28 [00:09<00:00,  2.87it/s]
  0%|          | 0/113 [00:00<?, ?it/s]


    Thu Feb 18 02:22:47 2021 

    Fold:0, Epoch:33, LR:2.012853e-06, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.9626   |   2.8398

    LWLRAP:              0.9690   |   0.8691

    Class Loss:          0.4566   |   2.2723

    Consistency Loss:    0.5060   |   0.5675

    --------------------------------------------------------

    


Epoch:34 - Loss:0.8732: 100%|██████████| 113/113 [00:38<00:00,  2.95it/s]
Epoch:34 - Loss:2.9689: 100%|██████████| 28/28 [00:08<00:00,  3.28it/s]



    Thu Feb 18 02:23:34 2021 

    Fold:0, Epoch:34, LR:0.0, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.8732   |   2.9689

    LWLRAP:              0.9716   |   0.8673

    Class Loss:          0.4522   |   2.3302

    Consistency Loss:    0.4210   |   0.6388

    --------------------------------------------------------

    


Epoch:0 - Loss:7.6153:  67%|██████▋   | 76/113 [00:14<00:04,  7.72it/s]

# Predict on Test Set
We'll predict using the teacher model but you could also use the student or a combination of the two. Inference works just like it would for a vanilla baseline model.

In [None]:
def test(test_df, train_FSLfold):
    test_dataset = TestDataset(
        df=test_df,
        data_path="/media/paniquex/samsung_2tb/rfcx_kaggle/rfcx-species-audio-detection/test",
        period=config.period_val,
        step=config.step
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=config.num_workers
    )
    
    weights_path = os.path.join(config.save_path, f'fold-{train_fold}.bin')
    model = get_model()
    model.load_state_dict(torch.load(weights_path, map_location=config.device), strict=False)
    
    test_pred, ids = predict_on_test(model, test_loader)

    # Build Submission File
    test_pred_df = pd.DataFrame({
        "recording_id": test_df.recording_id.values
    })
    target_cols = test_df.columns[1:].values.tolist()
    test_pred_df = test_pred_df.join(pd.DataFrame(np.array(test_pred), 
                                                  columns=target_cols))
    test_pred_df.to_csv(os.path.join(config.save_path, 
                                     f"fold-{train_fold}-submission.csv"), 
                        index=False)
    
    
test_df = pd.read_csv(config.test_csv)
for fold in range(5 if config.train_5_folds else 1):
    test(test_df, fold)

## 5 Fold Ensemble
For 5 fold runs, we'll create a single ensemble prediction by simply averaging all of the folds.

In [None]:
def ensemble(submission_path):
    dfs = [pd.read_csv(os.path.join(
        config.save_path, f"fold-{i}-submission.csv")) for i in range(5)]
    anchor = dfs[0].copy()
    cols = anchor.columns[1:]
   
    for c in cols:
        total = 0
        for df in dfs:
            total += df[c]
        anchor[c] = total / len(dfs)
    anchor.to_csv(submission_path, index=False)


submission_path = os.path.join(config.save_path, f"submission.csv")
if config.train_5_folds:
    ensemble(submission_path)
else:
    fold0_submission = os.path.join(config.save_path, f"fold-0-submission.csv")
    os.rename(fold0_submission, submission_path)
    

# Conclusion 
Thanks for reading! I dropped some unrelated tricks from this and didn't spend much time tuning so there's almost definetely room for improvement.

I know it's pretty late in the competition for new notebooks, but considering that there are a few other public notebooks that score higher, I'm hoping this won't cause a significant shakeup. 