In [6]:
import os
import torch
import torchaudio
import librosa
import numpy as np
import soundfile as sf
import scipy.signal as signal
import webrtcvad
from asteroid.models import DCCRNet

In [7]:
# === DEVICE SETUP ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # Optional: improves performance for fixed input sizes

# === Load pretrained Asteroid model ===
model = DCCRNet.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k")
model.to(device)

# === SAFE FILTERING HELPERS ===
def safe_cutoff(cutoff, sr, margin=0.95):
    nyquist = 0.5 * sr
    return min(cutoff, nyquist * margin)

# === AUDIO ENHANCEMENT FUNCTIONS ===
def bandpass_filter(audio, sr, lowcut=80, highcut=3400, order=4):
    nyquist = 0.5 * sr
    low = safe_cutoff(lowcut, sr) / nyquist
    high = safe_cutoff(highcut, sr) / nyquist
    if low >= high or high >= 1.0:
        return audio
    sos = signal.butter(order, [low, high], btype='band', output='sos')
    return signal.sosfilt(sos, audio)

def apply_eq(audio, sr):
    try:
        b1, a1 = signal.iirpeak(100 / (0.5 * sr), Q=1.0)
        audio = signal.lfilter(b1, a1, audio)
        audio *= 1.5
    except:
        pass
    try:
        b2, a2 = signal.iirpeak(3000 / (0.5 * sr), Q=1.5)
        audio = signal.lfilter(b2, a2, audio)
        audio *= 1.2
    except:
        pass
    return audio

def low_shelf_filter(audio, sr, cutoff=200, gain_db=6):
    cutoff = safe_cutoff(cutoff, sr)
    if cutoff <= 0:
        return audio
    gain = 10**(gain_db / 20)
    b, a = signal.butter(1, cutoff / (0.5 * sr), btype='low')
    return signal.lfilter(b, a, audio) * gain

def high_shelf_filter(audio, sr, cutoff=4000, gain_db=4):
    cutoff = safe_cutoff(cutoff, sr)
    if cutoff >= 0.5 * sr:
        return audio
    gain = 10**(gain_db / 20)
    b, a = signal.butter(1, cutoff / (0.5 * sr), btype='high')
    return signal.lfilter(b, a, audio) * gain

def compressor_limiter(audio, threshold_db=-20, ratio=4.0, makeup_gain_db=6):
    threshold = 10**(threshold_db / 20)
    makeup_gain = 10**(makeup_gain_db / 20)
    def compress_sample(x):
        abs_x = abs(x)
        if abs_x < threshold:
            return x
        else:
            compressed = threshold + (abs_x - threshold) / ratio
            return np.sign(x) * compressed
    compressed = np.array([compress_sample(x) for x in audio])
    return compressed * makeup_gain

def normalize_audio(audio):
    max_val = np.max(np.abs(audio))
    return audio / max_val if max_val > 0 else audio

# === SPECTRAL NOISE GATE ===
def spectral_noise_gate(audio, sr, gate_threshold_db=-40):
    stft = librosa.stft(audio, n_fft=1024, hop_length=256)
    magnitude, phase = np.abs(stft), np.angle(stft)
    db_mag = librosa.amplitude_to_db(magnitude)
    gate_mask = db_mag > gate_threshold_db
    gated_mag = magnitude * gate_mask
    stft_gated = gated_mag * np.exp(1j * phase)
    return librosa.istft(stft_gated, hop_length=256)

# === VAD CLEANING USING WEBRTC ===
def remove_non_speech_segments(audio, sr, aggressiveness=2):
    vad = webrtcvad.Vad(aggressiveness)
    window_duration = 30  # ms
    samples_per_window = int(sr * window_duration / 1000)
    bytes_per_sample = 2

    # Convert to 16-bit PCM
    int16_audio = np.int16(audio * 32768)
    pcm_audio = int16_audio.tobytes()

    voiced_audio = bytearray()
    for i in range(0, len(pcm_audio), samples_per_window * bytes_per_sample):
        window = pcm_audio[i:i + samples_per_window * bytes_per_sample]
        if len(window) < samples_per_window * bytes_per_sample:
            break
        if vad.is_speech(window, sample_rate=sr):
            voiced_audio.extend(window)

    # Convert back to float32
    if len(voiced_audio) == 0:
        return audio  # fallback
    voiced_np = np.frombuffer(voiced_audio, dtype=np.int16).astype(np.float32) / 32768.0
    return voiced_np

# === ASTEROID ENHANCEMENT ===
def enhance_with_asteroid_chunked(input_path, tmp_output_path, target_sr=16000, chunk_size=10):
    waveform, sr = torchaudio.load(input_path)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
    waveform = waveform.mean(dim=0, keepdim=True)  # mono

    chunk_len = chunk_size * target_sr  # e.g., 10 seconds
    total_len = waveform.shape[-1]

    enhanced_audio = []

    with torch.no_grad():
        for start in range(0, total_len, chunk_len):
            end = min(start + chunk_len, total_len)
            chunk = waveform[:, start:end].to(device)
            try:
                enhanced = model.separate(chunk)[0].squeeze(0).cpu().numpy()
            except RuntimeError as e:
                print(f"Error on chunk {start}-{end}: {e}")
                continue
            enhanced_audio.append(enhanced)

    full_audio = np.concatenate(enhanced_audio, axis=-1)
    sf.write(tmp_output_path, full_audio, target_sr)
    return tmp_output_path


  conf = torch.load(cached_model, map_location="cpu")


In [8]:
def enhance_atc_audio_asteroid_vad(input_folder):
    output_folder = input_folder.rstrip('/\\') + '-asteroid-vad-enhanced'
    os.makedirs(output_folder, exist_ok=True)

    for file_name in os.listdir(input_folder):
        if file_name.lower().endswith('.wav'):
            input_path = os.path.join(input_folder, file_name)
            tmp_path = os.path.join("tmp_asteroid_clean.wav")
            output_path = os.path.join(output_folder, file_name)

            print(f"🔊 Processing: {file_name}")
            try:
                # ML Denoising Step
                enhanced_path = enhance_with_asteroid_chunked(input_path, tmp_path)

                # Load enhanced audio
                audio, sr = librosa.load(enhanced_path, sr=None)

                # Remove static-only parts using VAD
                audio = remove_non_speech_segments(audio, sr)

                # Traditional Enhancements
                filtered = bandpass_filter(audio, sr)
                equalized = apply_eq(filtered, sr)
                shelved = low_shelf_filter(equalized, sr)
                shelved = high_shelf_filter(shelved, sr)
                compressed = compressor_limiter(shelved)
                gated = spectral_noise_gate(compressed, sr)
                normalized = normalize_audio(gated)

                sf.write(output_path, normalized, sr)
                print(f"✅ Saved to: {output_path}")

            except Exception as e:
                print(f"⚠️ Error processing {file_name}: {e}")

In [9]:
enhance_atc_audio_asteroid_vad('/home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files')

🔊 Processing: 08NC15MBP_0101.wav
✅ Saved to: /home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files-asteroid-vad-enhanced/08NC15MBP_0101.wav
🔊 Processing: NI66MBQ_0101.wav
✅ Saved to: /home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files-asteroid-vad-enhanced/NI66MBQ_0101.wav
🔊 Processing: NI56MBX_0101.wav
✅ Saved to: /home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files-asteroid-vad-enhanced/NI56MBX_0101.wav
🔊 Processing: 15NC30MBQ_0101.wav
✅ Saved to: /home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files-asteroid-vad-enhanced/15NC30MBQ_0101.wav
🔊 Processing: NI06FBP_0101.wav
✅ Saved to: /home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files-asteroid-vad-enhanced/NI06FBP_0101.wav
🔊 Processing: 28NC51MBP_0101.wav
✅ Saved to: /home/shawnyzy/Documents/benchmarking-pipeline/datasets/zh-2-dev_en_separate/wav_files-asteroid-vad-en