In [31]:
import torch
import torchaudio
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display, Markdown
from huggingface_hub import hf_hub_download
import ipywidgets as widgets
import torch.nn.functional as F

#### Step 1: Load the WaveSP-Net model from Hugging Face

In [None]:
from model import WaveSP_Net

# Download the fine-tuned WaveSP-Net checkpoint from the Hugging Face Hub.
ckpt_path = hf_hub_download(
    repo_id="xxuan-speech/WaveSP-Net",
    filename="WaveSP-Net.pt"
)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `model_dir` is an absolute path on the original author's machine
# (presumably a local wav2vec2 XLS-R 300M directory, judging by the path name).
# Point it at your own local copy before running this cell.
model = WaveSP_Net(model_dir="/home/xxuan/pre-model/huggingface/wav2vec2-xls-r-300m/",
                   prompt_dim=1024,
                   num_prompt_tokens = 6,
                   num_wavlet_tokens = 4,
                   dropout= 0.1).to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
# The load is non-strict, so keys that do not line up are skipped instead of
# raising. Report how many were skipped so a mismatched checkpoint is visible.
load_result = model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Non-strict checkpoint load: {len(load_result.missing_keys)} missing key(s), "
        f"{len(load_result.unexpected_keys)} unexpected key(s)."
    )

# Inference only: switch to eval mode.
model.eval()
print("WaveSP-Net Model loaded successfully.")

#### Step 2: Load Real and Fake audio

In [None]:
def preprocess_audio(wav_path):
    """Load an audio file and return it as a mono, 16 kHz waveform.

    Args:
        wav_path: Path of the audio file, passed straight to ``torchaudio.load``.

    Returns:
        A ``(waveform, sr)`` tuple. ``waveform`` is the loaded tensor after
        averaging multiple channels into one and resampling to 16000 Hz when
        needed; ``sr`` is the sample rate of the returned waveform.
    """
    target_sr = 16000
    waveform, sr = torchaudio.load(wav_path)

    # More than one channel: average them into a single channel (dim 0 is kept).
    channel_count = waveform.shape[0]
    if channel_count > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Bring every clip to the same sample rate.
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
        sr = target_sr

    # Guard for a 1-D tensor: add a leading dimension so the result is 2-D.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    return waveform, sr

# Demo clips: one genuine recording and one spoofed recording.
real_wav_path = "./demo/Real0001.wav"
fake_wav_path = "./demo/Fake0001.wav"

# Load both clips through the same preprocessing. The sample rate is taken from
# the real clip; the one returned for the fake clip is discarded.
(real_audio, sr), (fake_audio, _) = map(preprocess_audio, (real_wav_path, fake_wav_path))


def draw_waveform(waveform, sr, title="Waveform"):
    """Plot amplitude against time for a waveform tensor.

    Args:
        waveform: Tensor holding the samples; it is squeezed and converted to
            a NumPy array before plotting.
        sr: Sample rate used to convert sample indices to seconds.
        title: Title drawn above the plot.
    """
    samples = waveform.squeeze().numpy()
    seconds = np.arange(len(samples)) / sr

    fig, ax = plt.subplots(figsize=(5, 1.6))
    ax.plot(seconds, samples, linewidth=0.8, color='steelblue')
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Amplitude")
    ax.set_title(title, fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    plt.show()

def draw_spectrogram(waveform, sr, title="Spectrogram"):
    """Plot the log of the torchaudio spectrogram of a waveform tensor.

    Args:
        waveform: Tensor holding the samples; it is squeezed before the
            transform is applied.
        sr: Unused here; kept so the signature matches ``draw_waveform``.
        title: Title drawn above the plot.
    """
    samples = waveform.squeeze().numpy()

    to_spectrogram = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256)
    # A small constant is added before the log so that zero bins do not give -inf.
    log_spec = torch.log(to_spectrogram(torch.tensor(samples)) + 1e-9)

    fig, ax = plt.subplots(figsize=(5, 1.6))
    image = ax.imshow(log_spec.numpy(), aspect='auto', origin='lower', cmap='magma')
    ax.set_xlabel("Time Frame")
    ax.set_ylabel("Frequency Bin")
    ax.set_title(title, fontsize=10)
    fig.colorbar(image, ax=ax)
    fig.tight_layout()
    plt.show()


def make_plot_output(plot_func, waveform, sr, title):
    """Run a plotting function inside a fresh ``Output`` widget and return it.

    Args:
        plot_func: Callable invoked as ``plot_func(waveform, sr, title)``.
        waveform: Forwarded to ``plot_func``.
        sr: Forwarded to ``plot_func``.
        title: Forwarded to ``plot_func``.

    Returns:
        The 500px-wide ``widgets.Output`` that the plot was drawn into.
    """
    panel = widgets.Output(layout={'width': '500px'})
    # Anything displayed inside this context is captured by the widget.
    with panel:
        plot_func(waveform, sr, title)
    return panel


# Audio players, each rendered into its own Output widget so they can be
# placed inside the column layout below.
real_audio_out = widgets.Output()
fake_audio_out = widgets.Output()

with real_audio_out:
    display(Audio(real_audio, rate=sr, autoplay=False))
with fake_audio_out:
    display(Audio(fake_audio, rate=sr, autoplay=False))


# Waveform and spectrogram plots for each clip.
real_wf_out = make_plot_output(draw_waveform, real_audio, sr, "Real - Waveform")
fake_wf_out = make_plot_output(draw_waveform, fake_audio, sr, "Fake - Waveform")

real_sp_out = make_plot_output(draw_spectrogram, real_audio, sr, "Real - Spectrogram")
fake_sp_out = make_plot_output(draw_spectrogram, fake_audio, sr, "Fake - Spectrogram")


# One column per clip: heading, audio player, waveform, spectrogram.
# The headings use an HTML widget for bold text, because `font_weight` is not a
# `Layout` trait and passing it through `layout=` has no visible effect.
# The 20px space between the columns is a right margin on the first column,
# because `gap` is not a documented `Layout` trait either.
real_col = widgets.VBox([
    widgets.HTML("<b>Real</b>"),
    real_audio_out,
    real_wf_out,
    real_sp_out
], layout=widgets.Layout(padding='10px', border='1px solid #ddd', margin='0 20px 0 0'))

fake_col = widgets.VBox([
    widgets.HTML("<b>Fake</b>"),
    fake_audio_out,
    fake_wf_out,
    fake_sp_out
], layout=widgets.Layout(padding='10px', border='1px solid #ddd'))


# Show the two columns side by side, aligned at the top.
display(widgets.HBox([real_col, fake_col], layout=widgets.Layout(align_items='flex-start')))

HBox(children=(VBox(children=(Label(value='Real'), Output(), Output(layout=Layout(width='500px')), Output(layo…

#### Step 3: Inference

In [None]:

# Ensure the model is in eval mode. Inputs are moved to the model's device
# inside `predict_fake_probability`.
model.eval()

# Fixed input length in samples (4 s at the 16 kHz rate used by preprocessing).
max_len = 64000


def predict_fake_probability(audio, net, num_samples):
    """Return the softmax probability at class index 1 for one clip.

    The clip is truncated, or zero-padded on the right, to exactly
    ``num_samples`` samples, moved to the device of ``net``'s parameters, and
    passed through ``net`` with gradients disabled.

    Args:
        audio: 2-D waveform tensor whose second dimension is indexed as samples.
        net: Model whose call returns a pair; the second element is used as the
            logits.
        num_samples: Number of samples fed to the model.

    Returns:
        ``softmax(logits, dim=1)[0, 1]`` as a Python float. The notebook treats
        this as the "fake" probability.
    """
    # Clone so the caller's tensor is never shared with the model input.
    waveform = audio[:, :num_samples].clone()
    shortfall = num_samples - waveform.shape[1]
    if shortfall > 0:
        waveform = F.pad(waveform, (0, shortfall))

    waveform = waveform.to(next(net.parameters()).device)

    with torch.no_grad():
        _, logits = net(waveform)

    return F.softmax(logits, dim=1)[0, 1].item()


# -------------------------
# Inference on real_audio
# -------------------------
real_fake_prob = predict_fake_probability(real_audio, model, max_len)
real_pred = "fake" if real_fake_prob > 0.5 else "real"

# -------------------------
# Inference on fake_audio
# -------------------------
fake_fake_prob = predict_fake_probability(fake_audio, model, max_len)
fake_pred = "fake" if fake_fake_prob > 0.5 else "real"

# -------------------------
# Display results
# -------------------------
display(Markdown(f"##### 🔹 Real0001.wav - Prediction: `{real_pred}` - Fake Probability: {real_fake_prob:.6f}"))
display(Markdown(f"##### 🔸 Fake0001.wav - Prediction: `{fake_pred}` - Fake Probability: {fake_fake_prob:.6f}"))

##### 🔹 Real0001.wav - Prediction: `real` - Fake Probability: 0.032278

##### 🔸 Fake0001.wav - Prediction: `fake` - Fake Probability: 0.827179