# AGENDA
#### TO DO LIST
1. Plotting confusion matrix, plot prediction for top losses (need typedispatch on AudioTensor)

#### Completed
1. [11/05/2021] Set up stratified KFold
2. [11/05/2021] Crop training set data
3. [14/05/2021] Try multiclass vs. multilabel
4. [17/05/2021] Prototyped W&B integration
5. [17/05/2021] Fixed `Normalize` missing from both train and valid
6. [23/05/2021] Fixed sample rate != 48000 error
7. [23/05/2021] Fixed imagenet normalization on train+val 
8. [03/06/2021] Prototype of Pseudo Labeling

#### Reference
1. code starts from: https://www.kaggle.com/scart97/fastaudio-starter-kit/notebook

### 0. Install Packages

In [None]:
!pip install --up../input/rfcx-species-audio-detection/he-torch
# specify version to resolve version compability
!ltt install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
#!pip install --upgrade git+https://github.com/fastaudio/fastaudio.git
!pip install fastaudio==1.0.0
!pip install wandb --upgrade

In [None]:
import pkg_resources

def placeholder(x):
    """Stand-in for `pkg_resources.get_distribution` that always fails the lookup.

    Raises:
        pkg_resources.DistributionNotFound: unconditionally, whatever `x` is.
    """
    raise pkg_resources.DistributionNotFound
# From here on every call to `pkg_resources.get_distribution` in this session raises.
# NOTE(review): presumably this pushes a dependency's version lookup onto its
# fallback path - confirm which package needs it.
pkg_resources.get_distribution = placeholder

In [None]:
from enum import Enum
from math import ceil

import wandb
from kaggle_secrets import UserSecretsClient

import pandas as pd
from fastaudio.all import *
from fastai.vision.all import *
from fastai.callback.wandb import *

import torch.nn as nn
# migrate to sox_io
import torchaudio
torchaudio.set_audio_backend("sox_io")

from sklearn.model_selection import StratifiedKFold

# check cuda is available
import torch
print(torch.cuda.is_available())

In [None]:
import torch
import fastai
import fastaudio
import torchaudio
print(f'torch version: {torch.__version__}')
print(f'torchaudio version: {torchaudio.__version__}')
print(f'fastaudio version: {fastaudio.__version__}')
print(f'fastai version: {fastai.__version__}')

In [None]:
# fix failure to preserve metadata in multiprocessing in DataLoader
def _rebuild_from_type(func, type, args, dict):
    ret = func(*args).as_subclass(type)
    ret.__dict__ = dict
    return ret

@patch
def __reduce_ex__(self: TensorBase, proto):
    """Pickle hook patched onto `TensorBase` that also carries the subclass and `__dict__`.

    Returns the `(callable, args)` pair used to rebuild the object:
    `_rebuild_from_type` receives the fastai rebuild function, the concrete
    tensor type, the storage/layout arguments and the instance `__dict__`.
    `proto` is accepted but not used.
    """
    from fastai.torch_core import _fa_rebuild_qtensor, _fa_rebuild_tensor
    torch.utils.hooks.warn_if_has_hooks(self)
    # storage + layout description handed to the rebuild function
    args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride())
    # quantized tensors additionally need their scale and zero point
    if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point())
    args = args + (self.requires_grad, OrderedDict())
    f = _fa_rebuild_qtensor if self.is_quantized else  _fa_rebuild_tensor
    # the instance __dict__ travels as the last argument so metadata is restored on load
    return (_rebuild_from_type, (f, type(self), args, self.__dict__))

### 1a. Configuration

In [None]:
class LossFunction(str, Enum):
    """Names of the selectable training losses (`str` mixin: members compare equal to their value)."""
    FOCAL_LOSS = 'focal_loss'
    BCE_LOGIT_LOSS = 'bce_logit_loss'
    CE_SOFTMAX_LOSS = 'ce_softmax_loss'
    MASKED_BCE_WITH_TPFP = 'masked_bce_with_tpfp'


class RecordMetric(str, Enum):
    """Names of the quantities that can be monitored when saving the best model."""
    LWRAP = 'LWRAP'
    VALID_LOSS = 'valid_loss'
    

class Normalizer(str, Enum):
    """Names of the available spectrogram normalization schemes."""
    IMAGENET_STATS = 'imagenet_stats'
    SAMPLE_STATS = 'sample_stats'

    
class Channel(int, Enum):
    """Supported numbers of input channels fed to the model."""
    SINGLE = 1
    THREE = 3

    
# the `str` mixin lets a member be compared directly with a plain string
assert LossFunction.FOCAL_LOSS == 'focal_loss'

In [None]:
# W&B CONFIG
# set DISABLE_WB to True to skip the W&B login/init and the WandbCallback in the training cell
DISABLE_WB = False
# keyword arguments passed to `wandb.init`
WB_CONFIG = {
    "project": "RFCX Experiment Tracker", 
    "name": "baseline with n_mels=256, masked bce with TPFP, monitor LWRAP (10+30)",
    "notes": "rerun baseline again, without mixup and audio aug, n_mels=256, masked bce loss with TPFP, monitor LWRAP (10+30)",
    "job_type": "train",
    "tags": ['resnet34', "TPFP", "Masked BCE loss", "imagenet norm"]
}


# DATA CONFIG
# both directories are asserted to exist after the DataFrames are prepared
DATA_DIR = Path('../input/rfcx-species-audio-detection')
MODEL_DIR = Path('../input/fastiai-fastaudio-rainforest-starter')


# MORE DATA
# when True, rows from PSEUDO_LABEL_CSV are appended to the training DataFrame
IS_PSEUDO_LABEL = False
PSEUDO_LABEL_CSV = Path('../input/rainforest-model-checkpoints/df_train_pseudo_3s.csv')


# AUDIO TRANSFORM CONFIG
NORMALIZER = Normalizer.IMAGENET_STATS.value
SR = 48000  # sample rate (Hz) used for frame arithmetic and the mel spectrogram
INTERVAL = 10 # length (sec) of the window loaded around each annotated call
RANDOM_INTERVAL = 8 # length (sec) ResizeSignal resizes each loaded clip to; also the test sub-clip length
# keyword arguments for AudioConfig.BasicMelSpectrogram
MELSPEC_CONFIG = {
    'n_fft': 2048,
    'sample_rate': SR,
    'n_mels': 256
#     'mel': False,
#     'f_max': 14000
}
IS_MIN_MAX_RESCALE = False  # add the MinMaxRescale batch transform
RESIZE_MEL = None  # None, or a tuple passed as `size` to TfmResizeGPU


# TRAINING SCHEME CONFIG
IS_AUDIO_AUG = False  # add AddNoise + SignalShifter to the item transforms
IS_MIXUP = False  # attach the MixUp callback to the learner
NUM_WORKERS = 0
CHANNEL = Channel.SINGLE.value
LOSS_FUNCTION = LossFunction.MASKED_BCE_WITH_TPFP.value # any LossFunction value
IS_MULTILABEL = True
MONITOR_METRIC = RecordMetric.LWRAP.value  # quantity watched by SaveModelCallback
FOLD_ID = 0  # which `fold_<k>` column marks the validation rows
BATCH_SIZE = 128
# multilabel: 0.033, multiclass: 0.0132
#LR = 0.033 if IS_MULTILABEL else 0.0132
# Masked BCE Loss: lr_min=0.33, lr_steep=0.0052
LR = 5e-2
FREEZE_EPOCH = 10  # passed as `freeze_epochs` to `fine_tune`
UNFREEZE_EPOCH = 30  # passed as `epochs` to `fine_tune`


RANDOM_SEED = 144  # seed of the stratified K-fold shuffle


# https://www.kaggle.com/c/rfcx-species-audio-detection/discussion/220389
# MELSPEC_CONFIG = {
#     'sample_rate': 32000,
#     'n_fft': 2048,
#     'n_mels': 384,
#     'hop_length': 512,
#     'win_length': 2048
# }

### 1b. Preset Files and DataFrame

In [None]:
# extract dataset
train_path = DATA_DIR/ 'train'
test_path = DATA_DIR/ 'test'
train_fns = get_audio_files(train_path)

# massage df
# recording ids are prefixed with 'train/' so they resolve relative to DATA_DIR
df_train_tp = pd.read_csv(DATA_DIR/ 'train_tp.csv')
df_train_tp['recording_id'] = df_train_tp['recording_id'].map(lambda x: 'train/' + x)

# K FOLD STRATIFICATION
# one 0/1 column per fold: 1 marks the rows that form that fold's validation set
df_train_tp['species_id'] = df_train_tp['species_id'].astype(str)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
for fold_id, (train_idxs, val_idxs) in enumerate(skf.split(df_train_tp.recording_id.values, df_train_tp.species_id.values)):
    kfold_col = f'fold_{fold_id}'
    df_train_tp[kfold_col] = 0
    # val_idxs are positions; they are used as labels here, which relies on the default 0..n-1 index
    df_train_tp.loc[val_idxs, kfold_col] = 1
target_fold = f'fold_{FOLD_ID}'

if LOSS_FUNCTION == LossFunction.MASKED_BCE_WITH_TPFP.value:
    print('Concatenating TPFP when using Masked Loss')
    df_train_fp = pd.read_csv(DATA_DIR/ 'train_fp.csv')
    df_train_fp['recording_id'] = df_train_fp['recording_id'].map(lambda x: 'train/' + x)
    df_train_fp['species_id'] = df_train_fp['species_id'].astype(str)
    # FP rows get 0 in every fold column, so they never become validation rows
    for fold_id in range(5):
        df_train_fp[f'fold_{fold_id}'] = 0
    df_train_tp['label_type'] = 'TP'
    df_train_fp['label_type'] = 'FP'
    # TP rows first, FP rows appended; the index is rebuilt as 0..n-1
    df_train_tp = pd.concat([df_train_tp, df_train_fp]).reset_index(drop=True)
    print(f'DataFrame size after concat TPFP: {df_train_tp.shape[0]}')
    print(f'# Val after concat TPFP: {df_train_tp[df_train_tp[target_fold] == 1].shape[0]}')
    

# sanity check
# NOTE(review): these run after DATA_DIR has already been read from above
assert DATA_DIR.is_dir()
assert MODEL_DIR.is_dir()

In [None]:
df_train_tp[df_train_tp[target_fold]==1].head(3)

### 2a. Sample Create Cropped AudioTensor

In [None]:
SAMPLE_IDX = 1
test_row = df_train_tp.loc[SAMPLE_IDX]
test_row

In [None]:
def create_audio_tensor(row, is_truncate=True):
    """Load the recording referenced by `row` as an `AudioTensor`.

    With `is_truncate` the load is limited to an `INTERVAL`-second window
    centred on the annotation (`row.t_min`..`row.t_max`), the window start
    being clamped at 0. Otherwise the load starts at frame 0 with
    `num_frames=-1` (no frame limit).
    """
    frame_offset, num_frames = 0, -1
    if is_truncate:
        midpoint = (row.t_max + row.t_min) / 2
        window_start = midpoint - (INTERVAL/2.)
        frame_offset = int(max(0, window_start) * SR)
        num_frames = int(INTERVAL * SR)

    path = DATA_DIR.resolve()/f'{row.recording_id}.flac'
    return AudioTensor.create(path,
                              frame_offset=frame_offset,
                              num_frames=num_frames)

In [None]:
# load the sample row twice: cropped around the annotation and in full
truncated_audio = create_audio_tensor(test_row, is_truncate=True)
orig_audio = create_audio_tensor(test_row, is_truncate=False)

# the crop may be shorter than INTERVAL seconds, never longer
assert truncated_audio.shape[1] <= INTERVAL * SR
# the full recording is expected to be exactly 60 seconds long
assert orig_audio.shape[1] == 60 * SR

In [None]:
truncated_audio.show();

In [None]:
orig_audio.show();

### 2b. Set up `DataBlock`, `DataLoaders`

In [None]:
# define custom block for audio
# `batch_tfms` holds a transform *class* (not an instance) in both branches
if NORMALIZER == Normalizer.IMAGENET_STATS.value:
    # ImageNet normalization itself is requested later through `normalize=True` on the learner
    batch_tfms = IntToFloatTensor
    
elif NORMALIZER == Normalizer.SAMPLE_STATS.value:
    class SampleNormalize(Transform):
        "Standardize each spectrogram in the batch with its own mean and std"
        order=99
        def encodes(self, o:AudioSpectrogram):
            # (BS, C, W, H)
            # statistics are taken over every axis except the batch axis
            means = o.mean(axis=(1,2,3), keepdim=True)
            stds = o.std(axis=(1,2,3), keepdim=True)
            # NOTE(review): no epsilon - a constant spectrogram divides by zero
            return (o-means)/stds
    batch_tfms = SampleNormalize
    
else:
    raise NotImplementedError
    
class CustomAudioBlock(TransformBlock):
    "A `TransformBlock` for audio whose batch transforms are supplied by the caller"
    @delegates(audio_item_tfms)
    def __init__(self, batch_tfms, cache_folder=None, **kwargs):
        # `cache_folder` is accepted but unused; the remaining keyword
        # arguments configure the audio item transforms
        super().__init__(
            type_tfms=None,
            item_tfms=audio_item_tfms(**kwargs),
            batch_tfms=batch_tfms,
        )

    
# class labels are the strings '0'..'23'
vocab = list(map(str, range(24)))
if LOSS_FUNCTION == LossFunction.MASKED_BCE_WITH_TPFP.value:
    def label_encodes_dict(self, d: dict):
        """Encode `d['label']` to vocab indices and attach `d['type']` as `label_type`.

        Raises:
            KeyError: if a label is not part of the vocab.
        """
        o = d['label']
        if not all(elem in self.vocab.o2i.keys() for elem in o):
            diff = [elem for elem in o if elem not in self.vocab.o2i.keys()]
            diff_str = "', '".join(diff)
            raise KeyError(f"Labels '{diff_str}' were not included in the training dataset")
        t = TensorMultiCategory([self.vocab.o2i[o_] for o_ in o])
        # metadata travels on the tensor so later transforms can read it
        t.label_type = d['type']
        return t
    
    def one_hot_with_metadata(self, o: TensorMultiCategory):
        "One-hot encode `o` while keeping its attributes (e.g. `label_type`)"
        t = TensorMultiCategory(one_hot(o, self.c).float())
        # the attribute dict is shared with `o`, not copied
        t.__dict__ = o.__dict__
        return t
    
    class NegateFP(Transform):
        "Multiply the one-hot target of FP rows by -1 (in place); TP rows pass through"
        # runs right after the one-hot encoding
        order = OneHotEncode.order + 1
        def encodes(self, o: TensorMultiCategory):
            # a missing `label_type` attribute raises AttributeError here
            assert getattr(o, 'label_type')
            assert o.label_type in ('TP', 'FP')
            if o.label_type == 'FP':
                o *= -1
            return o

        def decodes(self, o: TensorMultiCategory):
            # multiplying by -1 again undoes the encoding
            assert getattr(o, 'label_type')
            assert o.label_type in ('TP', 'FP')
            if o.label_type == 'FP':
                o *= -1
            return o        
    
    def MultiCategoryBlockForMetadata(vocab):
        "Target block: categorize, one-hot encode, then negate FP rows"
        tfm = [MultiCategorize(vocab=vocab), OneHotEncode, NegateFP]
        return TransformBlock(type_tfms=tfm)
    
    # register the dict / metadata-aware variants on the existing fastai transforms
    MultiCategorize.encodes.add(label_encodes_dict)
    OneHotEncode.encodes.add(one_hot_with_metadata)
    blocks = (
        partial(CustomAudioBlock, batch_tfms=batch_tfms, sample_rate=SR),
        partial(MultiCategoryBlockForMetadata, vocab=vocab)
    )

elif IS_MULTILABEL:
    blocks = (
        partial(CustomAudioBlock, batch_tfms=batch_tfms, sample_rate=SR), 
        partial(MultiCategoryBlock, vocab=vocab)
    )
else:
    blocks = (
        partial(CustomAudioBlock, batch_tfms=batch_tfms, sample_rate=SR),
        partial(CategoryBlock, vocab=vocab)
    )


# duration in ms
# every loaded clip is resized to RANDOM_INTERVAL seconds; AudioPadType.Repeat is the padding mode
item_tfms = [ResizeSignal(duration=RANDOM_INTERVAL*1000, pad_mode=AudioPadType.Repeat)]
if IS_AUDIO_AUG:
    # data augmentation on raw audio
    data_augmentation = [
        AddNoise(color=NoiseColor.White, noise_level=0.1), 
        SignalShifter(max_pct=0.3)
    ]
    item_tfms += data_augmentation
    


# batch/ datablock setup
# the first batch transform turns the audio into a mel spectrogram built from MELSPEC_CONFIG
db_batch_tfms = []
db_batch_tfms += [
    AudioToSpec.from_cfg(
        AudioConfig.BasicMelSpectrogram(**MELSPEC_CONFIG)
    )
]

if IS_MIN_MAX_RESCALE:
    class MinMaxRescale(Transform):
        "Rescale every (sample, channel) spectrogram to the [0, 1] range"
        def encodes(self, o: AudioSpectrogram):
            # `o` is unpacked as (BS, C, W, H); min and max are taken over the last two axes
            BS, C, W, H = o.shape
            flat = o.reshape(BS, C, W*H)
            _min = flat.min(axis=2, keepdim=True)[0][:,:,:,None]
            _max = flat.max(axis=2, keepdim=True)[0][:,:,:,None]
            # divide by the range (max - min), not by the max, so the result spans [0, 1];
            # the clamp avoids a division by zero on a constant spectrogram
            return (o - _min) / (_max - _min).clamp(min=1e-8)
    db_batch_tfms += [MinMaxRescale()]
    
# optional resize of the spectrogram on the batch
if RESIZE_MEL is not None:
    assert isinstance(RESIZE_MEL, tuple)
    resize_tfms = TfmResizeGPU(size=RESIZE_MEL)
    db_batch_tfms += [resize_tfms]

if CHANNEL == Channel.THREE.value:
    class SpectrogramStacker(Transform):
        "Expand the channel axis to 3 on encode; keep only the first channel on decode"
        def encodes(self, o: AudioSpectrogram):
            # presumably `o` has a single channel here (expand needs a size-1 axis) - confirm
            return o.expand(-1, 3, -1, -1)
        def decodes(self, o: AudioSpectrogram):
            return o[:, :1, :, :]
    db_batch_tfms += [SpectrogramStacker()]


#  val split
if IS_PSEUDO_LABEL:
    df_pseudo = pd.read_csv(PSEUDO_LABEL_CSV)
    df_pseudo.recording_id = df_pseudo.recording_id.apply(lambda _id: f'train/{_id}')
    # pseudo-labelled rows carry no fold flag, so they never match `== 1` below
    for _idx in range(5):
        df_pseudo[f'fold_{_idx}'] = None
    df_train_tp = pd.concat([df_train_tp, df_pseudo])
    df_train_tp = df_train_tp.reset_index(drop=True)
    df_train_tp['species_id'] = df_train_tp['species_id'].astype(str)
# rows flagged 1 in the target fold column form the validation set
val_idxs = df_train_tp[df_train_tp[target_fold] == 1].index.tolist()
splitter=IndexSplitter(val_idxs)


# get label from DataFrame row
# a single column reader shared by every getter (built once, not per row)
y_reader = ColReader('species_id')

if LOSS_FUNCTION == LossFunction.MASKED_BCE_WITH_TPFP.value:
    def label_to_list_v2(row):
        """Return the label of `row` together with its TP/FP flag.

        The result is a dict `{'label': [species_id], 'type': label_type}` so
        that the flag is propagated to the target pipeline. Rows without a
        usable `label_type` (attribute missing, or not a string such as NaN
        after a concat) are treated as 'TP'.
        """
        # only a missing attribute triggers the fallback; other errors surface
        label_type = getattr(row, 'label_type', 'TP')
        if not isinstance(label_type, str):
            label_type = 'TP'
        return {'label': [y_reader(row)], 'type': label_type}

    y_getter = label_to_list_v2
else:
    def label_to_list(row):
        "Return the label of `row` wrapped in a list (multi-label targets)"
        return [y_reader(row)]

    y_getter = label_to_list if IS_MULTILABEL else y_reader


# get the datablock
# x: audio crop loaded by `create_audio_tensor`; y: label produced by `y_getter`;
# the train/valid split comes from the precomputed fold column via `splitter`
datablock = DataBlock(
    blocks=blocks,
    item_tfms=item_tfms,
    batch_tfms=db_batch_tfms,
    get_x=create_audio_tensor,
    get_y=y_getter,
    #splitter=RandomSplitter(valid_pct=0.2, seed=RANDOM_SEED)
    splitter=splitter
)

In [None]:
dls = datablock.dataloaders(source=df_train_tp, 
                            bs=BATCH_SIZE,
                            num_workers=NUM_WORKERS)
one_batch = dls.one_batch()

# sanity check
assert one_batch[0].shape[0] == BATCH_SIZE
# no OneHot transform in CategoryBlock
# because torch.nn.CrossEntropyLoss does not require one-hot encoding
# Rows are picked by position: the training items keep the index labels of the
# full DataFrame, so a label lookup fails whenever that row sits in the validation fold.
# NOTE: the multi-label check expects TP rows (an FP target sums to -1).
for idx in [1, 3, 6]:
    test_row = dls.train.items.iloc[idx]
    gt_category = int(test_row.species_id)
    # index 1 selects the target pipeline
    tfms_category = dls.train.tfms[1](test_row)
    if IS_MULTILABEL:
        assert tfms_category.sum().numpy() == 1.
        assert tfms_category[gt_category].numpy() == 1.
    else:
        assert isinstance(tfms_category, TensorCategory)

print(len(dls.train.items))
print(len(dls.valid.items))
print('DataLoader transform completed and checked')

In [None]:
# if sample norm, its scale may not look right coz its not reversable
dls.show_batch(ncols=2, nrows=3, figsize=(15, 10))

In [None]:
# the channel axis of the input batch must match the configured channel count
if CHANNEL == Channel.SINGLE.value:
    assert one_batch[0].shape[1] == 1
elif CHANNEL == Channel.THREE.value:
    assert one_batch[0].shape[1] == 3
else:
    raise NotImplementedError
# the sample-rate metadata must survive the batch transforms
assert one_batch[0].sr == SR
one_batch[0].shape, type(one_batch[0]), dls.num_workers, dls.after_batch

### 3a. Set up `Learner` and `AccumMetric`
- nn.BCEWithLogitsLoss: for each instance, sigmoid for each class, BCE for each class, average across classes

In [None]:
# source: https://www.kaggle.com/c/rfcx-species-audio-detection/discussion/198418
def LWRAP(preds, labels):
    """Label-weighted label-ranking average precision over a batch.

    Args:
        preds: (BS, C) scores; only their ordering within each row is used.
        labels: (BS, C) multi-hot targets when `IS_MULTILABEL`, otherwise
            (BS,) class indices, which are one-hot encoded here with as many
            classes as `preds` has columns.

    Returns:
        The score as a Python float (`nan` when `labels` has no positive entry).
    """
    # labels: (BS, ) for multiclass, (BS, C) for multilabel
    if not IS_MULTILABEL:
        labels = torch.nn.functional.one_hot(labels, preds.size(-1))

    # Classes of each row ordered from highest to lowest score
    ranked_classes = torch.argsort(preds, dim=-1, descending=True)
    # Each row of `ranked_classes` is a permutation; its argsort is the inverse
    # permutation, i.e. the 0-based rank of every class (hence the + 1)
    class_ranks = torch.argsort(ranked_classes, dim=-1) + 1
    # Mask out to only use the ranks of relevant GT labels
    ground_truth_ranks = class_ranks * labels + (1e6) * (1 - labels)
    # All the GT ranks are in front now
    sorted_ground_truth_ranks, _ = torch.sort(ground_truth_ranks, dim=-1, descending=False)
    # Positions 1..C, created on the same device as the labels
    pos_matrix = torch.arange(1, labels.size(-1) + 1, device=labels.device).unsqueeze(0)
    score_matrix = pos_matrix / sorted_ground_truth_ranks
    # Keeps as many leading positions per row as the row has GT labels
    score_mask_matrix, _ = torch.sort(labels, dim=-1, descending=True)
    scores = score_matrix * score_mask_matrix
    score = scores.sum() / labels.sum()
    return score.item()

# AccumMetric compares `activation` against lower-case names ('sigmoid', 'softmax', ...);
# a capitalised string matches none of them and silently disables the activation.
# Multi-label heads are scored through a sigmoid, single-label heads through a softmax.
# LWRAP only uses the per-row ordering of the scores, which both activations preserve.
activation_type = 'sigmoid' if IS_MULTILABEL else 'softmax'
lwrap_metric = AccumMetric(
    func=LWRAP, activation=activation_type,
    to_np=False, flatten=False
)

In [None]:
# source: https://www.kaggle.com/c/rfcx-species-audio-detection/discussion/213075
class BinaryFocalLoss(nn.Module):
    """Focal variant of binary cross-entropy on logits, averaged over all elements.

    Each element's BCE is scaled by `(1 - p) ** gamma` (times `alpha`) where
    the target is >= 0.5 and by `p ** gamma` elsewhere, with `p = sigmoid(logit)`.
    """
    def __init__(self, gamma, alpha=1.):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        # element-wise losses are needed so each one can be re-weighted
        self.loss_func = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, preds, targets):
        per_element = self.loss_func(preds, targets)
        p = torch.sigmoid(preds)
        positive_term = self.alpha*((1.-p)**self.gamma)*per_element
        negative_term = (p**self.gamma)*per_element
        is_positive = targets>=0.5
        return torch.where(is_positive, positive_term, negative_term).mean()
    

class MaskedLossWithTPFP(BCEWithLogitsLossFlat):
    """BCE-with-logits summed over the labelled target entries and divided by their count.

    Target entries equal to 0 are masked out (weight 0); entries equal to -1
    (FP) count as negatives and every other non-zero entry keeps its value.
    """
    def __init__(self, **kwargs):
        # 'sum' so the normalization by the number of labelled entries is done in __call__
        super().__init__(reduction='sum', **kwargs)
    
    def __call__(self, inp, targ, **kwargs):
        # the mask must be computed before the FP entries are rewritten below
        weight = (targ != 0.).float().to(targ.device)
        # apply masking on non-TP/ non-FP
        # NOTE: the mask is stored on the wrapped loss, so it persists until the next call
        self.func.weight = weight.view(-1)
        # set FP to 0.
        # NOTE: this rewrites the caller's target tensor in place
        targ[targ == -1] = 0.
        n = weight.sum()
        masked_sum = super().__call__(inp, targ, **kwargs)
        return masked_sum / n

In [None]:
# get the metrics
# instantiate the loss selected by LOSS_FUNCTION
if LOSS_FUNCTION == LossFunction.FOCAL_LOSS.value:
    loss_func = BinaryFocalLoss(gamma=2., alpha=1.)
elif LOSS_FUNCTION == LossFunction.BCE_LOGIT_LOSS.value:
    loss_func = BCEWithLogitsLossFlat()
elif LOSS_FUNCTION == LossFunction.CE_SOFTMAX_LOSS.value:
    loss_func = CrossEntropyLossFlat()
elif LOSS_FUNCTION == LossFunction.MASKED_BCE_WITH_TPFP.value:
    loss_func = MaskedLossWithTPFP()
else:
    raise ValueError('Invalid loss function')
    
    
# number of input channels handed to the model through `config`
if CHANNEL == Channel.SINGLE.value:
    config_dict = {'n_in': 1}
elif CHANNEL == Channel.THREE.value:
    config_dict = {'n_in': 3}
else:
    raise NotImplementedError


cbs = MixUp(1.) if IS_MIXUP else None

    
# the two branches differ only in the Normalize patch and the `normalize` flag
if NORMALIZER == Normalizer.IMAGENET_STATS.value:
    # type-dispatch for Normalizer to work on AudioSpec tensor
    def encode_tensorimage(self, x:TensorImage): 
        return (x-self.mean) / self.std
    def encode_audiospec(self, x:AudioSpectrogram): 
        return (x-self.mean) / self.std
    # replaces `Normalize.encodes` for the whole session
    Normalize.encodes = TypeDispatch([encode_tensorimage, encode_audiospec])
    
    learner = cnn_learner(
        dls, resnet34, 
        pretrained=True,
        normalize=True,
        config=config_dict,
        loss_func=loss_func,
        metrics=[lwrap_metric],
        cbs=cbs
    )
    
elif NORMALIZER == Normalizer.SAMPLE_STATS.value:
    # normalization is already done by the SampleNormalize batch transform
    learner = cnn_learner(
        dls, resnet34, 
        pretrained=True,
        normalize=False,
        config=config_dict,
        loss_func=loss_func,
        metrics=[lwrap_metric],
        cbs=cbs
    )

else:
    raise NotImplementedError

print('Learner is set')

In [None]:
# NOTE(review): assumes the last train batch transform is the Normalize added by
# `cnn_learner(normalize=True)` - confirm with the printout at the end of this cell
if (NORMALIZER == Normalizer.IMAGENET_STATS.value):
    if CHANNEL == (Channel.SINGLE.value):
        # make sure normalizer can handle AudioSpec tensor
        # keep only the first channel of the stats
        learner.dls.after_batch[-1].mean = learner.dls.after_batch[-1].mean[:, 0:1, :, :]
        learner.dls.after_batch[-1].std = learner.dls.after_batch[-1].std[:, 0:1, :, :]    
        # assert normalizer stats are 1 channel
        assert learner.dls.after_batch[-1].mean.shape[1] == 1
        assert learner.dls.train.after_batch[-1].mean.shape[1] == 1
        
    # add Normalize transform into valid (bug fix in latest release)
    # single channel: one scalar mean/std; three channels: the flattened per-channel stats
    if CHANNEL == (Channel.SINGLE.value):
        mean=learner.dls.after_batch[-1].mean.cpu().numpy().flatten()[0]
        std=learner.dls.after_batch[-1].std.cpu().numpy().flatten()[0]
    elif CHANNEL == (Channel.THREE.value):
        mean=learner.dls.after_batch[-1].mean.cpu().numpy().flatten()
        std=learner.dls.after_batch[-1].std.cpu().numpy().flatten()
    else:
        raise NotImplementedError
        
    learner.dls.valid.after_batch.add(
        Normalize.from_stats(mean=mean, std=std),
        'after_batch'
    )

# print both pipelines so the presence of Normalize can be checked by eye
print(f'Train batch_tfms: \n{learner.dls.train.after_batch}')
print(f'Valid batch_tfms: \n{learner.dls.valid.after_batch}')

In [None]:
learner.loss_func

In [None]:
learner.show_training_loop()

### 3b. Finding Optimal Learning Rate

In [None]:
#learner.lr_find() 

In [None]:
#learner.fine_tune(epochs=1, base_lr=0.033)

### 3c. Start Training
Issue Tracker:
- Low GPU utilization (0%) and very high CPU utilization

In [None]:
# the checkpoint callback watches MONITOR_METRIC
cbs = [SaveModelCallback(monitor=MONITOR_METRIC)]
if not DISABLE_WB:
    # start up W&B run
    # the API key is read from the Kaggle secret named "wandb_key"
    user_secrets = UserSecretsClient()
    wandb_api = user_secrets.get_secret("wandb_key")
    wandb.login(key=wandb_api)
    wandb.init(**WB_CONFIG)
    config = wandb.config
    # additionally log the mel spectrogram transform config
    config.audio_to_spec = AudioConfig.BasicMelSpectrogram(**MELSPEC_CONFIG).__repr__()
    cbs += [WandbCallback()]
    print('Completed setup for W&B run')


print('Start training model')
# FREEZE_EPOCH epochs are passed as `freeze_epochs`, UNFREEZE_EPOCH as `epochs`
learner.fine_tune(
    epochs=UNFREEZE_EPOCH, base_lr=LR, 
    freeze_epochs=FREEZE_EPOCH, cbs=cbs
)

In [None]:
learner.recorder.plot_loss()

In [None]:
learner.save('model_last_epoch')
print('Model checkpoint saved')

#### How to de-register a callback?
- Learner.remove_cbs: arg is the `Callback` class, NOT its instance
- Details
    - remove attribute `name` from the callback
    - set its attribute `learn` = None
    - remove from `Learner.cbs` list (calling `cbs.remove`, inherited from `list`)

In [None]:
#learner.show_training_loop();
try:
    learner.remove_cbs(FetchPredsCallback)
except Exception as err:
    # best-effort cleanup: report the reason instead of hiding it, and let
    # KeyboardInterrupt / SystemExit propagate (a bare `except:` swallows them)
    print('Failed to remove FetchPredsCallback, probably its absent')
    print(f'Reason: {err!r}')
learner.cbs

#### Sanity Check Best Model Validation Loss

In [None]:
# losses are requested together with the validation predictions; their mean is reported
preds, targs, loss = learner.get_preds(with_loss=True)
l = loss.mean().cpu().numpy()
print(f'Loss on validation: {l:.2f}')

### 4. Prepare Submission

In [None]:
def sample_one_subclips(row, i):
    """Load the `i`-th `RANDOM_INTERVAL`-second window of a test recording.

    Windows are laid end to end from the start of a recording assumed to be
    60 seconds long; a window that would overrun that length is moved back so
    that it ends exactly at the 60-second mark.
    """
    clip_frames = int(RANDOM_INTERVAL * SR)
    recording_frames = int(60 * SR)
    overruns = (i + 1) * clip_frames > recording_frames
    start_frame = recording_frames - clip_frames if overruns else i * clip_frames

    path = test_path.resolve()/f'{row.recording_id}.flac'
    return AudioTensor.create(path,
                              frame_offset=start_frame,
                              num_frames=clip_frames)

In [None]:
test_df = pd.read_csv(DATA_DIR/'sample_submission.csv')

In [None]:
dls.train_ds.tls[0].tfms

In [None]:
# Build a sample test DataLoader whose x-pipeline loads the first sub-clip.
# The sub-clip loader is swapped into the validation dataset only while the
# DataLoader is created; the original validation pipeline is put back afterwards
# so that later validation passes still load the training recordings.
_valid_x_tfms = dls.valid_ds.tls[0].tfms
dls.valid_ds.tls[0].tfms = Pipeline(partial(sample_one_subclips, i=0))
try:
    sample_test_dl = dls.test_dl(test_df, with_labels=False)
finally:
    dls.valid_ds.tls[0].tfms = _valid_x_tfms
print(sample_test_dl.after_batch)

In [None]:
subclips_preds = []
# number of RANDOM_INTERVAL-second windows needed to cover a 60-second recording
segment_n = ceil((60*SR)/(RANDOM_INTERVAL*SR))
# keep a handle on the validation x-pipeline so it can be reassigned later
cp_pipeline = learner.dls.valid_ds.tls[0].tfms

for crt_i in range(segment_n):
    print(f'Running prediction on subclip {crt_i} for all test samples...')
    #dls.valid_ds.tls[0].tfms = Pipeline(partial(sample_one_subclips, i=crt_i))
    crt_test_dl = learner.dls.test_dl(test_df, with_labels=False)
    # only the test DataLoader's x-pipeline is replaced with the sub-clip loader
    crt_test_dl.tls[0].tfms = Pipeline(partial(sample_one_subclips, i=crt_i))
    # predict on a subclip for all test samples
    _preds = learner.get_preds(dl=crt_test_dl)[0]

    # softmax makes output statistics of each class dependent
    # softmax could distort the order if u take max across all subclips
    # `get_preds` already applies the loss function's own `activation`, so a
    # softmax is added here only for single-label runs whose loss defines none;
    # applying it unconditionally would run softmax twice with a cross-entropy loss
    if not IS_MULTILABEL and not hasattr(learner.loss_func, 'activation'):
        _preds = torch.nn.Softmax(dim=-1)(_preds)

    subclips_preds.append(_preds)

In [None]:
# sanity check if normalizer is present in test dataloader
crt_test_dl.after_batch

In [None]:
# stack the per-sub-clip predictions along a new axis 1
all_preds = torch.stack(subclips_preds, dim=1)
# per class, keep the highest score found across the sub-clips
final_preds = all_preds.max(dim=1)[0]

# column 0 is left untouched; every other column receives the predictions
test_df.iloc[:, 1:] = final_preds
test_df.to_csv('my_submission.csv', index=False)
print('Submission CSV written')

#### Revalidate that the saved model works

In [None]:
learner.dls.valid_ds.tls[0].tfms

In [None]:
# load the checkpoint named 'model'
# NOTE(review): presumably the file written by SaveModelCallback - confirm its file name
learner = learner.load('model')
# reassign the x-pipeline captured before the sub-clip prediction loop
learner.dls.valid_ds.tls[0].tfms = cp_pipeline
preds, targs, loss = learner.get_preds(with_loss=True)
l = loss.mean().cpu().numpy()
print(f'Loss on validation after model load: {l:.2f}')