In [31]:
import sys
sys.path.append('./obsoleted_models')

import torch

from obsoleted_models.models import Cnn14
from model.stft import AudioPreprocessor

In [32]:
# Root directory of the dataset.
DATASET = "./Dataset/"
# Pretrained weights for the model.
CHECKPOINT = "./checkpoint/Cnn14_16k_mAP=0.438.pth"
# Device handed to torch: CUDA when it is available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [33]:
# Read the pretrained checkpoint, mapping its tensors onto DEVICE.
# SECURITY NOTE(review): weights_only=False performs a full unpickle of the
# file, which can run arbitrary code embedded in it. Only use this with a
# checkpoint obtained from a trusted source.
checkpoint = torch.load(
    CHECKPOINT,
    map_location=DEVICE,
    weights_only=False,
)

In [34]:
# The parameters are stored under the "model" entry of the checkpoint.
state_dict = checkpoint["model"]

# List every entry whose name mentions one of the two front-end extractors.
print("전처리 키:")
for k in state_dict:
    if any(tag in k for tag in ("spectrogram_extractor", "logmel_extractor")):
        print(k)

전처리 키:
spectrogram_extractor.stft.conv_real.weight
spectrogram_extractor.stft.conv_imag.weight
logmel_extractor.melW


In [35]:
# Random test signal of shape (1, 16000); with the 16000 Hz sample rate
# configured for the model below, that is one second of audio.
# A dedicated, seeded generator makes the input identical on every run
# (Restart & Run All reproducibility) without touching the global RNG state.
waveform = torch.randn(1, 16000, generator=torch.Generator().manual_seed(0))

In [36]:
# Front-end settings shared by the model and the stand-alone preprocessor,
# kept in one place so both objects are built from identical values.
frontend_kwargs = dict(
    sample_rate=16000,  # sampling rate (16 kHz)
    window_size=512,    # window size
    hop_size=160,       # hop size
    mel_bins=64,        # number of mel frequency channels
    fmin=50,            # lowest mel frequency
    fmax=8000,          # highest mel frequency
)

# NOTE(review): nothing in the visible cells loads the checkpoint's
# state_dict into `model` - confirm whether that is intended.
model = Cnn14(classes_num=527, **frontend_kwargs)  # 527 target classes
model.eval()

preproc = AudioPreprocessor(**frontend_kwargs)

In [37]:
# Gradients are not needed for this comparison.
with torch.no_grad():
    # The model's own front end: spectrogram extractor, then log-mel extractor.
    logmel_internal = model.logmel_extractor(model.spectrogram_extractor(waveform))

    # The stand-alone preprocessor applied to the same waveform.
    logmel_external = preproc(waveform, training=False)

In [38]:
# Show the Python type of each front-end output, separated by a space.
print(*map(type, (logmel_internal, logmel_external)))

<class 'torch.Tensor'> <class 'torch.Tensor'>


In [39]:
# torch.allclose broadcasts its operands, so two outputs with different shapes
# could still be reported as close (or raise if they cannot be broadcast).
# Require identical shapes first, then compare values within atol=1e-5.
equal = (
    logmel_internal.shape == logmel_external.shape
    and torch.allclose(logmel_internal, logmel_external, atol=1e-5)
)
print("equal logmel:", equal)

equal logmel: True
