In [None]:
%%capture
!pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# !pip install onnxruntime
!pip -q install pydub
!wget https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/UVR-MDX-NET-Inst_HQ_3.onnx
!wget https://github.com/microsoft/DNS-Challenge/raw/refs/heads/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx

In [None]:
import os
import json
import warnings
import librosa
import torch
import time
import numpy as np
import pandas as pd
import onnxruntime as ort
from pydub import AudioSegment
from IPython.display import Audio, display
from tqdm import tqdm

# Silence all Python warnings to keep notebook output readable.
warnings.filterwarnings(action="ignore")

# Sample rate (Hz) that ComputeScore resamples/loads audio to before DNSMOS scoring.
SAMPLING_RATE = 16_000
# Length (seconds) of each DNSMOS scoring window; ComputeScore hops through the clip in 1 s steps.
INPUT_LENGTH = 9.01


def source_separation(predictor, audio):
    """
    Separate an audio clip into vocals and non-vocals using ``predictor``.

    Args:
        predictor: Object with ``predict(mix) -> (vocals, no_vocals)``, where ``mix`` is a
            44.1 kHz waveform and both outputs are ``(num_samples, channels)`` arrays
            (as returned by ``Predictor.predict``).
        audio (str or dict): An audio file path, or a dict with at least ``"waveform"``
            (np.ndarray) and ``"sample_rate"`` (int) keys.

    Returns:
        dict: For dict input, the same dict updated in place. For a path, a new dict with
            ``"name"`` (file basename) and ``"sample_rate"`` keys. In both cases
            ``"waveform"`` holds the first channel of the separated vocals and
            ``"other_waveform"`` the first channel of the remainder, both at
            ``"sample_rate"``.
    """
    model_sr = 44100  # the separation model operates on 44.1 kHz audio

    if isinstance(audio, str):
        mix, rate = librosa.load(audio, mono=False, sr=model_sr)
        # A path has no dict to update; build one so the outputs can be stored
        # (previously this branch failed on ``audio["waveform"] = ...``).
        audio = {"name": os.path.basename(audio), "sample_rate": rate}
    else:
        rate = audio["sample_rate"]
        mix = librosa.resample(audio["waveform"], orig_sr=rate, target_sr=model_sr)

    vocals, no_vocals = predictor.predict(mix)

    # predict returns (samples, channels); librosa resamples along the last axis,
    # so transpose around the call to resample each channel back to the input rate.
    vocals = librosa.resample(vocals.T, orig_sr=model_sr, target_sr=rate).T
    no_vocals = librosa.resample(no_vocals.T, orig_sr=model_sr, target_sr=rate).T

    # The predictor output is (at least) two-channel; keep only the first channel.
    audio["waveform"] = vocals[:, 0]
    audio["other_waveform"] = no_vocals[:, 0]

    return audio

class ConvTDFNet:
    """
    STFT / inverse-STFT front end used around the MDX-Net ONNX separator.

    Despite the name, this class holds no learnable layers. ``stft`` turns
    two-channel waveform chunks into a 4-channel (channel x real/imag) spectrogram
    truncated to ``dim_f`` bins. ``istft`` zero-fills the dropped bins and converts
    a spectrogram back to waveform chunks.
    """

    def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
        """
        Build the STFT configuration.

        Args:
            target_name (str): Separation target. "*" means all stems, which makes the
                zero frequency padding (and ``istft``'s channel count) four times larger.
            L (int): Layer count of the original network; only ``L // 2`` is kept, as ``self.n``.
            dim_f (int): Number of frequency bins kept by ``stft``.
            dim_t (int): log2 of the number of STFT frames per chunk.
            n_fft (int): FFT size.
            hop (int, optional): Hop length in samples. Defaults to 1024.
        """
        super().__init__()
        self.target_name = target_name
        self.n_fft = n_fft
        self.hop = hop
        self.dim_c = 4
        self.dim_f = dim_f
        self.dim_t = 2 ** dim_t
        self.n_bins = n_fft // 2 + 1
        # Chunk length for which a centered STFT yields exactly dim_t frames.
        self.chunk_size = hop * (self.dim_t - 1)
        self.window = torch.hann_window(window_length=n_fft, periodic=True)

        # Zeros appended above dim_f by istft to restore the full n_bins.
        if target_name == "*":
            pad_channels = self.dim_c * 4
        else:
            pad_channels = self.dim_c
        self.freq_pad = torch.zeros([1, pad_channels, self.n_bins - self.dim_f, self.dim_t])
        self.n = L // 2

    def stft(self, x):
        """
        Convert waveform chunks to a truncated real-valued spectrogram.

        Args:
            x (torch.Tensor): Waveforms. They are flattened into rows of ``chunk_size``
                samples, and consecutive row pairs are grouped as the two channels of
                one item.

        Returns:
            torch.Tensor: Shape ``(batch, dim_c, dim_f, dim_t)``.
        """
        rows = x.reshape([-1, self.chunk_size])
        spec = torch.stft(
            rows,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        # (rows, bins, frames) complex -> (rows, re/im, bins, frames) real
        spec = torch.view_as_real(spec).permute([0, 3, 1, 2])
        # (batch, channel, re/im, bins, frames) -> (batch, channel*re/im, bins, frames)
        grouped = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        spec = grouped.reshape([-1, self.dim_c, self.n_bins, self.dim_t])
        return spec[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        """
        Convert a (possibly truncated) spectrogram back to waveform chunks.

        Args:
            x (torch.Tensor): Spectrogram laid out as produced by ``stft``.
            freq_pad (torch.Tensor, optional): Bins appended along the frequency axis.
                Defaults to ``self.freq_pad`` repeated over the batch.

        Returns:
            torch.Tensor: Shape ``(batch, c, chunk_size)``, where c is 8 for target
            "*" and 2 otherwise.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])
        full = torch.cat([x, freq_pad], -2)
        n_channels = 4 * 2 if self.target_name == "*" else 2

        spec = full.reshape([-1, n_channels, 2, self.n_bins, self.dim_t])
        spec = spec.reshape([-1, 2, self.n_bins, self.dim_t])
        # (rows, re/im, bins, frames) -> (rows, bins, frames) complex
        spec = torch.view_as_complex(spec.permute([0, 2, 3, 1]).contiguous())
        waves = torch.istft(
            spec, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
        )
        return waves.reshape([-1, n_channels, self.chunk_size])


class Predictor:
    """
    Source separator: ConvTDFNet STFT front end plus an ONNX Runtime MDX-Net model.

    Long inputs are split into overlapping chunks of ``args["chunks"]`` seconds at
    44.1 kHz, overlapping by ``args["margin"]`` samples. Each chunk is separated and
    the results are stitched back together.
    """

    def __init__(self, args, device):
        """
        Initialize the Predictor.

        Args:
            args (dict): Configuration. Reads "dim_f", "dim_t", "n_fft" and "model_path"
                here, and "margin", "chunks" and "denoise" during prediction.
            device (str): 'cuda' (CUDA provider with CPU fallback) or 'cpu'.

        Returns:
            None

        Raises:
            ValueError: If the provided device is not 'cuda' or 'cpu'.
        """
        self.args = args
        self.model_ = ConvTDFNet(
            target_name="vocals",
            L=11,
            dim_f=args["dim_f"],
            dim_t=args["dim_t"],
            n_fft=args["n_fft"],
        )

        if device == "cuda":
            # List CPU after CUDA so the session can still run if CUDA is unusable.
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        elif device == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            raise ValueError("Device must be either 'cuda' or 'cpu'")
        self.model = ort.InferenceSession(args["model_path"], providers=providers)

    def demix(self, mix):
        """
        Split ``mix`` into overlapping segments and separate each one.

        Args:
            mix (np.ndarray): Mixture of shape ``(2, num_samples)`` at 44.1 kHz.

        Returns:
            np.ndarray: Separated signal of shape ``(1, 2, num_samples)``.

        Raises:
            AssertionError: If margin is zero.
        """
        samples = mix.shape[-1]
        margin = self.args["margin"]
        chunk_size = self.args["chunks"] * 44100

        assert margin != 0, "Margin cannot be zero!"

        if margin > chunk_size:
            margin = chunk_size

        # chunks == 0 (or a short input) means "process in a single piece".
        if self.args["chunks"] == 0 or samples < chunk_size:
            chunk_size = samples

        # Segments keyed by their start offset. Each one carries `margin` extra samples
        # on both sides, except before the first and after the last segment.
        segmented_mix = {}
        for index, skip in enumerate(range(0, samples, chunk_size)):
            start = skip - (0 if index == 0 else margin)
            end = min(skip + chunk_size + margin, samples)
            segmented_mix[skip] = mix[:, start:end].copy()
            if end == samples:
                break

        return self.demix_base(segmented_mix, margin_size=margin)

    def demix_base(self, mixes, margin_size):
        """
        Run the model over each segment and concatenate the margin-trimmed results.

        Args:
            mixes (dict): Segment start offset -> ``(2, n)`` segment (insertion-ordered).
            margin_size (int): Overlap to trim from interior segment edges.

        Returns:
            np.ndarray: Separated signal of shape ``(1, 2, total_samples)``.
        """
        model = self.model_
        trim = model.n_fft // 2
        gen_size = model.chunk_size - 2 * trim  # useful samples produced per model window
        last_key = list(mixes)[-1] if mixes else None

        chunked_sources = []
        with tqdm(total=len(mixes), desc="Source separation") as progress_bar:
            for key, cmix in mixes.items():
                n_sample = cmix.shape[1]
                # Zero-pad so the segment is a whole number of gen_size windows, with
                # `trim` extra zeros on each side to absorb the STFT edge effects.
                pad = gen_size - n_sample % gen_size
                mix_p = np.concatenate(
                    (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
                )
                mix_waves = np.array(
                    [
                        mix_p[:, i : i + model.chunk_size]
                        for i in range(0, n_sample + pad, gen_size)
                    ]
                )
                mix_waves = torch.tensor(mix_waves, dtype=torch.float32)

                with torch.no_grad():
                    spek = model.stft(mix_waves).cpu().numpy()
                    if self.args["denoise"]:
                        # Average the prediction for the spectrum with the sign-flipped
                        # prediction for its negation.
                        spec_pred = (
                            -self.model.run(None, {"input": -spek})[0] * 0.5
                            + self.model.run(None, {"input": spek})[0] * 0.5
                        )
                    else:
                        spec_pred = self.model.run(None, {"input": spek})[0]
                    tar_waves = model.istft(torch.tensor(spec_pred))

                    # Drop each window's `trim` edges, lay the windows end to end as one
                    # (2, n) signal, then cut the zero padding added above.
                    tar_signal = (
                        tar_waves[:, :, trim:-trim]
                        .transpose(0, 1)
                        .reshape(2, -1)
                        .numpy()[:, :-pad]
                    )

                start = 0 if key == 0 else margin_size
                end = None if (key == last_key or margin_size == 0) else -margin_size
                chunked_sources.append(tar_signal[:, start:end])
                progress_bar.update(1)

        # Leading axis keeps the historical (1, 2, n) output shape.
        return np.concatenate(chunked_sources, axis=-1)[np.newaxis]

    def predict(self, mix):
        """
        Predict the separated sources from the input mix.

        Args:
            mix (np.ndarray): Mixture at 44.1 kHz, either ``(num_samples,)`` (duplicated
                to two channels) or ``(2, num_samples)``.

        Returns:
            tuple: ``(mix - separated, separated)``, each of shape ``(num_samples, 2)``
            and trimmed to the input length. ``source_separation`` treats the first
            element as vocals.
        """
        if mix.ndim == 1:
            mix = np.asfortranarray([mix, mix])

        n_samples = mix.shape[1]
        chunk_size = self.args["chunks"] * 44100
        # Pad to a whole number of chunks. chunks == 0 means a single pass, so there
        # is nothing to pad (and no modulo by zero).
        tail = n_samples % chunk_size if chunk_size > 0 else 0
        if tail != 0:
            mix = np.pad(mix, ((0, 0), (0, chunk_size - tail)))

        sources = self.demix(mix)
        opt = sources[0].T  # (samples, channels)
        mix = mix.T
        # Trim the padding from BOTH outputs so they match the input length.
        return (mix - opt)[:n_samples], opt[:n_samples]


class ComputeScore:
    """
    DNSMOS (SIG / BAK / OVRL) scorer backed by an ONNX Runtime session.
    """

    def __init__(self, primary_model_path, device="cpu") -> None:
        """
        Args:
            primary_model_path (str): Path to the DNSMOS ONNX model.
            device (str): "cuda" to prefer the CUDA provider; anything else runs on CPU.
        """
        if device == "cuda":
            self.onnx_sess = ort.InferenceSession(
                primary_model_path,
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            print(f"Using CUDA: {self.onnx_sess.get_providers()}")
        else:
            # Builds with several execution providers (e.g. onnxruntime-gpu) reject a
            # session created without an explicit providers list, so name CPU here.
            self.onnx_sess = ort.InferenceSession(
                primary_model_path, providers=["CPUExecutionProvider"]
            )

    def audio_melspec(self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True):
        """Mel spectrogram transposed so time is the first axis; optionally (dB + 40) / 40 scaled. Not used by ``__call__``."""
        mel_spec = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels
        )
        if to_db:
            mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40
        return mel_spec.T

    def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
        """Map raw SIG/BAK/OVRL model outputs through fixed polynomial calibration curves."""
        if is_personalized_MOS:
            p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046])
            p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
            p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132])
        else:
            p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
            p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
            p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])

        return p_sig(sig), p_bak(bak), p_ovr(ovr)

    def __call__(self, audio, sampling_rate, is_personalized_MOS=False):
        """
        Score a clip with DNSMOS.

        Args:
            audio (str or np.ndarray): File path (loaded at SAMPLING_RATE) or 1-D waveform.
            sampling_rate (int): Sample rate of ``audio`` when it is an array.
            is_personalized_MOS (bool): Use the personalized calibration polynomials.

        Returns:
            dict: Per-window means of raw and calibrated SIG/BAK/OVRL, plus metadata.

        Raises:
            ValueError: If the audio is empty.
        """
        fs = SAMPLING_RATE
        if isinstance(audio, str):
            audio, _ = librosa.load(audio, sr=fs)
        elif sampling_rate != fs:
            audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=fs)

        actual_audio_len = len(audio)
        if actual_audio_len == 0:
            # The tiling loop below would never terminate on an empty clip.
            raise ValueError("Cannot compute DNSMOS on empty audio")
        len_samples = int(INPUT_LENGTH * fs)

        # Clips shorter than one window are tiled (repeated) until they cover it.
        while len(audio) < len_samples:
            audio = np.append(audio, audio)

        # Windows of INPUT_LENGTH seconds, advancing by one second.
        num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1
        hop_len_samples = fs

        predicted_mos_sig_seg_raw = []
        predicted_mos_bak_seg_raw = []
        predicted_mos_ovr_seg_raw = []
        predicted_mos_sig_seg = []
        predicted_mos_bak_seg = []
        predicted_mos_ovr_seg = []

        for idx in range(num_hops):
            audio_seg = audio[int(idx * hop_len_samples): int((idx + INPUT_LENGTH) * hop_len_samples)]
            if len(audio_seg) < len_samples:
                continue

            input_features = np.array(audio_seg).astype("float32")[np.newaxis, :]
            oi = {"input_1": input_features}
            mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]

            mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(
                mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS
            )

            predicted_mos_sig_seg_raw.append(mos_sig_raw)
            predicted_mos_bak_seg_raw.append(mos_bak_raw)
            predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
            predicted_mos_sig_seg.append(mos_sig)
            predicted_mos_bak_seg.append(mos_bak)
            predicted_mos_ovr_seg.append(mos_ovr)

        return {
            "filename": "audio_clip",
            "len_in_sec": actual_audio_len / fs,
            "sr": fs,
            "num_hops": num_hops,
            "OVRL_raw": np.mean(predicted_mos_ovr_seg_raw),
            "SIG_raw": np.mean(predicted_mos_sig_seg_raw),
            "BAK_raw": np.mean(predicted_mos_bak_seg_raw),
            "OVRL": np.mean(predicted_mos_ovr_seg),
            "SIG": np.mean(predicted_mos_sig_seg),
            "BAK": np.mean(predicted_mos_bak_seg),
        }
def mos_prediction(waveform, sampling_rate, dnsmos_compute_score):
    """
    Return the overall (OVRL) DNSMOS score of ``waveform``.

    ``dnsmos_compute_score`` is called as ``(waveform, sampling_rate, False)``, i.e.
    with the non-personalized calibration, and must return a dict with an 'OVRL' key.
    """
    scores = dnsmos_compute_score(waveform, sampling_rate, False)
    return scores['OVRL']

def standardization(audio, target_sr=24000, verbose=False):
    """
    Preprocess audio: resample, convert to 16-bit mono, apply a bounded gain and
    peak-normalize.

    Args:
        audio (str or AudioSegment): Audio file path or an already-loaded AudioSegment.
        target_sr (int, optional): Output sample rate in Hz. Defaults to 24000.
        verbose (bool, optional): Print progress messages. Defaults to False.

    Returns:
        dict: A dictionary formatted as:
              {
                  "waveform": np.ndarray, dtype np.float32, shape (num_samples,),
                              peak-normalized to [-1, 1] (all zeros for silent input)
                  "name": str, the file basename for a path, else "audio_<n>"
                  "sample_rate": int, equal to target_sr
              }

    Raises:
        ValueError: If the audio parameter is neither a str nor an AudioSegment.
    """
    global audio_count
    name = "audio"

    if isinstance(audio, str):
        name = os.path.basename(audio)
        audio = AudioSegment.from_file(audio)
    elif isinstance(audio, AudioSegment):
        # No cell in this notebook defines audio_count, so start at 0 on first use
        # instead of raising NameError.
        audio_count = globals().get("audio_count", 0)
        name = f"audio_{audio_count}"
        audio_count += 1
    else:
        raise ValueError("Invalid audio type")

    if verbose:
        print("Entering the preprocessing of audio")

    # Convert the audio file to WAV format
    audio = audio.set_frame_rate(target_sr)
    audio = audio.set_sample_width(2)  # Set bit depth to 16bit
    audio = audio.set_channels(1)  # Set to mono

    if verbose:
        print("Audio file converted to WAV format")

    # Gain toward -20 dBFS, clamped to the range [-3, 3] dB.
    # NOTE(review): the peak normalisation below rescales the whole clip, so this gain
    # mostly cancels out. Confirm it is still wanted.
    target_dBFS = -20
    gain = target_dBFS - audio.dBFS
    if verbose:
        print(f"Calculating the gain needed for the audio: {gain} dB")

    normalized_audio = audio.apply_gain(min(max(gain, -3), 3))

    waveform = np.array(normalized_audio.get_array_of_samples(), dtype=np.float32)
    max_amplitude = np.max(np.abs(waveform)) if waveform.size else 0.0
    # Peak-normalize; skip silent or empty clips, where 0/0 would produce NaNs.
    if max_amplitude > 0:
        waveform /= max_amplitude

    if verbose:
        print(f"waveform shape: {waveform.shape}")
        print("waveform in np ndarray, dtype=" + str(waveform.dtype))

    return {
        "waveform": waveform,
        "name": name,
        "sample_rate": target_sr,
    }


def write_mp3(path, sr, x):
    """
    Write a waveform to ``path`` as MP3.

    This is best effort: failures are printed, not raised.

    Args:
        path (str): Output file path.
        sr (int): Sample rate in Hz.
        x (np.ndarray): Samples, shape ``(num_samples,)`` for mono or
            ``(num_samples, channels)`` for multi-channel audio. Input that is not int16
            is peak-normalised to the int16 range.
    """
    try:
        x = np.asarray(x)
        if x.dtype != np.int16:
            peak = np.max(np.abs(x)) if x.size else 0
            if peak > 0:
                x = (x / peak * 32767).astype(np.int16)
            else:
                # Silent input: dividing by a zero peak would produce NaN garbage.
                x = np.zeros(x.shape, dtype=np.int16)

        channels = 1 if x.ndim == 1 else x.shape[1]
        # tobytes() emits C order, so (samples, channels) becomes the interleaved
        # frame layout that raw PCM expects.
        audio = AudioSegment(
            x.tobytes(), frame_rate=sr, sample_width=x.dtype.itemsize, channels=channels
        )
        # Export as MP3 file
        audio.export(path, format="mp3")
    except Exception as e:
        print(e)
        print("Error: Failed to write MP3 file.")

In [None]:
# Use CUDA only when the installed onnxruntime build offers it (the CPU-only
# `onnxruntime` package from the setup cell does not), otherwise run on CPU.
device_name = (
    "cuda" if "CUDAExecutionProvider" in ort.get_available_providers() else "cpu"
)
device = torch.device(device_name)  # torch device handle; not used by the cells shown here

# Settings for the UVR MDX-Net separation model, consumed by Predictor.
cfg = {
    "model_path": "UVR-MDX-NET-Inst_HQ_3.onnx",
    "denoise": True,  # average predictions for +spectrum and -spectrum
    "margin": 44100,  # overlap between chunks, in samples at 44.1 kHz
    "chunks": 15,  # chunk length in seconds (chunks * 44100 samples)
    "n_fft": 6144,
    "dim_t": 8,  # 2**8 STFT frames per model window
    "dim_f": 3072,  # frequency bins fed to the model
}

# DNSMOS quality scorer and the vocal/accompaniment separator.
dnsmos_compute_score = ComputeScore('sig_bak_ovr.onnx', device_name)

separate_predictor1 = Predictor(args=cfg, device=device_name)

In [None]:
# Demo: standardize a clip, play it, then play the two separated stems.
audio_path = 'path'  # TODO: point this at a real audio file before running
audio = standardization(audio_path)

print("Before")
display(Audio(audio["waveform"], rate=audio["sample_rate"]))
time.sleep(1)

audio = source_separation(separate_predictor1, audio)
print("After")
for label, key in (('Vocals', "waveform"), ('Other', "other_waveform")):
    print(label)
    display(Audio(audio[key], rate=audio["sample_rate"]))