In [1]:
import os
import random
import soundfile as sf
import torch
import yaml
import json
import argparse
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from pathlib import Path

from asteroid.metrics import get_metrics
from asteroid.utils import tensors_to_device
from asteroid.dsp.normalization import normalize_estimates

import matplotlib.pyplot as plt 
import numpy as np

import sys 
sys.path.append("/data1/wangyiwen/repos/graduateproject/speakerbeam/src/")


from models.td_speakerbeam import TimeDomainSpeakerBeam

In [2]:
model = TimeDomainSpeakerBeam(i_adapt_layer=7, adapt_layer_type='mul', adapt_enroll_dim=128, causal=True)

In [3]:
test_input = torch.randn(1, 32000)
test_enroll_input = torch.randn(1, 32000)
test_input = test_input.cuda()
test_enroll_input = test_enroll_input.cuda()
model = model.cuda()
model.eval()


TimeDomainSpeakerBeam(
  (encoder): Encoder(
    (filterbank): FreeFB()
  )
  (masker): TDConvNetInformed(
    (bottleneck): Sequential(
      (0): GlobLN()
      (1): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
    )
    (TCN): ModuleList(
      (0): Conv1DBlock(
        (shared_block): Sequential(
          (0): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
          (1): PReLU(num_parameters=1)
          (2): GlobLN()
          (3): Sequential(
            (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,), groups=512)
            (1): _Chop1d()
          )
          (4): PReLU(num_parameters=1)
          (5): GlobLN()
        )
        (res_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
        (skip_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
      )
      (1): Conv1DBlock(
        (shared_block): Sequential(
          (0): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
          (1): PReLU(num_parameters=1)
          (2): GlobLN()
        

In [4]:
model_forward_output = model(test_input, test_enroll_input)
print(model_forward_output.shape)

torch.Size([1, 1, 32000])


In [5]:
from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape
from asteroid.models.base_models import _shape_reconstructed, _unsqueeze_to_3d
from asteroid.models.base_models import BaseEncoderMaskerDecoder

In [6]:
shape = jitable_shape(test_input)
wav = _unsqueeze_to_3d(test_input)
tf_rep = model.forward_encoder(wav)
enroll_emb = model.auxiliary(test_enroll_input)
est_masks = model.forward_masker(tf_rep, enroll_emb)
masked_tf_rep = model.apply_masks(tf_rep, est_masks)
decoded = model.forward_decoder(masked_tf_rep)
reconstructed = pad_x_to_y(decoded, wav)
print("tf rep shape: ", tf_rep.shape)
print("enroll_emb shape: ", enroll_emb.shape)
print("est_masks shape: ", est_masks.shape)
print("masked_tf_rep shape: ", masked_tf_rep.shape)
print("decoded shape: ", decoded.shape)
print("reconstructed shape: ", reconstructed.shape)

tf rep shape:  torch.Size([1, 512, 3999])
enroll_emb shape:  torch.Size([1, 256])
est_masks shape:  torch.Size([1, 1, 512, 3999])
masked_tf_rep shape:  torch.Size([1, 1, 512, 3999])
decoded shape:  torch.Size([1, 1, 32000])
reconstructed shape:  torch.Size([1, 1, 32000])


In [7]:
streaming_reconstructed = torch.zeros((wav.shape[0], wav.shape[-1]), device=wav.device)
mixture_emb_buffer = torch.zeros((wav.shape[0], 1531, 512), device=wav.device)
curr_idx = 0
curr_tf_rep = tf_rep[:, :, curr_idx]
mixture_emb_buffer = torch.roll(mixture_emb_buffer, shifts=-1, dims=1)
mixture_emb_buffer[:, -1, :] = curr_tf_rep[:, :]
curr_est_masks = model.forward_masker(mixture_emb_buffer.permute(0, 2, 1), enroll_emb)
print(curr_est_masks.shape)

torch.Size([1, 1, 512, 1531])


In [8]:
est_masks[0, 0, :20, 0]

tensor([0.4690, 0.2773, 0.5096, 0.3732, 0.5565, 0.4809, 0.3234, 0.6867, 0.6778,
        0.7450, 0.2281, 0.6850, 0.4019, 0.1765, 0.3586, 0.4410, 0.6121, 0.4257,
        0.2441, 0.3039], device='cuda:0', grad_fn=<SelectBackward0>)

In [9]:
curr_est_masks[0, 0, :20, -1]

tensor([9.7874e-01, 9.9917e-01, 8.4398e-07, 9.6288e-08, 3.8234e-10, 1.0000e+00,
        2.3468e-02, 7.8944e-01, 7.4635e-02, 1.0000e+00, 3.1124e-01, 1.0000e+00,
        1.0000e+00, 9.9531e-01, 3.8189e-16, 1.6419e-01, 8.0911e-07, 9.8032e-01,
        4.5149e-04, 5.3674e-06], device='cuda:0', grad_fn=<SelectBackward0>)

In [29]:
model.masker

TDConvNetInformed(
  (bottleneck): Sequential(
    (0): GlobLN()
    (1): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
  )
  (TCN): ModuleList(
    (0): Conv1DBlock(
      (shared_block): Sequential(
        (0): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
        (1): PReLU(num_parameters=1)
        (2): GlobLN()
        (3): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,), groups=512)
          (1): _Chop1d()
        )
        (4): PReLU(num_parameters=1)
        (5): GlobLN()
      )
      (res_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
      (skip_conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
    )
    (1): Conv1DBlock(
      (shared_block): Sequential(
        (0): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
        (1): PReLU(num_parameters=1)
        (2): GlobLN()
        (3): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,), groups=512)
          (1): _C

In [31]:
tf_rep[0, :20, 0]

tensor([-6.2517e-05,  9.9679e-03, -2.4003e-02,  2.3115e-02,  8.3813e-02,
        -6.7316e-02, -5.7936e-02,  4.9822e-03,  1.4618e-02,  5.2281e-02,
        -3.1733e-02,  1.3381e-04, -4.3485e-03, -1.0658e-02,  3.3922e-02,
         3.2041e-02, -1.3777e-01, -1.0748e-01, -4.9663e-03, -5.4417e-02],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [32]:
curr_tf_rep[:, :20]

tensor([[-6.2517e-05,  9.9679e-03, -2.4003e-02,  2.3115e-02,  8.3813e-02,
         -6.7316e-02, -5.7936e-02,  4.9822e-03,  1.4618e-02,  5.2281e-02,
         -3.1733e-02,  1.3381e-04, -4.3485e-03, -1.0658e-02,  3.3922e-02,
          3.2041e-02, -1.3777e-01, -1.0748e-01, -4.9663e-03, -5.4417e-02]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [10]:
# est_masks = model.forward_masker(tf_rep, enroll_emb)
print("tf rep shape: ", tf_rep.shape)
print("enroll_emb shape: ", enroll_emb.shape)

tf rep shape:  torch.Size([1, 512, 3999])
enroll_emb shape:  torch.Size([1, 256])


In [11]:
est_masks1 = model.forward_masker(tf_rep, enroll_emb)
est_masks2 = model.forward_masker(tf_rep[:, :, 1000:], enroll_emb)
print("shape: ", est_masks1.shape, est_masks2.shape)
print(est_masks1[:, :, :20, -1])
print(est_masks2[:, :, :20, -1])

shape:  torch.Size([1, 1, 512, 3999]) torch.Size([1, 1, 512, 2999])
tensor([[[0.3182, 0.4103, 0.2749, 0.7780, 0.2668, 0.4868, 0.3352, 0.8830,
          0.6401, 0.5845, 0.3706, 0.5875, 0.4581, 0.2609, 0.4486, 0.9302,
          0.3708, 0.0555, 0.0947, 0.2834]]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[[0.3169, 0.4103, 0.2741, 0.7786, 0.2666, 0.4863, 0.3350, 0.8828,
          0.6404, 0.5862, 0.3707, 0.5884, 0.4599, 0.2604, 0.4490, 0.9299,
          0.3687, 0.0554, 0.0949, 0.2839]]], device='cuda:0',
       grad_fn=<SelectBackward0>)


In [49]:
est_masks1 = model.forward_masker(tf_rep, enroll_emb)
# est_masks2 = model.forward_masker(tf_rep[:, :, 100:100+1531].repeat(1, 1, 1533//3), enroll_emb)
est_mask3 = model.forward_masker(tf_rep[:, :, 100:100+1531], enroll_emb)
print("shape: ", est_masks1.shape, est_mask3.shape)
pos_idx = 0
print(est_masks1[:, :, :20, 1630])
# print(est_masks2[:, :, :20, pos_idx])
print(est_mask3[:, :, :20, -1])

shape:  torch.Size([1, 1, 512, 3999]) torch.Size([1, 1, 512, 1531])
tensor([[[0.4033, 0.4412, 0.1502, 0.6177, 0.1871, 0.6373, 0.4551, 0.7668,
          0.6814, 0.8087, 0.7720, 0.4975, 0.7221, 0.2791, 0.3796, 0.3808,
          0.2626, 0.2632, 0.0346, 0.7667]]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[[0.4007, 0.4411, 0.1487, 0.6205, 0.1826, 0.6469, 0.4549, 0.7678,
          0.6834, 0.8120, 0.7707, 0.4983, 0.7289, 0.2783, 0.3755, 0.3814,
          0.2602, 0.2552, 0.0337, 0.7665]]], device='cuda:0',
       grad_fn=<SelectBackward0>)


tensor([[[0.4674, 0.3100, 0.5621, 0.3368, 0.5088, 0.4865, 0.3196, 0.6484,
          0.6562, 0.7752, 0.2165, 0.6907, 0.4120, 0.1770, 0.3037, 0.4914,
          0.5729, 0.4348, 0.2399, 0.2737]]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[[0.4537, 0.2613, 0.6042, 0.4140, 0.5241, 0.5348, 0.2955, 0.6810,
          0.7003, 0.8005, 0.2099, 0.6353, 0.4235, 0.1421, 0.3295, 0.3737,
          0.6370, 0.3513, 0.2106, 0.3253]]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[[0.5432, 0.3283, 0.5765, 0.3530, 0.4134, 0.6392, 0.3182, 0.6407,
          0.6767, 0.8328, 0.2003, 0.5588, 0.3980, 0.1498, 0.3431, 0.3774,
          0.6378, 0.4756, 0.1734, 0.3796]]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[[0.1797, 0.3244, 0.3382, 0.6887, 0.3228, 0.7301, 0.4728, 0.7321,
          0.7064, 0.5452, 0.3705, 0.6328, 0.7500, 0.0324, 0.7958, 0.6931,
          0.3967, 0.0968, 0.0432, 0.5006]]], device='cuda:0',
       grad_fn=<SelectBackward0>)
