In [1]:
import sys
sys.path.append('./models/')

import os
import torch
import torchaudio
import torch.nn.functional as F
import numpy as np

In [16]:
# Inference configuration: checkpoint location, label set, and target device.
MODEL_PATH = "./result/best.pt"
# Class labels, one per model output unit. NOTE(review): order is assumed to
# match the label order used at training time — confirm.
CLASS_NAME = ["danger", "fire", "gas", "non", "tsunami"]
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Audio clip format the model is fed.
SAMPLE_RATE = 16000  # Hz
DURATION = 1.0  # seconds
NUM_SAMPLE = int(SAMPLE_RATE * DURATION)  # samples per clip

# Example file used by the cells below.
sample_audio_path = "./sample/sample.wav"

In [17]:
def preprocess(filepath, target_sample_rate=SAMPLE_RATE, target_duration=DURATION):
    '''
    Load an audio file and turn it into a fixed-length, mono, peak-normalized waveform.

    Pipeline: load -> average channels to mono -> resample -> zero-pad or
    truncate to exactly ``int(target_sample_rate * target_duration)`` samples
    -> divide by the peak absolute value.

    Args:
        filepath: Path of the audio file, read with ``torchaudio.load``.
        target_sample_rate: Sample rate (Hz) the waveform is resampled to.
        target_duration: Clip length in seconds.

    Returns:
        A 1-D tensor of length ``int(target_sample_rate * target_duration)``,
        or ``None`` if the file could not be loaded (the error is printed).
    '''
    num_samples_target = int(target_sample_rate * target_duration)

    try:
        waveform, sr = torchaudio.load(filepath)
    except Exception as e:
        # Best-effort load: report the problem and let the caller handle None.
        print(f"파일로드 오류 {filepath}: {e}")
        return None

    # Dim 0 is channels here; average them down to a single mono channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Force a fixed clip length: zero-pad short clips on the right, and
    # truncate every clip (not only padded ones) so long inputs are cut too.
    current_length = waveform.shape[1]
    if current_length < num_samples_target:
        waveform = F.pad(waveform, (0, num_samples_target - current_length))
    waveform = waveform[:, :num_samples_target]

    # Peak-normalize; the epsilon avoids division by zero on an all-silent clip.
    waveform = waveform / (waveform.abs().max() + 1e-9)
    waveform = waveform.squeeze(0)

    return waveform

def predict(model, audio_tensor, device=DEVICE, class_names=CLASS_NAME, threshold=0.5):
    '''
    Run the model on one preprocessed clip and return the predicted label(s).

    Multi-label decision: every class whose score is strictly greater than
    ``threshold`` is returned. If no class passes, the single highest-scoring
    class is returned instead, so the label list is never empty.

    Args:
        model: Module whose forward returns a dict with a ``'clipwise_output'``
            tensor of shape (batch, num_classes). NOTE(review): scores are
            compared directly to ``threshold``, so they are assumed to already
            be probabilities (e.g. sigmoid outputs) — confirm for this model.
        audio_tensor: A 1-D waveform, or a batch of waveforms. Only the first
            item of a batch is scored.
        device: Device the input is moved to; the model is expected to be there already.
        class_names: Label for each output unit, in output order.
        threshold: Score cutoff for a class to count as predicted (default 0.5).

    Returns:
        ``(labels, scores)``: list of predicted class names, and a NumPy array of
        per-class scores for the first item.
    '''
    model.eval()
    audio_tensor = audio_tensor.to(device)

    # The model is called with a leading batch dimension.
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    with torch.no_grad():
        output_dict = model(audio_tensor)
        # Scores for the first (normally the only) item in the batch.
        scores = output_dict['clipwise_output'][0]

    # Compare in torch so the threshold is applied in the scores' own dtype.
    above = (scores > threshold).tolist()
    predicted_labels = [class_names[i] for i, hit in enumerate(above) if hit]

    if not predicted_labels:
        # Fall back to the top class of the same single item; indexing the
        # 1-D row keeps `.item()` valid even when a batch of >1 clips is passed.
        predicted_labels = [class_names[int(torch.argmax(scores).item())]]

    return predicted_labels, scores.cpu().numpy()

In [18]:
# NOTE(review): weights_only=False unpickles arbitrary Python objects (the whole
# model), which can execute code — only load checkpoints from a trusted source.
# The pickled model class presumably must be importable (see the sys.path tweak
# for ./models/ in the first cell) — confirm.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
model.to(DEVICE)

Cnn14(
  (spectrogram_extractor): Spectrogram(
    (stft): STFT(
      (conv_real): Conv1d(1, 257, kernel_size=(512,), stride=(160,), bias=False)
      (conv_imag): Conv1d(1, 257, kernel_size=(512,), stride=(160,), bias=False)
    )
  )
  (logmel_extractor): LogmelFilterBank()
  (spec_augmenter): SpecAugmentation(
    (time_dropper): DropStripes()
    (freq_dropper): DropStripes()
  )
  (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_block1): ConvBlock(
    (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block2): ConvBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (con

In [19]:
audio_tensor = preprocess(sample_audio_path)
# preprocess() returns None when the file cannot be loaded; stop here with a
# clear message rather than an AttributeError on `.shape`.
if audio_tensor is None:
    raise RuntimeError(f"could not load or preprocess {sample_audio_path}")
print(audio_tensor.shape)

torch.Size([16000])


In [25]:
predicted_labels, class_probabilities = predict(model, audio_tensor)

# Predicted label(s) first, then the score of every class.
if len(predicted_labels) > 0:
    print(f"예측결과: {predicted_labels}")

print("\n클래스별 확률:")
for idx, label in enumerate(CLASS_NAME):
    prob = class_probabilities[idx]
    print(f" - {label}: {prob:.4f}")

예측결과: ['fire']

클래스별 확률:
 - danger: 0.0000
 - fire: 0.9999
 - gas: 0.0002
 - non: 0.0000
 - tsunami: 0.0019
