# RegNet Inference (submission generation)

In [1]:
import glob
import os
import re
import sys
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import pywt
import timm
import torch
import torch.nn as nn
import torchaudio
from joblib import Parallel, delayed
from torchaudio import functional as F_audio
from torchvision import transforms
from torchvision.models import RegNet_Y_1_6GF_Weights, regnet_y_1_6gf
from tqdm import tqdm

In [2]:
## REUSE IN INFERENCE NOTEBOOK

# Make the shared `utils` module importable: use the copy bundled with the
# inference dataset when it exists, otherwise fall back to the parent directory.
custom_dataset_path = '/kaggle/input/birdclef2023-inference'
utils_dir = custom_dataset_path if os.path.exists(os.path.join(custom_dataset_path, 'utils.py')) else '..'
sys.path.append(utils_dir)
import utils

IS_IN_KAGGLE_ENV = utils.get_is_in_kaggle_env()

# Where to read the competition audio and the saved artifacts
# (label encoder, model weights) in each environment.
if IS_IN_KAGGLE_ENV:
    DATA_PATH = '/kaggle/input/birdclef-2023'
    JOBLIB_PATH = custom_dataset_path
else:
    DATA_PATH = '../data'
    JOBLIB_PATH = './'

DEVICE = 'cpu'

# Each model input covers AUDIO_LENGTH_S seconds of audio at SAMPLE_RATE samples/s.
AUDIO_LENGTH_S = 5
SAMPLE_RATE = 32_000

We are running code on Localhost


In [3]:
## REUSE IN INFERENCE NOTEBOOK

class BirdMelspecClf(nn.Module):
    """RegNetY-1.6GF classifier over a stack of spectrogram-like feature channels.

    Args:
        out_features: number of output classes.
        pretrained: if True, start from torchvision's ``RegNet_Y_1_6GF_Weights.DEFAULT``
            weights. The replaced stem conv and classifier head are freshly
            initialized either way.
        in_channels: number of input feature channels. The default of 4 matches the
            features built in this notebook: melspec, two wavelet melspecs, and MFCC.
    """

    def __init__(self, out_features, pretrained, in_channels=4):
        super().__init__()

        # https://pytorch.org/vision/stable/models.html
        weights = RegNet_Y_1_6GF_Weights.DEFAULT if pretrained else None
        self.regnet = regnet_y_1_6gf(weights=weights)

        # Replace the stem conv so the network takes `in_channels` feature maps
        # instead of 3 RGB channels. The original stem:
        #   (stem): SimpleStemIN(
        #     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        #     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        #     (2): ReLU(inplace=True)
        #   )
        self.regnet.stem[0] = nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Replace the original classifier: (fc): Linear(in_features=888, out_features=1000, bias=True)
        self.regnet.fc = nn.Linear(self.regnet.fc.in_features, out_features)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(logits, probas)``, where probas is the softmax of logits over dim 1."""
        logits = self.regnet(x)
        probas = self.softmax(logits)
        return logits, probas


def get_model(out_features, device, pretrained=False, load_state_dict=True, state_dict_starts_with=f"{AUDIO_LENGTH_S}s_regnetY16GF"):
    """Build a BirdMelspecClf on ``device`` and optionally load the newest saved weights.

    Candidate weight files are the ``*.pt`` files in JOBLIB_PATH whose name starts with
    ``state_dict_starts_with``. The chosen file is the one whose LAST integer in the
    filename is largest (treated as a timestamp). Files without digits sort first.

    Args:
        out_features: number of output classes.
        device: device to move the model to.
        pretrained: forwarded to BirdMelspecClf.
        load_state_dict: if False, return the freshly built model without loading weights.
        state_dict_starts_with: filename prefix of the weight files to consider.

    Returns:
        The model on ``device``.

    NOTE: if no matching file is found, only a message is printed. The returned model
    then has untrained stem/head weights, and its predictions are meaningless.
    """
    model = BirdMelspecClf(out_features=out_features, pretrained=pretrained)
    print(f"Loaded model {model.__class__.__name__} with {sum(p.numel() for p in model.parameters())} parameters, pretained={pretrained}")
    model.to(device)

    if not load_state_dict:
        return model

    model_files = [f for f in os.listdir(JOBLIB_PATH) if f.startswith(state_dict_starts_with) and f.endswith('.pt')]
    if not model_files:
        print(f"No model starting with {state_dict_starts_with} found in {JOBLIB_PATH}")
        return model

    def _timestamp(filename):
        # Last integer in the filename, or -1 when it contains no digits
        numbers = re.findall(r'\d+', filename)
        return int(numbers[-1]) if numbers else -1

    # Stable sort, then take the last entry, to keep the original tie-breaking
    latest_model_file = sorted(model_files, key=_timestamp)[-1]
    model_path = os.path.join(JOBLIB_PATH, latest_model_file)
    # map_location lets weights saved from a GPU run load on a CPU-only machine.
    # torch.load unpickles the file, so only load trusted checkpoints.
    # load_state_dict copies into the existing parameters, so they stay on `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model weights from {model_path}")

    return model


def get_label_encoder():
    """Load the label encoder stored as ``label_encoder.joblib`` in JOBLIB_PATH.

    Its ``classes_`` give the class order used for the model's output columns.
    ``joblib.load`` unpickles the file, so only load trusted artifacts.
    """
    encoder_path = os.path.join(JOBLIB_PATH, 'label_encoder.joblib')
    encoder = joblib.load(encoder_path)
    print(f"Loaded label encoder from {encoder_path}")
    return encoder

In [4]:
class WaveletTransformSingle(nn.Module):
    """Single-level discrete wavelet transform (``pywt.dwt``) wrapped as an ``nn.Module``.

    Args:
        wavelet: wavelet passed to ``pywt.dwt``, e.g. ``pywt.Wavelet('sym4')``.
        cut_to_nearest: if given, trailing coefficients are dropped so that the
            coefficient length is a multiple of this value.
    """

    def __init__(
        self,
        wavelet: pywt.Wavelet,
        cut_to_nearest: int | None = None
    ):
        super().__init__()
        self.wavelet = wavelet
        self.ctn = cut_to_nearest

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Transform the first row/channel ``X[0]``.

        Returns a ``[2, n_coeffs]`` tensor that stacks the two arrays returned by
        ``pywt.dwt``, in that order: approximation (cA), then detail (cD).
        The work is done in NumPy, so the result is a CPU tensor with no autograd history.
        """
        # detach() so inputs that track gradients can still be converted to NumPy
        signal = X[0].detach().cpu().numpy()

        approx, detail = pywt.dwt(signal, self.wavelet)
        out = torch.stack((torch.from_numpy(approx), torch.from_numpy(detail)))

        if self.ctn is not None:
            remainder = out.shape[-1] % self.ctn
            # Only trim a non-zero remainder: slicing with [:-0] would drop everything
            if remainder:
                out = out[:, :-remainder]

        return out

In [5]:
## REUSE IN INFERENCE NOTEBOOK

def resample(audio, current_sample_rate, desired_sample_rate=SAMPLE_RATE):
    """Resample ``audio`` from ``current_sample_rate`` to ``desired_sample_rate``."""
    return torchaudio.transforms.Resample(
        orig_freq=current_sample_rate,
        new_freq=desired_sample_rate,
    )(audio)

def load_audio(audio_path, sample_rate=SAMPLE_RATE):
    """Load ``audio_path`` with torchaudio, resampling to ``sample_rate`` if needed.

    Returns the waveform tensor as produced by ``torchaudio.load``
    (channels first; callers index it as ``audio[:, start:end]``).
    """
    waveform, file_rate = torchaudio.load(audio_path)
    if sr_matches := (file_rate == sample_rate):
        return waveform
    return resample(waveform, file_rate, sample_rate)

# Using librosa defaults for n_fft and hop_length
def get_melspec_transform(sample_rate=SAMPLE_RATE, n_fft=2048, hop_length=512, n_mels=128):
    """Build a torchaudio MelSpectrogram transform (n_fft/hop_length follow librosa's defaults)."""
    spec_kwargs = dict(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    return torchaudio.transforms.MelSpectrogram(**spec_kwargs)

# Using librosa defaults for top_db
def get_melspec_db_transform(stype='power', top_db=80):
    """Build a torchaudio AmplitudeToDB transform; ``top_db`` is its dB cut-off (librosa's default)."""
    db_kwargs = {'stype': stype, 'top_db': top_db}
    return torchaudio.transforms.AmplitudeToDB(**db_kwargs)

# Copied from torchaudio/transforms/_transforms.py (to avoid converting to melspec twice)
# DCT-II basis with 128 MFCCs over 128 mel bins, orthonormal, built once at import time
dct_mat = F_audio.create_dct(n_mfcc=128, n_mels=128, norm="ortho")

def get_mfcc_from_melspec(melspec):
    """Apply the DCT to an (already dB-scaled) melspec along its second-to-last axis.

    The mel axis is moved last, multiplied by ``dct_mat``, and then moved back.
    This mirrors torchaudio's MFCC transform without recomputing the melspec.
    """
    mel_last = melspec.transpose(-1, -2)
    return (mel_last @ dct_mat).transpose(-1, -2)

def normalize_tensor(tensor):
    """Min-max scale ``tensor`` to [0, 1] using its global minimum and maximum.

    A constant tensor (max == min) is returned unchanged to avoid dividing by zero.
    NOTE(review): such a tensor is therefore NOT mapped into [0, 1]. For example, a
    constant dB spectrogram keeps its raw dB value. This is presumably kept for parity
    with the training preprocessing; confirm before changing it.
    """
    lo = torch.min(tensor)
    hi = torch.max(tensor)
    value_range = hi - lo
    if value_range == 0:
        return tensor
    return (tensor - lo) / value_range

In [6]:
# glob returns files in filesystem order; sort so the submission row order is
# deterministic across runs and machines.
filepaths = sorted(glob.glob(os.path.join(DATA_PATH, 'test_soundscapes', '*.ogg')))
print(f"filepaths length: {len(filepaths)} (amount of test soundscapes)")

filepaths length: 1 (amount of test soundscapes)


# Inference

In [7]:
debug = False
simulate_200_files = False

if simulate_200_files:
    filepaths = [filepaths[0]] * 200  # simulate a 200-file submission
    print(f"filepaths length: {len(filepaths)} after simulation additions")

label_encoder = get_label_encoder()
model = get_model(out_features=len(label_encoder.classes_), device=DEVICE, pretrained=False, load_state_dict=True)
model.eval()

MIN_WINDOW = AUDIO_LENGTH_S * SAMPLE_RATE  # samples per crop
melspec_transform = get_melspec_transform(n_mels=128)
melspec_db_transform = get_melspec_db_transform()
wave_transform = WaveletTransformSingle(pywt.Wavelet('sym4'))
# Target size of every feature map (128 mels x 313 frames for a 5 s crop)
resize = transforms.Resize((128, 313), antialias=True)


def extract_features(crop):
    """Stack the four normalized feature maps fed to the model for one audio crop.

    The channels are, in order:
    - the crop's dB melspec;
    - dB melspecs of the DWT approximation and detail coefficients, resized to
      the crop melspec size;
    - the MFCC.

    Keep this order: it is presumably the order the loaded weights were trained with.
    For a mono crop the result is ``[4, 128, 313]``.
    """
    melspec = melspec_db_transform(melspec_transform(crop))
    norm_melspec = normalize_tensor(melspec)

    # Unpacking the [2, n] tensor from wave_transform yields (approx, detail)
    approx, detail = wave_transform(crop)
    norm_approx = normalize_tensor(resize(melspec_db_transform(melspec_transform(approx)).unsqueeze(0)))
    norm_detail = normalize_tensor(resize(melspec_db_transform(melspec_transform(detail)).unsqueeze(0)))

    mfcc = get_mfcc_from_melspec(melspec)
    norm_mfcc = normalize_tensor(resize(mfcc))

    return torch.cat((norm_melspec, norm_approx, norm_detail, norm_mfcc), dim=0)


def infer(filepath):
    """Score every complete AUDIO_LENGTH_S window of one soundscape.

    Returns one dict per window, of the form
    ``{"row_id": f"{file_stem}_{end_s}", "predictions": probas}``.
    Here ``probas`` is the model's softmax output for a batch of one, and ``end_s``
    is the window's end time in seconds. A trailing partial window is skipped.
    """
    name = Path(filepath).stem
    audio = load_audio(filepath)
    n_samples = audio.shape[1]
    n_crops = n_samples // MIN_WINDOW
    if debug:
        print(f"Infering file {filepath} with length {n_samples / SAMPLE_RATE} s")

    predictions = []
    for i in range(n_crops):
        start = i * MIN_WINDOW
        crop = audio[:, start:start + MIN_WINDOW]
        features = extract_features(crop).unsqueeze(0)  # add batch dimension -> [1, 4, 128, 313]
        if debug:
            print(f"Crop {i} / {n_crops}")
            print(f"Audio length: {n_samples} samples")
            print(f"Crop dimensions: {crop.shape}")
            print(f"features unsqueezed shape: {features.shape}")
        with torch.no_grad():
            _, proba = model(features)
        end_s = (i + 1) * AUDIO_LENGTH_S
        predictions.append({"row_id": f'{name}_{end_s}', "predictions": proba})
        if debug:
            print('---')
    return predictions


# tqdm wraps the task generator, so in the parallel branch the bar advances as tasks
# are dispatched and can finish before inference does.
if debug:
    # Sequential, so the debug prints stay in order
    all_preds = [infer(filepath) for filepath in tqdm(filepaths, desc='Infering files')]
else:
    parallel_task = (delayed(infer)(filepath) for filepath in tqdm(filepaths, desc='Infering files'))
    all_preds = Parallel(n_jobs=os.cpu_count())(parallel_task)

all_preds_flat = [item for file_preds in all_preds for item in file_preds]

print(f"all_preds length: {len(all_preds)}, all_preds_flat length: {len(all_preds_flat)}")

Loaded label encoder from ./label_encoder.joblib
Loaded model BirdMelspecClf with 10548414 parameters, pretained=False
No model starting with 5s_regnetY800MF found in ./


Infering files: 100%|██████████| 1/1 [00:00<00:00, 124.10it/s]


all_preds length: 1, all_preds_flat length: 120


In [8]:
# Spot-check one crop's probability vector. The index is clamped so the cell
# also runs when fewer than 101 crops were scored.
sample_idx = min(100, len(all_preds_flat) - 1)
all_preds_flat[sample_idx]['predictions'][0]

tensor([0.0036, 0.0037, 0.0039, 0.0038, 0.0035, 0.0043, 0.0038, 0.0035, 0.0038,
        0.0040, 0.0037, 0.0038, 0.0037, 0.0038, 0.0038, 0.0040, 0.0038, 0.0039,
        0.0036, 0.0040, 0.0037, 0.0039, 0.0039, 0.0035, 0.0037, 0.0035, 0.0037,
        0.0039, 0.0035, 0.0037, 0.0039, 0.0038, 0.0035, 0.0037, 0.0039, 0.0040,
        0.0038, 0.0037, 0.0035, 0.0038, 0.0036, 0.0036, 0.0039, 0.0038, 0.0037,
        0.0039, 0.0039, 0.0039, 0.0039, 0.0038, 0.0039, 0.0036, 0.0038, 0.0040,
        0.0036, 0.0038, 0.0038, 0.0039, 0.0037, 0.0042, 0.0036, 0.0042, 0.0039,
        0.0036, 0.0039, 0.0040, 0.0039, 0.0036, 0.0039, 0.0037, 0.0038, 0.0040,
        0.0039, 0.0038, 0.0039, 0.0039, 0.0037, 0.0041, 0.0040, 0.0036, 0.0039,
        0.0038, 0.0041, 0.0038, 0.0036, 0.0039, 0.0038, 0.0039, 0.0037, 0.0040,
        0.0038, 0.0040, 0.0037, 0.0037, 0.0039, 0.0036, 0.0038, 0.0035, 0.0038,
        0.0038, 0.0038, 0.0036, 0.0036, 0.0038, 0.0037, 0.0036, 0.0037, 0.0036,
        0.0038, 0.0036, 0.0041, 0.0036, 

In [9]:
# One row per scored crop: row_id first, then one probability column per class
# (label encoder order).
row_ids = [p['row_id'] for p in all_preds_flat]
probas = torch.stack([p['predictions'][0] for p in all_preds_flat]).numpy()
df = pd.DataFrame(probas, columns=label_encoder.classes_)
df.insert(0, 'row_id', row_ids)

df

Unnamed: 0,row_id,abethr1,abhori1,abythr1,afbfly1,afdfly1,afecuc1,affeag1,afgfly1,afghor1,...,yebsto1,yeccan1,yefcan,yelbis1,yenspu1,yertin1,yesbar1,yespet1,yetgre1,yewgre1
0,soundscape_29201_5,0.003633,0.003702,0.003929,0.003761,0.003518,0.004286,0.003805,0.003509,0.003754,...,0.003642,0.003717,0.003785,0.003596,0.003784,0.003925,0.003797,0.003750,0.003776,0.003974
1,soundscape_29201_10,0.003636,0.003699,0.003930,0.003759,0.003514,0.004290,0.003809,0.003506,0.003752,...,0.003644,0.003719,0.003791,0.003599,0.003784,0.003925,0.003797,0.003749,0.003773,0.003980
2,soundscape_29201_15,0.003634,0.003703,0.003930,0.003767,0.003530,0.004272,0.003799,0.003515,0.003751,...,0.003650,0.003718,0.003788,0.003598,0.003783,0.003921,0.003803,0.003751,0.003773,0.003966
3,soundscape_29201_20,0.003636,0.003702,0.003931,0.003763,0.003524,0.004278,0.003798,0.003512,0.003749,...,0.003647,0.003719,0.003789,0.003601,0.003781,0.003923,0.003801,0.003751,0.003773,0.003967
4,soundscape_29201_25,0.003637,0.003705,0.003932,0.003765,0.003527,0.004272,0.003796,0.003516,0.003750,...,0.003649,0.003718,0.003787,0.003600,0.003781,0.003921,0.003802,0.003750,0.003775,0.003962
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,soundscape_29201_580,0.003631,0.003709,0.003924,0.003773,0.003537,0.004262,0.003793,0.003532,0.003762,...,0.003657,0.003714,0.003788,0.003602,0.003786,0.003919,0.003801,0.003753,0.003772,0.003961
116,soundscape_29201_585,0.003631,0.003709,0.003923,0.003768,0.003530,0.004270,0.003795,0.003524,0.003760,...,0.003653,0.003715,0.003790,0.003599,0.003785,0.003920,0.003801,0.003754,0.003775,0.003965
117,soundscape_29201_590,0.003632,0.003706,0.003926,0.003769,0.003529,0.004271,0.003799,0.003524,0.003759,...,0.003652,0.003716,0.003787,0.003601,0.003785,0.003921,0.003797,0.003753,0.003772,0.003967
118,soundscape_29201_595,0.003635,0.003710,0.003926,0.003777,0.003543,0.004255,0.003788,0.003537,0.003762,...,0.003665,0.003712,0.003792,0.003604,0.003780,0.003916,0.003804,0.003754,0.003771,0.003957


In [10]:
# Written without the DataFrame index, so row_id is the first column
df.to_csv(path_or_buf='submission.csv', index=False)

In [11]:
# Confirm the submission file exists (pure Python, so it also works outside IPython
# and fails loudly instead of just printing an `ls` error)
submission_path = Path('submission.csv')
assert submission_path.is_file(), f"{submission_path} was not written"
print(submission_path)

submission.csv
