In [None]:
import os
import torch
import librosa
from natsort import natsorted
from dcase24t6.nn.hub import baseline_pipeline

In [None]:
# Function to list the audio files found in a directory
def load_audio_files(directory):
    """Return the naturally sorted paths of the audio files in ``directory``.

    Only regular files located directly inside ``directory`` are considered;
    sub-directories are neither descended into nor returned, even when their
    name looks like an audio file. A file counts as audio when its name ends
    in ``.wav`` or ``.flac``; the comparison is case-insensitive so that
    names such as ``CLIP.WAV`` are picked up as well.

    Args:
        directory: Path of the directory to scan.

    Returns:
        list: ``directory`` joined with each matching file name, ordered by
        ``natsorted``.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``directory`` cannot be
            listed (for example, when it does not exist).
    """
    audio_extensions = ('.wav', '.flac')
    audio_files = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        # Lower-case the name so upper/mixed-case extensions still match.
        if name.lower().endswith(audio_extensions)
        # Skip directories that merely carry an audio-like name.
        and os.path.isfile(os.path.join(directory, name))
    ]
    return natsorted(audio_files)

In [None]:
# Function to perform inference and print the results
def perform_inference(input_directory):
    """Caption every audio file in ``input_directory`` and print the results.

    The baseline pipeline is created once, switched to evaluation mode and
    then called on each file returned by ``load_audio_files``. For every file
    the first entry of the pipeline's ``'candidates'`` output is printed
    (an empty string when there are no candidates), labelled with the file
    name without its directory and extension.

    Failures are reported with ``print`` rather than raised: a model
    initialization error aborts the whole run, while an error on one file
    only skips that file.

    Args:
        input_directory: Path of the directory holding the audio files.

    Returns:
        None. Results are written to standard output only.
    """
    # Initialize the baseline model
    try:
        model = baseline_pipeline()
        model.eval()
    except Exception as e:
        print(f"Error initializing model: {e}")
        return

    # Collect the audio files to process
    audio_files = load_audio_files(input_directory)
    if not audio_files:
        # Say so explicitly instead of finishing without any output.
        print(f"No .wav or .flac files found in {input_directory}")
        return

    # Process each audio file independently so one bad file cannot stop the run
    for audio_file in audio_files:
        try:
            # Load as mono; sr=None keeps the file's native sample rate
            data, sr = librosa.load(audio_file, mono=True, sr=None)
            data = torch.from_numpy(data).float()

            # Add a leading dimension when the waveform is 1-D
            # (assumed to be the batch axis the pipeline expects - TODO confirm).
            if data.ndim == 1:
                data = data.unsqueeze(0)

            item = {"audio": data, "sr": sr}

            # Inference only: no gradients are needed, so skip autograd tracking
            with torch.no_grad():
                output = model(item)
            candidates = output['candidates']
            candidate_str = candidates[0] if candidates else ""

            # Derive the label from the path itself (not a hard-coded prefix)
            # and drop only the trailing extension.
            file_name = os.path.splitext(os.path.basename(audio_file))[0]
            print(f'{file_name}:\nCandidate: {candidate_str}')

        except Exception as e:
            print(f"Error processing {audio_file}: {e}")

In [None]:
# Directory whose .wav/.flac files will be captioned.
# The value below is a placeholder; set it before running this cell.
input_directory = '/path/to/files/'  # Replace with the path to your input directory

# Run the pipeline over that directory; one candidate caption is printed per file.
perform_inference(input_directory)