# RegNet Inference (submission generation)

In [None]:
import os
import sys
from pathlib import Path
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm
import glob
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import numpy as np
from torchvision.models import regnet_y_800mf, RegNet_Y_800MF_Weights
import timm
import re
from torchaudio import functional as F_audio

In [None]:
## REUSE IN INFERENCE NOTEBOOK

custom_dataset_path = '/kaggle/input/birdclef2023-inference'

# Make utils.py importable: use the copy in the custom dataset directory when
# it exists there, otherwise look one directory up.
sys.path.append(
    custom_dataset_path
    if os.path.exists(os.path.join(custom_dataset_path, 'utils.py'))
    else '..'
)
import utils

IS_IN_KAGGLE_ENV = utils.get_is_in_kaggle_env()

# Data and serialized artifacts (weights, label encoder) live in different
# places depending on the environment.
if IS_IN_KAGGLE_ENV:
    DATA_PATH = '/kaggle/input/birdclef-2023'
    JOBLIB_PATH = custom_dataset_path
else:
    DATA_PATH = '../data'
    JOBLIB_PATH = './'

DEVICE = 'cpu'

# Length of one scored window in seconds, and the sample rate audio is brought to.
AUDIO_LENGTH_S = 5
SAMPLE_RATE = 32_000

In [None]:
## REUSE IN INFERENCE NOTEBOOK

class BirdMelspecClf(nn.Module):
    """RegNetY-800MF classifier adapted to single-channel inputs.

    The torchvision backbone's stem is replaced by one whose first convolution
    takes 1 input channel, and its ``fc`` layer is replaced by a three-layer
    MLP head ending in ``out_features`` units.
    """

    def __init__(self, out_features, pretrained):
        """Build the network.

        Args:
            out_features: Output size of the final linear layer.
            pretrained: If truthy, initialise the backbone from
                ``RegNet_Y_800MF_Weights.DEFAULT``; otherwise build it
                without pretrained weights.
        """
        super().__init__()

        # https://pytorch.org/vision/stable/models.html
        backbone_weights = RegNet_Y_800MF_Weights.DEFAULT if pretrained else None
        self.regnet = regnet_y_800mf(weights=backbone_weights)

        # Replace the stem to take 1 channel instead of 3. The original stem:
        #   (stem): SimpleStemIN(
        #     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        #     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        #     (2): ReLU(inplace=True)
        #   )
        stem_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        ]
        self.regnet.stem = nn.Sequential(*stem_layers)

        # Fine-tune the regnet classifier: swap the single fc layer for an MLP head.
        # The input width is read from the backbone's own fc layer before it is replaced.
        head_in = self.regnet.fc.in_features
        head_layers = [
            nn.Linear(head_in, 512),
            nn.BatchNorm1d(512),
            nn.PReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.PReLU(),
            nn.Linear(256, out_features),
        ]
        self.regnet.fc = nn.Sequential(*head_layers)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            A ``(logits, probas)`` tuple, where ``probas`` is the softmax of
            ``logits`` over dim 1.
        """
        logits = self.regnet(x)
        return logits, self.softmax(logits)


def get_model(out_features, device, pretrained=False, load_state_dict=True, state_dict_starts_with=f"{AUDIO_LENGTH_S}s_regnetY800MFWC"):
    """Build a BirdMelspecClf and optionally load the newest saved weights.

    Weight files are looked up in JOBLIB_PATH: every ``.pt`` file whose name
    starts with ``state_dict_starts_with`` is a candidate, and the one whose
    last run of digits in the filename is numerically largest is loaded.

    Args:
        out_features: Output size of the classifier head.
        device: Device the model is moved to and the checkpoint is mapped to.
        pretrained: Forwarded to BirdMelspecClf.
        load_state_dict: If falsy, return the freshly built model without
            looking for a weight file.
        state_dict_starts_with: Filename prefix of candidate weight files.

    Returns:
        The model on ``device``. If no matching weight file exists, a message
        is printed and the model is returned with its initial weights.
    """
    model = BirdMelspecClf(out_features=out_features, pretrained=pretrained)
    print(f"Loaded model {model.__class__.__name__} with {sum(p.numel() for p in model.parameters())} parameters, pretained={pretrained}")
    model.to(device)

    if not load_state_dict:
        return model

    model_files = [f for f in os.listdir(JOBLIB_PATH) if f.startswith(state_dict_starts_with) and f.endswith('.pt')]
    if not model_files:
        print(f"No model starting with {state_dict_starts_with} found in {JOBLIB_PATH}")
        return model

    def _last_number(filename):
        # Timestamp = last run of digits in the filename; -1 if there is none.
        numbers = re.findall(r'\d+', filename)
        return int(numbers[-1]) if numbers else -1

    # Sort by timestamp; the latest model file is the last one in the sorted list.
    model_files.sort(key=_last_number)
    latest_model_file = model_files[-1]
    model_path = os.path.join(JOBLIB_PATH, latest_model_file)

    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved from a GPU can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model weights from {model_path}")

    # load_state_dict copies values into the existing parameters, which are
    # already on `device`, so no second .to(device) is needed.
    return model


def get_label_encoder():
    """Load the label encoder stored as ``label_encoder.joblib`` under JOBLIB_PATH.

    Returns:
        The object deserialised by ``joblib.load``.
    """
    encoder_file = os.path.join(JOBLIB_PATH, 'label_encoder.joblib')
    # NOTE: joblib.load unpickles the file, so it must come from a trusted source.
    encoder = joblib.load(encoder_file)
    print(f"Loaded label encoder from {encoder_file}")
    return encoder

In [None]:
## REUSE IN INFERENCE NOTEBOOK

def resample(audio, current_sample_rate, desired_sample_rate=SAMPLE_RATE):
    """Resample ``audio`` from ``current_sample_rate`` to ``desired_sample_rate``.

    A new ``torchaudio.transforms.Resample`` is built on every call.

    Returns:
        The resampled audio tensor.
    """
    return torchaudio.transforms.Resample(
        orig_freq=current_sample_rate,
        new_freq=desired_sample_rate,
    )(audio)

def load_audio(audio_path, sample_rate=SAMPLE_RATE):
    """Load an audio file and make sure it is at ``sample_rate``.

    Args:
        audio_path: Path passed to ``torchaudio.load``.
        sample_rate: Target sample rate; the audio is resampled only when the
            file's own rate differs.

    Returns:
        The waveform tensor returned by ``torchaudio.load``, resampled if needed.
    """
    waveform, file_rate = torchaudio.load(audio_path)
    if file_rate == sample_rate:
        return waveform
    return resample(waveform, file_rate, sample_rate)

# Using librosa defaults for n_fft and hop_length
def get_melspec_transform(sample_rate=SAMPLE_RATE, n_fft=2048, hop_length=512, n_mels=128):
    """Create a ``torchaudio.transforms.MelSpectrogram`` with the given settings.

    Args:
        sample_rate: Sample rate of the audio the transform will receive.
        n_fft: FFT size.
        hop_length: Hop between successive frames.
        n_mels: Number of mel bins.

    Returns:
        The configured MelSpectrogram transform.
    """
    settings = {
        'sample_rate': sample_rate,
        'n_fft': n_fft,
        'hop_length': hop_length,
        'n_mels': n_mels,
    }
    return torchaudio.transforms.MelSpectrogram(**settings)

# Using librosa defaults for top_db
def get_melspec_db_transform(stype='power', top_db=80):
    """Create a ``torchaudio.transforms.AmplitudeToDB`` with the given settings.

    Args:
        stype: Forwarded as ``AmplitudeToDB``'s ``stype``.
        top_db: Forwarded as ``AmplitudeToDB``'s ``top_db``.

    Returns:
        The configured AmplitudeToDB transform.
    """
    settings = {'stype': stype, 'top_db': top_db}
    return torchaudio.transforms.AmplitudeToDB(**settings)

# Copied from torchaudio/transforms/_transforms.py (to avoid converting to melspec twice)
# DCT matrix built once at import time and shared by every call below.
# Presumably 40 MFCC coefficients over 128 mel bins - confirm against the
# torchaudio create_dct signature.
dct_mat = F_audio.create_dct(40, 128, "ortho")
def get_mfcc_from_melspec(melspec):
    """Project a mel spectrogram onto the module-level DCT matrix.

    The last two dimensions are swapped so the matrix product runs over the
    dimension that was second-to-last, then swapped back.

    Returns:
        A tensor with the same layout as ``melspec`` whose second-to-last
        dimension has been replaced by the DCT output dimension.
    """
    swapped = melspec.transpose(-1, -2)
    projected = torch.matmul(swapped, dct_mat)
    return projected.transpose(-1, -2)

def normalize_tensor(tensor):
    """Min-max scale ``tensor`` into the range [0, 1].

    The minimum and maximum are taken over the whole tensor. When all values
    are equal the scaling is undefined, so the input is returned unchanged.

    Returns:
        The scaled tensor, or ``tensor`` itself when its value range is zero.
    """
    lowest = torch.min(tensor)
    value_range = torch.max(tensor) - lowest
    if value_range == 0:
        return tensor
    return (tensor - lowest) / value_range

In [None]:
# Collect the test soundscapes. glob returns matches in arbitrary order, so
# sort them to make the processing order (and therefore the submission row
# order and the file picked by the 200-file simulation) reproducible.
filepaths = sorted(glob.glob(f"{DATA_PATH}/test_soundscapes/*.ogg"))
print(f"filepaths length: {len(filepaths)} (amount of test soundscapes)")

# Inference

In [None]:
# debug: run files one by one in this process and print per-crop details.
debug = False
# simulate_200_files: repeat the first soundscape 200 times.
simulate_200_files = False

if simulate_200_files:
    filepaths = [filepaths[0]] * 200 # simulate submission
    print(f"filepaths length: {len(filepaths)} after simulation additions")

# Number of samples in one scored window.
MIN_WINDOW = AUDIO_LENGTH_S * SAMPLE_RATE
melspec_transform = get_melspec_transform(n_mels=128)
melspec_db_transform = get_melspec_db_transform()

label_encoder = get_label_encoder()
model = get_model(
    out_features=len(label_encoder.classes_),
    device=DEVICE,
    pretrained=False,
    load_state_dict=True,
)
model.eval()

def infer(filepath):
    """Predict class probabilities for every full window of one audio file.

    The file is loaded with ``load_audio`` and cut into consecutive,
    non-overlapping windows of ``MIN_WINDOW`` samples (``AUDIO_LENGTH_S``
    seconds). Each window is converted to a dB mel spectrogram, min-max
    normalised, and passed to the module-level ``model``. A trailing remainder
    shorter than one window is not scored.

    Args:
        filepath: Path to the audio file.

    Returns:
        A list with one dict per window. ``row_id`` is
        ``"<file stem>_<window end time in seconds>"`` and ``predictions`` is
        the probability tensor the model returned for that window.
    """
    all_predictions = []
    name = Path(filepath).stem
    audio = load_audio(filepath)
    # NOTE(review): the model stem takes a single input channel, so this
    # assumes `audio` has one channel in dim 0 - confirm for all test files.
    n_samples = audio.shape[1]
    if debug:
        print(f"Infering file {filepath} with length {n_samples / SAMPLE_RATE} s")

    # Count windows in whole samples, using the same window size as the
    # slicing below, so the two stay consistent if AUDIO_LENGTH_S changes.
    n_crops = n_samples // MIN_WINDOW
    for i in range(n_crops):
        crop = audio[:, i * MIN_WINDOW:(i + 1) * MIN_WINDOW]
        if debug:
            print(f"Crop {i} / {n_crops}")
            print(f"Audio length: {n_samples}")
            print(f"Crop dimensions: {crop.shape}")

        melspec = normalize_tensor(melspec_db_transform(melspec_transform(crop)))
        if debug:
            print(f"melspec shape: {melspec.shape}") # [1, 128, 313]
        melspec = melspec.unsqueeze(0) # add batch dimension (1)
        if debug:
            print(f"melspec unsqueezed shape: {melspec.shape}") # [1, 1, 128, 313]

        with torch.no_grad():
            logit, proba = model(melspec)

        # Row ids are keyed by the end time of the window, in seconds.
        t = (i + 1) * AUDIO_LENGTH_S
        all_predictions.append({"row_id": f'{name}_{t}', "predictions": proba})
        if debug:
            print('---')
    return all_predictions

# In debug mode every file is processed in this process, one after another;
# otherwise the files are handed to joblib with one job per CPU.
if not debug:
    parallel_task = (delayed(infer)(filepath) for filepath in tqdm(filepaths, desc='Infering files'))
    all_preds = Parallel(n_jobs=os.cpu_count())(parallel_task)
else:
    all_preds = [infer(filepath) for filepath in tqdm(filepaths, desc='Infering files')]

# infer() returns one list per file; flatten them into a single list of windows.
all_preds_flat = [window for file_windows in all_preds for window in file_windows]

print(f"all_preds length: {len(all_preds)}, all_preds_flat length: {len(all_preds_flat)}")

In [None]:
# Peek at one window's probability vector. Index 100 is shown when it exists;
# the index is clamped (and an empty result shows None) so that a run with
# 100 or fewer windows does not raise IndexError here and abort the notebook
# before the submission file is written.
all_preds_flat[min(100, len(all_preds_flat) - 1)]['predictions'][0] if all_preds_flat else None

In [None]:
# One row per window: probabilities as columns named after the label encoder
# classes, with row_id placed in front as the first column.
df = pd.DataFrame(
    torch.stack([pred['predictions'][0] for pred in all_preds_flat]).numpy(),
    columns=label_encoder.classes_,
)
df.insert(0, 'row_id', [pred['row_id'] for pred in all_preds_flat])

df

In [None]:
# Write the submission file to the working directory, without the DataFrame index.
df.to_csv(path_or_buf='submission.csv', index=False)

In [None]:
!ls submission.csv