In [None]:
import json
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
import torchvision
from torchvision import datasets, models, transforms
from torchvision.models.resnet import ResNet, BasicBlock

In [None]:
# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Prepare paths. os.path.join avoids the doubled "//" produced by concatenating
# "/xxx" onto a root that already ends with a slash.
root_path = "../input/birdclef-2022/"
# Trailing '' keeps the final slash for code that appends file names directly.
input_path = os.path.join(root_path, 'train_audio', '')

# Read dataset and labels
train_meta = pd.read_csv(os.path.join(root_path, 'train_metadata.csv'))
with open(os.path.join(root_path, 'scored_birds.json')) as sbfile:
    scored_birds = json.load(sbfile)

# Bird codes in the order used to map model output indices to labels.
# Alternative (all primary labels from the training metadata), kept for reference:
# bird_label = train_meta["primary_label"].unique()
bird_label = np.asarray(scored_birds)

# Preprocessing data
# NOTE(review): audio is loaded without resampling, so this should equal the
# native sample rate of the .ogg files for chunks to really last 5 s — confirm.
sample_rate = 48000
n_fft = 2048
win_length = None  # None -> torchaudio uses n_fft as the window length
hop_length = 1024
n_mels = 128
# Chunk length in *samples* (nominally 5 s at `sample_rate`), despite the name.
min_sec_proc = sample_rate * 5

# Magnitude (power=1.0) mel spectrogram. With center=True each chunk yields
# 1 + min_sec_proc // hop_length frames (235 for the values above).
mel_spectrogram = T.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",
    power=1.0,
    norm='slaney',
    onesided=True,
    n_mels=n_mels,
    mel_scale="slaney",
)

In [None]:
# Fix every pseudo-random generator so runs are reproducible
def torch_fix_seed(seed=42, deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and optionally force deterministic kernels.

    Args:
        seed: value used to seed every random number generator.
        deterministic: when True (default), make cuDNN and PyTorch select
            deterministic algorithms. PyTorch then raises for ops that have no
            deterministic implementation.
    """
    # Python random
    random.seed(seed)
    # NumPy (legacy global RNG)
    np.random.seed(seed)
    # PyTorch: CPU generator and every visible GPU
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        # Autotuning may select different kernels between runs.
        torch.backends.cudnn.benchmark = False
        # Deterministic cuBLAS matmuls on CUDA >= 10.2 need this workspace
        # setting, otherwise PyTorch raises. setdefault keeps any user value.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Must be *called*. Assigning True here would overwrite the torch
        # function with a bool and never enable deterministic mode.
        torch.use_deterministic_algorithms(True)


torch_fix_seed()

In [None]:
# Create log-mel spectrograms for fixed-length chunks of an audio file
def audio_to_mel_label(filepath,
                       min_sec_proc,
                       mode='train',
                       data_index=0,
                       label_list=None,
                       bird_label=None,
                       label_file=None,
                       mel_list=None,
                       max_chunks=12):
    """Split an audio file into fixed-length chunks and return their log-mel spectrograms.

    Only the first channel is used. At most ``max_chunks`` chunks are taken from
    the start of the file, and a trailing partial chunk is dropped. Each
    spectrogram is log10-scaled and standardized to zero mean and unit std.

    Args:
        filepath: path of an audio file readable by ``torchaudio.load``.
        min_sec_proc: chunk length in samples.
        mode, data_index, label_list, bird_label, label_file: not used by this
            function; kept so existing callers keep working.
        mel_list: list to append spectrograms to; a fresh list is created when
            None, so results never leak between calls.
        max_chunks: maximum number of chunks taken from the file (default 12).

    Returns:
        ``mel_list`` with one tensor of shape (1, n_mels, n_frames) appended per chunk.
    """
    if mel_list is None:
        mel_list = []

    # NOTE(review): sample_rate_file is ignored (no resampling). If it differs
    # from the rate the mel transform was built for, chunks do not last the
    # nominal duration — confirm against the data.
    waveform, sample_rate_file = torchaudio.load(filepath=filepath)
    # First channel only (stereo -> mono, mono -> mono); shape stays (1, n_samples).
    waveform = waveform[:1, :]
    # Cap the audio at max_chunks chunks; shorter audio is left untouched.
    waveform = waveform[:, :min_sec_proc * max_chunks]

    n_chunks = waveform.shape[1] // min_sec_proc
    for index in range(n_chunks):
        start = index * min_sec_proc
        chunk = waveform[0, start:start + min_sec_proc]
        # mel_spectrogram maps (time,) -> (n_mels, frames); add a channel axis.
        # The epsilon keeps log10 finite for all-zero bins.
        log_melspec = torch.log10(mel_spectrogram(chunk).unsqueeze(0) + 1e-10)
        # Clamp std so a constant (e.g. silent) chunk normalizes to zeros, not NaN.
        std = torch.clamp(torch.std(log_melspec), min=1e-8)
        log_melspec = (log_melspec - torch.mean(log_melspec)) / std

        mel_list.append(log_melspec)

    return mel_list

# class ResNetBird(ResNet):
#     def __init__(self):
#         super().__init__(BasicBlock, [5, 8, 6, 3], num_classes=21)

#         self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)


# net = ResNetBird().to(device)

# Build the classifier: ResNet-18 with a 1-channel stem for log-mel input and
# 21 outputs (must match the saved checkpoint), then load the trained weights.
net = models.resnet18(pretrained=False)
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
net.fc = nn.Linear(net.fc.in_features, 21)
net.to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net.load_state_dict(torch.load('../input/birdsclassification/model.pt', map_location=device))
net.eval()
out_sigmoid = nn.Sigmoid()

test_audio_dir = os.path.join(root_path, 'test_soundscapes', '')
# Soundscape ids: .ogg file names without extension, in sorted order.
file_list = [os.path.splitext(f)[0]
             for f in sorted(os.listdir(test_audio_dir))
             if f.endswith('.ogg')]

n_chunks = 12  # prediction rows per soundscape, one per 5-second window
binary_th = 0.001  # NOTE(review): very permissive threshold — confirm intent
# Position of each bird code in the model output (bird_label defines the order).
label_index = {label: k for k, label in enumerate(bird_label)}

pred = {'row_id': [], 'target': []}

# Inference only: no_grad avoids building an autograd graph for every file.
with torch.no_grad():
    for afile in file_list:
        path = os.path.join(test_audio_dir, afile + '.ogg')

        # Pass a fresh list explicitly so chunks never carry over between files.
        mel_chunks = audio_to_mel_label(path, min_sec_proc, 'test', mel_list=[])
        if mel_chunks:
            probs = out_sigmoid(net(torch.stack(mel_chunks).to(device))).cpu()
        else:
            # Audio shorter than one chunk: every window scores 0 below.
            probs = torch.empty(0, len(bird_label))

        for chunk_idx in range(n_chunks):
            # NOTE(review): assumes each chunk spans 5 s, which only holds if the
            # files' native rate matches the mel transform's sample_rate.
            chunk_end_time = (chunk_idx + 1) * 5
            for bird in scored_birds:
                k = label_index.get(bird)
                # Windows the audio did not cover (or unknown birds) score 0.
                if chunk_idx < len(probs) and k is not None:
                    score = probs[chunk_idx, k].item()
                else:
                    score = 0.0
                row_id = afile + '_' + bird + '_' + str(chunk_end_time)

                pred['row_id'].append(row_id)
                pred['target'].append(score > binary_th)

results = pd.DataFrame(pred, columns=['row_id', 'target'])

print(results)

results.to_csv("./submission.csv", index=False)