In [1]:
# Root of the local s3prl checkout; it and its transformer/ subdirectory are put on sys.path in the next cell.
s3prl_path = "../s3prl"


In [2]:
import sys
sys.path.append(f'{s3prl_path}/transformer/')
sys.path.append(f'{s3prl_path}/')

import yaml
import torch
from torch import nn
from model import TransformerModel , TransformerForMaskedAcousticModel , TransformerConfig
import transformer


Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


In [3]:
sys.path.append('../scripts')
from model_builder import Wrapper_Model

In [4]:
class Runner():
    '''
    Builds a Wrapper_Model around a base transformer and places it on a device.

    Attributes:
        device : device that set_model moves the wrapped model to
        base_transformer_model : TransformerModel built by set_transformer_model (None until then)
        model : Wrapper_Model built by set_model (None until then)
        hidden_size : model_config.hidden_size of the last loaded config (None until loaded)
        downsample_rate : model_config.downsample_rate of the last loaded config (None until loaded)
    '''
    def __init__(self, device):
        self.device = device
        self.base_transformer_model = None
        self.model = None
        # Exposed so downstream preprocessing can reuse the config values
        # instead of hard-coding them.
        self.hidden_size = None
        self.downsample_rate = None

    def set_transformer_model(self, transformer_config_path, transformer_weights_path=None):
        '''
        This Function loads the base transformer model.

        Args:
            transformer_config_path : config path(yaml) of the transformer
            transformer_weights_path : optional. If given (truthy), the state dict stored under
                the 'Transformer' key of this checkpoint is loaded into the model as well.

        Returns: None. The model (on CPU) is stored on self.base_transformer_model, and the
            config's hidden_size / downsample_rate on self.hidden_size / self.downsample_rate.
        '''
        # load base transformer model from config
        with open(transformer_config_path, 'r') as file:
            config = yaml.load(file, yaml.FullLoader)

        model_config = TransformerConfig(config)
        input_dim = config['transformer']['input_dim']

        # Config values that callers (e.g. a preprocessor) must agree with.
        downsample_rate = model_config.downsample_rate
        hidden_size = model_config.hidden_size

        # Built on CPU; set_model moves the wrapped model to self.device afterwards.
        base_transformer_model = TransformerModel(model_config, input_dim, output_attentions=False).to('cpu')

        # load weights
        if transformer_weights_path:
            ckpt = torch.load(transformer_weights_path, map_location='cpu')
            base_transformer_model.load_state_dict(ckpt['Transformer'])

        # Only publish state once everything above succeeded.
        self.base_transformer_model = base_transformer_model
        self.hidden_size = hidden_size
        self.downsample_rate = downsample_rate

    def set_model(self, transformer_config_path, transformer_weights_path=None, ckpt_path=None):
        '''
        Builds the base transformer, wraps it in Wrapper_Model and moves it to self.device.

        Args:
            transformer_config_path : config path(yaml) of the transformer
            transformer_weights_path : optional checkpoint for the base transformer
                (passed to set_transformer_model)
            ckpt_path : optional. If given, the object loaded from this path is applied
                to the whole wrapped model via load_state_dict

        Returns: None. The wrapped model is stored on self.model.
        '''
        self.set_transformer_model(transformer_config_path, transformer_weights_path)
        self.model = Wrapper_Model(self.base_transformer_model)

        if ckpt_path:
            # Loaded onto CPU first; the whole model is moved to self.device below.
            ckpt = torch.load(ckpt_path, map_location='cpu')
            self.model.load_state_dict(ckpt)

        self.model.to(self.device)
        
        

In [5]:
# Prefer the first GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
runner = Runner(device)

  return torch._C._cuda_getDeviceCount() > 0


In [6]:
transformer_config_path = '/home/jupyter/rfcx_submission/config/upstream_config.yaml'
transformer_weights_path = (
    '/home/jupyter/rfcx/rfcx/model_weights/'
    'mockingjay_mel80_no_delta_cmvn_run4/states-2000.ckpt'
)

# Build the wrapped model from the upstream config plus pretrained transformer weights
# (no ckpt_path, so the wrapper's own parameters are not loaded from disk).
runner.set_model(transformer_config_path, transformer_weights_path=transformer_weights_path)

In [None]:
import glob
from IPython.display import Audio
import matplotlib.pyplot as plt
import os

import numpy as np
import librosa

import sys
sys.path.append('/home/jupyter/rfcx_submission/scripts')

import load_mel
from preprocessor import Preprocessor

In [None]:
def plot_spectrogram_to_numpy(spectrogram, figsize=(18, 3), cmap='magma'):
    '''
    Render a 2-D spectrogram with matplotlib and return the rendered image as a numpy array.

    Args:
        spectrogram : 2-D array or tensor. It is transposed with .transpose(1, 0) before
            plotting, so its first axis runs along the horizontal ("Frames") axis and its
            second axis along the vertical ("Channels") axis.
        figsize : matplotlib figure size in inches.
        cmap : matplotlib colormap name.

    Returns:
        np.ndarray of shape (height, width, 3), dtype uint8: the RGB pixels of the rendered
        figure. The figure is closed afterwards so repeated calls do not accumulate open figures.
    '''
    spectrogram = spectrogram.transpose(1, 0)
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", cmap=cmap)
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Frames")
    ax.set_ylabel("Channels")
    fig.tight_layout()

    fig.canvas.draw()
    # buffer_rgba() exposes the Agg canvas as (H, W, 4); drop alpha and copy so the
    # array stays valid after the figure is closed.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return data


In [None]:
# NOTE(review): outputs of this preprocessor are built for the CPU device below,
# while the model may live on a GPU; move its tensors to the model's device before inference.
preprocessor = Preprocessor(
    hidden_size=768,  # NOTE(review): presumably must match hidden_size in the upstream config — confirm
    dr=1,             # NOTE(review): presumably the model's downsample_rate — confirm
    device=torch.device('cpu'),
)


In [None]:
# glob returns files in arbitrary, filesystem-dependent order; sort so any
# index-based pick from this list is reproducible across machines and runs.
audio_files = sorted(glob.glob('/home/jupyter/rfcx/data/*/*.flac'))
# Alternative dataset used during development:
# audio_files = sorted(glob.glob('/home/jupyter/librispeech/LibriSpeech/test-other/1688/142285/*.flac'))
len(audio_files)

In [None]:
sample_index = -50  # which clip to inspect (negative = counted from the end of audio_files)
if len(audio_files) < abs(sample_index):
    # Usually means the glob above matched nothing (e.g. data directory not mounted).
    raise IndexError(
        f'audio_files has {len(audio_files)} entries; index {sample_index} needs at least {abs(sample_index)}'
    )
input_file = audio_files[sample_index]
input_file

In [None]:
sample_rate = 32_000  # passed to load_mel.load_audio (presumably the target sampling rate — confirm)

# Load the clip, then compute its spectrogram feature without denoising.
y, sr = load_mel.load_audio(input_file, sample_rate)
feat = load_mel.get_spectrogram(
    y,
    sr,
    apply_denoise=False,
    return_audio=False,
)

load_mel.plot_feature(feat)

In [None]:
# NOTE(review): permute(1, 0) swaps the two feature axes; presumably load_mel yields
# (n_mels, frames) and the preprocessor expects (frames, n_mels) — confirm.
spec = torch.tensor(feat).permute(1, 0)
spec_stacked, pos_enc, attn_mask = preprocessor.process_MAM_data(spec=spec)

spec_stacked.shape, pos_enc.shape, attn_mask.shape

In [None]:
# Inference pass:
#  * eval() switches modules such as dropout to inference behaviour, so the output is repeatable;
#  * no_grad() avoids building an autograd graph for a forward-only pass;
#  * the preprocessor was constructed for the CPU while runner.model lives on runner.device
#    (cuda:0 when available), so move the inputs there to avoid a device-mismatch error.
runner.model.eval()
with torch.no_grad():
    z = runner.model(
        spec_stacked.to(runner.device),
        pos_enc.to(runner.device),
        attn_mask.to(runner.device),
    )

In [None]:
# Display the model output. Its type/shape depends on Wrapper_Model's forward, which is not
# visible here — TODO(review): if it is a large tensor, consider showing only its shape.
z