In [1]:
#@title Installation. { vertical-output: true }
#@markdown Run this notebook in Google Colab by following [this link](https://colab.research.google.com/github/google-research/perch/blob/main/embed_audio.ipynb).
#@markdown
#@markdown Run this cell to install the project dependencies.
#%pip install git+https://github.com/google-research/perch.git

# update to allow running under python 3.12
!git clone https://github.com/google-research/perch.git
%cd perch
!sed -i 's/<3.12/<3.13/' pyproject.toml
%pip install -e .


Cloning into 'perch'...
remote: Enumerating objects: 7505, done.[K
remote: Counting objects: 100% (926/926), done.[K
remote: Compressing objects: 100% (112/112), done.[K
remote: Total 7505 (delta 841), reused 816 (delta 814), pack-reused 6579 (from 2)[K
Receiving objects: 100% (7505/7505), 15.24 MiB | 14.75 MiB/s, done.
Resolving deltas: 100% (5435/5435), done.
/content/perch
Obtaining file:///content/perch
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting perch_hoplite@ git+https://github.com/google-research/perch-hoplite.git (from chirp==0.1.0)
  Cloning https://github.com/google-research/perch-hoplite.git to /tmp/pip-install-ixv5lqhu/perch-hoplite_30fab4e1080f43f9928c00e5a66dfba4
  Running command git clone --filter=blob:none --quiet https://github.com/google-resear

In [1]:
#@title Imports. { vertical-output: true }

from etils import epath
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm
from chirp.inference import colab_utils
colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp.inference import embed_lib
from chirp.inference import tf_examples
from perch_hoplite.zoo import model_configs



In [2]:
#@title Basic Configuration. { vertical-output: true }

#@markdown Define the model: perch or birdnet are most common for birds.
model_choice = 'humpback'  #@param['perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']
#@markdown Set the base directory for the project.
working_dir = '/tmp/agile'  #@param

# Set the embedding and labeled data directories.
embeddings_path = epath.Path(working_dir) / 'embeddings'
labeled_data_path = epath.Path(working_dir) / 'labeled'
embeddings_glob = embeddings_path / 'embeddings-*'

# OPTIONAL: Set up separation model.
separation_model_key = 'separator_model_tf'  #@param
separation_model_path = ''  #@param

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os
#@title Create a new folder in Drive (if it doesn't already exist) within your Google drive.
base_dir = '/content/drive/MyDrive/'
#@ markdown Name of your new folder in Drive
new_folder_name = 'whale_denoising' #@param

drive_output_directory = base_dir + new_folder_name

try:
  if not os.path.exists(drive_output_directory):
    os.makedirs(drive_output_directory, exist_ok=True)
    print(f'Directory {drive_output_directory} created successfully.')
  else:
    print(f'Directory {drive_output_directory} already exists.')
except OSError as e:
    print("Error:", e)

Directory /content/drive/MyDrive/whale_denoising already exists.


In [5]:
#@title Embedding Configuration. { vertical-output: true }

config = config_dict.ConfigDict()
config.embed_fn_config = config_dict.ConfigDict()
config.embed_fn_config.model_config = config_dict.ConfigDict()

#@markdown IMPORTANT: Select the target audio files.
#@markdown source_file_patterns should contain a list of globs of audio files, like:
#@markdown ['/home/me/*.wav', '/home/me/other/*.flac']
#config.source_file_patterns = ['gs://chirp-public-bucket/soundscapes/powdermill/Recording*/*.wav']  #@param
config.source_file_patterns = ['/content/drive/MyDrive/whale_denoising/input_audio/noisy/*.wav', '/content/drive/MyDrive/whale_denoising/input_audio/clean/*.wav']  #@param
config.output_dir = embeddings_path.as_posix()

preset_info = model_configs.get_preset_model_config(model_choice)
config.embed_fn_config.model_key = preset_info.model_key
config.embed_fn_config.model_config = preset_info.model_config

# Only write embeddings to reduce size.
config.embed_fn_config.write_embeddings = True
config.embed_fn_config.write_logits = False
config.embed_fn_config.write_separated_audio = False
config.embed_fn_config.write_raw_audio = False

#@markdown File sharding automatically splits audio files into one-minute chunks
#@markdown for embedding. This limits both system and GPU memory usage,
#@markdown especially useful when working with long files (>1 hour).
use_file_sharding = False  #@param {type:'boolean'}
if use_file_sharding:
  config.shard_len_s = 60.0

# Number of parent directories to include in the filename.
config.embed_fn_config.file_id_depth = 1

In [6]:
#@title Set up. { vertical-output: true }

# Set up the embedding function, including loading models.
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\n\nLoading model(s)...')
embed_fn.setup()

# Create output directory and write the configuration.
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
embed_lib.maybe_write_config(config, output_dir)

# Create SourceInfos.
source_infos = embed_lib.create_source_infos(
    config.source_file_patterns,
    num_shards_per_file=config.get('num_shards_per_file', -1),
    shard_len_s=config.get('shard_len_s', -1))
print(f'Found {len(source_infos)} source infos.')

print('\n\nTest-run of model...')
window_size_s = config.embed_fn_config.model_config.window_size_s
sr = config.embed_fn_config.model_config.sample_rate
z = np.zeros([int(sr * window_size_s)], dtype=np.float32)
embed_fn.embedding_model.embed(z)
print('Setup complete!')



Loading model(s)...
Found 1473 source infos.


Test-run of model...
Setup complete!


# Extract the model

In [7]:
embedding_model = embed_fn.embedding_model

In [8]:
#@title Run embedding. { vertical-output: true }

# Uses multiple threads to load audio before embedding.
# This tends to be faster, but can fail if any audio files are corrupt.

embed_fn.min_audio_s = 1.0
record_file = (output_dir / 'embeddings.tfrecord').as_posix()
succ, fail = 0, 0

existing_embedding_ids = embed_lib.get_existing_source_ids(
    output_dir, 'embeddings-*')

new_source_infos = embed_lib.get_new_source_infos(
    source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)

print(f'Found {len(existing_embedding_ids)} existing embedding ids. \n'
      f'Processing {len(new_source_infos)} new source infos. ')

try:
  audio_loader = lambda fp, offset: audio_utils.load_audio_window(
      fp, offset, sample_rate=config.embed_fn_config.model_config.sample_rate,
      window_size_s=config.get('shard_len_s', -1.0))
  audio_iterator = audio_utils.multi_load_audio_window(
      filepaths=[s.filepath for s in new_source_infos],
      offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],
      audio_loader=audio_loader,
  )
  with tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:
    for source_info, audio in tqdm.tqdm(
        zip(new_source_infos, audio_iterator), total=len(new_source_infos)):
      if not embed_fn.validate_audio(source_info, audio):
        continue
      file_id = source_info.file_id(config.embed_fn_config.file_id_depth)
      offset_s = source_info.shard_num * source_info.shard_len_s
      example = embed_fn.audio_to_example(file_id, offset_s, audio)
      if example is None:
        fail += 1
        continue
      file_writer.write(example.SerializeToString())
      succ += 1
    file_writer.flush()
finally:
  del(audio_iterator)
print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

fns = [fn for fn in output_dir.glob('embeddings-*')]
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
ds = ds.map(parser)
for ex in ds.as_numpy_iterator():
  print(ex['filename'])
  print(ex['embedding'].shape, flush=True)
  break

Found 0 existing embedding ids. 
Processing 1473 new source infos. 


100%|██████████| 1473/1473 [02:05<00:00, 11.72it/s]



Successfully processed 1473 source_infos, failed 0 times.
b'noisy/whale_1_10k.wav'
(1, 1, 2048)





# Create Denoiser

In [9]:
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np

# Simple 1D convolutional denoiser
class Denoiser1D(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=15, padding=7, bias=True),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=15, padding=7, bias=True),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=15, padding=7, bias=True),
            nn.ReLU(),
            nn.Conv1d(32, 1, kernel_size=15, padding=7, bias=False)
        )

    def forward(self, x):
        return self.net(x)

denoiser = Denoiser1D().cuda()
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()



# Define bandpass processing function

In [17]:
import torchaudio

def bandpass_waveform(waveform, sr, low_hz=200, high_hz=800):
    # waveform: (T,) or (1,T)
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    # Apply highpass then lowpass
    bp = torchaudio.functional.highpass_biquad(waveform, sr, low_hz)
    bp = torchaudio.functional.lowpass_biquad(bp, sr, high_hz)
    return bp.squeeze(0)


# Define post-processing (de-mean, DC filter, bandpass, silence edges, fade, normalise):

In [32]:
import torchaudio
import torch

def postprocess_audio(
    waveform,
    sr=10_000,
    hp_hz=10,
    peak=0.99,
    low_hz=200,
    high_hz=800,
    silence_ms=15.0,
    fade_ms=50.0,
):
    """
    waveform may be (T,), (1,T), or (B,1,T); returns (T,) after:
      - demean
      - high-pass
      - band-pass
      - zero first/last `silence_ms`
      - fade in/out over `fade_ms` (immediately after the silenced edges)
      - peak normalization to `peak`
    """
    x = waveform
    if x.dim() == 3:  # (B,1,T)
        x = x[0, 0]
    elif x.dim() == 2:  # (1,T)
        x = x[0]
    x = x.contiguous()

    # 1) Demean
    x = x - x.mean()

    # 2) Gentle high-pass
    x = torchaudio.functional.highpass_biquad(x.unsqueeze(0), sr, hp_hz).squeeze(0)

    # 3) Band-pass
    y = bandpass_waveform(x, sr, low_hz, high_hz)

    # --- 4 & 5) Edge silencing and fades (before normalization) ---
    T = y.numel()
    device = y.device

    n_sil = int(round(silence_ms * sr / 1000.0))
    n_fade = int(round(fade_ms * sr / 1000.0))

    # Build envelope
    env = torch.ones(T, device=device)

    # Zero first/last n_sil samples (if feasible)
    n_sil_eff = min(n_sil, T // 2) if T > 0 else 0
    if n_sil_eff > 0:
        env[:n_sil_eff] = 0.0
        env[-n_sil_eff:] = 0.0

    # Fades directly after the silenced edges
    avail = T - 2 * n_sil_eff
    if avail > 0:
        n_fade_eff = min(n_fade, avail // 2)  # ensure room for both fades
        if n_fade_eff > 0:
            # Fade-in: from n_sil_eff .. n_sil_eff + n_fade_eff
            fi_start = n_sil_eff
            fi_end = fi_start + n_fade_eff
            env[fi_start:fi_end] = torch.linspace(0.0, 1.0, steps=n_fade_eff, device=device)

            # Fade-out: from T - n_sil_eff - n_fade_eff .. T - n_sil_eff
            fo_end = T - n_sil_eff
            fo_start = fo_end - n_fade_eff
            env[fo_start:fo_end] = torch.linspace(1.0, 0.0, steps=n_fade_eff, device=device)

    # Apply envelope
    y = y * env

    # 6) Peak normalize last
    m = y.abs().max()
    if m > 0:
        y = y * (peak / m)

    return y



# Load denoiser from disk if it exists

In [None]:
denoiser.load_state_dict(torch.load("/content/drive/MyDrive/whale_denoising/whale_denoiser.pt"))
print("Denoiser loaded successfully.")

# Define data loader

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import random
import numpy as np
import glob

# --- Optional band-pass (no-op if torchaudio filter fails) ---
# def bandpass_waveform(x, sr, low_hz=200, high_hz=800):
#     try:
#         return torchaudio.functional.bandpass_biquad(
#             x.unsqueeze(0), sample_rate=sr,
#             center_freq=(low_hz + high_hz) * 0.5,
#             Q=(low_hz + high_hz) * 0.5 / max(high_hz - low_hz, 1.0)
#         ).squeeze(0)
#     except Exception:
#         return x

def _to_mono_and_crop_or_pad(wav, target_len):
    # wav: (C, T)
    if wav.dim() == 2:
        wav = wav[0]  # take first channel
    else:
        wav = wav.squeeze(0)
    T = wav.shape[-1]
    if T < target_len:
        wav = torch.nn.functional.pad(wav, (0, target_len - T))
    elif T > target_len:
        start = random.randint(0, T - target_len)
        wav = wav[start:start+target_len]
    return wav.contiguous()

def _scale_noise_to_snr(clean, noise, snr_db):
    # SNR (power): P_sig / P_noise = 10^(snr_db/10)
    eps = 1e-12
    sig_p = torch.clamp(clean.pow(2).mean(), min=eps)
    noi_p = torch.clamp(noise.pow(2).mean(), min=eps)
    scale = torch.sqrt(sig_p / (noi_p * (10.0 ** (snr_db / 10.0))))
    return noise * scale


# =========================
# Stage 1: CLEAN -> (noisy, clean) with synthetic noise
# =========================
class CleanWithSyntheticNoiseDataset(Dataset):
    """
    Creates supervised pairs for Stage 1 by corrupting CLEAN clips with synthetic noise.
    - If noise_paths is provided and use_noise_clips=True, uses those clips as noise sources.
    - Otherwise uses white noise (optionally band-passed).
    Targets are the original clean clips.
    """
    def __init__(self,
                 clean_paths,
                 noise_paths=None,
                 use_noise_clips=True,
                 sr=10_000,
                 target_len=39124,
                 snr_db_range=(0, 15),
                 bandpass_noise=True,
                 noise_bp_low_hz=200,
                 noise_bp_high_hz=800):
        self.clean_paths = list(clean_paths)
        self.noise_paths = list(noise_paths) if noise_paths is not None else []
        self.use_noise_clips = use_noise_clips and len(self.noise_paths) > 0
        self.sr = sr
        self.target_len = target_len
        self.snr_db_range = snr_db_range
        self.bandpass_noise = bandpass_noise
        self.noise_bp_low_hz = noise_bp_low_hz
        self.noise_bp_high_hz = noise_bp_high_hz

    def __len__(self):
        return len(self.clean_paths)

    def __getitem__(self, idx):
        clean_path = self.clean_paths[idx]
        clean, sr = torchaudio.load(clean_path)
        assert sr == self.sr, f"Expected sr={self.sr}, got {sr} for {clean_path}"

        clean = _to_mono_and_crop_or_pad(clean, self.target_len)

        # ----- build synthetic noise -----
        if self.use_noise_clips:
            noise_path = random.choice(self.noise_paths)
            noise, sr_n = torchaudio.load(noise_path)
            if sr_n != self.sr:
                noise = torchaudio.functional.resample(noise, sr_n, self.sr)
            noise = _to_mono_and_crop_or_pad(noise, self.target_len)
            # Optionally band-limit the NOISE (not the clean target)
            if self.bandpass_noise:
                noise = bandpass_waveform(noise, self.sr,
                                          self.noise_bp_low_hz, self.noise_bp_high_hz)
        else:
            # White noise (optionally band-pass)
            noise = torch.randn_like(clean)
            if self.bandpass_noise:
                noise = bandpass_waveform(noise, self.sr,
                                          self.noise_bp_low_hz, self.noise_bp_high_hz)

        # ----- mix at random SNR -----
        snr_db = random.uniform(*self.snr_db_range)
        noise = _scale_noise_to_snr(clean, noise, snr_db)
        noisy = (clean + noise).clamp_(-1.0, 1.0)

        # Optionally band-pass the CLEAN too (comment out if you prefer raw clean target)
        # clean = bandpass_waveform(clean, self.sr, self.noise_bp_low_hz, self.noise_bp_high_hz)

        # Return (B,T) style elements; your training loop can unsqueeze to (B,1,T)
        return noisy, clean


# =========================
# Stage 2: NOISY -> (waveform) for fine-tune with ScoreNet/perceptual loss
# =========================
class NoisyOnlyDataset(Dataset):
    """
    Provides real noisy waveforms cropped/padded to target_len.
    (No targets; Stage 2 uses ScoreNet/perceptual losses.)
    """
    def __init__(self,
                 noisy_paths,
                 sr=10_000,
                 target_len=39124,
                 apply_bandpass=False,
                 bp_low_hz=200,
                 bp_high_hz=800):
        self.noisy_paths = list(noisy_paths)
        self.sr = sr
        self.target_len = target_len
        self.apply_bandpass = apply_bandpass
        self.bp_low_hz = bp_low_hz
        self.bp_high_hz = bp_high_hz

    def __len__(self):
        return len(self.noisy_paths)

    def __getitem__(self, idx):
        path = self.noisy_paths[idx]
        wav, sr = torchaudio.load(path)
        assert sr == self.sr, f"Expected sr={self.sr}, got {sr} for {path}"
        wav = _to_mono_and_crop_or_pad(wav, self.target_len)
        if self.apply_bandpass:
            wav = bandpass_waveform(wav, self.sr, self.bp_low_hz, self.bp_high_hz)
        return wav


# ====== Build the two loaders ======
noisy_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/input_audio/noisy/*.wav'))
clean_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/input_audio/clean/*.wav'))

# Stage 1: uses ONLY clean clips as targets; noise source can be noisy clips or white noise
stage1_dataset = CleanWithSyntheticNoiseDataset(
    clean_paths=clean_paths,
    noise_paths=noisy_paths,         # optional; set to None to use white noise only
    use_noise_clips=True,            # True = use 'noisy' files as a natural noise bed
    sr=10_000,
    target_len=39124,
    snr_db_range=(0, 15),
    bandpass_noise=True,
    noise_bp_low_hz=200,
    noise_bp_high_hz=800
)
clean_loader = DataLoader(stage1_dataset, batch_size=16, shuffle=True,
                          num_workers=2, pin_memory=True)

# Stage 2: real noisy waveforms (no targets)
stage2_dataset = NoisyOnlyDataset(
    noisy_paths=noisy_paths,
    sr=10_000,
    target_len=39124,
    apply_bandpass=False  # keep raw; you can set True if you want pre-filtering here
)
noisy_loader = DataLoader(stage2_dataset, batch_size=16, shuffle=True,
                          num_workers=2, pin_memory=True)



# DC offset penalty

In [13]:
def dc_loss(x):  # x: (B,1,T) torch
    m = x.mean(dim=-1, keepdim=True)  # (B,1,1)
    return (m ** 2).mean()

lambda_dc = 0.01  # tune 0.005–0.05 if needed


# Stage 1: Supervised Pre-training (MSE)

In [14]:
# === Stage 1: Supervised MSE training with synthetic noise (Option A: dataset returns (noisy, clean)) ===
from tqdm import tqdm
import numpy as np
import torch

denoiser.train()
num_epochs_stage1 = 40

def to_numpy_for_tf(x):
    # x: (B,1,T) tensor
    x = x.detach().squeeze(1).cpu().numpy()
    x = x[..., np.newaxis].astype(np.float32)  # (B,T,1)
    target_len = 39124
    return x[:, :target_len, :] if x.shape[1] > target_len else np.pad(
        x, ((0, 0), (0, target_len - x.shape[1]), (0, 0))
    )

for epoch in range(num_epochs_stage1):
    total_loss = 0.0
    total_score = 0.0
    total_batches = 0

    for noisy, clean in tqdm(clean_loader):
        # Accept (B,T) or (B,1,T)
        if noisy.ndim == 2:  # (B,T) -> (B,1,T)
            noisy = noisy.unsqueeze(1)
            clean = clean.unsqueeze(1)

        noisy = noisy.cuda(non_blocking=True)   # (B,1,T)
        clean = clean.cuda(non_blocking=True)   # (B,1,T)

        denoised = denoiser(noisy)              # (B,1,T)

        # MSE reconstruction + DC penalty
        recon_loss = mse_loss(denoised, clean) + lambda_dc * dc_loss(denoised)

        # Optional: monitor TF whale score (not part of loss)
        denoised_np = to_numpy_for_tf(denoised)
        whale_scores = embedding_model.model(denoised_np, False, None).numpy()
        mean_score = float(whale_scores.mean())

        optimizer.zero_grad(set_to_none=True)
        recon_loss.backward()
        optimizer.step()

        total_loss += recon_loss.item()
        total_score += mean_score
        total_batches += 1

    print(f"[Stage 1][Epoch {epoch+1}] MSE loss = {total_loss / max(total_batches,1):.4f} | "
          f"TF whale score = {total_score / max(total_batches,1):.4f}")


100%|██████████| 35/35 [00:10<00:00,  3.29it/s]


[Stage 1][Epoch 1] MSE loss = 0.0009 | TF whale score = 3.1209


100%|██████████| 35/35 [00:06<00:00,  5.36it/s]


[Stage 1][Epoch 2] MSE loss = 0.0006 | TF whale score = 3.2386


100%|██████████| 35/35 [00:06<00:00,  5.26it/s]


[Stage 1][Epoch 3] MSE loss = 0.0005 | TF whale score = 3.4730


100%|██████████| 35/35 [00:06<00:00,  5.40it/s]


[Stage 1][Epoch 4] MSE loss = 0.0006 | TF whale score = 3.6923


100%|██████████| 35/35 [00:06<00:00,  5.30it/s]


[Stage 1][Epoch 5] MSE loss = 0.0005 | TF whale score = 3.8049


100%|██████████| 35/35 [00:06<00:00,  5.24it/s]


[Stage 1][Epoch 6] MSE loss = 0.0006 | TF whale score = 3.8155


100%|██████████| 35/35 [00:06<00:00,  5.32it/s]


[Stage 1][Epoch 7] MSE loss = 0.0006 | TF whale score = 3.7695


100%|██████████| 35/35 [00:06<00:00,  5.26it/s]


[Stage 1][Epoch 8] MSE loss = 0.0006 | TF whale score = 3.7806


100%|██████████| 35/35 [00:06<00:00,  5.44it/s]


[Stage 1][Epoch 9] MSE loss = 0.0005 | TF whale score = 3.8660


100%|██████████| 35/35 [00:06<00:00,  5.20it/s]


[Stage 1][Epoch 10] MSE loss = 0.0005 | TF whale score = 3.8058


100%|██████████| 35/35 [00:06<00:00,  5.38it/s]


[Stage 1][Epoch 11] MSE loss = 0.0006 | TF whale score = 3.7819


100%|██████████| 35/35 [00:06<00:00,  5.32it/s]


[Stage 1][Epoch 12] MSE loss = 0.0006 | TF whale score = 3.7728


100%|██████████| 35/35 [00:06<00:00,  5.28it/s]


[Stage 1][Epoch 13] MSE loss = 0.0006 | TF whale score = 3.7890


100%|██████████| 35/35 [00:06<00:00,  5.32it/s]


[Stage 1][Epoch 14] MSE loss = 0.0006 | TF whale score = 3.6626


100%|██████████| 35/35 [00:06<00:00,  5.31it/s]


[Stage 1][Epoch 15] MSE loss = 0.0006 | TF whale score = 3.7713


100%|██████████| 35/35 [00:06<00:00,  5.31it/s]


[Stage 1][Epoch 16] MSE loss = 0.0006 | TF whale score = 3.7000


100%|██████████| 35/35 [00:06<00:00,  5.33it/s]


[Stage 1][Epoch 17] MSE loss = 0.0005 | TF whale score = 3.7667


100%|██████████| 35/35 [00:06<00:00,  5.33it/s]


[Stage 1][Epoch 18] MSE loss = 0.0005 | TF whale score = 3.8062


100%|██████████| 35/35 [00:06<00:00,  5.35it/s]


[Stage 1][Epoch 19] MSE loss = 0.0005 | TF whale score = 3.7516


100%|██████████| 35/35 [00:06<00:00,  5.36it/s]


[Stage 1][Epoch 20] MSE loss = 0.0006 | TF whale score = 3.6350


100%|██████████| 35/35 [00:06<00:00,  5.30it/s]


[Stage 1][Epoch 21] MSE loss = 0.0005 | TF whale score = 3.6567


100%|██████████| 35/35 [00:06<00:00,  5.31it/s]


[Stage 1][Epoch 22] MSE loss = 0.0005 | TF whale score = 3.6686


100%|██████████| 35/35 [00:06<00:00,  5.26it/s]


[Stage 1][Epoch 23] MSE loss = 0.0006 | TF whale score = 3.6562


100%|██████████| 35/35 [00:06<00:00,  5.36it/s]


[Stage 1][Epoch 24] MSE loss = 0.0005 | TF whale score = 3.5571


100%|██████████| 35/35 [00:06<00:00,  5.31it/s]


[Stage 1][Epoch 25] MSE loss = 0.0005 | TF whale score = 3.6395


100%|██████████| 35/35 [00:06<00:00,  5.28it/s]


[Stage 1][Epoch 26] MSE loss = 0.0005 | TF whale score = 3.5895


100%|██████████| 35/35 [00:06<00:00,  5.31it/s]


[Stage 1][Epoch 27] MSE loss = 0.0006 | TF whale score = 3.5558


100%|██████████| 35/35 [00:06<00:00,  5.41it/s]


[Stage 1][Epoch 28] MSE loss = 0.0006 | TF whale score = 3.5235


100%|██████████| 35/35 [00:06<00:00,  5.35it/s]


[Stage 1][Epoch 29] MSE loss = 0.0006 | TF whale score = 3.5567


100%|██████████| 35/35 [00:06<00:00,  5.28it/s]


[Stage 1][Epoch 30] MSE loss = 0.0005 | TF whale score = 3.6587


100%|██████████| 35/35 [00:06<00:00,  5.25it/s]


[Stage 1][Epoch 31] MSE loss = 0.0005 | TF whale score = 3.5272


100%|██████████| 35/35 [00:06<00:00,  5.29it/s]


[Stage 1][Epoch 32] MSE loss = 0.0005 | TF whale score = 3.5379


100%|██████████| 35/35 [00:06<00:00,  5.31it/s]


[Stage 1][Epoch 33] MSE loss = 0.0005 | TF whale score = 3.5008


100%|██████████| 35/35 [00:06<00:00,  5.21it/s]


[Stage 1][Epoch 34] MSE loss = 0.0006 | TF whale score = 3.3648


100%|██████████| 35/35 [00:06<00:00,  5.37it/s]


[Stage 1][Epoch 35] MSE loss = 0.0005 | TF whale score = 3.3778


100%|██████████| 35/35 [00:06<00:00,  5.32it/s]


[Stage 1][Epoch 36] MSE loss = 0.0005 | TF whale score = 3.4097


100%|██████████| 35/35 [00:06<00:00,  5.36it/s]


[Stage 1][Epoch 37] MSE loss = 0.0006 | TF whale score = 3.2428


100%|██████████| 35/35 [00:06<00:00,  5.30it/s]


[Stage 1][Epoch 38] MSE loss = 0.0005 | TF whale score = 3.1641


100%|██████████| 35/35 [00:06<00:00,  5.30it/s]


[Stage 1][Epoch 39] MSE loss = 0.0005 | TF whale score = 3.3395


100%|██████████| 35/35 [00:06<00:00,  5.38it/s]

[Stage 1][Epoch 40] MSE loss = 0.0005 | TF whale score = 3.0591





# Build & Train ScoreNet

In [None]:
# # === Stage 2 prep: Train ScoreNet to mimic TF whale score ===
# class ScoreNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Conv1d(1, 16, 9, stride=2, padding=4),
#             nn.ReLU(),
#             nn.Conv1d(16, 32, 9, stride=2, padding=4),
#             nn.ReLU(),
#             nn.Conv1d(32, 64, 9, stride=2, padding=4),
#             nn.ReLU(),
#             nn.AdaptiveAvgPool1d(1),
#             nn.Flatten(),
#             nn.Linear(64, 1)
#         )

#     def forward(self, x):  # x: (B,1,T)
#         return self.net(x)

# score_net = ScoreNet().cuda()
# opt_score = torch.optim.Adam(score_net.parameters(), lr=1e-3)
# loss_score = nn.MSELoss()

# # Prepare data for ScoreNet training
# def get_tf_score_tensor(waveform_batch):
#     x_np = waveform_batch.squeeze(1).cpu().numpy()[..., np.newaxis].astype(np.float32)
#     target_len = 39124
#     x_np = x_np[:, :target_len, :] if x_np.shape[1] > target_len else np.pad(
#         x_np, ((0, 0), (0, target_len - x_np.shape[1]), (0, 0))
#     )
#     tf_scores = embedding_model.model(x_np, False, None).numpy()
#     return torch.tensor(tf_scores, dtype=torch.float32).unsqueeze(1).to(waveform_batch.device)

# # Train ScoreNet
# epochs_score = 5
# for epoch in range(epochs_score):
#     total_loss = 0
#     for noisy, _ in tqdm(train_loader):
#         noisy = noisy.unsqueeze(1).cuda()
#         target_scores = get_tf_score_tensor(noisy)

#         pred_scores = score_net(noisy)
#         loss = loss_score(pred_scores, target_scores)

#         opt_score.zero_grad()
#         loss.backward()
#         opt_score.step()

#         total_loss += loss.item()

#     print(f"[ScoreNet][Epoch {epoch+1}] loss = {total_loss/len(train_loader):.4f}")


100%|██████████| 35/35 [00:52<00:00,  1.49s/it]


[ScoreNet][Epoch 1] loss = 13.7140


100%|██████████| 35/35 [00:40<00:00,  1.15s/it]


[ScoreNet][Epoch 2] loss = 6.4450


100%|██████████| 35/35 [00:29<00:00,  1.17it/s]


[ScoreNet][Epoch 3] loss = 6.1384


100%|██████████| 35/35 [00:23<00:00,  1.50it/s]


[ScoreNet][Epoch 4] loss = 6.2409


100%|██████████| 35/35 [00:20<00:00,  1.69it/s]

[ScoreNet][Epoch 5] loss = 6.6068





# Save ScoreNet to Disk

In [35]:
    # torch.save(score_net.state_dict(), "/content/drive/MyDrive/whale_denoising/scorenet.pt")

# Define ScoreNet and load from Disk if it exists

In [19]:
# === Stage 2: ScoreNet definition (+ helpers, no training here) ===
import torch, torch.nn as nn
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

# Per‑example waveform normalization (important — your inputs were tiny std≈0.008)
def norm_per_example(x: torch.Tensor) -> torch.Tensor:
    # x: (B,1,T)
    m = x.mean(dim=-1, keepdim=True)
    s = x.std(dim=-1, keepdim=True) + 1e-6
    return (x - m) / s

class ScoreNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fe = nn.Sequential(
            nn.Conv1d(1,   32, 9, stride=2, padding=4), nn.BatchNorm1d(32),  nn.ReLU(inplace=True),
            nn.Conv1d(32,  64, 9, stride=2, padding=4), nn.BatchNorm1d(64),  nn.ReLU(inplace=True),
            nn.Conv1d(64, 128, 9, stride=2, padding=4), nn.BatchNorm1d(128), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool1d(1),  # -> (B,128,1)
        )
        self.head = nn.Sequential(
            nn.Flatten(),               # -> (B,128)
            nn.Linear(128, 64), nn.ReLU(inplace=True),
            nn.Linear(64, 1)            # -> (B,1)
        )
        # Kaiming init for stability
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, x):  # x: (B,1,T), assume already normalized
        x = self.fe(x)
        return self.head(x)

# Your existing bridge to the TF humpback model (unchanged)
def get_tf_score_tensor(waveform_batch: torch.Tensor) -> torch.Tensor:
    """
    waveform_batch: torch.Tensor (B,1,T) on device (values in [-?, ?])
    Returns: torch.Tensor (B,1) of TF detector scalar scores
    """
    x_np = waveform_batch.squeeze(1).detach().cpu().numpy()[..., np.newaxis].astype(np.float32)
    target_len = 39124
    if x_np.shape[1] > target_len:
        x_np = x_np[:, :target_len, :]
    else:
        pad = target_len - x_np.shape[1]
        x_np = np.pad(x_np, ((0,0), (0,pad), (0,0)))
    tf_scores = embedding_model.model(x_np, False, None).numpy()  # shape (B,) or (B,1)
    tf_scores = tf_scores.reshape(-1, 1).astype(np.float32)
    return torch.tensor(tf_scores, dtype=torch.float32, device=waveform_batch.device)

# Instantiate + (optionally) load
score_net = ScoreNet().to(device)

ckpt_path = "/content/drive/MyDrive/whale_denoising/scorenet.pt"  # adjust if needed
import os, torch
if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path, map_location=device)
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    missing, unexpected = score_net.load_state_dict(state, strict=False)
    print(f"Loaded ScoreNet from {ckpt_path}")
    if missing:   print("  Missing keys:", missing)
    if unexpected: print("  Unexpected keys:", unexpected)
    score_net.eval()
else:
    print(f"No scorenet checkpoint found at {ckpt_path} — train it in the next cell.")



No scorenet checkpoint found at /content/drive/MyDrive/whale_denoising/scorenet.pt — train it in the next cell.


# Test ScoreNet

In [None]:
# ==== ScoreNet one-step diagnostic ====
import torch, numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
score_net = score_net.to(device)

def param_vec(model):
    return torch.nn.utils.parameters_to_vector([p.detach().float().cpu() for p in model.parameters()])

def grad_vec(model):
    gs = []
    for p in model.parameters():
        if p.grad is not None:
            gs.append(p.grad.detach().float().cpu().flatten())
    return torch.cat(gs) if gs else torch.tensor([])

# Grab one batch
noisy, _ = next(iter(train_loader))             # (B,T) or (B,?) from your DataLoader
noisy = noisy.unsqueeze(1).to(device).float()   # (B,1,T)

# Targets from TF detector
with torch.no_grad():                           # TF call is external; no torch grad needed
    target_scores = get_tf_score_tensor(noisy)  # (B,1)

print("Input batch stats: mean=", noisy.mean().item(), "std=", noisy.std().item())
print("Target stats: mean=", target_scores.mean().item(), "std=", target_scores.std().item())

# Fresh optimizer for the test (won't disturb your main one)
opt = torch.optim.Adam(score_net.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

score_net.train()
p_before = param_vec(score_net)
opt.zero_grad()

pred = score_net(noisy)                         # (B,1)
loss = loss_fn(pred, target_scores)
loss.backward()

g = grad_vec(score_net)
print("Pred stats: mean=", pred.mean().item(), "std=", pred.std().item())
print("Loss:", loss.item())
print("Grad L2 norm:", float(g.norm()) if g.numel() else 0.0)

opt.step()
p_after = param_vec(score_net)

delta = (p_after - p_before).norm().item()
print("Param delta L2 after one step:", delta)


Input batch stats: mean= -4.6275096110548475e-07 std= 0.012189652770757675
Target stats: mean= 5.3401336669921875 std= 2.1295998096466064
Pred stats: mean= 4.387958526611328 std= 0.0002856987121049315
Loss: 5.158382892608643
Grad L2 norm: 24.009103775024414
Param delta L2 after one step: 0.07807283103466034


# Train ScoreNet if needed

In [20]:
# === Stage 2: Train ScoreNet (unpaired; works with noisy_loader, optionally clean_loader) ===
from tqdm import tqdm
import numpy as np
import torch.nn as nn

score_net.train()
opt_score  = torch.optim.AdamW(score_net.parameters(), lr=3e-3, weight_decay=1e-4)
loss_score = nn.MSELoss()

def _audio_from_batch(batch):
    """Accepts either a tensor (B,T) or (B,1,T), or a tuple/list where the first item is audio.
       Returns (B,1,T) float tensor on device."""
    x = batch[0] if isinstance(batch, (list, tuple)) else batch
    if x.ndim == 2:  # (B,T) -> (B,1,T)
        x = x.unsqueeze(1)
    return x.to(device).float()

# Choose which loaders to use for ScoreNet training
# Default: only real noisy audio
USE_CLEAN_SYNTHETIC_TOO = False  # set True to also learn from clean_loader's synthetic noisy
train_loaders = [noisy_loader]
if USE_CLEAN_SYNTHETIC_TOO:
    train_loaders.append(clean_loader)  # clean_loader yields (noisy, clean), we'll take the first

# (Optional) Calibrate target scaling once on a few batches (stabilizes z-space training)
CALIBRATION_BATCHES = 5
targets_for_mu = []
score_net.eval()
with torch.no_grad():
    taken = 0
    for dl in train_loaders:
        for batch in dl:
            audio = _audio_from_batch(batch)     # (B,1,T)
            audio = norm_per_example(audio)       # keep consistent with inference-time norm
            t = get_tf_score_tensor(audio)        # (B,1) from the TF model
            targets_for_mu.append(t.detach().cpu())
            taken += 1
            if taken >= CALIBRATION_BATCHES:
                break
        if taken >= CALIBRATION_BATCHES:
            break

if len(targets_for_mu):
    tgt_cat = torch.cat(targets_for_mu, dim=0).numpy()
    mu, sig = float(tgt_cat.mean()), float(tgt_cat.std() + 1e-6)
else:
    mu, sig = 0.0, 1.0
print(f"[Calibration] target mean={mu:.3f}, std={sig:.3f}")
score_net.train()

EPOCHS = 10
for epoch in range(EPOCHS):
    losses, pstds, tstds = [], [], []
    for dl in train_loaders:
        for batch in tqdm(dl, leave=False):
            audio = _audio_from_batch(batch)      # (B,1,T)
            audio = norm_per_example(audio)

            with torch.no_grad():
                target = get_tf_score_tensor(audio)  # (B,1)

            target_z = (target - mu) / sig
            pred_z   = score_net(audio)             # (B,1)
            loss     = loss_score(pred_z, target_z)

            opt_score.zero_grad(set_to_none=True)
            loss.backward()
            opt_score.step()

            losses.append(loss.item())
            pstds.append(float(pred_z.std().item()))
            tstds.append(float(target_z.std().item()))

    print(f"[ScoreNet][Epoch {epoch+1:02d}] "
          f"loss={np.mean(losses):.4f} | pred std={np.mean(pstds):.3f} | targ std={np.mean(tstds):.3f}")

# Save for reuse
torch.save(score_net.state_dict(), ckpt_path)
print(f"Saved ScoreNet to {ckpt_path}")
score_net.eval()


[Calibration] target mean=-0.517, std=2.402




[ScoreNet][Epoch 01] loss=1.0788 | pred std=0.179 | targ std=1.009




[ScoreNet][Epoch 02] loss=0.9821 | pred std=0.272 | targ std=1.008




[ScoreNet][Epoch 03] loss=0.9232 | pred std=0.248 | targ std=1.013




[ScoreNet][Epoch 04] loss=0.8012 | pred std=0.465 | targ std=1.015




[ScoreNet][Epoch 05] loss=0.7706 | pred std=0.480 | targ std=1.008




[ScoreNet][Epoch 06] loss=0.7416 | pred std=0.545 | targ std=0.994




[ScoreNet][Epoch 07] loss=0.7152 | pred std=0.566 | targ std=1.016




[ScoreNet][Epoch 08] loss=0.7144 | pred std=0.549 | targ std=0.999




[ScoreNet][Epoch 09] loss=0.7141 | pred std=0.595 | targ std=0.998


                                               

[ScoreNet][Epoch 10] loss=0.6999 | pred std=0.583 | targ std=0.999
Saved ScoreNet to /content/drive/MyDrive/whale_denoising/scorenet.pt




ScoreNet(
  (fe): Sequential(
    (0): Conv1d(1, 32, kernel_size=(9,), stride=(2,), padding=(4,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv1d(32, 64, kernel_size=(9,), stride=(2,), padding=(4,))
    (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv1d(64, 128, kernel_size=(9,), stride=(2,), padding=(4,))
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): AdaptiveAvgPool1d(output_size=1)
  )
  (head): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=128, out_features=64, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=64, out_features=1, bias=True)
  )
)

# Save scorenet calibration values

In [36]:
import json, os
calib_path = "/content/drive/MyDrive/whale_denoising/scorenet_calibration.json"
with open(calib_path, "w") as f:
    json.dump({"mu": mu, "sig": sig}, f)
print("Saved calibration to", calib_path)

Saved calibration to /content/drive/MyDrive/whale_denoising/scorenet_calibration.json


## Or load scorenet calibration values if loading scorenet from disk

In [None]:
import json
with open(calib_path) as f:
    cs = json.load(f)
mu, sig = float(cs["mu"]), float(cs["sig"])
print("Loaded calibration:", mu, sig)


# ScoreNet Sanity Check

In [22]:
# --- Quick sanity: parameter norm, batch preds, TF targets ---
def param_norm(model: nn.Module) -> float:
    with torch.no_grad():
        vec = torch.nn.utils.parameters_to_vector([p.detach().float().cpu() for p in model.parameters()])
        return float(vec.norm().item())

print("ScoreNet param L2 norm:", param_norm(score_net))

noisy = next(iter(noisy_loader))
noisy = norm_per_example(noisy.unsqueeze(1).to(device).float())

with torch.no_grad():
    pred_z = score_net(noisy).squeeze(-1)          # (B,)
    targ   = get_tf_score_tensor(noisy).squeeze(-1)  # (B,)
    pred   = pred_z * sig + mu                      # back to original scale

print("Pred (first 5):", pred[:5].detach().cpu().numpy())
print("Targ (first 5):", targ[:5].detach().cpu().numpy())
print("Pred std:", float(pred.std().item()), "| Targ std:", float(targ.std().item()))


ScoreNet param L2 norm: 31.741336822509766
Pred (first 5): [ 1.3257439   0.7778993   0.9369108   0.901056   -0.85629934]
Targ (first 5): [-0.28420022  1.0387262   1.9200736  -1.7230052  -5.367212  ]
Pred std: 0.5064626932144165 | Targ std: 2.184911012649536


# Find paths to files not used in ScoreNet training

In [None]:
# ==== Build a held‑out (val) list that excludes all ScoreNet training wavs ====
import os, glob, json

def get_training_paths_from_loader(loader):
    ds = loader.dataset
    # Try common attribute names your dataset might expose
    cand_attrs = [
        "paths", "files", "file_paths", "filepaths", "file_list", "items", "samples"
    ]
    for a in cand_attrs:
        if hasattr(ds, a):
            v = getattr(ds, a)
            # Some datasets use .samples = [(path, label), ...]
            if isinstance(v, list) and len(v) and isinstance(v[0], (list, tuple)) and isinstance(v[0][0], str):
                return [p for p, *_ in v]
            # Plain list[str]
            if isinstance(v, list) and (len(v) == 0 or isinstance(v[0], str)):
                return v
    # Torchaudio/ImageFolder style: .dataset.samples or .dataset.targets
    if hasattr(ds, "samples"):
        return [p for p, *_ in ds.samples]
    # If your dataset wraps another dataset
    if hasattr(ds, "dataset"):
        inner = ds.dataset
        for a in cand_attrs:
            if hasattr(inner, a):
                v = getattr(inner, a)
                if isinstance(v, list) and (len(v) == 0 or isinstance(v[0], str)):
                    return v
                if isinstance(v, list) and len(v) and isinstance(v[0], (list, tuple)) and isinstance(v[0][0], str):
                    return [p for p, *_ in v]
    return None

train_paths = get_training_paths_from_loader(train_loader)
if train_paths is None:
    print("Could not introspect training file list from train_loader.dataset.")
    print("You can: (a) add an attribute like `dataset.paths`, or (b) rebuild the split deterministically (see below).")
    train_paths = []

# Normalize to absolute paths for stable set ops
train_set = set(os.path.abspath(p) for p in train_paths)

# 1) Point this at the directory that contains *all* candidate wavs
ALL_WAV_ROOT = "/content/drive/MyDrive/whale_denoising/input_audio"  # <-- set this to your corpus root
all_paths = [os.path.abspath(p) for p in glob.glob(os.path.join(ALL_WAV_ROOT, "**/*.wav"), recursive=True)]

# 2) Held‑out = all − train
val_paths = [p for p in all_paths if p not in train_set]

print(f"All wavs       : {len(all_paths)}")
print(f"Training wavs  : {len(train_set)}")
print(f"Held‑out (val) : {len(val_paths)}")

# Save manifests for reproducibility
SPLIT_DIR = "/content/drive/MyDrive/whale_denoising/splits"
os.makedirs(SPLIT_DIR, exist_ok=True)
with open(os.path.join(SPLIT_DIR, "scorenet_train_paths.json"), "w") as f:
    json.dump(sorted(train_set), f, indent=2)
with open(os.path.join(SPLIT_DIR, "scorenet_val_paths.json"), "w") as f:
    json.dump(sorted(val_paths), f, indent=2)
print("Saved split manifests to", SPLIT_DIR)


# ScoreNet Validation

In [None]:
import glob, numpy as np, matplotlib.pyplot as plt
from scipy.stats import spearmanr

device = "cuda" if torch.cuda.is_available() else "cpu"
score_net.eval().to(device)

VAL_DIR = "/content/drive/MyDrive/whale_denoising/val_wavs"  # adjust
val_paths = sorted(glob.glob(f"{VAL_DIR}/*.wav"))

def load_mono(path, sr=10000):
    import librosa
    y, _ = librosa.load(path, sr=sr, mono=True)
    return y.astype(np.float32)

def to_batch(xs, sr=10000):
    # pad/trim to the TF target length you used in get_tf_score_tensor
    import numpy as np, torch
    target_len = 39124
    arrs = []
    for x in xs:
        if len(x) >= target_len:
            x = x[:target_len]
        else:
            x = np.pad(x, (0, target_len-len(x)))
        arrs.append(x)
    t = torch.from_numpy(np.stack(arrs)).float()  # (B,T)
    return t.unsqueeze(1)  # (B,1,T)

# Batch through files (keeps TF calls small if you reuse get_tf_score_tensor)
y_true, y_pred = [], []
B = 16
for i in range(0, len(val_paths), B):
    batch_paths = val_paths[i:i+B]
    xs = [load_mono(p) for p in batch_paths]
    xb = to_batch(xs).to(device)

    # normalize per-example (must match your training)
    m = xb.mean(dim=-1, keepdim=True)
    s = xb.std(dim=-1, keepdim=True) + 1e-6
    xb_norm = (xb - m) / s

    with torch.no_grad():
        # TF detector (targets)
        t = get_tf_score_tensor(xb_norm)  # your function expects (B,1,T) on device
        # Note: If your TF bridge expects raw (pre-norm) wave, call it with xb instead.
        # Keep it consistent with how you trained!

        # ScoreNet prediction in z-space, then un-standardize
        pred_z = score_net(xb_norm).squeeze(-1)  # (B,)
        pred   = pred_z * sig + mu               # back to original scale

    y_true.extend(t.squeeze(-1).detach().cpu().numpy().tolist())
    y_pred.extend(pred.detach().cpu().numpy().tolist())

y_true = np.array(y_true); y_pred = np.array(y_pred)

def r2(y, yhat):
    ss_res = np.sum((y - yhat)**2)
    ss_tot = np.sum((y - y.mean())**2) + 1e-12
    return 1 - ss_res/ss_tot

print(f"MAE      : {np.mean(np.abs(y_true - y_pred)):.4f}")
print(f"R^2      : {r2(y_true, y_pred):.3f}")
print(f"Spearman : {spearmanr(y_true, y_pred).correlation:.3f}")

# Scatter + y=x
plt.figure(figsize=(4.5,4))
plt.scatter(y_true, y_pred, s=12, alpha=0.7)
lo, hi = min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())
plt.plot([lo, hi], [lo, hi], "--")
plt.xlabel("TF detector score")
plt.ylabel("ScoreNet prediction")
plt.title("ScoreNet vs TF (held‑out)")
plt.tight_layout()
plt.show()

# Residuals
res = y_pred - y_true
plt.figure(figsize=(5,3))
plt.hist(res, bins=30)
plt.title("Residuals (ScoreNet - TF)")
plt.tight_layout()
plt.show()


# Denoiser Sanity Check

In [33]:
import os
from tqdm import tqdm
import torch
import torchaudio
from IPython.display import Audio
import soundfile as sf

# Make sure you have this helper defined earlier:
# def bandpass_waveform(waveform, sr, low_hz=200, high_hz=800): ...
# def postprocess_audio(waveform, sr=10000, hp_hz=10, peak=0.99, low_hz=200, high_hz=800): ...

# # Pick one test clip
TEST_WAV = "/content/drive/MyDrive/whale_denoising/input_audio/noisy/whale_64_10k.wav"  # <-- set your path
assert os.path.exists(TEST_WAV), f"File not found: {TEST_WAV}"

denoiser.eval()
waveform, sr = torchaudio.load(TEST_WAV)  # (1, T)
x = waveform[0]

# 1) Pre-bandpass to match training
x_bp = bandpass_waveform(x, sr, low_hz=200, high_hz=800)

# 2) Denoise
tmp = x_bp.unsqueeze(0).unsqueeze(0).cuda()
with torch.no_grad():
  tmp = denoiser(tmp)
denoised = tmp.squeeze().cpu()

# 3) Post-process (demean + HPF + normalize)
out_pp = postprocess_audio(denoised, sr=sr)

# 4) Compare
display(Audio(x, rate=sr))
display(Audio(x_bp, rate=sr))
display(Audio(out_pp, rate=sr))


sf.write("/content/drive/MyDrive/whale_denoising/output_audio/test/test_noisy.wav", x, sr)
sf.write("/content/drive/MyDrive/whale_denoising/output_audio/test/test_bandpassed.wav", x_bp, sr)
sf.write("/content/drive/MyDrive/whale_denoising/output_audio/test/test_denoised.wav", out_pp, sr)
print("Wrote test/test_noisy.wav, test/test_bandpassed.wav and test/test_denoised.wav")


Wrote test/test_noisy.wav, test/test_bandpassed.wav and test/test_denoised.wav


# ScoreNet sanity check

# Stage 2: Whale-Score Maximisation Fine-Tune

In [23]:
# === Stage 2: Fine-tune denoiser to maximise whale score (uses noisy_loader) ===
lambda_score = 1.0    # strength of score maximisation
lambda_mse   = 0.05   # optional regulariser

num_epochs_stage2 = 15
denoiser.train()

# Ensure ScoreNet is fixed (no parameter updates), but still differentiable wrt input
score_net.eval()
for p in score_net.parameters():
    p.requires_grad_(False)

for epoch in range(num_epochs_stage2):
    total_loss = 0.0
    total_score = 0.0
    total_batches = 0

    for noisy in tqdm(noisy_loader):          # <-- only noisy waveforms from Stage 2 dataset
        # Accept (B,T) or (B,1,T)
        if noisy.ndim == 2:
            noisy = noisy.unsqueeze(1)        # (B,1,T)

        noisy = noisy.cuda(non_blocking=True)

        denoised = denoiser(noisy)            # (B,1,T)

        # ScoreNet whale score (maximize)
        # NOTE: do NOT wrap in torch.no_grad(); we want grads wrt denoised
        score_out = score_net(denoised)       # shape could be (B, ...) -> reduce to mean
        mean_score = score_out.mean()

        # Optional reg: keep output close to input
        mse_reg = mse_loss(denoised, noisy)

        # Combine losses
        loss = -lambda_score * mean_score + lambda_mse * mse_reg
        loss = loss + lambda_dc * dc_loss(denoised)   # DC penalty

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        total_loss  += loss.item()
        total_score += mean_score.item()
        total_batches += 1

    print(f"[Stage 2][Epoch {epoch+1}] loss = {total_loss/total_batches:.4f} | "
          f"ScoreNet whale score = {total_score/total_batches:.4f}")




100%|██████████| 58/58 [00:05<00:00, 10.65it/s]


[Stage 2][Epoch 1] loss = -8.7984 | ScoreNet whale score = 16.2232


100%|██████████| 58/58 [00:05<00:00, 10.54it/s]


[Stage 2][Epoch 2] loss = -9.7671 | ScoreNet whale score = 17.9024


100%|██████████| 58/58 [00:05<00:00, 10.73it/s]


[Stage 2][Epoch 3] loss = -9.7698 | ScoreNet whale score = 17.8993


100%|██████████| 58/58 [00:05<00:00, 10.63it/s]


[Stage 2][Epoch 4] loss = -9.7715 | ScoreNet whale score = 17.9007


100%|██████████| 58/58 [00:05<00:00, 10.77it/s]


[Stage 2][Epoch 5] loss = -9.7721 | ScoreNet whale score = 17.9027


100%|██████████| 58/58 [00:05<00:00, 10.53it/s]


[Stage 2][Epoch 6] loss = -9.7725 | ScoreNet whale score = 17.9034


100%|██████████| 58/58 [00:05<00:00, 10.97it/s]


[Stage 2][Epoch 7] loss = -9.7728 | ScoreNet whale score = 17.9044


100%|██████████| 58/58 [00:05<00:00, 10.86it/s]


[Stage 2][Epoch 8] loss = -9.7731 | ScoreNet whale score = 17.9044


100%|██████████| 58/58 [00:05<00:00, 10.38it/s]


[Stage 2][Epoch 9] loss = -9.7734 | ScoreNet whale score = 17.9053


100%|██████████| 58/58 [00:05<00:00, 10.68it/s]


[Stage 2][Epoch 10] loss = -9.7737 | ScoreNet whale score = 17.9060


100%|██████████| 58/58 [00:05<00:00, 10.60it/s]


[Stage 2][Epoch 11] loss = -9.7740 | ScoreNet whale score = 17.9062


100%|██████████| 58/58 [00:05<00:00, 10.62it/s]


[Stage 2][Epoch 12] loss = -9.7745 | ScoreNet whale score = 17.9070


100%|██████████| 58/58 [00:05<00:00, 10.60it/s]


[Stage 2][Epoch 13] loss = -9.7752 | ScoreNet whale score = 17.9091


100%|██████████| 58/58 [00:05<00:00, 10.88it/s]


[Stage 2][Epoch 14] loss = -9.7775 | ScoreNet whale score = 17.9093


100%|██████████| 58/58 [00:05<00:00, 10.69it/s]

[Stage 2][Epoch 15] loss = -9.7834 | ScoreNet whale score = 17.9162





# Quick Diagnostics

In [None]:
# === Quick Diagnostics on a FRESH batch ===
import numpy as np
import torch
import IPython.display as ipd

denoiser.eval()

# 1) Pull a fresh batch
noisy_b, clean_b = next(iter(train_loader))     # (B, T)
noisy_b  = noisy_b.unsqueeze(1).cuda()          # (B, 1, T)
clean_b  = clean_b.unsqueeze(1).cuda()          # (B, 1, T)

with torch.no_grad():
    denoised_b = denoiser(noisy_b)              # (B, 1, T)

# 2) Basic amplitude / energy stats
def batch_rms(x):  # x: (B,1,T)
    return x.pow(2).mean(dim=(-1,-2)).sqrt().detach().cpu().numpy()

print("max|noisy|    :", float(noisy_b.abs().max()))
print("max|denoised| :", float(denoised_b.abs().max()))
print("RMS noisy     :", batch_rms(noisy_b)[:8])
print("RMS denoised  :", batch_rms(denoised_b)[:8])

# 3) ScoreNet scores (student) – differentiable model we trained
with torch.no_grad():
    score_noisy    = score_net(noisy_b).mean().item()
    score_denoised = score_net(denoised_b).mean().item()
print(f"ScoreNet mean   noisy={score_noisy:.3f}  denoised={score_denoised:.3f}")

# 4) TF whale score (teacher) – frozen model, numpy only
def to_tf_batch(x_torch):
    TARGET_LEN = 39124
    x = x_torch.detach().squeeze(1).cpu().numpy().astype(np.float32)  # (B,T)
    x = x[..., np.newaxis]  # (B,T,1)
    if x.shape[1] < TARGET_LEN:
        pad = TARGET_LEN - x.shape[1]
        x = np.pad(x, ((0,0),(0,pad),(0,0)))
    else:
        x = x[:, :TARGET_LEN, :]
    return x

tf_noisy    = embedding_model.model(to_tf_batch(noisy_b),    False, None).numpy().mean()
tf_denoised = embedding_model.model(to_tf_batch(denoised_b), False, None).numpy().mean()
print(f"TF mean score   noisy={tf_noisy:.3f}  denoised={tf_denoised:.3f}")

# 5) Listen to the first example (assumes 10kHz SR)
sr = 10000
i = 0
print("\n🔊 Example 0 — Noisy:")
ipd.display(ipd.Audio(noisy_b[i,0].detach().cpu().numpy(), rate=sr))
print("🔊 Example 0 — Denoised:")
ipd.display(ipd.Audio(denoised_b[i,0].detach().cpu().numpy(), rate=sr))

# If you want to hear the clean target (from the band-passed dataset):
print("🔊 Example 0 — Clean target (band-passed):")
ipd.display(ipd.Audio(clean_b[i,0].detach().cpu().numpy(), rate=sr))


max|noisy|    : 0.14899013936519623
max|denoised| : 49.96099853515625
RMS noisy     : [0.03116972 0.01057382 0.01835111 0.01020089 0.01045195 0.00538983
 0.01200303 0.0023462 ]
RMS denoised  : [49.28384  49.27941  49.26566  49.28075  49.28463  49.27108  49.28037
 49.263042]
ScoreNet mean   noisy=4.650  denoised=294.533
TF mean score   noisy=4.259  denoised=3.525

🔊 Example 0 — Noisy:


🔊 Example 0 — Denoised:


🔊 Example 0 — Clean target (band-passed):


In [None]:
# Example: demean + HPF + normalize
out_pp = postprocess_audio(denoised_b[0,0], sr=10000)

# Listen in Colab
import IPython.display as ipd
ipd.display(ipd.Audio(out_pp.cpu().numpy(), rate=10000))


In [None]:
import torchaudio
import os

output_dir_pp = "/content/drive/MyDrive/whale_denoising/diagnostics_postprocessed"
os.makedirs(output_dir_pp, exist_ok=True)

# Assuming sr is defined from previous cells (e.g., 10000)
output_path_pp = os.path.join(output_dir_pp, "denoised_example_0_postprocessed.wav")

# out_pp is (T,) so unsqueeze(0) to make it (1, T) for torchaudio.save
torchaudio.save(output_path_pp, out_pp.detach().cpu().unsqueeze(0), sr)

print(f"Post-processed denoised example audio saved to: {output_path_pp}")

Post-processed denoised example audio saved to: /content/drive/MyDrive/whale_denoising/diagnostics_postprocessed/denoised_example_0_postprocessed.wav


In [None]:
import torchaudio
import os

output_dir = "/content/drive/MyDrive/whale_denoising/diagnostics"
os.makedirs(output_dir, exist_ok=True)

# Assuming sr is defined from previous cells (e.g., 10000)
output_path = os.path.join(output_dir, "denoised_example_0.wav")

# denoised_b[0,0] is (T,) so unsqueeze(0) to make it (1, T) for torchaudio.save
torchaudio.save(output_path, denoised_b[0,0].detach().cpu().unsqueeze(0), sr)

print(f"Denoised example audio saved to: {output_path}")

Denoised example audio saved to: /content/drive/MyDrive/whale_denoising/diagnostics/denoised_example_0.wav


# Listen & Compare

In [None]:
import IPython.display as ipd
import random

noisy_path = random.choice(noisy_paths)
waveform, sr = torchaudio.load(noisy_path)
waveform = waveform[0]

# Apply same band-pass as training
waveform_bp = bandpass_waveform(waveform, sr, low_hz=200, high_hz=800)

# Prepare
input_tensor = waveform_bp.unsqueeze(0).unsqueeze(0).cuda()

# Denoise
denoiser.eval()
with torch.no_grad():
    denoised_tensor = denoiser(input_tensor).squeeze().cpu()

print(f"🎧 Listening to: {noisy_path}")
print("🔊 Noisy (band-passed):")
ipd.display(ipd.Audio(waveform_bp.numpy(), rate=sr))

print("🔊 Denoised:")
ipd.display(ipd.Audio(denoised_tensor.numpy(), rate=sr))

🎧 Listening to: /content/drive/MyDrive/whale_denoising/noisy/whale_25_10k.wav
🔊 Noisy (band-passed):


🔊 Denoised:


# Save trained model to disk

In [34]:
torch.save(denoiser.state_dict(), "/content/drive/MyDrive/whale_denoising/whale_denoiser.pt")


# Save denoised files to disk

In [None]:
import os
from tqdm import tqdm
import torch
import torchaudio

# Make sure you have this helper defined earlier:
# def bandpass_waveform(waveform, sr, low_hz=200, high_hz=800): ...
# def postprocess_audio(waveform, sr=10000, hp_hz=10, peak=0.99): ...

noisy_output_dir = "/content/drive/MyDrive/whale_denoising/output_audio/noisy_denoised"
os.makedirs(noisy_output_dir, exist_ok=True)

NUM_PASSES = 1  # try 1 first; if you liked iterative, set to 2–3

denoiser.eval()
for noisy_path in tqdm(noisy_paths):
    waveform, sr = torchaudio.load(noisy_path)  # (1, T)
    x = waveform[0]

    # 1) Pre-bandpass to match training
    x_bp = bandpass_waveform(x, sr, low_hz=200, high_hz=800)

    # 2) Denoise (optionally multiple passes)
    current = x_bp.unsqueeze(0).unsqueeze(0).cuda()
    with torch.no_grad():
        for _ in range(NUM_PASSES):
            current = denoiser(current)
    denoised = current.squeeze().cpu()

    # 3) Post-process (demean + HPF + normalize)
    out_pp = postprocess_audio(denoised, sr=sr)

    # 4) Save
    filename = os.path.basename(noisy_path)
    out_path = os.path.join(noisy_output_dir, filename)
    torchaudio.save(out_path, out_pp.unsqueeze(0).cpu(), sr)

print(f"All denoised files saved to: {noisy_output_dir}")


100%|██████████| 923/923 [00:23<00:00, 38.69it/s]

All denoised files saved to: /content/drive/MyDrive/whale_denoising/noisy_denoised





# Save Bandpassed noisy originals for comparison

In [None]:
import os
from tqdm import tqdm
import torchaudio

# Output folder for band-passed noisy files
bandpass_output_dir = "/content/drive/MyDrive/whale_denoising/output_audio/clean_bandpassed"
os.makedirs(bandpass_output_dir, exist_ok=True)

LOW_HZ, HIGH_HZ = 200, 800

for clean_path in tqdm(clean_paths):
    # Load audio
    waveform, sr = torchaudio.load(clean_path)  # (1, T)
    x = waveform[0]

    # Apply the same band-pass as in training/inference
    x_bp = bandpass_waveform(x, sr, low_hz=LOW_HZ, high_hz=HIGH_HZ)

    # Post-process (demean, tiny HPF, normalize) for fair comparison
    out_pp = postprocess_audio(x_bp, sr=sr)

    # Save to new directory
    filename = os.path.basename(clean_path)
    out_path = os.path.join(bandpass_output_dir, filename)
    torchaudio.save(out_path, out_pp.unsqueeze(0).cpu(), sr)

print(f"All band-passed files saved to: {bandpass_output_dir}")


100%|██████████| 550/550 [00:11<00:00, 47.45it/s]

All band-passed files saved to: /content/drive/MyDrive/whale_denoising/clean_bandpassed





# Save denoised clean files

In [None]:
import os
from tqdm import tqdm
import torch
import torchaudio

# Make sure you have this helper defined earlier:
# def bandpass_waveform(waveform, sr, low_hz=200, high_hz=800): ...
# def postprocess_audio(waveform, sr=10000, hp_hz=10, peak=0.99): ...

clean_output_dir = "/content/drive/MyDrive/whale_denoising/output_audio/clean_denoised_batch2"
os.makedirs(noisy_output_dir, exist_ok=True)

NUM_PASSES = 1  # try 1 first; if you liked iterative, set to 2–3

denoiser.eval()
for clean_path in tqdm(clean_paths):
    waveform, sr = torchaudio.load(clean_path)  # (1, T)
    x = waveform[0]

    # 1) Pre-bandpass to match training
    x_bp = bandpass_waveform(x, sr, low_hz=200, high_hz=800)

    # 2) Denoise (optionally multiple passes)
    current = x_bp.unsqueeze(0).unsqueeze(0).cuda()
    with torch.no_grad():
        for _ in range(NUM_PASSES):
            current = denoiser(current)
    denoised = current.squeeze().cpu()

    # 3) Post-process (demean + HPF + normalize)
    out_pp = postprocess_audio(denoised, sr=sr)

    # 4) Save
    filename = os.path.basename(clean_path)
    out_path = os.path.join(clean_output_dir, filename)
    torchaudio.save(out_path, out_pp.unsqueeze(0).cpu(), sr)

print(f"All denoised files saved to: {clean_output_dir}")


100%|██████████| 550/550 [00:11<00:00, 45.91it/s]

All denoised files saved to: /content/drive/MyDrive/whale_denoising/clean_denoised_batch2





# Post-process rendered denoised files to create fades

In [None]:
import os
from tqdm import tqdm
import torchaudio

# Input and output directories
denoised_input_dir = "/content/drive/MyDrive/whale_denoising/output_audio/clean_denoised"
fade_output_dir = "/content/drive/MyDrive/whale_denoising/output_audio/clean_denoised_faded"
os.makedirs(fade_output_dir, exist_ok=True)

# Fade length in milliseconds
FADE_MS = 150  # adjust to taste (e.g. 20ms, 100ms)

for file_name in tqdm(os.listdir(denoised_input_dir)):
    if not file_name.lower().endswith(".wav"):
        continue

    path_in = os.path.join(denoised_input_dir, file_name)
    waveform, sr = torchaudio.load(path_in)

    # Calculate fade length in samples
    fade_samples = int((FADE_MS / 1000.0) * sr)

    # Apply fade in/out
    fade = torchaudio.transforms.Fade(fade_in_len=fade_samples,
                                      fade_out_len=fade_samples,
                                      fade_shape="linear")
    waveform_faded = fade(waveform)

    # Save
    path_out = os.path.join(fade_output_dir, file_name)
    torchaudio.save(path_out, waveform_faded, sr)

print(f"Faded files saved to: {fade_output_dir}")


100%|██████████| 550/550 [00:10<00:00, 54.02it/s]

Faded files saved to: /content/drive/MyDrive/whale_denoising/clean_denoised_faded



