
Conversation

@nateanl (Member) commented Aug 16, 2021

Summary:

  • Add a PSD module for computing the covariance matrix.
  • Add a Time-Frequency-Masking-based MVDR beamforming module.

TODO:

  • Add steering vector solutions for computing the MVDR beamforming weights
  • Add test code for the PSD and MVDR modules
  • Add a Colab example showing how to use MVDR on multi-channel audio
  • Add support for multi-channel Time-Frequency masks
  • Add support for complex-valued Time-Frequency masks

@nateanl (Member, Author) commented Aug 16, 2021

I also benchmarked the performance of torch.einsum and torch.matmul when computing the PSD matrix. Here's the test script.
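The script itself isn't reproduced here, but a minimal sketch of a comparable benchmark with torch.utils.benchmark could look like the following (the helper names psd_einsum and psd_matmul are mine, and the PSD is assumed to be the time-summed outer product over channels):

```python
import torch
from torch.utils import benchmark

def psd_einsum(spec):
    # spec: (batch, channel, time); sum the per-frame outer products over time.
    return torch.einsum("...ct,...et->...ce", spec, spec.conj())

def psd_matmul(spec):
    # The same computation expressed as a batched matrix product.
    return torch.matmul(spec, spec.conj().transpose(-2, -1))

results = []
for num_threads in (1, 4, 16, 32):
    for channel in (2, 4, 6, 8):
        for frames in (4000, 8000, 16000, 32000):
            spec = torch.rand(1, channel, frames, dtype=torch.cfloat)
            for description, stmt in (("einsum", "psd_einsum(spec)"), ("matmul", "psd_matmul(spec)")):
                results.append(
                    benchmark.Timer(
                        stmt=stmt,
                        globals={"spec": spec, "psd_einsum": psd_einsum, "psd_matmul": psd_matmul},
                        num_threads=num_threads,
                        label="PSD Computation",
                        sub_label=str(list(spec.shape)),
                        description=description,
                    ).blocked_autorange()
                )
benchmark.Compare(results).print()
```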

Here is the result:

CPU:
[----------- PSD Computation ------------]
                      |  einsum  |  matmul
1 threads: -------------------------------
      [1, 2, 4000]    |     43   |    100
      [1, 2, 8000]    |     62   |    170
      [1, 2, 16000]   |     97   |    290
      [1, 2, 32000]   |    400   |    710
      [1, 4, 4000]    |     90   |    240
      [1, 4, 8000]    |    300   |    600
      [1, 4, 16000]   |    550   |   1100
      [1, 4, 32000]   |   2400   |   3400
      [1, 6, 4000]    |    400   |    800
      [1, 6, 8000]    |    600   |   1200
      [1, 6, 16000]   |    770   |   2100
      [1, 6, 32000]   |   2300   |   4200
      [1, 8, 4000]    |    520   |   1000
      [1, 8, 8000]    |   1300   |   2200
      [1, 8, 16000]   |   1600   |   3400
      [1, 8, 32000]   |   9000   |  13000
4 threads: -------------------------------
      [1, 2, 4000]    |     43   |    100
      [1, 2, 8000]    |     61   |    160
      [1, 2, 16000]   |    140   |    400
      [1, 2, 32000]   |    460   |    900
      [1, 4, 4000]    |     90   |    240
      [1, 4, 8000]    |    290   |    570
      [1, 4, 16000]   |    550   |   1100
      [1, 4, 32000]   |   2000   |   3600
      [1, 6, 4000]    |    460   |    700
      [1, 6, 8000]    |    890   |   1200
      [1, 6, 16000]   |    700   |   2460
      [1, 6, 32000]   |   2300   |   4500
      [1, 8, 4000]    |    600   |   1000
      [1, 8, 8000]    |   1000   |   2500
      [1, 8, 16000]   |   1600   |   3700
      [1, 8, 32000]   |   9000   |  13000
16 threads: ------------------------------
      [1, 2, 4000]    |     43   |     99
      [1, 2, 8000]    |     60   |    160
      [1, 2, 16000]   |    110   |    330
      [1, 2, 32000]   |    320   |    720
      [1, 4, 4000]    |     83   |    240
      [1, 4, 8000]    |    300   |    560
      [1, 4, 16000]   |    560   |   1500
      [1, 4, 32000]   |   1800   |   3100
      [1, 6, 4000]    |    330   |    700
      [1, 6, 8000]    |    700   |   1300
      [1, 6, 16000]   |    900   |   1800
      [1, 6, 32000]   |   2100   |   4400
      [1, 8, 4000]    |    510   |   1000
      [1, 8, 8000]    |   1300   |   2400
      [1, 8, 16000]   |   1600   |   3600
      [1, 8, 32000]   |   8900   |  13000
32 threads: ------------------------------
      [1, 2, 4000]    |     40   |     98
      [1, 2, 8000]    |     60   |    160
      [1, 2, 16000]   |    110   |    340
      [1, 2, 32000]   |    320   |    710
      [1, 4, 4000]    |     90   |    240
      [1, 4, 8000]    |    300   |    600
      [1, 4, 16000]   |    800   |   1200
      [1, 4, 32000]   |   2000   |   3200
      [1, 6, 4000]    |    300   |    640
      [1, 6, 8000]    |    620   |   1300
      [1, 6, 16000]   |    700   |   1700
      [1, 6, 32000]   |   2100   |   4400
      [1, 8, 4000]    |    530   |   1000
      [1, 8, 8000]    |   1400   |   2000
      [1, 8, 16000]   |   1600   |   3400
      [1, 8, 32000]   |   9300   |  13000

Times are in microseconds (us).

According to the table, einsum is consistently faster than matmul for this computation.

@nateanl (Member, Author) commented Aug 17, 2021

PyTorch currently doesn't support a batched version of the trace method. There's an ongoing linalg.trace PR; we can switch to the PyTorch method once that PR lands.
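In the meantime, a batched trace can be computed with existing ops; a minimal sketch (the helper name batched_trace is mine):

```python
import torch

def batched_trace(mat: torch.Tensor) -> torch.Tensor:
    # mat: (..., channel, channel); take each matrix's diagonal and sum its entries.
    return mat.diagonal(dim1=-2, dim2=-1).sum(-1)

psd = torch.rand(3, 201, 4, 4, dtype=torch.cdouble)
print(batched_trace(psd).shape)  # torch.Size([3, 201])
```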

@nateanl (Member, Author) commented Aug 24, 2021

@Emrys365 Could you help review the MVDR implementation? Thanks!
Regarding the unit tests, do you have any ideas on how to verify the correctness of the solutions?

param(solution="ref_channel"),
param(solution="stv_power"),
# evd will fail since the eigenvalues are not distinct
# param(solution="stv_evd"),
@nateanl (Member, Author) commented Aug 24, 2021

The eigenvalue decomposition solution fails the autograd test. I guess the reason is that the eigenvalues are not distinct (i.e., some eigenvalues are close or identical). Is there a way to generate a masked matrix (STFT * mask) whose PSD matrix has distinct eigenvalues?

Reply:
Maybe follow the usual simulation procedure, i.e., simulate a mixture of white noise and some signal and calculate the ideal ratio mask (IRM)?
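For reference, a minimal sketch of that idea (the shapes and the magnitude-based IRM definition are assumptions; power-based variants also exist):

```python
import torch

# Simulated multi-channel STFTs of clean speech and white noise: (channel, freq, time).
stft_speech = torch.randn(6, 201, 100, dtype=torch.cdouble)
stft_noise = torch.randn(6, 201, 100, dtype=torch.cdouble)
stft_mix = stft_speech + stft_noise

# Ideal ratio mask computed from the oracle magnitudes of a reference channel.
eps = 1e-10
irm_speech = stft_speech[0].abs() / (stft_speech[0].abs() + stft_noise[0].abs() + eps)
irm_noise = 1.0 - irm_speech
```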

@nateanl (Member, Author):

That's a good idea. I may take an utterance and a room impulse response from open-source data and do the multi-channel signal simulation.

Contributor:

If there is a condition that does not support autograd, please document it somewhere.
(we should be doing the same for time stretch as well.)

@nateanl (Member, Author):

Addressed in f976eae

@carolineechen (Contributor) left a comment

A few minor comments:

@Emrys365 left a comment

The MVDR implementation looks good to me. I just made some minor comments.

Comment on lines 256 to 271
```python
def get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor:
    r"""Estimate the steering vector by eigenvalue decomposition.

    Args:
        psd_s (torch.Tensor): covariance matrix of speech
            Tensor of dimension (..., freq, channel, channel)

    Returns:
        torch.Tensor: the estimated steering vector
            Tensor of dimension (..., freq, channel, 1)
    """
    # Eigendecomposition of the speech PSD; the eigenvectors are the columns of v.
    w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
    # Select the eigenvector associated with the dominant eigenvalue.
    _, indices = torch.max(w.abs(), dim=-1, keepdim=True)
    indices = indices.unsqueeze(-1)
    stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,)))  # (..., freq, channel, 1)
    return stv
```

I notice there are several conventional approaches to estimating the relative transfer function (RTF), or normalized steering vector (reference paper), e.g., covariance subtraction (CS), covariance subtraction with EVD (CS-EVD), and covariance whitening (CW).
But I'm not sure which one is more robust.
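For illustration, a minimal sketch of the simplest of these, covariance subtraction (the function name and eps value are mine):

```python
import torch

def rtf_covariance_subtraction(psd_s: torch.Tensor, reference_channel: int = 0, eps: float = 1e-8) -> torch.Tensor:
    # psd_s: speech PSD, often obtained as psd_mix - psd_noise, (..., freq, channel, channel).
    # For a rank-1 speech PSD, every column is proportional to the steering vector.
    rtf = psd_s[..., reference_channel]  # (..., freq, channel)
    # Normalize so that the reference channel's entry is 1.
    return rtf / (rtf[..., reference_channel : reference_channel + 1] + eps)
```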

Contributor:

@nateanl, sorry for my late response. I've tried several formulations this spring, and this one [1] was the most robust when training end-to-end.

[1] Souden, M., Benesty, J., & Affes, S. (2009). On optimal frequency-domain multichannel linear filtering for noise reduction. IEEE Transactions on Audio, Speech, and Language Processing, 18(2), 260-276.

@nateanl (Member, Author):

Thanks, Samuele! The formula refers to the reference-channel selection solution in the module, right?

$$\mathbf{w} = \frac{\Phi_{\mathrm{NN}}^{-1} \Phi_{\mathrm{SS}}}{\operatorname{trace}\left(\Phi_{\mathrm{NN}}^{-1} \Phi_{\mathrm{SS}}\right)} \mathbf{u}$$

I will benchmark the performance of all the solutions and compare them with yours. They should be consistent.
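As a sanity check, a minimal sketch of that formula (not the module's exact code; the function name and eps value are mine):

```python
import torch

def mvdr_weights_souden(psd_s: torch.Tensor, psd_n: torch.Tensor,
                        reference_channel: int = 0, eps: float = 1e-8) -> torch.Tensor:
    # psd_s, psd_n: speech / noise PSD matrices, (..., freq, channel, channel).
    numerator = torch.linalg.solve(psd_n, psd_s)  # \Phi_NN^{-1} \Phi_SS
    # Batched trace over the channel dimensions.
    trace = numerator.diagonal(dim1=-2, dim2=-1).sum(-1)  # (..., freq)
    # Multiplying by the one-hot reference vector u amounts to selecting a column.
    return numerator[..., reference_channel] / (trace.unsqueeze(-1) + eps)  # (..., freq, channel)
```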

Contributor:

Yep, it looks like the formula you already implemented as ref_channel.

@nateanl (Member, Author) commented Aug 25, 2021

I will keep the "complex-valued mask support" as a future plan since I'm not sure of the formula for normalizing the PSD matrix by the mask. Here is my proposed method:

```python
S = X * mask.unsqueeze(-3)  # complex multiplication, mask broadcast over channels
psd = torch.einsum("...cft,...eft->...fce", [S, S.conj()])  # outer product over channels, summed over time
if self.normalize:
    # the numerator multiplies the mask twice, so normalize by the summed mask power
    mask_power = (mask.abs() * mask.abs()).sum(dim=-1)
    psd = psd / (mask_power.unsqueeze(-1).unsqueeze(-1) + eps)
```
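A self-contained shape check of the proposal above (the shapes and eps value are illustrative):

```python
import torch

eps = 1e-10
X = torch.rand(2, 4, 201, 100, dtype=torch.cdouble)   # (batch, channel, freq, time) STFT
mask = torch.rand(2, 201, 100, dtype=torch.cdouble)   # complex T-F mask, (batch, freq, time)

S = X * mask.unsqueeze(-3)
psd = torch.einsum("...cft,...eft->...fce", [S, S.conj()])
mask_power = (mask.abs() * mask.abs()).sum(dim=-1)
psd = psd / (mask_power.unsqueeze(-1).unsqueeze(-1) + eps)
print(psd.shape)  # torch.Size([2, 201, 4, 4])
```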

param(solution="ref_channel"),
param(solution="stv_power"),
# evd will fail since the eigenvalues are not distinct
# param(solution="stv_evd"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a condition that does not support autograd, please document it somewhere.
(we should be doing the same for time stretch as well.)

spec = torch.rand((2, 6, 201, 100), dtype=torch.cdouble)

# Single then transform then batch
expected = PSD()(spec).repeat(3, 1, 1, 1)
Contributor:

I am aware that this is taken from the other tests, but this turned out to miss certain cases. (Ref: #1451)

Can you make the samples in the batch different?
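One way to do that, sketched under the assumption that the PSD transform is importable as in this PR:

```python
import torch
from torchaudio.transforms import PSD

# Three distinct samples in the batch instead of repeats of a single sample.
spec = torch.rand((3, 6, 201, 100), dtype=torch.cdouble)

# Transform each sample individually, stack the results, and compare with the batched call.
expected = torch.stack([PSD()(s) for s in spec])
assert torch.allclose(PSD()(spec), expected)
```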

Contributor:

(we also need to update the existing batch consistency tests)

@nateanl (Member, Author):

Thanks for pointing this out, I updated the test in f976eae.

I will create another PR to update the rest of the batch consistency tests.

@nateanl merged commit 4915524 into pytorch:main Aug 26, 2021