In [1]:
#@title Installation. { vertical-output: true }
#@markdown Run this notebook in Google Colab by following [this link](https://colab.research.google.com/github/google-research/perch/blob/main/embed_audio.ipynb).
#@markdown
#@markdown Run this cell to install the project dependencies.
%pip install git+https://github.com/google-research/perch.git

Collecting git+https://github.com/google-research/perch.git
  Cloning https://github.com/google-research/perch.git to /tmp/pip-req-build-bbd4w33l
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/perch.git /tmp/pip-req-build-bbd4w33l
  Resolved https://github.com/google-research/perch.git to commit 09c304e92d2ef093f353a7454206cbe2a4069249
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting perch_hoplite@ git+https://github.com/google-research/perch-hoplite.git (from chirp==0.1.0)
  Cloning https://github.com/google-research/perch-hoplite.git to /tmp/pip-install-tzu6r91i/perch-hoplite_4f9ff4cc5e4a431d984bdf15831564b3
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/perch-hoplite.git /tmp/pip-install-tzu6r91i/perch-hoplite_4f9ff4cc5e4a431d984bdf15831564b3
  Resolved https://github

In [1]:
#@title Imports. { vertical-output: true }

from etils import epath
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm
from chirp.inference import colab_utils
colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp.inference import embed_lib
from chirp.inference import tf_examples
from perch_hoplite.zoo import model_configs



In [2]:
#@title Basic Configuration. { vertical-output: true }

#@markdown Define the model: perch or birdnet are most common for birds.
model_choice = 'humpback'  #@param['perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']
#@markdown Set the base directory for the project.
working_dir = '/tmp/agile'  #@param

# Set the embedding and labeled data directories.
embeddings_path = epath.Path(working_dir) / 'embeddings'
labeled_data_path = epath.Path(working_dir) / 'labeled'
embeddings_glob = embeddings_path / 'embeddings-*'

# OPTIONAL: Set up separation model.
separation_model_key = 'separator_model_tf'  #@param
separation_model_path = ''  #@param

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os
#@title Create a new folder in Drive (if it doesn't already exist) within your Google drive.
base_dir = '/content/drive/MyDrive/'
#@ markdown Name of your new folder in Drive
new_folder_name = 'whale_denoising' #@param

drive_output_directory = base_dir + new_folder_name

try:
  if not os.path.exists(drive_output_directory):
    os.makedirs(drive_output_directory, exist_ok=True)
    print(f'Directory {drive_output_directory} created successfully.')
  else:
    print(f'Directory {drive_output_directory} already exists.')
except OSError as e:
    print("Error:", e)

Directory /content/drive/MyDrive/whale_denoising already exists.


In [5]:
#@title Embedding Configuration. { vertical-output: true }

config = config_dict.ConfigDict()
config.embed_fn_config = config_dict.ConfigDict()
config.embed_fn_config.model_config = config_dict.ConfigDict()

#@markdown IMPORTANT: Select the target audio files.
#@markdown source_file_patterns should contain a list of globs of audio files, like:
#@markdown ['/home/me/*.wav', '/home/me/other/*.flac']
#config.source_file_patterns = ['gs://chirp-public-bucket/soundscapes/powdermill/Recording*/*.wav']  #@param
config.source_file_patterns = ['/content/drive/MyDrive/whale_denoising/noisy/*.wav', '/content/drive/MyDrive/whale_denoising/clean/*.wav']  #@param
config.output_dir = embeddings_path.as_posix()

preset_info = model_configs.get_preset_model_config(model_choice)
config.embed_fn_config.model_key = preset_info.model_key
config.embed_fn_config.model_config = preset_info.model_config

# Only write embeddings to reduce size.
config.embed_fn_config.write_embeddings = True
config.embed_fn_config.write_logits = False
config.embed_fn_config.write_separated_audio = False
config.embed_fn_config.write_raw_audio = False

#@markdown File sharding automatically splits audio files into one-minute chunks
#@markdown for embedding. This limits both system and GPU memory usage,
#@markdown especially useful when working with long files (>1 hour).
use_file_sharding = False  #@param {type:'boolean'}
if use_file_sharding:
  config.shard_len_s = 60.0

# Number of parent directories to include in the filename.
config.embed_fn_config.file_id_depth = 1

In [6]:
#@title Set up. { vertical-output: true }

# Set up the embedding function, including loading models.
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\n\nLoading model(s)...')
embed_fn.setup()

# Create output directory and write the configuration.
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
embed_lib.maybe_write_config(config, output_dir)

# Create SourceInfos.
source_infos = embed_lib.create_source_infos(
    config.source_file_patterns,
    num_shards_per_file=config.get('num_shards_per_file', -1),
    shard_len_s=config.get('shard_len_s', -1))
print(f'Found {len(source_infos)} source infos.')

print('\n\nTest-run of model...')
window_size_s = config.embed_fn_config.model_config.window_size_s
sr = config.embed_fn_config.model_config.sample_rate
z = np.zeros([int(sr * window_size_s)], dtype=np.float32)
embed_fn.embedding_model.embed(z)
print('Setup complete!')



Loading model(s)...
Found 1473 source infos.


Test-run of model...
Setup complete!


# Extract the model

In [7]:
embedding_model = embed_fn.embedding_model

In [None]:
#@title Run embedding. { vertical-output: true }

# Uses multiple threads to load audio before embedding.
# This tends to be faster, but can fail if any audio files are corrupt.

embed_fn.min_audio_s = 1.0
record_file = (output_dir / 'embeddings.tfrecord').as_posix()
succ, fail = 0, 0

existing_embedding_ids = embed_lib.get_existing_source_ids(
    output_dir, 'embeddings-*')

new_source_infos = embed_lib.get_new_source_infos(
    source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)

print(f'Found {len(existing_embedding_ids)} existing embedding ids. \n'
      f'Processing {len(new_source_infos)} new source infos. ')

try:
  audio_loader = lambda fp, offset: audio_utils.load_audio_window(
      fp, offset, sample_rate=config.embed_fn_config.model_config.sample_rate,
      window_size_s=config.get('shard_len_s', -1.0))
  audio_iterator = audio_utils.multi_load_audio_window(
      filepaths=[s.filepath for s in new_source_infos],
      offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],
      audio_loader=audio_loader,
  )
  with tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:
    for source_info, audio in tqdm.tqdm(
        zip(new_source_infos, audio_iterator), total=len(new_source_infos)):
      if not embed_fn.validate_audio(source_info, audio):
        continue
      file_id = source_info.file_id(config.embed_fn_config.file_id_depth)
      offset_s = source_info.shard_num * source_info.shard_len_s
      example = embed_fn.audio_to_example(file_id, offset_s, audio)
      if example is None:
        fail += 1
        continue
      file_writer.write(example.SerializeToString())
      succ += 1
    file_writer.flush()
finally:
  del(audio_iterator)
print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

fns = [fn for fn in output_dir.glob('embeddings-*')]
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
ds = ds.map(parser)
for ex in ds.as_numpy_iterator():
  print(ex['filename'])
  print(ex['embedding'].shape, flush=True)
  break

# Create Denoiser

In [9]:
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np

# Simple 1D convolutional denoiser
class Denoiser1D(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=15, padding=7),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=15, padding=7),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=15, padding=7),
            nn.ReLU(),
            nn.Conv1d(32, 1, kernel_size=15, padding=7)
        )

    def forward(self, x):
        return self.net(x)

denoiser = Denoiser1D().cuda()
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()



# Pre bandpass the audio

In [11]:
import torchaudio

def bandpass_waveform(waveform, sr, low_hz=200, high_hz=800):
    # waveform: (T,) or (1,T)
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    # Apply highpass then lowpass
    bp = torchaudio.functional.highpass_biquad(waveform, sr, low_hz)
    bp = torchaudio.functional.lowpass_biquad(bp, sr, high_hz)
    return bp.squeeze(0)


# Define data loader

In [13]:
from torch.utils.data import Dataset, DataLoader
import torchaudio
import random
import numpy as np
import glob

class WhaleDenoisingDatasetBP(Dataset):
    def __init__(self, clean_paths, noise_paths, snr_db_range=(10, 25),
                 low_hz=200, high_hz=800):
        self.clean_paths = clean_paths
        self.noise_paths = noise_paths
        self.snr_db_range = snr_db_range
        self.low_hz = low_hz
        self.high_hz = high_hz

    def __len__(self):
        return len(self.clean_paths)

    def __getitem__(self, idx):
        clean_path = self.clean_paths[idx]
        noise_path = random.choice(self.noise_paths)

        clean, sr = torchaudio.load(clean_path)
        noise, _  = torchaudio.load(noise_path)

        # Mono
        clean = clean[0]
        noise = noise[0]

        # Ensure same length
        target_len = 39124
        if len(clean) < target_len:
            clean = torch.nn.functional.pad(clean, (0, target_len - len(clean)))
        else:
            clean = clean[:target_len]

        if len(noise) < target_len:
            noise = torch.nn.functional.pad(noise, (0, target_len - len(noise)))
        else:
            start = random.randint(0, len(noise) - target_len)
            noise = noise[start:start+target_len]

        # Apply band-pass to both
        clean = bandpass_waveform(clean, sr, self.low_hz, self.high_hz)
        noise = bandpass_waveform(noise, sr, self.low_hz, self.high_hz)

        # Mix at random SNR
        snr_db = random.uniform(*self.snr_db_range)
        alpha = 10 ** (-snr_db / 20)
        noisy = clean + alpha * noise

        return noisy, clean


noisy_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/noisy/*.wav'))
clean_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/clean/*.wav'))
#noisy_dataset = NoisyWhaleDataset(noisy_paths)
#train_loader = DataLoader(noisy_dataset, batch_size=16, shuffle=True)
train_dataset = WhaleDenoisingDatasetBP(clean_paths, noisy_paths,
                                        snr_db_range=(10, 25),
                                        low_hz=200, high_hz=800)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)



In [36]:
noisy, clean = train_dataset[0]
ipd.display(ipd.Audio(clean.numpy(), rate=10000))  # Clean whale
ipd.display(ipd.Audio(noisy.numpy(), rate=10000))  # Clean + real-world noise


# Stage 1: Supervised Pre-training (MSE)

In [37]:
# === Stage 1: Supervised MSE training with band-pass data ===
num_epochs_stage1 = 40
denoiser.train()

for epoch in range(num_epochs_stage1):
    total_loss, total_score, total_batches = 0, 0, 0

    for noisy, clean in tqdm(train_loader):
        noisy = noisy.unsqueeze(1).cuda()   # (B, 1, T)
        clean = clean.unsqueeze(1).cuda()   # (B, 1, T)

        denoised = denoiser(noisy)

        # Reconstruction loss
        recon_loss = mse_loss(denoised, clean)

        # Optional: monitor TF whale score (not in loss)
        def to_numpy_for_tf(x):
            x = x.detach().squeeze(1).cpu().numpy()
            x = x[..., np.newaxis].astype(np.float32)
            target_len = 39124
            return x[:, :target_len, :] if x.shape[1] > target_len else np.pad(
                x, ((0, 0), (0, target_len - x.shape[1]), (0, 0))
            )

        denoised_np = to_numpy_for_tf(denoised)
        whale_scores = embedding_model.model(denoised_np, False, None).numpy()
        mean_score = whale_scores.mean()

        optimizer.zero_grad()
        recon_loss.backward()
        optimizer.step()

        total_loss += recon_loss.item()
        total_score += mean_score
        total_batches += 1

    print(f"[Stage 1][Epoch {epoch+1}] MSE loss = {total_loss / total_batches:.4f} | "
          f"TF whale score = {total_score / total_batches:.4f}")


100%|██████████| 35/35 [00:20<00:00,  1.68it/s]


[Stage 1][Epoch 1] MSE loss = 0.0004 | TF whale score = 3.8099


100%|██████████| 35/35 [00:19<00:00,  1.82it/s]


[Stage 1][Epoch 2] MSE loss = 0.0001 | TF whale score = 3.5816


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


[Stage 1][Epoch 3] MSE loss = 0.0000 | TF whale score = 3.7047


100%|██████████| 35/35 [00:19<00:00,  1.80it/s]


[Stage 1][Epoch 4] MSE loss = 0.0000 | TF whale score = 3.7787


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


[Stage 1][Epoch 5] MSE loss = 0.0000 | TF whale score = 3.9571


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


[Stage 1][Epoch 6] MSE loss = 0.0000 | TF whale score = 3.8433


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 1][Epoch 7] MSE loss = 0.0000 | TF whale score = 3.8269


100%|██████████| 35/35 [00:19<00:00,  1.75it/s]


[Stage 1][Epoch 8] MSE loss = 0.0000 | TF whale score = 4.0921


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


[Stage 1][Epoch 9] MSE loss = 0.0000 | TF whale score = 4.0445


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]


[Stage 1][Epoch 10] MSE loss = 0.0000 | TF whale score = 3.9939


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 1][Epoch 11] MSE loss = 0.0000 | TF whale score = 3.9831


100%|██████████| 35/35 [00:20<00:00,  1.73it/s]


[Stage 1][Epoch 12] MSE loss = 0.0000 | TF whale score = 4.0709


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 13] MSE loss = 0.0000 | TF whale score = 4.0723


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 14] MSE loss = 0.0000 | TF whale score = 4.1388


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 15] MSE loss = 0.0000 | TF whale score = 4.1244


100%|██████████| 35/35 [00:20<00:00,  1.69it/s]


[Stage 1][Epoch 16] MSE loss = 0.0000 | TF whale score = 4.1185


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 17] MSE loss = 0.0000 | TF whale score = 4.2401


100%|██████████| 35/35 [00:20<00:00,  1.71it/s]


[Stage 1][Epoch 18] MSE loss = 0.0000 | TF whale score = 4.1194


100%|██████████| 35/35 [00:20<00:00,  1.71it/s]


[Stage 1][Epoch 19] MSE loss = 0.0000 | TF whale score = 4.1587


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 20] MSE loss = 0.0000 | TF whale score = 4.0900


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]


[Stage 1][Epoch 21] MSE loss = 0.0000 | TF whale score = 4.1326


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 22] MSE loss = 0.0000 | TF whale score = 4.1372


100%|██████████| 35/35 [00:20<00:00,  1.69it/s]


[Stage 1][Epoch 23] MSE loss = 0.0000 | TF whale score = 4.1546


100%|██████████| 35/35 [00:20<00:00,  1.69it/s]


[Stage 1][Epoch 24] MSE loss = 0.0000 | TF whale score = 4.1787


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]


[Stage 1][Epoch 25] MSE loss = 0.0000 | TF whale score = 4.1436


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 26] MSE loss = 0.0000 | TF whale score = 4.2446


100%|██████████| 35/35 [00:20<00:00,  1.68it/s]


[Stage 1][Epoch 27] MSE loss = 0.0000 | TF whale score = 4.1892


100%|██████████| 35/35 [00:20<00:00,  1.70it/s]


[Stage 1][Epoch 28] MSE loss = 0.0000 | TF whale score = 4.2103


100%|██████████| 35/35 [00:20<00:00,  1.68it/s]


[Stage 1][Epoch 29] MSE loss = 0.0000 | TF whale score = 4.2605


100%|██████████| 35/35 [00:20<00:00,  1.71it/s]


[Stage 1][Epoch 30] MSE loss = 0.0000 | TF whale score = 4.2431


100%|██████████| 35/35 [00:19<00:00,  1.75it/s]


[Stage 1][Epoch 31] MSE loss = 0.0000 | TF whale score = 4.1607


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


[Stage 1][Epoch 32] MSE loss = 0.0000 | TF whale score = 4.1744


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 1][Epoch 33] MSE loss = 0.0000 | TF whale score = 4.1486


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


[Stage 1][Epoch 34] MSE loss = 0.0000 | TF whale score = 4.1762


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 1][Epoch 35] MSE loss = 0.0000 | TF whale score = 4.1616


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]


[Stage 1][Epoch 36] MSE loss = 0.0000 | TF whale score = 4.2599


100%|██████████| 35/35 [00:20<00:00,  1.73it/s]


[Stage 1][Epoch 37] MSE loss = 0.0000 | TF whale score = 4.1860


100%|██████████| 35/35 [00:20<00:00,  1.71it/s]


[Stage 1][Epoch 38] MSE loss = 0.0000 | TF whale score = 4.2698


100%|██████████| 35/35 [00:20<00:00,  1.72it/s]


[Stage 1][Epoch 39] MSE loss = 0.0000 | TF whale score = 4.2031


100%|██████████| 35/35 [00:19<00:00,  1.75it/s]

[Stage 1][Epoch 40] MSE loss = 0.0000 | TF whale score = 4.1752





# Build & Train ScoreNet

In [15]:
# === Stage 2 prep: Train ScoreNet to mimic TF whale score ===
class ScoreNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 16, 9, stride=2, padding=4),
            nn.ReLU(),
            nn.Conv1d(16, 32, 9, stride=2, padding=4),
            nn.ReLU(),
            nn.Conv1d(32, 64, 9, stride=2, padding=4),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(64, 1)
        )

    def forward(self, x):  # x: (B,1,T)
        return self.net(x)

score_net = ScoreNet().cuda()
opt_score = torch.optim.Adam(score_net.parameters(), lr=1e-3)
loss_score = nn.MSELoss()

# Prepare data for ScoreNet training
def get_tf_score_tensor(waveform_batch):
    x_np = waveform_batch.squeeze(1).cpu().numpy()[..., np.newaxis].astype(np.float32)
    target_len = 39124
    x_np = x_np[:, :target_len, :] if x_np.shape[1] > target_len else np.pad(
        x_np, ((0, 0), (0, target_len - x_np.shape[1]), (0, 0))
    )
    tf_scores = embedding_model.model(x_np, False, None).numpy()
    return torch.tensor(tf_scores, dtype=torch.float32).unsqueeze(1).to(waveform_batch.device)

# Train ScoreNet
epochs_score = 5
for epoch in range(epochs_score):
    total_loss = 0
    for noisy, _ in tqdm(train_loader):
        noisy = noisy.unsqueeze(1).cuda()
        target_scores = get_tf_score_tensor(noisy)

        pred_scores = score_net(noisy)
        loss = loss_score(pred_scores, target_scores)

        opt_score.zero_grad()
        loss.backward()
        opt_score.step()

        total_loss += loss.item()

    print(f"[ScoreNet][Epoch {epoch+1}] loss = {total_loss/len(train_loader):.4f}")


100%|██████████| 35/35 [06:22<00:00, 10.92s/it]


[ScoreNet][Epoch 1] loss = 14.1566


100%|██████████| 35/35 [01:43<00:00,  2.96s/it]


[ScoreNet][Epoch 2] loss = 6.1256


100%|██████████| 35/35 [00:58<00:00,  1.66s/it]


[ScoreNet][Epoch 3] loss = 6.4080


100%|██████████| 35/35 [01:08<00:00,  1.95s/it]


[ScoreNet][Epoch 4] loss = 5.6820


100%|██████████| 35/35 [00:30<00:00,  1.14it/s]

[ScoreNet][Epoch 5] loss = 6.1510





# Stage 2: Whale-Score Maximisation Fine-Tune

In [39]:
# === Stage 2: Fine-tune denoiser to maximise whale score ===
lambda_score = 1.0   # strength of score maximisation
lambda_mse   = 0.05  # optional regulariser

num_epochs_stage2 = 15
denoiser.train()

for epoch in range(num_epochs_stage2):
    total_loss, total_score, total_batches = 0, 0, 0

    for noisy, clean in tqdm(train_loader):
        noisy = noisy.unsqueeze(1).cuda()
        denoised = denoiser(noisy)

        # Whale score from ScoreNet (we want to maximise it)
        score_out = score_net(denoised)
        mean_score = score_out.mean()

        # Optional reg: keep output similar to input
        mse_reg = mse_loss(denoised, noisy)

        # Loss
        loss = -lambda_score * mean_score + lambda_mse * mse_reg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_score += mean_score.item()
        total_batches += 1

    print(f"[Stage 2][Epoch {epoch+1}] loss = {total_loss/total_batches:.4f} | "
          f"ScoreNet whale score = {total_score/total_batches:.4f}")


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 2][Epoch 1] loss = -371.4446 | ScoreNet whale score = 771.0091


100%|██████████| 35/35 [00:19<00:00,  1.82it/s]


[Stage 2][Epoch 2] loss = -532.6840 | ScoreNet whale score = 1052.9669


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


[Stage 2][Epoch 3] loss = -534.4637 | ScoreNet whale score = 1064.9798


100%|██████████| 35/35 [00:20<00:00,  1.73it/s]


[Stage 2][Epoch 4] loss = -534.5084 | ScoreNet whale score = 1066.4488


100%|██████████| 35/35 [00:19<00:00,  1.83it/s]


[Stage 2][Epoch 5] loss = -534.5150 | ScoreNet whale score = 1066.7016


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 2][Epoch 6] loss = -534.5158 | ScoreNet whale score = 1066.7327


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


[Stage 2][Epoch 7] loss = -534.5151 | ScoreNet whale score = 1066.7355


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


[Stage 2][Epoch 8] loss = -534.5146 | ScoreNet whale score = 1066.7452


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


[Stage 2][Epoch 9] loss = -534.5129 | ScoreNet whale score = 1066.7429


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


[Stage 2][Epoch 10] loss = -534.5118 | ScoreNet whale score = 1066.7416


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


[Stage 2][Epoch 11] loss = -534.5107 | ScoreNet whale score = 1066.7407


100%|██████████| 35/35 [00:19<00:00,  1.82it/s]


[Stage 2][Epoch 12] loss = -534.5113 | ScoreNet whale score = 1066.7337


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


[Stage 2][Epoch 13] loss = -534.5104 | ScoreNet whale score = 1066.7439


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


[Stage 2][Epoch 14] loss = -534.5115 | ScoreNet whale score = 1066.7408


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]

[Stage 2][Epoch 15] loss = -534.5153 | ScoreNet whale score = 1066.7446





# Quick Diagnostics

In [16]:
# === Quick Diagnostics on a FRESH batch ===
import numpy as np
import torch
import IPython.display as ipd

denoiser.eval()

# 1) Pull a fresh batch
noisy_b, clean_b = next(iter(train_loader))     # (B, T)
noisy_b  = noisy_b.unsqueeze(1).cuda()          # (B, 1, T)
clean_b  = clean_b.unsqueeze(1).cuda()          # (B, 1, T)

with torch.no_grad():
    denoised_b = denoiser(noisy_b)              # (B, 1, T)

# 2) Basic amplitude / energy stats
def batch_rms(x):  # x: (B,1,T)
    return x.pow(2).mean(dim=(-1,-2)).sqrt().detach().cpu().numpy()

print("max|noisy|    :", float(noisy_b.abs().max()))
print("max|denoised| :", float(denoised_b.abs().max()))
print("RMS noisy     :", batch_rms(noisy_b)[:8])
print("RMS denoised  :", batch_rms(denoised_b)[:8])

# 3) ScoreNet scores (student) – differentiable model we trained
with torch.no_grad():
    score_noisy    = score_net(noisy_b).mean().item()
    score_denoised = score_net(denoised_b).mean().item()
print(f"ScoreNet mean   noisy={score_noisy:.3f}  denoised={score_denoised:.3f}")

# 4) TF whale score (teacher) – frozen model, numpy only
def to_tf_batch(x_torch):
    TARGET_LEN = 39124
    x = x_torch.detach().squeeze(1).cpu().numpy().astype(np.float32)  # (B,T)
    x = x[..., np.newaxis]  # (B,T,1)
    if x.shape[1] < TARGET_LEN:
        pad = TARGET_LEN - x.shape[1]
        x = np.pad(x, ((0,0),(0,pad),(0,0)))
    else:
        x = x[:, :TARGET_LEN, :]
    return x

tf_noisy    = embedding_model.model(to_tf_batch(noisy_b),    False, None).numpy().mean()
tf_denoised = embedding_model.model(to_tf_batch(denoised_b), False, None).numpy().mean()
print(f"TF mean score   noisy={tf_noisy:.3f}  denoised={tf_denoised:.3f}")

# 5) Listen to the first example (assumes 10kHz SR)
sr = 10000
i = 0
print("\n🔊 Example 0 — Noisy:")
ipd.display(ipd.Audio(noisy_b[i,0].detach().cpu().numpy(), rate=sr))
print("🔊 Example 0 — Denoised:")
ipd.display(ipd.Audio(denoised_b[i,0].detach().cpu().numpy(), rate=sr))

# If you want to hear the clean target (from the band-passed dataset):
print("🔊 Example 0 — Clean target (band-passed):")
ipd.display(ipd.Audio(clean_b[i,0].detach().cpu().numpy(), rate=sr))


max|noisy|    : 0.1870051473379135
max|denoised| : 105.6954345703125
RMS noisy     : [0.01038477 0.01430972 0.01709626 0.0112268  0.0103725  0.0169315
 0.00635836 0.00213336]
RMS denoised  : [103.17622  103.19276  103.15672  103.195045 103.2011   103.1726
 103.17705  103.17051 ]
ScoreNet mean   noisy=4.217  denoised=694.116
TF mean score   noisy=4.971  denoised=2.563

🔊 Example 0 — Noisy:


🔊 Example 0 — Denoised:


🔊 Example 0 — Clean target (band-passed):


In [18]:
import torchaudio
import torch

def postprocess_audio(waveform, sr=10000, hp_hz=10, peak=0.99):
    # waveform: torch.Tensor (T,) or (1,T) or (B,1,T)
    x = waveform
    # squeeze to (T,)
    if x.dim() == 3: x = x[0,0]
    elif x.dim() == 2: x = x[0]
    # 1) remove DC
    x = x - x.mean()
    # 2) tiny high-pass to remove residual DC/rumble
    x = torchaudio.functional.highpass_biquad(x.unsqueeze(0), sr, hp_hz).squeeze(0)
    # 3) normalize peak (optional but helpful)
    mx = x.abs().max()
    if mx > 0: x = x * (peak / mx)
    return x

In [20]:
# Example: demean + HPF + normalize
out_pp = postprocess_audio(denoised_b[0,0], sr=10000)

# Listen in Colab
import IPython.display as ipd
ipd.display(ipd.Audio(out_pp.cpu().numpy(), rate=10000))


In [17]:
import torchaudio
import os

output_dir = "/content/drive/MyDrive/whale_denoising/diagnostics"
os.makedirs(output_dir, exist_ok=True)

# Assuming sr is defined from previous cells (e.g., 10000)
output_path = os.path.join(output_dir, "denoised_example_0.wav")

# denoised_b[0,0] is (T,) so unsqueeze(0) to make it (1, T) for torchaudio.save
torchaudio.save(output_path, denoised_b[0,0].detach().cpu().unsqueeze(0), sr)

print(f"Denoised example audio saved to: {output_path}")

Denoised example audio saved to: /content/drive/MyDrive/whale_denoising/diagnostics/denoised_example_0.wav


# Listen & Compare

In [40]:
import IPython.display as ipd
import random

noisy_path = random.choice(noisy_paths)
waveform, sr = torchaudio.load(noisy_path)
waveform = waveform[0]

# Apply same band-pass as training
waveform_bp = bandpass_waveform(waveform, sr, low_hz=200, high_hz=800)

# Prepare
input_tensor = waveform_bp.unsqueeze(0).unsqueeze(0).cuda()

# Denoise
denoiser.eval()
with torch.no_grad():
    denoised_tensor = denoiser(input_tensor).squeeze().cpu()

print(f"🎧 Listening to: {noisy_path}")
print("🔊 Noisy (band-passed):")
ipd.display(ipd.Audio(waveform_bp.numpy(), rate=sr))

print("🔊 Denoised:")
ipd.display(ipd.Audio(denoised_tensor.numpy(), rate=sr))

🎧 Listening to: /content/drive/MyDrive/whale_denoising/noisy/whale_25_10k.wav
🔊 Noisy (band-passed):


🔊 Denoised:


# Save trained model to disk

In [41]:
torch.save(denoiser.state_dict(), "/content/drive/MyDrive/whale_denoising/whale_denoiser.pt")


In [10]:
denoiser.load_state_dict(torch.load("/content/drive/MyDrive/whale_denoising/whale_denoiser.pt"))
print("Denoiser loaded successfully.")

Denoiser loaded successfully.


# Save denoised files to disk

In [22]:
import os

output_dir = "/content/drive/MyDrive/whale_denoising/denoised_batch"
os.makedirs(output_dir, exist_ok=True)

denoiser.eval()
for noisy_path in noisy_paths:
    waveform, sr = torchaudio.load(noisy_path)
    waveform = waveform[0]
    input_tensor = waveform.unsqueeze(0).unsqueeze(0).cuda()

    with torch.no_grad():
        denoised_tensor = denoiser(input_tensor).squeeze().cpu()

    # Match filename, save as WAV
    filename = os.path.basename(noisy_path)
    out_path = os.path.join(output_dir, filename)
    torchaudio.save(out_path, denoised_tensor.unsqueeze(0), sr)

print(f"All denoised files saved to: {output_dir}")


All denoised files saved to: /content/drive/MyDrive/whale_denoising/denoised_batch
