In [2]:
import os
import glob
import numpy as np
import IPython.display as ipd

import librosa
import soundfile as sd
import noisereduce as nr

import torch
from torch import Tensor
import torchaudio

from fusion_synthesis.ddsp.vocoder import SubtractiveSynthesiser
from phone_mapper.mapper import PhoneMapper
from utils import get_wavlm, get_hnr


In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
device = 'cpu'
sr = 16000
hop_length = 320
min_pitch = 80
pool_wav_num = 100
topN = 10

In [4]:
def get_wav(path, target_sr=sr, offset=None, dur=None):
    if offset == None and dur == None:
        src, orig_sr = librosa.load(src_wav_path)
    else:
        src, orig_sr = librosa.load(src_wav_path, offset=offset, duration=dur)
    src = librosa.resample(src, orig_sr=orig_sr, target_sr=target_sr)
    src = src/np.max(np.abs(src))
    return src

@torch.inference_mode()
def synthesise(ddsp: SubtractiveSynthesiser, x6: Tensor, x12: Tensor, f0: Tensor, tgt_loudness_db: float | None = -16,
               device: str = 'cpu', sr: int = 16000) -> Tensor:
    """
        Inputs:
            - `x6`: the 6th layer representation of wavlm of the source utterance
            - `x12`: the 12th layer representation of wavlm of the source utterance
             - `f0`: the z-normalised f0 of the source utterance
            - `tgt_loudness_db`: float db used to normalize the output volume. Set to None to disable.
            - `device`: if None, uses default device at initialization. Otherwise uses specified device
        Returns:
            - converted waveform of shape (T,)
        """

    prediction, _, _, _, _ = ddsp(x6=x6[None].to(device), x12=x12[None].to(device),
                                  f0_norm=f0.to(device))

    prediction = prediction.squeeze().cpu()

    # normalization
    if tgt_loudness_db is not None:
        src_loudness = torchaudio.functional.loudness(prediction[None], sr)
        tgt_loudness = tgt_loudness_db
        pred_wav = torchaudio.functional.gain(prediction, tgt_loudness - src_loudness)

    else:
        pred_wav = prediction
    return pred_wav

def get_mapper(sr:int =16000, hop_size:int =320, device:str='cpu'):
    wavlm = get_wavlm(True, True, device)
    mapper = PhoneMapper(wavlm=wavlm, device=device, hop_length=hop_size, sr=sr)
    return mapper


def get_ddsp_model(ddsp_model_path, sr: int = 16000, hop_size: int = 320, n_mag_harmonic: int = 176,
                   n_mag_noise: int = 80, n_harmonics: int = 150, device: str = 'cpu'):
    model = SubtractiveSynthesiser(
        sampling_rate=sr,
        block_size=hop_size,
        n_mag_harmonic=n_mag_harmonic,
        n_mag_noise=n_mag_noise,
        n_harmonics=n_harmonics,
        device=device
    )
    # load ddsp model's weights
    model.load_state_dict(torch.load(ddsp_model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

In [5]:
# load models and components

# get mapper
mapper = get_mapper(device=device)

# get ddsp model
model = get_ddsp_model(ddsp_model_path='resources/model_weights/ddsp-qbe.pt', device=device)

WavLM-Large loaded with 315,453,120 parameters.
 [Model] Sawtooth (with sinusoids) Subtractive Synthesiser


In [6]:
# get pool (run only once for the target speaker)
pool_path = '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/'
ref_wav_paths = glob.glob(os.path.join(pool_path, '*/*.flac'))[0:pool_wav_num]
print(ref_wav_paths[0:10])
pool = mapper.get_phone_pool(ref_wav_paths)

['/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0013.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0007.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0003.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0006.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0002.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0012.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0004.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0000.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0008.flac', '/scratch/sghosh/datasets/DDSP/libri_esd/val_extended/audio/8842/304647/8842-304647-0010.flac']


In [7]:
target_wav, osr = librosa.load(ref_wav_paths[1])
print(osr)
ipd.Audio(target_wav, rate=osr)

22050


In [8]:
src_wav_path = '/scratch/sghosh/supplementary_v2/dementia/dementia1.wav'

In [9]:
src = get_wav(src_wav_path)
src_filtered = nr.reduce_noise(src, sr, n_std_thresh_stationary=1,stationary=False, prop_decrease=0.9, hop_length=hop_length)
print('src',"HNR:", get_hnr(src))
print('src_f',"HNR:", get_hnr(src_filtered))
ipd.display(
    ipd.Audio(src_filtered, rate=sr),
    ipd.Audio(src, rate=sr)
)

src HNR: 10.691750032299959
src_f HNR: 9.59570528511392


In [10]:
x6, x12, f0_norm = mapper.get_mapping(src=src_wav_path, pool=pool)
pred_wav = synthesise(ddsp=model, x6=x6, x12=x12, f0=f0_norm, device='cpu')

In [11]:
nrw = nr.reduce_noise(pred_wav, sr, n_std_thresh_stationary=1,stationary=True, prop_decrease=0.9, hop_length=hop_length)
nrw=nrw/np.max(np.abs(nrw))
print('out_wav',"HNR:", get_hnr(pred_wav))
print('out_wav_filtered',"HNR:", get_hnr(nrw))
ipd.display(
    ipd.Audio(pred_wav, rate=sr),
    ipd.Audio(nrw, rate=sr)
)

out_wav HNR: 10.644068142379908
out_wav_filtered HNR: 9.411747429175348
