<a href="https://colab.research.google.com/github/sridharkalluri/taalcadenza/blob/main/binaural_xtalk_demixing_estimation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Binaural cross-talk de-mixing estimation

**Author**: [Sridhar Kalluri](https://github.com/skim0514)_

Based on TorchAudio demo of Hybrid Demucs model for music source separation.


## 1. Overview

Binaural cross-talk de-mixing with the following steps

1. Build the pipeline -- install dependencies, load audio file, specify model, specify computational device.
2. Format the waveform into chunks of expected sizes and loop through chunks (with overlap) and feed into pipeline.
3. Collect output chunks and combine according to the way they have been overlapped.




## 2. Preparation

Install dependencies -- ``torchaudio`` and ``torch``.




In [1]:
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import matplotlib.pyplot as plt

2.1.0+cu118
2.1.0+cu118


In [2]:
from IPython.display import Audio
#from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
#from torchaudio.utils import download_asset

## 3. Construct the pipeline

Pre-trained model weights and related pipeline components are bundled as
:py:func:`torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS`. This is a
:py:class:`torchaudio.models.HDemucs` model trained on
[MUSDB18-HQ](https://zenodo.org/record/3338373)_ and additional
internal extra training data.
This specific model is suited for higher sample rates, around 44.1 kHZ
and has a nfft value of 4096 with a depth of 6 in the model implementation.



In [40]:
# Define model
#class NeuralNetwork(nn.Module):
#    def __init__(self):
#        super().__init__()
#        self.flatten = nn.Flatten()
#        self.linear_relu_stack = nn.Sequential(
#            nn.Linear(28*28, 512),
#            nn.ReLU(),
#            nn.Linear(512, 512),
#            nn.ReLU(),
#            nn.Linear(512, 10)
#        )

#    def forward(self, x):
#        x = self.flatten(x)
#        logits = self.linear_relu_stack(x)
#        return logits
#
#model = NeuralNetwork().to(device)
#print(model)

#bundle = HDEMUCS_HIGH_MUSDB_PLUS
#model = bundle.get_model()

# Define custom feature extraction pipeline.
#
# 1. Resample audio
# 2. Convert to mel-scale spectrogram
# 3. Apply channel-wise binaural calculations
# 4. Convert to mel-scale
#
class BinauralFeatures(torch.nn.Module):
    def __init__(
        self,
        input_freq=44100,
        resample_freq=22050,
        n_fft=256,
        n_mel=32,
        n_hop=16
    ):
        super().__init__()
        #self.resample = torchaudio.transforms.Resample(orig_freq=input_freq, new_freq=resample_freq)
        #self.mel_spect = torchaudio.transforms.MelSpectrogram(n_fft=n_fft,
        #                                                     sample_rate = resample_freq,
        #                                                     n_mels = n_mel)

        self.spec_centroid = torch.nn.Sequential(
            #torchaudio.transforms.Resample(input_freq, resample_freq),
            torchaudio.transforms.SpectralCentroid(sample_rate = resample_freq, n_fft= n_fft, hop_length = n_hop),
        )


    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Compute spectral centroid
        centroid = self.spec_centroid(waveform)
        # Compute centroid statistics
        avgcnt = torch.mean(centroid,1) #mean spectral centroid
        stdcnt = torch.std(centroid,1)  #standard deviation of spectral centroid
        slpcnt = torch.diff(centroid,1)/(n_hop/resample_freq) #slope of centroid versus time
        avgslpcnt = torch.mean(slpcnt)  #avg slope of centroid
        iacnt = torch.diff(centroid,0)       #interaural centroid difference
        avgiacnt = torch.mean(iacnt)    #avg interaural centroid difference
        avgslpiacnt = torch.mean(torch.diff(iacnt)) #avg slope of interaural centroid difference

        feats = {
                  "meancentroids": avgcnt,
                  "stdcentroids": stdcnt,
                  "meancentroidslope": avgslpcnt,
                  "avgiacnt": avgiacnt,
                }

        featvec = torch.cat([avgcnt, stdcnt, avgslpcnt, avgiacnt, avgslpiacnt],dim=1)

        return featvec

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BinauralFeatures()
model.to(device)


BinauralFeatures(
  (spec_centroid): Sequential(
    (0): SpectralCentroid()
  )
)

## 4. Configure the application function

Difficult to have sufficient memory to apply a large model to
an entire song at once. To work around this limitation,
obtain the separated sources of a full song by
chunking the song into smaller segments and run through the
model piece by piece, and then rearrange back together.

When doing this, it is important to ensure some
overlap between each of the chunks, to accommodate for artifacts at the
edges. Due to the nature of the model, sometimes the edges have
inaccurate or undesired sounds included.

The chunking and arrangement implementation takes an overlap of 1 second on each side, and then does
a linear fade in and fade out on each side. Summing faded overlaps results in constant intensity throughout.
This accommodates for the artifacts by using less of the edges of the
model outputs.

<img src="https://download.pytorch.org/torchaudio/tutorial-assets/HDemucs_Drawing.jpg">



In [11]:
from torchaudio.transforms import Fade
sample_rate = 22050

def featureanalysis(
    mix,
    segment=10.0,
    overlap=0.1,
    device=None,
):
    """
    Conduct feature analysis on a stereo signal. Use fade, and add segments together in order to add model segment by segment.

    Args:
        segment (int): segment length in seconds
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    overlap_frames = overlap * sample_rate
    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")

    final = torch.zeros(batch, channels, length, device=device)

    while start < length - overlap_frames:
        chunk = mix[:, :, start:end]
        out = binauralanalysis(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out
        if start == 0:
            fade.fade_in_len = int(overlap_frames)
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            fade.fade_out_len = 0
    return final


def plot_spectrogram(stft, title="Spectrogram"):
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    _, axis = plt.subplots(1, 1)
    axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
    axis.set_title(title)
    plt.tight_layout()


def binauralanalysis(signal):
  N_FFT = 256
  N_HOP = 16
  ANA_SAMPLE_RATE = 22050
  #stft = torchaudio.transforms.Spectrogram(
  #    n_fft=N_FFT,
  #    hop_length=N_HOP,
  #    power=1 #power = 0 (magnitude), 1 (power), 2 (none)
  #)
  #yft = stft(signal)

  speccentroid = torchaudio.transforms.SpectralCentroid(
      sample_rate = ANA_SAMPLE_RATE,
      n_fft= N_FFT,
      hop_length = N_HOP
  )
  ycn = speccentroid(signal)
  avgcnt = torch.mean(ycn,1) #mean spectral centroid
  stdcnt = torch.std(ycn,1)  #standard deviation of spectral centroid
  slpcnt = torch.diff(ycn,1)/(N_HOP/ANA_SAMPLE_RATE)
  avgslpcnt = torch.mean(slpcnt) #avg change of centroid
  iacnt = torch.diff(ycn,0)
  avgiacnt = torch.mean(iacnt)


  feats = {
                  "meancentroids": avgcnt,
                  "stdcentroids": stdcnt,
                  "meancentroidslope": avgslpcnt,
                  "avgiacnt": avgiacnt,
                }

  return feats




## 5. Run Model

Compute features

Test song, use Actions - One Minute Smile from
MedleyDB (Creative Commons BY-NC-SA 4.0). This is also located in
[MUSDB18-HQ](https://zenodo.org/record/3338373)_ dataset within
the ``train`` sources.

In [7]:
#Execute if necessary to upload test song in the Google Colab workspace
from google.colab import files
import os

# Upload multiple files
uploaded = files.upload()
# Print the uploaded files' details
for filename, content in uploaded.items():
    print(f'File {filename} uploaded with length {len(content)} bytes')
# Step 2: Get the original file name
original_filename = list(uploaded.keys())[0]
# Step 3: Specify the new file name
new_filename = "Actions - One Minute Smile-hlp_0005_mixture.wav"
# Step 4: Rename the file
os.rename(original_filename, new_filename)

Saving mixture.wav to mixture.wav
File mixture.wav uploaded with length 28854164 bytes


In [42]:
# Obtain audio file for processing
SAMPLE_SONG = '/content/Actions - One Minute Smile-hlp_0005_mixture.wav'
waveform, sample_rate = torchaudio.load(SAMPLE_SONG)  # replace SAMPLE_SONG with desired path for different song
waveform = waveform.to(device)

#resample waveform to analyis sampling rate and standardize (0-mean, unit std) it
ANA_SAMPLE_RATE = 22050
resampler = torchaudio.transforms.Resample(sample_rate, ANA_SAMPLE_RATE, dtype=waveform.dtype)
rswaveform = resampler(waveform)
ref = rswaveform.mean(0)
rswaveform = (rswaveform - ref.mean()) / ref.std()  # normalization

# parameters
#segment: int = 10
#overlap = 0.1
#features = featureanalysis(waveform,segment,overlap)

#bn = binauralanalysis(rswaveform)

#sources = separate_sources(
#    model,
#    waveform[None],
#    device=device,
#    segment=segment,
#    overlap=overlap,
#)[0]
#sources = sources * ref.std() + ref.mean()

#sources_list = model.sources
#sources = list(sources)

#audios = dict(zip(sources_list, sources))

# Process waveform through pipeline
featurevec = BinauralFeatures(rswaveform)

In [1]:
featurevec

NameError: name 'featurevec' is not defined

In [42]:
cntransformer = torchaudio.transforms.SpectralCentroid(sample_rate = ANA_SAMPLE_RATE, n_fft = 256, hop_length=16)
ycn = cntransformer(rswaveform)

In [11]:
feats = binauralanalysis(rswaveform)

### 5.1 Feature analysis





In [15]:
N_FFT = 128
N_HOP = 4
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)

In [12]:
feats

{'meancentroids': tensor([1897.6011, 1923.3502]),
 'stdcentroids': tensor([756.0094, 763.6097]),
 'meancentroidslope': tensor(-3.3405),
 'avgiacnt': tensor(1910.4757)}

tensor([0.0274, 0.0319, 0.0434,  ..., 0.0088, 0.0088, 0.0088])

In [35]:
bah.shape

torch.Size([4617455])

### 5.2 Audio Segmenting and Processing

Below is the processing steps and segmenting 5 seconds of the tracks in
order to feed into the spectrogram and to caclulate the respective SDR
scores.




In [None]:
def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
#    print(
#        "SDR score is:",
#        separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
#    )
    plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
    return Audio(predicted_source, rate=sample_rate)


segment_start = 150
segment_end = 155

frame_start = segment_start * sample_rate
frame_end = segment_end * sample_rate

drums_original = download_asset("tutorial-assets/hdemucs_drums_segment.wav")
bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")

drums_spec = audios["drums"][:, frame_start:frame_end].cpu()
drums, sample_rate = torchaudio.load(drums_original)

bass_spec = audios["bass"][:, frame_start:frame_end].cpu()
bass, sample_rate = torchaudio.load(bass_original)

vocals_spec = audios["vocals"][:, frame_start:frame_end].cpu()
vocals, sample_rate = torchaudio.load(vocals_original)

other_spec = audios["other"][:, frame_start:frame_end].cpu()
other, sample_rate = torchaudio.load(other_original)

mix_spec = mixture[:, frame_start:frame_end].cpu()

### 5.3 Spectrograms and Audio

In the next 5 cells, you can see the spectrograms with the respective
audios. The audios can be clearly visualized using the spectrogram.

The mixture clip comes from the original track, and the remaining
tracks are the model output




In [None]:
# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)

Drums SDR, Spectrogram, and Audio




In [None]:
# Drums Clip
output_results(drums, drums_spec, "drums")

Bass SDR, Spectrogram, and Audio




In [None]:
# Bass Clip
output_results(bass, bass_spec, "bass")

Vocals SDR, Spectrogram, and Audio




In [None]:
# Vocals Audio
output_results(vocals, vocals_spec, "vocals")

Other SDR, Spectrogram, and Audio




In [None]:
# Other Clip
output_results(other, other_spec, "other")

In [None]:
# Optionally, the full audios can be heard in from running the next 5
# cells. They will take a bit longer to load, so to run simply uncomment
# out the ``Audio`` cells for the respective track to produce the audio
# for the full song.
#

# Full Audio
# Audio(mixture, rate=sample_rate)

# Drums Audio
# Audio(audios["drums"], rate=sample_rate)

# Bass Audio
# Audio(audios["bass"], rate=sample_rate)

# Vocals Audio
# Audio(audios["vocals"], rate=sample_rate)

# Other Audio
# Audio(audios["other"], rate=sample_rate)