In [46]:
from module.model import Gvector
import torch
import json
from scipy.special import softmax
import soundfile as sf
from python_speech_features import logfbank
import numpy as np

In [3]:
# Constructor keyword arguments for the Gvector classifier
# (expanded as Gvector(**mdl_clf_kwargs) when the model is built).
mdl_clf_kwargs = dict(
    channels=16,
    block="BasicBlock",
    num_blocks=[2, 2, 2, 2],
    embd_dim=1024,
    drop=0.3,
    n_class=821,
)

# Checkpoint file the classifier weights are restored from.
model_path = 'exp/821b_bad_aug_0.1/chkpt/chkpt_best.pth'

In [5]:
def load_model(mdl_kwargs, model_path, device):
    """Build a Gvector network and restore its weights from a checkpoint.

    Args:
        mdl_kwargs: Keyword arguments forwarded to the ``Gvector`` constructor.
        model_path: Checkpoint location handed to ``torch.load``.
        device: Used as ``map_location`` when deserializing the checkpoint.

    Returns:
        The ``Gvector`` instance with the checkpoint's state dict loaded.
        This function neither calls ``.eval()`` nor moves the module with
        ``.to(device)``; callers do that themselves if needed.
    """
    net = Gvector(**mdl_kwargs)
    checkpoint = torch.load(model_path, map_location=device)
    # The weights may be nested under a 'model' key; unwrap them if so.
    weights = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
    net.load_state_dict(weights)
    return net

In [7]:
# Example recording used by the cells below. This is an absolute, machine-specific
# path — adjust it when running the notebook elsewhere.
sample_wav = '/DATA1/ziang/data/sw/3h-aligned/AmurFalcon/149819.wav'

In [35]:
class SVExtractor():
    """Callable wrapper that runs a checkpointed Gvector model on one feature array.

    Attributes:
        model: The loaded network, put in eval mode and moved to ``device``.
        device: Device the model is placed on and inputs are moved to.
    """

    def __init__(self, mdl_kwargs, model_path, device):
        """Load the network from ``model_path``, set eval mode, move it to ``device``."""
        self.device = device
        net = self.load_model(mdl_kwargs, model_path, device)
        net.eval()
        self.model = net.to(self.device)

    def load_model(self, mdl_kwargs, model_path, device):
        """Instantiate ``Gvector(**mdl_kwargs)`` and load weights from the checkpoint.

        The checkpoint is deserialized with ``map_location=device``; if it
        contains a 'model' key, the state dict is taken from there.
        """
        net = Gvector(**mdl_kwargs)
        checkpoint = torch.load(model_path, map_location=device)
        weights = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
        net.load_state_dict(weights)
        return net

    def __call__(self, frame_feats):
        """Run the model on a single NumPy feature array.

        Args:
            frame_feats: NumPy array of input features. A leading batch axis
                of size 1 is added and the data is cast to float32 before the
                forward pass.

        Returns:
            NumPy array with the model output, batch axis squeezed away.
        """
        batch = torch.from_numpy(frame_feats).unsqueeze(0).float().to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.model(batch)
        return output.squeeze(0).cpu().numpy()

In [36]:
# Eager-mode classifier restored from the checkpoint, running on CPU.
model_clf = SVExtractor(mdl_clf_kwargs, model_path, 'cpu')

In [38]:
def extract_feat(wav_path, to_tensor=False, cmn=True, device="cpu"):
    """Compute log mel-filterbank features for an audio file.

    Args:
        wav_path: Path of the audio file, read with ``soundfile.read``.
        to_tensor: Currently ignored; a NumPy array is always returned.
        cmn: When true, subtract the mean taken along axis 0 from the
            features (mean normalization).
        device: Currently ignored.

    Returns:
        ``float32`` NumPy array of the features produced by
        ``python_speech_features.logfbank`` (80 filters, 50-8000 Hz).
    """
    signal, sample_rate = sf.read(wav_path)
    feats = logfbank(
        signal,
        sample_rate,
        winlen=0.025,
        winstep=0.01,
        nfilt=80,
        nfft=2048,
        lowfreq=50,
        highfreq=8000,
        preemph=0.97,
    )
    if cmn:
        feats -= feats.mean(axis=0, keepdims=True)
    return feats.astype('float32')

In [39]:
# NOTE: `to_tensor=True` has no effect here — the tensor branch in extract_feat
# is commented out — so `sample_feats` is a float32 NumPy array.
sample_feats = extract_feat(sample_wav, to_tensor=True)

In [40]:
# Inspect the extracted feature array.
sample_feats

array([[-21.112143  , -20.329369  , -19.710909  , ..., -13.713343  ,
        -14.4133215 , -12.876054  ],
       [-13.920484  , -13.640373  , -13.68454   , ...,  -8.511074  ,
         -8.854626  ,  -8.220791  ],
       [ -6.4466815 ,  -6.4411244 ,  -6.434847  , ...,  -6.472795  ,
         -7.208386  ,  -7.3128448 ],
       ...,
       [ -0.28600848,  -0.3062762 ,  -0.33613098, ...,  -2.484031  ,
         -2.3440166 ,  -2.1353269 ],
       [ -5.0375047 ,  -5.0419736 ,  -5.0797276 , ...,  -8.505127  ,
         -8.492745  ,  -7.3151937 ],
       [-10.496453  , -10.426983  , -10.33989   , ..., -11.079787  ,
        -10.737795  ,  -9.616296  ]], dtype=float32)

In [44]:
# Index of the highest-probability class after a softmax over the model output.
np.argmax(softmax(model_clf(sample_feats)))

750

In [51]:
# Softmax probability of class 750 (the argmax index obtained for this sample).
softmax(model_clf(sample_feats))[750]

0.9992602

In [49]:
# Load the label-name -> class-index table, then build the reverse lookup
# (class index -> label name) used to decode predictions.
with open('index-try/label2int_l1k.json','r') as f:
    label2int = json.load(f)
int2label = {idx: name for name, idx in label2int.items()}

In [50]:
# Label name for class index 750, the top prediction for the sample.
int2label[750]

'AmurFalcon'

## Save the model as TorchScript (JIT)

In [52]:
# Fresh eager-mode model (checkpoint mapped to CPU) to be compiled to TorchScript.
model_jit = load_model(mdl_clf_kwargs, model_path, 'cpu')

In [57]:
# Compile the eager model to TorchScript, switch it to eval mode, and
# serialize the scripted module to disk.
model_scripted = torch.jit.script(model_jit)
model_scripted.eval()
torch.jit.save(model_scripted, '821b_jit.pt')

## Load the TorchScript (JIT) model

In [60]:
class JITExtractor():
    """Callable wrapper around a serialized TorchScript model.

    The constructor takes ``(mdl_kwargs, model_path, device)``; ``mdl_kwargs``
    is accepted but never used, because the TorchScript archive is loaded
    directly by ``torch.jit.load``.

    Attributes:
        model: The loaded scripted module, in eval mode, moved to ``device``.
        device: Device the model is placed on and inputs are moved to.
    """

    def __init__(self, mdl_kwargs, model_path, device):
        """Load the scripted module, set eval mode, and move it to ``device``."""
        self.device = device
        scripted = self.load_model(mdl_kwargs, model_path, device)
        scripted.eval()
        self.model = scripted.to(self.device)

    def load_model(self, mdl_kwargs, model_path, device):
        """Return ``torch.jit.load(model_path, map_location=device)``.

        ``mdl_kwargs`` is ignored.
        """
        return torch.jit.load(model_path, map_location=device)

    def __call__(self, frame_feats):
        """Run the scripted model on a single NumPy feature array.

        Args:
            frame_feats: NumPy array of input features. A leading batch axis
                of size 1 is added and the data is cast to float32 before the
                forward pass.

        Returns:
            NumPy array with the model output, batch axis squeezed away.
        """
        batch = torch.from_numpy(frame_feats).unsqueeze(0).float().to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.model(batch)
        return output.squeeze(0).cpu().numpy()

In [61]:
# The first argument is unused by JITExtractor; the model is loaded entirely
# from the saved TorchScript archive.
model_clf_jit = JITExtractor(mdl_clf_kwargs, '821b_jit.pt', 'cpu')

In [70]:
# NOTE: despite its name, `logits` holds softmax probabilities, not raw scores.
logits = softmax(model_clf_jit(sample_feats))
# Most likely class index and the probability assigned to it.
top_pred = logits.argmax()
confidence = logits[top_pred]

In [71]:
# Probability of the top class from the TorchScript model.
confidence

0.9992602

In [72]:
# Decode the predicted class index into its label name.
int2label[top_pred]

'AmurFalcon'