# RegNetY-800MF Inference (submission generation)

In [148]:
import os
from pathlib import Path
from joblib import Parallel, delayed
from tqdm import tqdm
import glob
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import timm
from torchvision.models import regnet_y_800mf, RegNet_Y_800MF_Weights

In [149]:
## REUSE IN INFERENCE NOTEBOOK

import sys
sys.path.append("..")
import utils

IS_IN_KAGGLE_ENV = utils.get_is_in_kaggle_env()

# Environment-dependent settings: Kaggle input dir + CPU on Kaggle, local data dir +
# whatever utils.determine_device() picks otherwise.
# NOTE(review): DEVICE appears unused below; the model and inputs are never moved off
# the default (CPU) device.
if IS_IN_KAGGLE_ENV:
    DATA_PATH = '/kaggle/input/birdclef-2023'
    DEVICE = 'cpu'
else:
    DATA_PATH = '../data'
    DEVICE = utils.determine_device()

# Length (seconds) of each window scored by the model, and the sample rate (Hz) audio is
# expected at.
AUDIO_LENGTH_S = 5
SAMPLE_RATE = 32_000

We are running code on Localhost
We are using device: mps


In [150]:
## REUSE IN INFERENCE NOTEBOOK

# class BirdMelspecClf(nn.Module):
#     def __init__(self, out_features, pretrained):
#         super().__init__()

#         self.cnn = timm.create_model('regnety_080', pretrained=pretrained)
#         for name, param in self.cnn.named_parameters():
#             if param.requires_grad:
#                 print(name)

#         # Replace original stem for black and white images
#         self.cnn.stem.conv = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)

#         # Replace original classifier for our task
#         self.cnn.head.fc = nn.Sequential(
#             nn.Linear(self.cnn.head.fc.in_features, 1024),
#             nn.BatchNorm1d(1024),
#             nn.PReLU(),

#             nn.Linear(1024, 512),
#             nn.BatchNorm1d(512),
#             nn.PReLU(),
            
#             nn.Linear(512, out_features),
#         )

#         self.softmax = nn.Softmax(dim=1)

#     def forward(self, x):
#         logits = self.cnn(x)
#         probas = self.softmax(logits)

#         return logits, probas


class BirdMelspecClf(nn.Module):
    """Bird-species classifier over single-channel mel-spectrogram images.

    Backbone is torchvision's RegNetY-800MF with two parts replaced:
      * `regnet.stem` takes 1 input channel instead of 3. It is a freshly initialised conv,
        so any pretrained RGB stem weights are discarded.
      * `regnet.fc` is an MLP head: in_features -> 1024 -> 512 -> out_features, with
        BatchNorm1d + PReLU after each hidden Linear.

    NOTE: the submodule names/indices (`regnet.stem.{0,1}`, `regnet.fc.{0..6}`) define the
    state_dict keys. Changing them breaks loading of previously saved checkpoints.

    Args:
        out_features: number of output classes.
        pretrained: if True, the backbone starts from RegNet_Y_800MF_Weights.DEFAULT.
    """

    def __init__(self, out_features, pretrained):
        super().__init__()

        # https://pytorch.org/vision/stable/models.html
        backbone_weights = RegNet_Y_800MF_Weights.DEFAULT if pretrained else None
        self.regnet = regnet_y_800mf(weights=backbone_weights)

        # The original stem is Conv2d(3, 32, 3, stride=2, padding=1, bias=False) ->
        # BatchNorm2d(32) -> ReLU(inplace=True). The replacement keeps that layout (and
        # the 0/1/2 indices) but accepts a single input channel.
        self.regnet.stem = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        )

        def dense(n_in, n_out):
            # Linear -> BatchNorm1d -> PReLU, returned as a list to be splatted into the
            # parent Sequential. This keeps the flat 0..6 indexing of the head.
            return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.PReLU()]

        # The splatted arguments are evaluated before the assignment, so `in_features`
        # is read from the backbone's original fc layer.
        self.regnet.fc = nn.Sequential(
            *dense(self.regnet.fc.in_features, 1024),
            *dense(1024, 512),
            nn.Linear(512, out_features),
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return `(logits, probas)`, where `probas` is the softmax of `logits` over dim 1."""
        logits = self.regnet(x)
        return logits, self.softmax(logits)


def get_model(out_features, pretrained=False):
    """Construct a BirdMelspecClf and print its class name and parameter count.

    Args:
        out_features: number of output classes.
        pretrained: forwarded to BirdMelspecClf. True starts the backbone from torchvision's
            default RegNetY-800MF weights.

    Returns:
        The newly constructed model. No checkpoint is loaded here.
    """
    model = BirdMelspecClf(out_features=out_features, pretrained=pretrained)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded model {model.__class__.__name__} with {n_params} parameters, pretrained={pretrained}")
    return model

In [151]:
## REUSE IN INFERENCE NOTEBOOK

def resample(audio, current_sample_rate, desired_sample_rate=SAMPLE_RATE):
    """Resample `audio` from `current_sample_rate` Hz to `desired_sample_rate` Hz.

    A new torchaudio Resample transform is built on every call.
    """
    return torchaudio.transforms.Resample(
        orig_freq=current_sample_rate,
        new_freq=desired_sample_rate,
    )(audio)

def load_audio(audio_path, sample_rate=SAMPLE_RATE):
    """Load an audio file and return its waveform at `sample_rate` Hz.

    The waveform from torchaudio.load is resampled only when the file's native rate
    differs from `sample_rate`. All channels are kept.
    """
    waveform, native_rate = torchaudio.load(audio_path)
    return waveform if native_rate == sample_rate else resample(waveform, native_rate, sample_rate)

def get_melspec_transform(sample_rate=SAMPLE_RATE, n_fft=1024, win_length=1024, hop_length=512, n_mels=128, f_min=0, f_max=SAMPLE_RATE // 2, normalized=True):
    """Build a torchaudio MelSpectrogram transform.

    Every argument is forwarded unchanged, by keyword, to
    torchaudio.transforms.MelSpectrogram.
    """
    settings = dict(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        normalized=normalized,
    )
    return torchaudio.transforms.MelSpectrogram(**settings)

def get_melspec_db_transform(stype='power', top_db=80):
    """Build a torchaudio AmplitudeToDB transform (dB conversion with a `top_db` floor)."""
    to_db = torchaudio.transforms.AmplitudeToDB(stype=stype, top_db=top_db)
    return to_db

def normalize_melspec(spectrogram):
    """Min-max scale `spectrogram` to [0, 1], using the global min/max over all elements.

    NOTE(review): a constant input (e.g. a digitally silent crop) is returned unchanged and
    therefore is not rescaled into [0, 1].
    """
    lowest = spectrogram.min()
    span = spectrogram.max() - lowest
    if span == 0:
        return spectrogram
    return (spectrogram - lowest) / span

In [152]:
# Class list: sorted unique primary labels. Column i of the model output is read as
# index_to_label[i] (presumably the same encoding used at training time — confirm).
df = pd.read_csv(f"{DATA_PATH}/train_metadata.csv")
index_to_label = sorted(df.primary_label.unique())

model = get_model(out_features=len(index_to_label), pretrained=False)
# map_location="cpu": the checkpoint may have been saved from an accelerator (e.g. mps or
# cuda) that does not exist in the CPU-only submission environment. Inference below runs
# on CPU.
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))

Loaded model BirdMelspecClf with 7114082 parameters, pretained=False


<All keys matched successfully>

In [153]:
# glob.glob returns paths in filesystem-dependent order. Sorting makes file order, and
# therefore submission row order, deterministic across runs.
filepaths = sorted(glob.glob(f"{DATA_PATH}/test_soundscapes/*.ogg"))
print(f"filepaths length: {len(filepaths)}")

filepaths length: 1


# Inference

In [155]:
# Switches:
# - `debug` enables diagnostic prints inside `infer` and runs the files sequentially.
# - `simulate_200_files` replaces the file list with 200 copies of the first test file,
#   to time a full-size run.
debug = False
simulate_200_files = False

if simulate_200_files:
    filepaths = [filepaths[0]] * 200  # simulate submission
    print(f"filepaths length: {len(filepaths)}")

model.eval()

# Samples in one AUDIO_LENGTH_S-second window at SAMPLE_RATE.
MIN_WINDOW = SAMPLE_RATE * AUDIO_LENGTH_S
melspec_transform = get_melspec_transform(n_mels=128)
melspec_db_transform = get_melspec_db_transform()

def infer(filepath, batch_size=32):
    """Predict class probabilities for every full AUDIO_LENGTH_S-second window of one file.

    Steps:
      1. Load the file, resampling to SAMPLE_RATE if needed, and keep channel 0.
      2. Cut it into consecutive, non-overlapping windows of MIN_WINDOW samples.
      3. Turn each window into a normalized dB mel-spectrogram.
      4. Score the windows with `model`, at most `batch_size` windows per forward pass.

    A trailing partial window (shorter than AUDIO_LENGTH_S) is dropped.

    Args:
        filepath: path to an audio file readable by torchaudio.
        batch_size: maximum number of windows per forward pass. Bounds peak memory.

    Returns:
        A list with one dict per window, in time order:
        {"row_id": f"{stem}_{end_second}", "predictions": tensor of shape (1, n_classes)}.
        The list is empty if the file is shorter than one window.
    """
    name = Path(filepath).stem
    # load_audio resamples to SAMPLE_RATE, so the MIN_WINDOW arithmetic below holds even
    # for files stored at another rate.
    audio = load_audio(filepath)[0]
    n_crops = len(audio) // MIN_WINDOW
    if debug:
        print(f"Inferring file {filepath} with length {len(audio) / SAMPLE_RATE} s ({n_crops} crops)")
    if n_crops == 0:
        return []

    # Spectrograms are built one crop at a time rather than on a stacked tensor.
    # normalize_melspec uses the global min/max of its input, and AmplitudeToDB's top_db
    # clamp may also span a batched input. Per-crop calls keep both per window.
    melspecs = [
        normalize_melspec(melspec_db_transform(melspec_transform(audio[i * MIN_WINDOW:(i + 1) * MIN_WINDOW])))
        for i in range(n_crops)
    ]
    batch = torch.stack(melspecs).unsqueeze(1)  # add channel dim -> (n_crops, 1, n_mels, n_frames)
    if debug:
        print(f"melspec batch shape: {batch.shape}")

    # model.eval() is called before inference, so BatchNorm uses running statistics.
    # Batching therefore changes per-window outputs only by floating-point rounding.
    with torch.no_grad():
        probas = torch.cat([model(chunk)[1] for chunk in batch.split(batch_size)])

    # Each row_id ends with the window's end time in seconds (5, 10, 15, ...).
    return [
        {"row_id": f"{name}_{(i + 1) * AUDIO_LENGTH_S}", "predictions": probas[i:i + 1]}
        for i in range(n_crops)
    ]

if debug:
    # Sequential path: easier to follow interleaved debug prints.
    all_preds = [infer(filepath) for filepath in tqdm(filepaths, desc='Infering files')]
else:
    # One joblib task per file. tqdm wraps the task generator, so the bar tracks task
    # dispatch and can run ahead of actual completion.
    all_preds = Parallel(n_jobs=os.cpu_count())(
        delayed(infer)(filepath) for filepath in tqdm(filepaths, desc='Infering files')
    )

# Flatten the per-file lists into one list of per-window predictions, in file order.
all_preds_flat = [window_pred for file_preds in all_preds for window_pred in file_preds]

print(f"all_preds length: {len(all_preds)}, all_preds_flat length: {len(all_preds_flat)}")

Infering files: 100%|██████████| 1/1 [00:00<00:00, 520.84it/s]


all_preds length: 1, all_preds_flat length: 120


In [156]:
# Spot-check one window's probability tensor (the 101st window overall).
sample_prediction = all_preds_flat[100]
sample_prediction['predictions']

tensor([[3.9706e-03, 7.2762e-03, 6.1396e-03, 2.8244e-03, 4.1173e-03, 7.1982e-03,
         5.5321e-03, 1.9108e-03, 4.4981e-03, 3.9354e-03, 6.1121e-03, 1.3410e-04,
         4.4746e-03, 4.8645e-03, 3.9582e-03, 6.1077e-03, 5.7932e-03, 4.8046e-03,
         2.1387e-03, 3.7121e-03, 5.7121e-03, 2.5794e-03, 1.4057e-03, 4.2507e-03,
         3.2525e-03, 6.1193e-03, 4.8145e-03, 4.3836e-03, 6.0398e-03, 6.5155e-03,
         4.2887e-03, 5.6927e-03, 2.3401e-03, 5.3657e-03, 3.6230e-03, 3.8857e-03,
         2.9716e-03, 2.4384e-03, 1.9920e-03, 1.5643e-03, 6.1607e-04, 3.3131e-03,
         2.5294e-03, 3.1214e-03, 1.8387e-03, 4.8135e-03, 1.1986e-04, 4.8293e-03,
         4.8389e-03, 3.5440e-03, 2.8350e-03, 2.5395e-03, 1.1849e-03, 4.9323e-03,
         3.5866e-03, 4.5290e-03, 5.4081e-03, 3.1593e-03, 4.8210e-03, 3.2281e-03,
         5.9865e-03, 5.5127e-03, 5.9646e-03, 2.4399e-03, 8.7471e-04, 1.0692e-04,
         5.1909e-03, 5.0739e-03, 4.0773e-03, 6.3981e-03, 4.8655e-03, 5.3319e-03,
         5.9780e-03, 6.4753e

In [157]:
# Submission table: one row per window, a `row_id` column followed by one probability
# column per class, in index_to_label order.
row_ids = [p['row_id'] for p in all_preds_flat]
if all_preds_flat:
    # Each entry holds a (1, n_classes) tensor. Concatenating on dim 0 gives (n_rows, n_classes).
    probas = torch.cat([p['predictions'] for p in all_preds_flat]).numpy()
else:
    # No test files or no full windows: still produce a well-formed, empty table.
    # torch.stack/cat would raise on an empty list.
    probas = torch.empty((0, len(index_to_label))).numpy()

df = pd.DataFrame(probas, columns=index_to_label)
df.insert(0, 'row_id', row_ids)

df

Unnamed: 0,row_id,abethr1,abhori1,abythr1,afbfly1,afdfly1,afecuc1,affeag1,afgfly1,afghor1,...,yebsto1,yeccan1,yefcan,yelbis1,yenspu1,yertin1,yesbar1,yespet1,yetgre1,yewgre1
0,soundscape_29201_5,0.003745,0.006981,0.006451,0.002825,0.004266,0.006654,0.005769,0.001988,0.004202,...,0.000182,0.003732,0.005797,0.002210,0.002989,0.006481,0.003585,0.002723,0.003634,0.004053
1,soundscape_29201_10,0.003929,0.006879,0.005417,0.002773,0.003959,0.007461,0.005716,0.001915,0.004356,...,0.000145,0.003380,0.006781,0.002258,0.003230,0.006479,0.003492,0.002174,0.003429,0.004584
2,soundscape_29201_15,0.003777,0.006647,0.006278,0.002545,0.004217,0.007128,0.005267,0.002044,0.004202,...,0.000208,0.003733,0.006265,0.002051,0.003148,0.006776,0.003919,0.002696,0.003364,0.004136
3,soundscape_29201_20,0.003810,0.006275,0.006033,0.002864,0.004304,0.006995,0.005365,0.001995,0.004050,...,0.000197,0.003767,0.006172,0.002483,0.003167,0.006245,0.003792,0.002598,0.003547,0.004299
4,soundscape_29201_25,0.003769,0.006725,0.005536,0.002871,0.004004,0.006357,0.005485,0.001855,0.004470,...,0.000204,0.003763,0.006099,0.002478,0.003194,0.006597,0.003596,0.002207,0.003769,0.004225
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,soundscape_29201_580,0.003805,0.007205,0.005741,0.002562,0.004218,0.007010,0.005767,0.001920,0.004451,...,0.000139,0.003611,0.007013,0.002162,0.003247,0.007340,0.003572,0.002432,0.003254,0.004031
116,soundscape_29201_585,0.003556,0.007137,0.006254,0.002671,0.004095,0.007339,0.005809,0.001800,0.004568,...,0.000141,0.003370,0.006661,0.002239,0.003135,0.006611,0.003522,0.002381,0.003509,0.004203
117,soundscape_29201_590,0.003756,0.007065,0.005734,0.002587,0.004245,0.007077,0.005284,0.002024,0.004438,...,0.000158,0.003616,0.006878,0.002207,0.003175,0.007314,0.003408,0.002259,0.003375,0.004155
118,soundscape_29201_595,0.004268,0.007350,0.006079,0.002650,0.004250,0.006820,0.005938,0.002110,0.004523,...,0.000173,0.003970,0.005949,0.002230,0.003718,0.006745,0.003297,0.002760,0.003494,0.003966


In [158]:
# Write to the working directory as submission.csv (presumably the file the competition
# scorer reads). row_id is a regular column, so the index is omitted.
df.to_csv(path_or_buf='submission.csv', index=False)