# RegNet Inference (submission generation)

In [1]:
import os
import sys
from pathlib import Path
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm
import glob
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from torchvision.models import regnet_y_800mf, RegNet_Y_800MF_Weights
import timm
import re

In [2]:
## REUSE IN INFERENCE NOTEBOOK

# Import the shared `utils` module: use the copy shipped in the Kaggle dataset when
# it exists there, otherwise fall back to the parent directory.
custom_dataset_path = '/kaggle/input/birdclef2023-inference'
sys.path.append(
    custom_dataset_path
    if os.path.exists(os.path.join(custom_dataset_path, 'utils.py'))
    else '..'
)
import utils

IS_IN_KAGGLE_ENV = utils.get_is_in_kaggle_env()

# Competition data and trained artifacts (model .pt files, label_encoder.joblib)
# live in different places on Kaggle vs. a local checkout.
if IS_IN_KAGGLE_ENV:
    DATA_PATH = '/kaggle/input/birdclef-2023'
    JOBLIB_PATH = custom_dataset_path
else:
    DATA_PATH = '../data'
    JOBLIB_PATH = './'

DEVICE = 'cpu'

# Each model input covers AUDIO_LENGTH_S seconds of audio sampled at SAMPLE_RATE Hz.
AUDIO_LENGTH_S = 5
SAMPLE_RATE = 32_000

We are running code on Localhost


In [3]:
## REUSE IN INFERENCE NOTEBOOK

# class BirdMelspecClf(nn.Module):
#     def __init__(self, out_features, pretrained):
#         super().__init__()

#         self.cnn = timm.create_model('regnety_080', pretrained=pretrained)
#         for name, param in self.cnn.named_parameters():
#             if param.requires_grad:
#                 print(name)

#         # Replace original stem for black and white images
#         self.cnn.stem.conv = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)

#         # Replace original classifier for our task
#         self.cnn.head.fc = nn.Sequential(
#             nn.Linear(self.cnn.head.fc.in_features, 1024),
#             nn.BatchNorm1d(1024),
#             nn.PReLU(),

#             nn.Linear(1024, 512),
#             nn.BatchNorm1d(512),
#             nn.PReLU(),
            
#             nn.Linear(512, out_features),
#         )

#         self.softmax = nn.Softmax(dim=1)

#     def forward(self, x):
#         logits = self.cnn(x)
#         probas = self.softmax(logits)

#         return logits, probas


class BirdMelspecClf(nn.Module):
    """Mel-spectrogram classifier built on torchvision's RegNetY-800MF.

    The backbone's stem is replaced by a 1-input-channel version, and its `fc`
    head by a small MLP that ends in `out_features` logits. `forward` returns
    both the logits and their softmax over dim 1.

    NOTE: the attribute names and layer order below define the state_dict keys
    (e.g. "regnet.stem.0.weight"). They must stay in sync with the checkpoints
    passed to `load_state_dict`.
    """

    def __init__(self, out_features: int, pretrained: bool):
        """Build the network.

        Args:
            out_features: number of classes (size of the final Linear layer).
            pretrained: if True, initialise the backbone with
                RegNet_Y_800MF_Weights.DEFAULT before the stem and head are replaced.
        """
        super().__init__()
        
        # https://pytorch.org/vision/stable/models.html

        self.regnet = regnet_y_800mf(weights=RegNet_Y_800MF_Weights.DEFAULT) if pretrained else regnet_y_800mf()

        """
        Replace the stem to take 1 channel instead of 3. The original stem:
        RegnetCNN(
        (regnet): RegNet(
            (stem): SimpleStemIN(
            (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
        )"""
        # Same layers as the original stem except for in_channels=1. Any pretrained
        # stem weights are discarded here, so this conv/BN starts from fresh init.
        self.regnet.stem = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        )
        
        # Fine-tune the regnet classifier
        # The right-hand side reads the original head's in_features before `fc` is overwritten.
        # NOTE: BatchNorm1d rejects single-sample batches in train mode; call model.eval() for batch-size-1 inference.
        self.regnet.fc = nn.Sequential(
            nn.Linear(self.regnet.fc.in_features, 512),
            nn.BatchNorm1d(512),
            nn.PReLU(),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.PReLU(),
            
            nn.Linear(256, out_features),
        )

        # Turns logits into per-class probabilities along the class dimension.
        self.softmax = nn.Softmax(dim=1)
 
    def forward(self, x: torch.Tensor):
        """Classify a batch of spectrogram images.

        Args:
            x: 1-channel input for the replaced stem; presumably shaped
                (batch, 1, n_mels, frames). TODO: confirm against callers.

        Returns:
            Tuple (logits, probas): raw class scores of shape (batch, out_features)
            and their softmax over dim 1.
        """
        logits = self.regnet(x)
        probas = self.softmax(logits)

        return logits, probas


def get_model(out_features, device, pretrained=False, load_state_dict=True, state_dict_starts_with=f"{AUDIO_LENGTH_S}s_model"):
    """Build a BirdMelspecClf and optionally load the newest matching checkpoint.

    Args:
        out_features: number of output classes.
        device: device (e.g. 'cpu') the model is moved to. Checkpoint tensors are
            mapped onto this device as well, so weights saved from a GPU can be
            loaded on a CPU-only machine.
        pretrained: forwarded to BirdMelspecClf (torchvision default backbone weights).
        load_state_dict: if False, return the freshly built model without loading
            any checkpoint.
        state_dict_starts_with: filename prefix of the `.pt` checkpoints searched
            for in JOBLIB_PATH.

    Returns:
        The model on `device`. If no matching checkpoint exists, a message is
        printed and the model is returned without loaded weights.
    """
    model = BirdMelspecClf(out_features=out_features, pretrained=pretrained)
    print(f"Loaded model {model.__class__.__name__} with {sum(p.numel() for p in model.parameters())} parameters, pretained={pretrained}")
    model.to(device)

    if not load_state_dict:
        return model

    model_files = [f for f in os.listdir(JOBLIB_PATH) if f.startswith(state_dict_starts_with) and f.endswith('.pt')]
    if not model_files:
        print(f"No model starting with {state_dict_starts_with} found in {JOBLIB_PATH}")
        return model

    def _save_timestamp(filename):
        # Treat the last integer in the filename as the save timestamp
        # (e.g. 5s_model_e0_valacc6_1684168058.pt -> 1684168058).
        # Names without any digits sort first.
        numbers = re.findall(r'\d+', filename)
        return int(numbers[-1]) if numbers else -1

    # Ascending sort: the latest checkpoint ends up last.
    model_files.sort(key=_save_timestamp)
    latest_model_file = model_files[-1]
    model_path = os.path.join(JOBLIB_PATH, latest_model_file)
    # map_location is required: without it, a checkpoint saved from CUDA tensors
    # fails to deserialize on a CPU-only host.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model weights from {model_path}")
    model.to(device)

    return model


def get_label_encoder():
    """Load and return the label encoder stored as `label_encoder.joblib` in JOBLIB_PATH.

    joblib.load unpickles the file, so only point JOBLIB_PATH at trusted artifacts.
    """
    encoder_file = os.path.join(JOBLIB_PATH, 'label_encoder.joblib')
    encoder = joblib.load(encoder_file)
    print(f"Loaded label encoder from {encoder_file}")
    return encoder

In [4]:
## REUSE IN INFERENCE NOTEBOOK

def resample(audio, current_sample_rate, desired_sample_rate=SAMPLE_RATE):
    """Resample `audio` from `current_sample_rate` Hz to `desired_sample_rate` Hz.

    A new torchaudio Resample transform is constructed on every call.
    """
    transform = torchaudio.transforms.Resample(
        orig_freq=current_sample_rate,
        new_freq=desired_sample_rate,
    )
    return transform(audio)

def load_audio(audio_path, sample_rate=SAMPLE_RATE):
    """Load an audio file with torchaudio and bring it to `sample_rate` Hz.

    Returns only the waveform tensor, in torchaudio.load's default channels-first
    layout. The file's original sample rate is dropped.
    """
    waveform, file_sample_rate = torchaudio.load(audio_path)
    if file_sample_rate == sample_rate:
        return waveform
    return resample(waveform, file_sample_rate, sample_rate)

# Using librosa defaults for n_fft and hop_length
def get_melspec_transform(sample_rate=SAMPLE_RATE, n_fft=2048, hop_length=512, n_mels=128):
    """Return a torchaudio MelSpectrogram transform configured with the given STFT/mel settings."""
    spectrogram_kwargs = dict(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    return torchaudio.transforms.MelSpectrogram(**spectrogram_kwargs)

# Using librosa defaults for top_db
def get_melspec_db_transform(stype='power', top_db=80):
    """Return a torchaudio AmplitudeToDB transform (decibel conversion clipped at `top_db` below the peak)."""
    db_kwargs = {'stype': stype, 'top_db': top_db}
    return torchaudio.transforms.AmplitudeToDB(**db_kwargs)

def normalize_melspec(spectrogram):
    """Min-max scale `spectrogram` into [0, 1] using its global min and max.

    A constant tensor (max == min) is returned unchanged (the same object),
    which avoids a division by zero.
    """
    lowest = torch.min(spectrogram)
    value_range = torch.max(spectrogram) - lowest
    if value_range == 0:
        return spectrogram
    return (spectrogram - lowest) / value_range

In [5]:
# Every .ogg file in the test_soundscapes directory is one soundscape to score.
filepaths = glob.glob(os.path.join(DATA_PATH, 'test_soundscapes', '*.ogg'))
print(f"filepaths length: {len(filepaths)} (amount of test soundscapes)")

filepaths length: 1 (amount of test soundscapes)


# Inference

In [6]:
# `debug`: run files sequentially with verbose prints.
# `simulate_200_files`: repeat the first soundscape 200 times to mimic a larger submission run.
debug = False
simulate_200_files = False

if simulate_200_files:
    filepaths = [filepaths[0]] * 200  # simulate submission
    print(f"filepaths length: {len(filepaths)} after simulation additions")

label_encoder = get_label_encoder()
model = get_model(
    out_features=len(label_encoder.classes_),
    device=DEVICE,
    pretrained=False,
    load_state_dict=True,
)
# Eval mode: BatchNorm layers use running statistics (needed for batch size 1).
model.eval()

# Number of samples in one AUDIO_LENGTH_S-second window.
MIN_WINDOW = AUDIO_LENGTH_S * SAMPLE_RATE
melspec_transform = get_melspec_transform(n_mels=128)
melspec_db_transform = get_melspec_db_transform()

def infer(filepath):
    """Predict class probabilities for consecutive AUDIO_LENGTH_S-second windows of one file.

    The audio is loaded with `load_audio`, which resamples it to SAMPLE_RATE when
    the file's rate differs. Only the first channel is used. Each complete window
    of MIN_WINDOW samples is turned into a min-max normalized dB mel spectrogram
    and scored by the global `model`. A trailing partial window is skipped.

    Args:
        filepath: path to an audio file readable by torchaudio.

    Returns:
        List of dicts, one per window, each with:
        "row_id": f"{file stem}_{window end in seconds}"
        "predictions": the model's softmax output for that window (a batch of 1).
    """
    all_predictions = []
    name = Path(filepath).stem
    # Resampling matters: crop boundaries and row_id timestamps are computed in
    # SAMPLE_RATE samples.
    audio = load_audio(filepath)[0]
    n_samples = len(audio)
    if debug:
        print(f"Infering file {filepath} with length {n_samples / SAMPLE_RATE} s")
    # Integer division counts complete windows exactly (no float rounding).
    n_crops = n_samples // MIN_WINDOW
    for i in range(n_crops):
        crop = audio[i * MIN_WINDOW:(i + 1) * MIN_WINDOW]
        melspec = normalize_melspec(melspec_db_transform(melspec_transform(crop)))
        # [n_mels, frames] -> [1, 1, n_mels, frames]: add batch and channel dimensions.
        melspec = melspec.unsqueeze(0).unsqueeze(0)
        if debug:
            print(f"Crop {i} / {n_crops}, crop length: {len(crop)}, melspec shape: {melspec.shape}")
        with torch.no_grad():
            _, proba = model(melspec)
        window_end_s = (i + 1) * AUDIO_LENGTH_S
        all_predictions.append({"row_id": f'{name}_{window_end_s}', "predictions": proba})
    return all_predictions

# Debug mode scores files one by one in this process; otherwise joblib spreads
# files over os.cpu_count() workers.
# NOTE: with Parallel, tqdm advances as tasks are dispatched, not as they finish.
if not debug:
    jobs = (delayed(infer)(filepath) for filepath in tqdm(filepaths, desc='Infering files'))
    all_preds = Parallel(n_jobs=os.cpu_count())(jobs)
else:
    all_preds = [infer(filepath) for filepath in tqdm(filepaths, desc='Infering files')]

# Flatten the per-file lists into one list of window predictions.
all_preds_flat = [window_pred for file_preds in all_preds for window_pred in file_preds]

print(f"all_preds length: {len(all_preds)}, all_preds_flat length: {len(all_preds_flat)}")

Loaded label encoder from ./label_encoder.joblib
Loaded model BirdMelspecClf with 6249570 parameters, pretained=False
Loaded model weights from ./5s_model_e0_valacc6_1684168058.pt


Infering files: 100%|██████████| 1/1 [00:00<00:00, 129.65it/s]


all_preds length: 1, all_preds_flat length: 120


In [7]:
# Spot-check: class probabilities of window index 100 (needs more than 100 windows).
all_preds_flat[100]['predictions'].squeeze(0)

tensor([0.0028, 0.0051, 0.0062, 0.0044, 0.0053, 0.0037, 0.0033, 0.0026, 0.0040,
        0.0041, 0.0049, 0.0023, 0.0037, 0.0085, 0.0025, 0.0041, 0.0041, 0.0049,
        0.0042, 0.0032, 0.0055, 0.0033, 0.0033, 0.0036, 0.0030, 0.0032, 0.0042,
        0.0037, 0.0059, 0.0049, 0.0039, 0.0048, 0.0028, 0.0058, 0.0063, 0.0030,
        0.0040, 0.0024, 0.0028, 0.0020, 0.0023, 0.0040, 0.0028, 0.0031, 0.0031,
        0.0041, 0.0029, 0.0038, 0.0036, 0.0046, 0.0048, 0.0027, 0.0023, 0.0055,
        0.0043, 0.0044, 0.0045, 0.0025, 0.0019, 0.0064, 0.0026, 0.0043, 0.0047,
        0.0039, 0.0036, 0.0037, 0.0053, 0.0025, 0.0028, 0.0048, 0.0052, 0.0058,
        0.0043, 0.0046, 0.0080, 0.0015, 0.0057, 0.0051, 0.0033, 0.0024, 0.0049,
        0.0023, 0.0032, 0.0032, 0.0065, 0.0036, 0.0053, 0.0027, 0.0031, 0.0050,
        0.0033, 0.0030, 0.0028, 0.0056, 0.0042, 0.0059, 0.0043, 0.0062, 0.0025,
        0.0034, 0.0024, 0.0026, 0.0026, 0.0059, 0.0035, 0.0050, 0.0062, 0.0065,
        0.0038, 0.0032, 0.0050, 0.0025, 

In [8]:
# One row per window: row_id first, then one probability column per class.
df = pd.DataFrame(
    torch.cat([p['predictions'] for p in all_preds_flat]).numpy(),
    columns=label_encoder.classes_,
)
df.insert(0, 'row_id', [p['row_id'] for p in all_preds_flat])

df

Unnamed: 0,row_id,abethr1,abhori1,abythr1,afbfly1,afdfly1,afecuc1,affeag1,afgfly1,afghor1,...,yebsto1,yeccan1,yefcan,yelbis1,yenspu1,yertin1,yesbar1,yespet1,yetgre1,yewgre1
0,soundscape_29201_5,0.003576,0.005822,0.004354,0.004870,0.006215,0.005141,0.004250,0.002504,0.006324,...,0.001851,0.003267,0.004422,0.003134,0.004531,0.005090,0.003747,0.002355,0.004569,0.004693
1,soundscape_29201_10,0.003977,0.005767,0.004102,0.004940,0.006419,0.004796,0.004301,0.002649,0.005563,...,0.001843,0.003486,0.004200,0.003021,0.004720,0.005277,0.003476,0.002504,0.004277,0.004917
2,soundscape_29201_15,0.003700,0.005966,0.004445,0.004714,0.006058,0.004608,0.004260,0.002649,0.005858,...,0.001777,0.003222,0.004369,0.003056,0.005464,0.005365,0.003434,0.002416,0.004033,0.004241
3,soundscape_29201_20,0.003312,0.005231,0.003939,0.004398,0.005670,0.004918,0.003720,0.002234,0.005700,...,0.001936,0.003264,0.003329,0.003551,0.004728,0.004708,0.003659,0.002273,0.004342,0.004883
4,soundscape_29201_25,0.003339,0.005025,0.003525,0.004613,0.005037,0.004028,0.003553,0.002562,0.004367,...,0.002032,0.003652,0.003661,0.004140,0.004924,0.004205,0.003497,0.002508,0.004702,0.005105
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,soundscape_29201_580,0.003643,0.006082,0.003923,0.005285,0.006118,0.004996,0.004085,0.002409,0.005923,...,0.001953,0.003683,0.004055,0.003263,0.004898,0.005271,0.003297,0.002646,0.003738,0.004690
116,soundscape_29201_585,0.003213,0.005496,0.005459,0.005213,0.005395,0.004037,0.003724,0.002507,0.004775,...,0.002170,0.003099,0.004410,0.003832,0.004727,0.004501,0.003237,0.001942,0.004888,0.004627
117,soundscape_29201_590,0.003353,0.005215,0.004896,0.004787,0.005457,0.004062,0.003728,0.002497,0.004367,...,0.001933,0.003429,0.003621,0.003230,0.005527,0.004699,0.003193,0.002080,0.004444,0.004989
118,soundscape_29201_595,0.003764,0.005768,0.004371,0.005322,0.006364,0.005460,0.003721,0.002613,0.006099,...,0.001948,0.003728,0.004256,0.003172,0.005074,0.004515,0.003438,0.002500,0.004582,0.004941


In [9]:
# Write the submission without the pandas index column.
df.to_csv(path_or_buf='submission.csv', index=False)

In [10]:
# Sanity check: list submission.csv to confirm it was written to the working directory.
!ls submission.csv

submission.csv
