In [49]:
# 📦 Imports and Setup
import os
import torch
import torchaudio
from dccrn_model import DCCRN  # Make sure dccrn_model.py is in the same folder
import IPython.display as ipd
import numpy as np

# 🔧 Config
# Sample rate (Hz) the notebook accepts; denoise_file compares each input
# file's rate against this value. Presumably the rate the checkpoint was
# trained at — confirm before changing it.
sample_rate = 48000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 📁 Paths (all relative to the notebook's current working directory)
base_dir = os.getcwd()
noisy_dir = os.path.join(base_dir, "dataset", "noisy")
output_dir = os.path.join(base_dir, "output")
model_path = os.path.join(base_dir, "dccrn_final.pth")
os.makedirs(output_dir, exist_ok=True)

# ✅ Load model
# These hyper-parameters must match the ones used when dccrn_final.pth was
# saved; load_state_dict (strict by default) raises on missing/unexpected
# keys or shape mismatches otherwise.
model = DCCRN(
    rnn_layers=2,
    rnn_units=256,
    masking_mode='E',
    use_clstm=True,
    use_cbn=False,
    kernel_num=[32, 64, 128, 256, 256, 256]
).to(device)

# weights_only=True (torch >= 1.13) limits unpickling to tensors and plain
# containers, so a tampered checkpoint cannot run arbitrary code on load.
# The file is consumed as a bare state_dict, which this mode supports.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("✅ Model loaded successfully")


DCCRN(
  (stft): ConvSTFT()
  (istft): ConviSTFT()
  (encoder): ModuleList(
    (0): Sequential(
      (0): ComplexConv2d(
        (real_conv): Conv2d(1, 16, kernel_size=(5, 2), stride=(2, 1), padding=(2, 0))
        (imag_conv): Conv2d(1, 16, kernel_size=(5, 2), stride=(2, 1), padding=(2, 0))
      )
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
    )
    (1): Sequential(
      (0): ComplexConv2d(
        (real_conv): Conv2d(16, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 0))
        (imag_conv): Conv2d(16, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 0))
      )
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
    )
    (2): Sequential(
      (0): ComplexConv2d(
        (real_conv): Conv2d(32, 64, kernel_size=(5, 2), stride=(2, 1), padding=(2, 0))
        (imag_conv): Conv2d(32, 64, kernel_size=(5, 2), stride=(2, 1),

In [51]:
def denoise_file(file_path, resample=False):
    """Denoise one audio file with the loaded model and save the result.

    The enhanced audio is written to ``output_dir`` under the same base name
    as the input file, at ``sample_rate``.

    Args:
        file_path: Path to the noisy input audio file.
        resample: If False (default), files whose sample rate differs from
            ``sample_rate`` are skipped. If True, they are resampled to
            ``sample_rate`` with ``torchaudio.functional.resample`` before
            being passed to the model.

    Returns:
        ``(noisy, enhanced)`` as 2-D CPU tensors (rows = channels/batch items,
        columns = samples), or ``(None, None)`` when the file is missing,
        empty, unreadable, or (with ``resample=False``) at another sample rate.
    """
    file_name = os.path.basename(file_path)

    if not os.path.isfile(file_path) or os.path.getsize(file_path) == 0:
        print(f"❌ Skipping invalid file: {file_path}")
        return None, None

    try:
        waveform, sr = torchaudio.load(file_path)
    except Exception as e:
        print(f"❌ torchaudio.load failed for {file_name}: {e}")
        return None, None

    if sr != sample_rate:
        if not resample:
            print(f"⚠️ Skipping {file_name}: expected {sample_rate}Hz, got {sr}Hz")
            return None, None
        print(f"🔁 Resampling {file_name}: {sr}Hz -> {sample_rate}Hz")
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)

    waveform = waveform.to(device)

    with torch.no_grad():
        # The model returns a pair; only the second element (the enhanced
        # waveform) is used.
        _, enhanced = model(waveform)

    # torchaudio.save requires a 2-D [channels, samples] tensor. Folding every
    # leading dimension into the first axis accepts 1-D [T], 2-D [C, T] and
    # 3-D outputs such as [B, 1, T] (which would otherwise be rejected).
    enhanced = enhanced.cpu()
    enhanced = enhanced.reshape(-1, enhanced.shape[-1])
    enhanced = enhanced.clamp(-1.0, 1.0).to(torch.float32)

    # ✅ Save output (overwrites any existing file of the same name)
    out_path = os.path.join(output_dir, file_name)
    torchaudio.save(out_path, enhanced, sample_rate)

    return waveform.cpu(), enhanced


In [52]:
# 🎧 Denoise one sample file (noisy_dir/file1.wav) and play the noisy and
# denoised versions one after the other for a quick listening comparison.
test_file = os.path.join(noisy_dir, "file1.wav")
noisy_audio, denoised_audio = denoise_file(test_file)

if noisy_audio is None or denoised_audio is None:
    print("⚠️ Could not process test file.")
else:
    # Tensors are [channels, samples]; ipd.Audio treats a 2-D array that way.
    for clip_label, clip in (
        ("🔊 Noisy Audio:", noisy_audio),
        ("🔉 Denoised Audio:", denoised_audio),
    ):
        print(clip_label)
        ipd.display(ipd.Audio(clip.numpy(), rate=sample_rate))


🔊 Noisy Audio:


🔉 Denoised Audio:
