In [None]:
import sys
import IPython.display as ipd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import librosa

from models.htsat import HTSAT_Swin_Transformer
import htsat_config

In [None]:
class HTSAT(nn.Module):
    """Thin wrapper around ``HTSAT_Swin_Transformer`` used to embed audio clips.

    All hyper-parameters come from the project-local ``htsat_config`` module.
    The transformer lives in the ``sed_model`` attribute.
    NOTE(review): this name presumably matches the "sed_model." key prefix of the
    checkpoint's ``state_dict`` -- do not rename it without remapping the keys.
    """

    def __init__(self):
        super().__init__()
        # Integer ratio between the square spectrogram side and the number of mel bins.
        self.freq_ratio = htsat_config.htsat_spec_size // htsat_config.mel_bins
        # Maximum number of waveform samples fed to the model.
        # NOTE(review): the "- 1" presumably makes a centred STFT produce exactly
        # htsat_spec_size * freq_ratio frames -- confirm against the model front-end.
        self.target_T = int(htsat_config.hop_size * htsat_config.htsat_spec_size * self.freq_ratio) - 1
        # hop_size and target_T are sample counts at this rate, so audio must be
        # resampled to it before being embedded.
        self.sample_rate = htsat_config.sample_rate
        self.sed_model = HTSAT_Swin_Transformer(
            spec_size=htsat_config.htsat_spec_size,
            patch_size=htsat_config.htsat_patch_size,
            in_chans=1,
            num_classes=htsat_config.classes_num,
            window_size=htsat_config.htsat_window_size,
            config=htsat_config,
            depths=htsat_config.htsat_depth,
            embed_dim=htsat_config.htsat_dim,
            patch_stride=htsat_config.htsat_stride,
            num_heads=htsat_config.htsat_num_head
        )

    def forward(self, x, infer_mode=False):
        """Run the wrapped Swin transformer.

        Args:
            x: waveform tensor; ``embed_sound`` passes shape (1, num_samples).
            infer_mode: forwarded to ``HTSAT_Swin_Transformer.forward``.

        Returns:
            The raw output of ``sed_model``; ``embed_sound`` reads its
            ``'latent_output'`` entry.
        """
        # Pass by keyword: in upstream HTS-AT the model's signature is
        # forward(x, mixup_lambda=None, infer_mode=False), so a positional call
        # would route this flag into mixup_lambda.
        # NOTE(review): confirm the signature of models.htsat.
        return self.sed_model(x, infer_mode=infer_mode)

    @torch.no_grad()
    def embed_sound(self, f, infer_mode=False):
        """Load an audio file and return its HTS-AT latent embedding.

        The clip is resampled to ``self.sample_rate``, trimmed of leading and
        trailing silence (25 dB below peak), and truncated to ``self.target_T``
        samples.

        Args:
            f: anything ``librosa.load`` accepts (typically a file path).
            infer_mode: forwarded to ``forward``; the default matches the
                previous behaviour of this method.

        Returns:
            The ``'latent_output'`` entry of the model output for a batch of one clip.
        """
        # Resample instead of keeping the file's native rate (e.g. 44.1 kHz):
        # every length computed in __init__ assumes self.sample_rate.
        audio, _ = librosa.load(f, sr=self.sample_rate)
        trimmed, _ = librosa.effects.trim(audio, top_db=25)
        # Trimming a (near-)silent file can leave nothing; keep the raw clip then
        # rather than feeding an empty tensor to the model.
        if trimmed.size > 0:
            audio = trimmed
        waveform = torch.from_numpy(audio).float().unsqueeze(0)[:, :self.target_T]
        # Follow the model's device so the helper still works after htsat.to(...).
        device = next(self.parameters()).device
        return self(waveform.to(device), infer_mode=infer_mode)['latent_output']

# Pretrained HTS-AT weights; the checkpoint's "state_dict" entry is loaded into the wrapper.
HTSAT_CKPT_PATH = 'htsat_audioset_2048d.ckpt'

htsat = HTSAT()
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
htsat.load_state_dict(
    torch.load(HTSAT_CKPT_PATH, map_location="cpu")["state_dict"]
)
# Switch to inference mode; the trailing ';' hides the (very long) module repr.
htsat.eval();

In [None]:
# Root of the FSDKaggle2018 validation split (one sub-directory per class label).
FSD_VAL_DIR = '../semaudio-single-ch/data/FSDSoundScapes/FSDKaggle2018/val'

# Two clips from each of three classes, to compare within- vs cross-class similarity.
y_ag0 = htsat.embed_sound(f'{FSD_VAL_DIR}/Acoustic_guitar/15ba13c4.wav')
y_ag1 = htsat.embed_sound(f'{FSD_VAL_DIR}/Acoustic_guitar/367ad7b1.wav')
y_apl0 = htsat.embed_sound(f'{FSD_VAL_DIR}/Applause/170eeda2.wav')
y_apl1 = htsat.embed_sound(f'{FSD_VAL_DIR}/Applause/2ee73a9d.wav')
y_brp0 = htsat.embed_sound(f'{FSD_VAL_DIR}/Burping_or_eructation/27d679b4.wav')
y_brp1 = htsat.embed_sound(f'{FSD_VAL_DIR}/Burping_or_eructation/7edba859.wav')

# Cosine similarity below requires every embedding to have the same shape.
assert all(
    y.shape == y_ag0.shape for y in (y_ag1, y_apl0, y_apl1, y_brp0, y_brp1)
), "embeddings differ in shape"
y_ag0.shape

In [None]:
# Label each comparison so the printed similarities can be read on their own:
# same-class pairs are expected to score higher than cross-class pairs.
embedding_pairs = [
    ("Acoustic_guitar 0 vs Acoustic_guitar 1 (same class)", y_ag0, y_ag1),
    ("Acoustic_guitar 0 vs Applause 0 (cross class)", y_ag0, y_apl0),
    ("Acoustic_guitar 1 vs Burping 1 (cross class)", y_ag1, y_brp1),
    ("Applause 0 vs Applause 1 (same class)", y_apl0, y_apl1),
    ("Burping 0 vs Burping 1 (same class)", y_brp0, y_brp1),
]
for label, emb_a, emb_b in embedding_pairs:
    print(f"{label}: {F.cosine_similarity(emb_a, emb_b)}")