In [None]:
# Relative path to the s3prl checkout; the import cell appends it (and its
# `transformer/` sub-directory) to sys.path.
s3prl_path = '../s3prl'


In [None]:
import sys
sys.path.append(f'{s3prl_path}/transformer/')
sys.path.append(f'{s3prl_path}/')

import yaml
import torch
from torch import nn
from model import TransformerModel , TransformerForMaskedAcousticModel , TransformerConfig
import transformer


In [None]:
import glob
import matplotlib.pyplot as plt

In [None]:
sys.path.append('../scripts')
import load_mel
from augment import do_aug
from mixup import MixUp
from preprocessor import Preprocessor

In [None]:
from IPython.display import Audio
import librosa
import librosa.display

In [None]:
def show_audio(y, sr):
    """Plot the waveform and spectrogram of a signal and return an audio player.

    Args:
        y: Audio samples; handed to librosa, ``load_mel.get_spectrogram`` and
            ``IPython.display.Audio``.
        sr: Sampling rate of ``y``.

    Returns:
        IPython.display.Audio: player for ``y``. It is rendered when this call
        is the last expression of a notebook cell.
    """
    # librosa 0.10 removed `waveplot` in favour of `waveshow`; prefer the new
    # function and fall back to the old one on older installations.
    plot_wave = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot
    plot_wave(y, sr=sr)

    spec = load_mel.get_spectrogram(y, sr, apply_denoise=False, return_audio=False)
    load_mel.plot_feature(spec)

    return Audio(y, rate=sr)

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order. Sort them so the
# positional picks made further down select the same recordings on every run.
audio_files = sorted(glob.glob('/home/jupyter/rfcx/data/*/*.flac'))
# audio_files= glob.glob('/home/jupyter/librispeech/LibriSpeech/test-other/1688/142285/*.flac')
len(audio_files)


In [None]:
def load_model(ckpt_path, device='cpu'):
    """Rebuild a ``TransformerForMaskedAcousticModel`` from a checkpoint file.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. It must
            contain the keys ``'Transformer'`` and ``'SpecHead'`` (state dicts)
            and ``'Settings'`` -> ``'Config'`` (the model configuration).
        device: Device the model is placed on, as a string or ``torch.device``.
            Defaults to ``'cpu'``. Pass ``None`` to pick ``cuda:0`` when CUDA is
            available and the CPU otherwise.

    Returns:
        tuple: ``(model, hidden_size, dr, device)`` - the model in eval mode,
        the hidden size and downsample rate read from the configuration, and
        the ``torch.device`` the model was moved to.
    """
    # Honour the caller's choice rather than always auto-selecting CUDA, so the
    # model ends up on the same device as the tensors the caller feeds it.
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # NOTE: torch.load unpickles the file - only open checkpoints you trust.
    # Tensors are first mapped to the CPU; `.to(device)` below moves the model.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    config = ckpt['Settings']['Config']

    model_config = TransformerConfig(config)
    input_dim = config['transformer']['input_dim']

    model = TransformerForMaskedAcousticModel(
        model_config,
        input_dim,
        output_dim=input_dim,  # output dimension mirrors the input dimension
        output_attentions=False,
    ).to(device)

    model.Transformer.load_state_dict(ckpt['Transformer'])
    model.SpecHead.load_state_dict(ckpt['SpecHead'])

    model.eval()
    return model, model_config.hidden_size, model_config.downsample_rate, device



In [None]:
class CustomTransformerModel(torch.nn.Module):
    """Run a masked-acoustic-model transformer over sequences of any length.

    Inputs whose sequence axis (dim 1) is longer than ``maxlen`` are cut into
    consecutive chunks, each chunk is passed through the wrapped model on its
    own, and the per-chunk predictions are concatenated back along dim 1.
    """

    def __init__(self, transformer_model: "TransformerForMaskedAcousticModel", maxlen: int = 3000):
        """Wrap ``transformer_model``.

        Args:
            transformer_model: Model called with the keyword arguments
                ``spec_input``, ``pos_enc``, ``mask_label``, ``attention_mask``,
                ``spec_label`` and ``head_mask``; it must return a 2-tuple
                whose first element is the predicted spectrogram.
            maxlen: Largest chunk length along dim 1 that is sent to the model
                in one call. Defaults to 3000.

        Raises:
            ValueError: If ``maxlen`` is smaller than 1.
        """
        super().__init__()
        if maxlen < 1:
            raise ValueError(f"maxlen must be a positive integer, got {maxlen}")
        self.transformer = transformer_model
        self.maxlen = maxlen

    def split(self, inp):
        """Cut ``inp`` along dim 1 into chunks of at most ``self.maxlen``.

        Args:
            inp: Tensor with at least two dimensions; dim 1 is the sequence
                axis (e.g. ``(batch, sequence)`` or
                ``(batch, sequence, features)``).

        Returns:
            list: Consecutive chunks in order. Every chunk has length
            ``self.maxlen`` along dim 1 except possibly the last, which holds
            the remainder. Inputs that already fit are returned unchanged as
            a one-element list.
        """
        if inp.shape[1] <= self.maxlen:
            return [inp]
        # torch.split yields full-size chunks followed by a shorter remainder.
        return list(torch.split(inp, self.maxlen, dim=1))

    def forward(self, spec, pos_enc, attn_mask):
        """Predict a spectrogram, processing long inputs chunk by chunk.

        Args:
            spec: Input spectrogram features, sequence on dim 1.
            pos_enc: Positional encoding matching ``spec``, sequence on dim 1.
            attn_mask: Attention mask matching ``spec``, sequence on dim 1.

        Returns:
            torch.Tensor: The per-chunk predictions concatenated along dim 1.
        """
        predictions = []
        chunks = zip(self.split(spec), self.split(pos_enc), self.split(attn_mask))

        for spec_chunk, pos_chunk, mask_chunk in chunks:
            # No labels are passed; only the first element of the returned
            # pair (the prediction) is kept.
            chunk_pred, _ = self.transformer(spec_input=spec_chunk,
                                             pos_enc=pos_chunk,
                                             mask_label=None,
                                             attention_mask=mask_chunk,
                                             spec_label=None,
                                             head_mask=None)
            predictions.append(chunk_pred)

        return torch.cat(predictions, dim=1)

In [None]:
def plot_spectrogram_to_numpy(spectrogram):
    """Draw a 2-D spectrogram as a heat map with a colour bar.

    The input is transposed before plotting, so its first axis runs along the
    x axis ("Frames") and its second axis along the y axis ("Channels").
    The function draws the figure and returns nothing.
    """
    image = spectrogram.transpose(1, 0)

    fig, ax = plt.subplots(figsize=(18, 3))
    heatmap = ax.imshow(image, aspect="auto", origin="lower", cmap='magma')
    fig.colorbar(heatmap, ax=ax)
    ax.set_xlabel("Frames")
    ax.set_ylabel("Channels")
    fig.tight_layout()

    # Force a render of the canvas.
    fig.canvas.draw()


In [None]:
# Checkpoint to inspect. The commented-out paths are alternative checkpoints
# kept for quick switching.
# ckpt_path= '/home/jupyter/rfcx/rfcx/model_weights/pretrained_model/states-1000000.ckpt'
# ckpt_path= '/home/jupyter/rfcx/rfcx/model_weights/mockingjay_mel80_no_delta_cmvn_run3/states-3000.ckpt'
ckpt_path = '/home/jupyter/rfcx/rfcx/model_weights/mockingjay_mel80_no_delta_cmvn_run4/states-2000.ckpt'

model, hidden_size, dr, device = load_model(ckpt_path)

In [None]:
# Configure the preprocessor from the values returned by load_model (hidden
# size, downsample rate, device) so it stays consistent with whichever
# checkpoint was loaded instead of relying on hard-coded numbers.
preprocessor = Preprocessor(hidden_size=hidden_size, dr=dr, device=device)

# Wrap the model so that long inputs are processed in chunks.
custom_model = CustomTransformerModel(model)
custom_model.eval()

In [None]:
# Two recordings, picked by position in `audio_files`.
input_file, input_file2 = audio_files[100], audio_files[200]


In [None]:
# Sampling rate handed to audio loading, mix-up, augmentation and playback.
SAMPLE_RATE = 16_000


In [None]:
# Mix-up helper, built from the project's denoise function and the sampling rate.
mixup = MixUp(load_mel.denoise, SAMPLE_RATE)

In [None]:
# Load the first recording and preview it (waveform, spectrogram, player).
y1, sr = load_mel.load_audio(input_file, SAMPLE_RATE)
show_audio(y1, SAMPLE_RATE)

In [None]:
# Load the second recording and preview it (waveform, spectrogram, player).
y2, sr = load_mel.load_audio(input_file2, SAMPLE_RATE)
show_audio(y2, SAMPLE_RATE)

In [None]:
# Blend the two recordings; the call returns a value named `alpha` here
# together with the mixed signal.
alpha, y3 = mixup(y1, y2)
print(alpha)

show_audio(y3, SAMPLE_RATE)


In [None]:

# Augment the mixed signal. NOTE: `y3` is rebound to the augmented audio, so
# re-running this cell applies the augmentation to already augmented audio.
y3 = do_aug(y3, SAMPLE_RATE)
show_audio(y3, SAMPLE_RATE)

In [None]:
# Spectrogram features of the augmented mix (no denoising, features only).
feat = load_mel.get_spectrogram(
    y3,
    sr,
    apply_denoise=False,
    return_audio=False,
)


In [None]:
# Swap the two axes of the feature matrix before handing it to the preprocessor.
spec = torch.tensor(feat).permute(1, 0)
spec_stacked, pos_enc, attn_mask = preprocessor.process_MAM_data(spec=spec)

spec_stacked.shape, pos_enc.shape, attn_mask.shape

In [None]:
# Inference only: skip autograd bookkeeping, since the prediction is merely
# converted to numpy and plotted afterwards.
with torch.no_grad():
    pred_spec = custom_model(spec_stacked, pos_enc, attn_mask)
pred_spec.shape

In [None]:
# `.numpy()` only works on CPU tensors; `.cpu()` is a no-op when the prediction
# is already on the CPU and makes the cell work when the model ran on a GPU.
plot_spectrogram_to_numpy(pred_spec.detach().cpu().numpy().squeeze())