In [None]:
import os

# `!export` would run in a throw-away subshell and never reach this kernel's
# environment. Set the variable on the kernel process itself instead; it only
# takes effect if set before CUDA is initialised in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

In [None]:
# Automatically reload edited local modules (e.g. spid_modules, data.wham_data_utils)
# before each cell executes, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
def sound(x, rate=8000, label=''):
    """Render an inline audio player for waveform ``x``.

    Parameters
    ----------
    x : array-like
        Samples passed straight to ``IPython.display.Audio``.
    rate : int
        Sample rate in Hz for playback.
    label : str
        Optional caption. When empty (or otherwise falsy) only the player is
        shown; otherwise the caption and the player are laid out in a
        borderless one-row HTML table.
    """
    from IPython.display import display, Audio, HTML
    # Truthiness test instead of `label is ''`: identity comparison with a
    # literal depends on string interning and raises a SyntaxWarning.
    if not label:
        display(Audio(x, rate=rate))
        return
    # NOTE(review): the first 3 characters of Audio's HTML repr are dropped,
    # presumably a leading wrapper that breaks the table cell; confirm for
    # the installed IPython version.
    player_html = Audio(x, rate=rate)._repr_html_()[3:]
    display(HTML(
        '<style> table, th, td {border: 0px; }</style> <table><tr><td>' + label +
        '</td><td>' + player_html + '</td></tr></table>'
    ))
import pdb

In [None]:
# Constructor arguments for ECAPA_TDNN. input_size=80 matches the n_mels=80
# filterbank front end used to compute the embedder's input features.
emb_hyp = dict(
    input_size=80,
    channels=[1024] * 4 + [3072],
    kernel_sizes=[5, 3, 3, 3, 1],
    dilations=[1, 2, 3, 4, 1],
    groups=[1] * 5,
    attention_channels=128,
    lin_neurons=192,  # presumably the embedding dimension; confirm in ECAPA_TDNN
)

from spid_modules.ECAPA_TDNN import ECAPA_TDNN
# Instantiate the speaker-embedding network from emb_hyp; its trained weights are
# loaded from a checkpoint by the Pretrainer in the next cell.
embedder = ECAPA_TDNN(**emb_hyp)


In [None]:
import speechbrain

# Trained embedding-model checkpoint.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
EMB_CKPT = '/mnt/data/zhepei/outputs/ecapa_augment_8k/1986/save/CKPT+2021-10-30+08-02-06+00/embedding_model.ckpt'

# Gather the checkpoint into ./emb_test, then load its weights into `embedder` on CPU.
# Only the embedding model is listed as a loadable here.
pretrainer = speechbrain.utils.parameter_transfer.Pretrainer(
    collect_in='./emb_test',
    loadables={'embedding_model': embedder},
    paths={'embedding_model': EMB_CKPT},
)
pretrainer.collect_files()
pretrainer.load_collected('cpu')
# Switch to evaluation mode; the trailing `;` stops IPython from dumping the
# (very long) module repr as the cell output.
embedder.eval();


In [None]:
from speechbrain.lobes.features import Fbank
from speechbrain.processing.features import InputNormalization

# Filterbank front end; n_mels must equal emb_hyp['input_size'] (80).
compute_features = Fbank(n_mels=80, sample_rate=8000)

# Per-sentence normalisation of the features. std_norm=False presumably means
# mean subtraction only (no variance scaling); confirm in speechbrain.
mean_var_norm = InputNormalization(std_norm=False, norm_type='sentence')

# "Global" normalisation applied to the embeddings in compute_embedding.
# NOTE(review): this module is freshly constructed and is not among the
# Pretrainer loadables, so no trained global statistics are loaded into it.
# Confirm how InputNormalization behaves in eval mode before any statistics
# have been accumulated; it may not act as an identity.
mean_var_norm_emb = InputNormalization(std_norm=False, norm_type='global')

mean_var_norm.eval()
mean_var_norm_emb.eval()

In [None]:
# Data-pipeline configuration read by dynamic_mixing_prep (train) and
# static_data_prep (valid) from data.wham_data_utils, and by the loaders below.
hyp = dict(
    sample_rate=8000,
    training_signal_len=40000,
    # Separate dict objects for train and valid, so mutating one never affects the other.
    train_dataloader_opts=dict(batch_size=1, num_workers=0),
    valid_dataloader_opts=dict(batch_size=1, num_workers=0),
    data_folder='/mnt/data/wham/wham_original',
    wsj_folder='/mnt/data/wsj0.8k',
    base_folder_dm_info_list=[
        dict(path='/mnt/data/wsj0.8k/si_tr_s/', ext='wav', type='clean'),
    ],
    data_clean_prob=1.,
    # Earlier data sources, kept for reference:
    # train_txtpath='/mnt/data/Speech/wsj_tse/mix_2_spk_tr_extr.txt',
    # train_wham_folder='/mnt/data/wham/wham_original/wav8k/min/tr',
    train_data='/mnt/data/zhepei/outputs/sb_tse/results/2021-10-11+21-30-05+seed_1234+xformer-wham/save/wham_tse_tr.csv',
    valid_data='/mnt/data/zhepei/outputs/sb_tse/results/2021-10-27+03-53-37+seed_1234+xformer-wham-pre/save/wham_tse_tt.csv',
)

In [None]:
import torch
import numpy as np
from speechbrain.dataio.batch import PaddedBatch

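# Test dynamic mixing

Sanity-check the training loader built by `dynamic_mixing_prep(hyp, 'train')`: listen to one batch, then compare the enrollment embedding's cosine similarity with the target source (`s1`) and with the other source (`s2`).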

In [None]:
from data.wham_data_utils import dynamic_mixing_prep
# Build the training data loader from the `hyp` config (presumably mixing sources on
# the fly, per the function name; confirm in data.wham_data_utils).
# NOTE(review): later cells expect each batch to expose mix_sig/s1_sig/s2_sig/
# noise_sig/enr_sig entries with `.data` and `.lengths`.
train_dl = dynamic_mixing_prep(hyp, 'train')

In [None]:
def listen_batch(batch, idx=0, rate=None):
    """Print mixture diagnostics and play every signal of one batch item.

    Prints the batch-wide peak |mix| and the mean squared residual of
    ``mix - s1 - s2 - noise``, which is 0 when the mixture is exactly the sum of
    its parts. Then plays mix, s1, s2, noise and enr for item ``idx``.

    Parameters
    ----------
    batch : PaddedBatch-like
        Must expose ``mix_sig``, ``s1_sig``, ``s2_sig``, ``noise_sig`` and
        ``enr_sig`` entries with a ``.data`` tensor.
    idx : int
        Which item of the batch to play (default: the first).
    rate : int or None
        Playback sample rate; defaults to ``hyp['sample_rate']``.
    """
    if rate is None:
        rate = hyp['sample_rate']
    names = ('mix', 's1', 's2', 'noise', 'enr')
    sigs = {name: batch[name + '_sig'].data for name in names}

    residual = sigs['mix'] - sigs['s1'] - sigs['s2'] - sigs['noise']
    print('mix peak |x|: {}'.format(abs(sigs['mix']).max().item()))
    print('mix - (s1 + s2 + noise) MSE: {}'.format((residual ** 2).mean().item()))

    for name in names:
        sound(sigs[name][idx].detach().cpu().numpy(), rate=rate, label=name)

In [None]:
def compute_cos_sim(v1, v2):
    """Cosine similarity of two tensors, treating each as one flattened vector.

    A 1e-8 term in the denominator keeps zero-norm inputs from dividing by zero;
    the result is then 0. Returns a 0-d tensor.
    """
    denom = torch.norm(v1) * torch.norm(v2) + 1e-8
    return torch.sum(v1 * v2) / denom

def compute_embedding(wavs, wav_lens):
    """Compute unit-norm speaker embeddings.

    The pipeline is: per-waveform peak normalisation to 0.9, then
    ``compute_features``, ``mean_var_norm``, ``embedder``,
    ``mean_var_norm_emb``, and finally L2 normalisation. Runs under
    ``torch.no_grad()``.

    Arguments
    ---------
    wavs : torch.Tensor
        Tensor containing the speech waveforms (batch, time).
        The sample rate must match the feature extractor ``compute_features``
        (built with sample_rate=8000 in this notebook).
    wav_lens : torch.Tensor
        Tensor containing the relative length of each sentence in the batch
        (e.g., [0.8, 0.6, 1.0]).

    Returns
    -------
    torch.Tensor
        Embeddings with dimension 1 squeezed out, each scaled to unit L2 norm
        (up to the 1e-8 epsilon). Presumably (batch, lin_neurons); confirm
        against the ECAPA_TDNN output shape.
    """
    with torch.no_grad():
        # Peak-normalise each waveform to 0.9. The clamp keeps an all-zero (silent)
        # waveform from producing 0 * (0.9 / 0) = NaN, which would poison the embedding.
        peaks = torch.amax(torch.abs(wavs), dim=-1, keepdim=True).clamp_min(1e-8)
        wavs = wavs * (0.9 / peaks)
        feats = compute_features(wavs)
        feats = mean_var_norm(feats, wav_lens)
        embeddings = embedder(feats, wav_lens)
        # Every embedding is a single full-length item, hence relative length 1.
        embeddings = mean_var_norm_emb(
            embeddings, torch.ones(embeddings.shape[0], device=embeddings.device)
        )
        embeddings = embeddings / (1e-8 + torch.norm(embeddings, p=2, dim=-1, keepdim=True))
    return embeddings.squeeze(1)

def check_emb(batch):
    """Print how close the enrollment embedding is to each source of the first item.

    ``s1`` is scored as the positive (presumably the enrollment speaker) and
    ``s2`` as the negative. A positive ``diff`` means the enrollment embedding
    is closer to ``s1`` than to ``s2``.

    Returns
    -------
    tuple of float
        ``(pos_sim, neg_sim)``, so callers may aggregate results. The printed
        line is unchanged.
    """
    embs = {
        name: compute_embedding(batch[name + '_sig'].data, batch[name + '_sig'].lengths)
        for name in ('s1', 's2', 'enr')
    }
    pos_sim = compute_cos_sim(embs['s1'][0], embs['enr'][0]).item()
    neg_sim = compute_cos_sim(embs['s2'][0], embs['enr'][0]).item()
    if np.isnan(pos_sim) or np.isnan(neg_sim):
        # Report instead of stopping in a leftover pdb breakpoint, so loops over
        # many batches keep running.
        print('WARNING: NaN similarity for this batch')
    print('Positive sim: {} -- Negative sim: {} -- diff: {}'.format(pos_sim, neg_sim, pos_sim - neg_sim))
    return pos_sim, neg_sim

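# NOTE(review): disabled and stale. It feeds raw waveforms straight to `embedder`,
# whereas compute_embedding passes Fbank features plus lengths. Update or delete.
# def check_emb_noise(batch):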
#     s1 = batch['s1_sig'].data
#     s2 = batch['s2_sig'].data
#     enr = batch['enr_sig'].data
#     noise = batch['noise_sig'].data
#     s1_emb = embedder(s1+noise)
#     s2_emb = embedder(s2+noise)
#     min_len = min(enr.shape[-1], noise.shape[-1])
#     enr_emb = embedder(enr[..., :min_len]+noise[..., :min_len])
#     pos_sim = compute_cos_sim(s1_emb[0], enr_emb[0]).item()
#     neg_sim = compute_cos_sim(s2_emb[0], enr_emb[0]).item()
#     print('Positive sim: {} -- Negative sim: {} -- diff: {}'.format(pos_sim, neg_sim, pos_sim-neg_sim))

In [None]:
# Audition the second training batch (index 1); batch 0 is skipped.
for i, batch in enumerate(train_dl):
    if i < 1:
        continue
    listen_batch(batch)
    break

In [None]:
# Embedding-similarity sanity check on the first 21 training batches (indices 0-20).
# zip stops after range(21) is exhausted, so no extra batch is drawn from the loader.
for i, batch in zip(range(21), train_dl):
    with torch.no_grad():
        check_emb(batch)

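# Test static mixing

Run the same listening and embedding-similarity checks on the validation set loaded by `static_data_prep(hyp, 'valid')`.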

In [None]:
import os

from data.wham_data_utils import static_data_prep


def _seed_numpy_worker(worker_id):
    """Seed NumPy's global RNG differently in each DataLoader worker.

    ``np.random.seed`` accepts only values in [0, 2**32 - 1]. A random 4-byte
    value plus the worker id can exceed that range, so the sum is wrapped modulo 2**32.
    """
    seed = int.from_bytes(os.urandom(4), "little") + worker_id
    np.random.seed(seed % 2**32)


valid_ds = static_data_prep(hyp, 'valid')
valid_dl = torch.utils.data.DataLoader(
    valid_ds,
    batch_size=hyp["valid_dataloader_opts"]["batch_size"],
    num_workers=hyp["valid_dataloader_opts"]["num_workers"],
    collate_fn=PaddedBatch,
    # Only called when num_workers > 0. The original lambda referenced `os`
    # without importing it, so it would have raised NameError.
    worker_init_fn=_seed_numpy_worker,
)

In [None]:
# Audition the second validation batch (index 1); batch 0 is skipped.
for i, batch in enumerate(valid_dl):
    if i < 1:
        continue
    listen_batch(batch)
    break

In [None]:
# Embedding-similarity sanity check on the first 21 validation batches (indices 0-20).
# compute_embedding already runs under torch.no_grad(), so no extra context is needed.
for i, batch in zip(range(21), valid_dl):
    check_emb(batch)

In [None]:
# NOTE(review): this repeats the previous cell exactly (first 21 validation batches).
# Presumably it is re-run to confirm that the validation data give repeatable
# similarities; delete it if that is not the intent.
for i, batch in zip(range(21), valid_dl):
    check_emb(batch)