In [1]:
#@title Installation. { vertical-output: true }
#@markdown Run this notebook in Google Colab by following [this link](https://colab.research.google.com/github/google-research/perch/blob/main/embed_audio.ipynb).
#@markdown
#@markdown Run this cell to install the project dependencies.
%pip install git+https://github.com/google-research/perch.git

Collecting git+https://github.com/google-research/perch.git
  Cloning https://github.com/google-research/perch.git to /tmp/pip-req-build-qpv2kl0l
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/perch.git /tmp/pip-req-build-qpv2kl0l
  Resolved https://github.com/google-research/perch.git to commit 09c304e92d2ef093f353a7454206cbe2a4069249
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting perch_hoplite@ git+https://github.com/google-research/perch-hoplite.git (from chirp==0.1.0)
  Cloning https://github.com/google-research/perch-hoplite.git to /tmp/pip-install-ioeuw_mu/perch-hoplite_97371222b8be42818ac4e3611a0c4bc8
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/perch-hoplite.git /tmp/pip-install-ioeuw_mu/perch-hoplite_97371222b8be42818ac4e3611a0c4bc8
  Resolved https://github

In [1]:
#@title Imports. { vertical-output: true }

from etils import epath
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm
from chirp.inference import colab_utils
colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)

from chirp import audio_utils
from chirp.inference import embed_lib
from chirp.inference import tf_examples
from perch_hoplite.zoo import model_configs



In [2]:
#@title Basic Configuration. { vertical-output: true }

#@markdown Define the model: perch or birdnet are most common for birds.
model_choice = 'humpback'  #@param['perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']
#@markdown Set the base directory for the project.
working_dir = '/tmp/agile'  #@param

# Set the embedding and labeled data directories.
embeddings_path = epath.Path(working_dir) / 'embeddings'
labeled_data_path = epath.Path(working_dir) / 'labeled'
embeddings_glob = embeddings_path / 'embeddings-*'

# OPTIONAL: Set up separation model.
separation_model_key = 'separator_model_tf'  #@param
separation_model_path = ''  #@param

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os
#@title Create a new folder in Drive (if it doesn't already exist) within your Google drive.
base_dir = '/content/drive/MyDrive/'
#@ markdown Name of your new folder in Drive
new_folder_name = 'whale_denoising' #@param

drive_output_directory = base_dir + new_folder_name

try:
  if not os.path.exists(drive_output_directory):
    os.makedirs(drive_output_directory, exist_ok=True)
    print(f'Directory {drive_output_directory} created successfully.')
  else:
    print(f'Directory {drive_output_directory} already exists.')
except OSError as e:
    print("Error:", e)

Directory /content/drive/MyDrive/whale_denoising already exists.


In [5]:
#@title Embedding Configuration. { vertical-output: true }

config = config_dict.ConfigDict()
config.embed_fn_config = config_dict.ConfigDict()
config.embed_fn_config.model_config = config_dict.ConfigDict()

#@markdown IMPORTANT: Select the target audio files.
#@markdown source_file_patterns should contain a list of globs of audio files, like:
#@markdown ['/home/me/*.wav', '/home/me/other/*.flac']
#config.source_file_patterns = ['gs://chirp-public-bucket/soundscapes/powdermill/Recording*/*.wav']  #@param
config.source_file_patterns = ['/content/drive/MyDrive/whale_denoising/noisy/*.wav', '/content/drive/MyDrive/whale_denoising/clean/*.wav']  #@param
config.output_dir = embeddings_path.as_posix()

preset_info = model_configs.get_preset_model_config(model_choice)
config.embed_fn_config.model_key = preset_info.model_key
config.embed_fn_config.model_config = preset_info.model_config

# Only write embeddings to reduce size.
config.embed_fn_config.write_embeddings = True
config.embed_fn_config.write_logits = False
config.embed_fn_config.write_separated_audio = False
config.embed_fn_config.write_raw_audio = False

#@markdown File sharding automatically splits audio files into one-minute chunks
#@markdown for embedding. This limits both system and GPU memory usage,
#@markdown especially useful when working with long files (>1 hour).
use_file_sharding = False  #@param {type:'boolean'}
if use_file_sharding:
  config.shard_len_s = 60.0

# Number of parent directories to include in the filename.
config.embed_fn_config.file_id_depth = 1

In [6]:
#@title Set up. { vertical-output: true }

# Set up the embedding function, including loading models.
embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)
print('\n\nLoading model(s)...')
embed_fn.setup()

# Create output directory and write the configuration.
output_dir = epath.Path(config.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
embed_lib.maybe_write_config(config, output_dir)

# Create SourceInfos.
source_infos = embed_lib.create_source_infos(
    config.source_file_patterns,
    num_shards_per_file=config.get('num_shards_per_file', -1),
    shard_len_s=config.get('shard_len_s', -1))
print(f'Found {len(source_infos)} source infos.')

print('\n\nTest-run of model...')
window_size_s = config.embed_fn_config.model_config.window_size_s
sr = config.embed_fn_config.model_config.sample_rate
z = np.zeros([int(sr * window_size_s)], dtype=np.float32)
embed_fn.embedding_model.embed(z)
print('Setup complete!')



Loading model(s)...
Found 1473 source infos.


Test-run of model...
Setup complete!


# Extract the model

In [7]:
embedding_model = embed_fn.embedding_model

In [8]:
#@title Run embedding. { vertical-output: true }

# Uses multiple threads to load audio before embedding.
# This tends to be faster, but can fail if any audio files are corrupt.

embed_fn.min_audio_s = 1.0
record_file = (output_dir / 'embeddings.tfrecord').as_posix()
succ, fail = 0, 0

existing_embedding_ids = embed_lib.get_existing_source_ids(
    output_dir, 'embeddings-*')

new_source_infos = embed_lib.get_new_source_infos(
    source_infos, existing_embedding_ids, config.embed_fn_config.file_id_depth)

print(f'Found {len(existing_embedding_ids)} existing embedding ids. \n'
      f'Processing {len(new_source_infos)} new source infos. ')

try:
  audio_loader = lambda fp, offset: audio_utils.load_audio_window(
      fp, offset, sample_rate=config.embed_fn_config.model_config.sample_rate,
      window_size_s=config.get('shard_len_s', -1.0))
  audio_iterator = audio_utils.multi_load_audio_window(
      filepaths=[s.filepath for s in new_source_infos],
      offsets=[s.shard_num * s.shard_len_s for s in new_source_infos],
      audio_loader=audio_loader,
  )
  with tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=output_dir, num_files=config.get('tf_record_shards', 1)) as file_writer:
    for source_info, audio in tqdm.tqdm(
        zip(new_source_infos, audio_iterator), total=len(new_source_infos)):
      if not embed_fn.validate_audio(source_info, audio):
        continue
      file_id = source_info.file_id(config.embed_fn_config.file_id_depth)
      offset_s = source_info.shard_num * source_info.shard_len_s
      example = embed_fn.audio_to_example(file_id, offset_s, audio)
      if example is None:
        fail += 1
        continue
      file_writer.write(example.SerializeToString())
      succ += 1
    file_writer.flush()
finally:
  del(audio_iterator)
print(f'\n\nSuccessfully processed {succ} source_infos, failed {fail} times.')

fns = [fn for fn in output_dir.glob('embeddings-*')]
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
ds = ds.map(parser)
for ex in ds.as_numpy_iterator():
  print(ex['filename'])
  print(ex['embedding'].shape, flush=True)
  break

Found 0 existing embedding ids. 
Processing 1473 new source infos. 


100%|██████████| 1473/1473 [01:25<00:00, 17.18it/s]




Successfully processed 1473 source_infos, failed 0 times.
b'noisy/whale_1_10k.wav'
(1, 1, 2048)


# This is wrong

In [16]:
import json
import torch
from chirp.models import build_model
from chirp.inference import embed_lib

# Load config from JSON
with open("/tmp/agile/embeddings/config.json", "r") as f:
    full_config = json.load(f)

model_config_dict = full_config["embed_fn_config"]["model_config"]

# Build model
embedding_model = build_model.build_model(model_config_dict)

# Load weights from TensorFlow Hub
model_url = model_config_dict["model_url"]
embedding_model = embed_lib.load_tfhub_model(model_url, embedding_model)
embedding_model.eval()

# Wrap in embed_fn
embed_fn = embed_lib.EmbedFn(embedding_model, frontend=True)

ImportError: cannot import name 'build_model' from 'chirp.models' (unknown location)

In [27]:
print(type(embedding_model))
print(dir(embedding_model))


<class 'perch_hoplite.zoo.models_tf.GoogleWhaleModel'>
['__annotations__', '__class__', '__dataclass_fields__', '__dataclass_params__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__match_args__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'batch_embed', 'class_list', 'embed', 'frame_audio', 'from_config', 'hop_size_s', 'load_humpback_model', 'model', 'model_url', 'normalize_audio', 'peak_norm', 'sample_rate', 'window_size_s']


# Create Denoiser

In [30]:
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np

# Simple 1D convolutional denoiser
class Denoiser1D(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=15, padding=7),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=15, padding=7),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=15, padding=7),
            nn.ReLU(),
            nn.Conv1d(32, 1, kernel_size=15, padding=7)
        )

    def forward(self, x):
        return self.net(x)

denoiser = Denoiser1D().cuda()
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()



# Pre bandpass the audio

In [23]:
import torchaudio

def bandpass_waveform(waveform, sr, low_hz=200, high_hz=800):
    # waveform: (T,) or (1,T)
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    # Apply highpass then lowpass
    bp = torchaudio.functional.highpass_biquad(waveform, sr, low_hz)
    bp = torchaudio.functional.lowpass_biquad(bp, sr, high_hz)
    return bp.squeeze(0)


# Define data loader

In [28]:
from torch.utils.data import Dataset, DataLoader
import torchaudio
import random
import numpy as np

class WhaleDenoisingDatasetBP(Dataset):
    def __init__(self, clean_paths, noise_paths, snr_db_range=(10, 25),
                 low_hz=200, high_hz=800):
        self.clean_paths = clean_paths
        self.noise_paths = noise_paths
        self.snr_db_range = snr_db_range
        self.low_hz = low_hz
        self.high_hz = high_hz

    def __len__(self):
        return len(self.clean_paths)

    def __getitem__(self, idx):
        clean_path = self.clean_paths[idx]
        noise_path = random.choice(self.noise_paths)

        clean, sr = torchaudio.load(clean_path)
        noise, _  = torchaudio.load(noise_path)

        # Mono
        clean = clean[0]
        noise = noise[0]

        # Ensure same length
        target_len = 39124
        if len(clean) < target_len:
            clean = torch.nn.functional.pad(clean, (0, target_len - len(clean)))
        else:
            clean = clean[:target_len]

        if len(noise) < target_len:
            noise = torch.nn.functional.pad(noise, (0, target_len - len(noise)))
        else:
            start = random.randint(0, len(noise) - target_len)
            noise = noise[start:start+target_len]

        # Apply band-pass to both
        clean = bandpass_waveform(clean, sr, self.low_hz, self.high_hz)
        noise = bandpass_waveform(noise, sr, self.low_hz, self.high_hz)

        # Mix at random SNR
        snr_db = random.uniform(*self.snr_db_range)
        alpha = 10 ** (-snr_db / 20)
        noisy = clean + alpha * noise

        return noisy, clean


noisy_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/noisy/*.wav'))
clean_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/clean/*.wav'))
#noisy_dataset = NoisyWhaleDataset(noisy_paths)
#train_loader = DataLoader(noisy_dataset, batch_size=16, shuffle=True)
train_dataset = WhaleDenoisingDatasetBP(clean_paths, noisy_paths,
                                        snr_db_range=(10, 25),
                                        low_hz=200, high_hz=800)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)



In [29]:
noisy, clean = train_dataset[0]
ipd.display(ipd.Audio(clean.numpy(), rate=10000))  # Clean whale
ipd.display(ipd.Audio(noisy.numpy(), rate=10000))  # Clean + real-world noise


In [14]:
print(dir(embed_fn))

['BundleContextParam', 'BundleFinalizerParam', 'DoFnProcessParams', 'DynamicTimerTagParam', 'ElementParam', 'KeyParam', 'PaneInfoParam', 'RestrictionParam', 'SetupContextParam', 'SideInputParam', 'StateParam', 'TimerParam', 'TimestampParam', 'WatermarkEstimatorParam', 'WindowParam', 'WindowedValueParam', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_can_yield_batches', '_get_display_data_namespace', '_get_element_type_from_return_annotation', '_get_input_batch_type_normalized', '_get_or_create_type_hints', '_get_output_batch_type_normalized', '_known_urns', '_log_exception', '_process_argspec_fn', '_process_batch_defined', '_process_batch_yields_elements', 

# Training Loop

In [31]:
for epoch in range(50):
    denoiser.train()
    total_loss = 0
    total_score = 0
    total_batches = 0

    for noisy, clean in tqdm(train_loader):
        noisy = noisy.unsqueeze(1).cuda()   # (B, 1, T)
        clean = clean.unsqueeze(1).cuda()   # (B, 1, T)
        denoised = denoiser(noisy)

        # Reconstruction loss (supervised denoising)
        recon_loss = mse_loss(denoised, clean)

        # Convert to numpy for whale score monitoring
        def to_numpy_for_tf(x):
            x = x.detach().squeeze(1).cpu().numpy()
            x = x[..., np.newaxis].astype(np.float32)
            target_len = 39124
            return x[:, :target_len, :] if x.shape[1] > target_len else np.pad(
                x, ((0, 0), (0, target_len - x.shape[1]), (0, 0))
            )

        denoised_np = to_numpy_for_tf(denoised)

        # Whale score from TF model (not used in loss, just for monitoring)
        whale_scores = embedding_model.model(denoised_np, False, None).numpy()
        mean_score = whale_scores.mean()

        # Backprop only on MSE loss
        optimizer.zero_grad()
        recon_loss.backward()
        optimizer.step()

        total_loss += recon_loss.item()
        total_score += mean_score
        total_batches += 1

    print(f"Epoch {epoch+1}: recon_loss = {total_loss / total_batches:.4f} | whale score = {total_score / total_batches:.4f}")



100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 1: recon_loss = 0.0001 | whale score = 3.6069


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 2: recon_loss = 0.0000 | whale score = 3.7414


100%|██████████| 35/35 [00:19<00:00,  1.82it/s]


Epoch 3: recon_loss = 0.0000 | whale score = 4.1182


100%|██████████| 35/35 [00:19<00:00,  1.75it/s]


Epoch 4: recon_loss = 0.0000 | whale score = 4.1715


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 5: recon_loss = 0.0000 | whale score = 4.2847


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 6: recon_loss = 0.0000 | whale score = 4.3121


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 7: recon_loss = 0.0000 | whale score = 4.3508


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 8: recon_loss = 0.0000 | whale score = 4.2706


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 9: recon_loss = 0.0000 | whale score = 4.2099


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 10: recon_loss = 0.0000 | whale score = 4.2244


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 11: recon_loss = 0.0000 | whale score = 4.2007


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 12: recon_loss = 0.0000 | whale score = 4.2436


100%|██████████| 35/35 [00:19<00:00,  1.80it/s]


Epoch 13: recon_loss = 0.0000 | whale score = 4.2454


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 14: recon_loss = 0.0000 | whale score = 4.1359


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 15: recon_loss = 0.0000 | whale score = 4.1491


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]


Epoch 16: recon_loss = 0.0000 | whale score = 4.0235


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 17: recon_loss = 0.0000 | whale score = 4.0441


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 18: recon_loss = 0.0000 | whale score = 4.1275


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 19: recon_loss = 0.0000 | whale score = 4.0082


100%|██████████| 35/35 [00:19<00:00,  1.83it/s]


Epoch 20: recon_loss = 0.0000 | whale score = 4.0542


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 21: recon_loss = 0.0000 | whale score = 4.0050


100%|██████████| 35/35 [00:19<00:00,  1.80it/s]


Epoch 22: recon_loss = 0.0000 | whale score = 4.0629


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


Epoch 23: recon_loss = 0.0000 | whale score = 4.0260


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 24: recon_loss = 0.0000 | whale score = 4.0350


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 25: recon_loss = 0.0000 | whale score = 4.0232


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 26: recon_loss = 0.0000 | whale score = 4.0696


100%|██████████| 35/35 [00:20<00:00,  1.74it/s]


Epoch 27: recon_loss = 0.0000 | whale score = 4.1236


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 28: recon_loss = 0.0000 | whale score = 4.0050


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


Epoch 29: recon_loss = 0.0000 | whale score = 4.0165


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 30: recon_loss = 0.0000 | whale score = 3.9664


100%|██████████| 35/35 [00:20<00:00,  1.75it/s]


Epoch 31: recon_loss = 0.0000 | whale score = 3.9748


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


Epoch 32: recon_loss = 0.0000 | whale score = 4.0320


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 33: recon_loss = 0.0000 | whale score = 4.0806


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 34: recon_loss = 0.0000 | whale score = 4.0659


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 35: recon_loss = 0.0000 | whale score = 4.0349


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


Epoch 36: recon_loss = 0.0000 | whale score = 4.1084


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 37: recon_loss = 0.0000 | whale score = 4.0621


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


Epoch 38: recon_loss = 0.0000 | whale score = 3.9995


100%|██████████| 35/35 [00:20<00:00,  1.75it/s]


Epoch 39: recon_loss = 0.0000 | whale score = 4.1383


100%|██████████| 35/35 [00:19<00:00,  1.76it/s]


Epoch 40: recon_loss = 0.0000 | whale score = 4.0515


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 41: recon_loss = 0.0000 | whale score = 4.0850


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]


Epoch 42: recon_loss = 0.0000 | whale score = 4.0084


100%|██████████| 35/35 [00:19<00:00,  1.82it/s]


Epoch 43: recon_loss = 0.0000 | whale score = 3.9897


100%|██████████| 35/35 [00:19<00:00,  1.78it/s]


Epoch 44: recon_loss = 0.0000 | whale score = 4.0941


100%|██████████| 35/35 [00:19<00:00,  1.77it/s]


Epoch 45: recon_loss = 0.0000 | whale score = 4.1484


100%|██████████| 35/35 [00:19<00:00,  1.82it/s]


Epoch 46: recon_loss = 0.0000 | whale score = 4.1191


100%|██████████| 35/35 [00:19<00:00,  1.75it/s]


Epoch 47: recon_loss = 0.0000 | whale score = 4.0601


100%|██████████| 35/35 [00:19<00:00,  1.79it/s]


Epoch 48: recon_loss = 0.0000 | whale score = 4.0807


100%|██████████| 35/35 [00:19<00:00,  1.80it/s]


Epoch 49: recon_loss = 0.0000 | whale score = 4.1323


100%|██████████| 35/35 [00:19<00:00,  1.81it/s]

Epoch 50: recon_loss = 0.0000 | whale score = 4.0162





# Save trained model to disk

In [21]:
torch.save(denoiser.state_dict(), "/content/drive/MyDrive/whale_denoising/whale_denoiser.pt")


# Iterative Denoising

In [64]:
import IPython.display as ipd
import torchaudio

# === Configurable number of passes ===
NUM_PASSES = 5  # ← change this to however many iterations you want

# === Load real-world noisy sample ===
noisy_path = noisy_paths[0]  # or any index
waveform, sr = torchaudio.load(noisy_path)
waveform = waveform[0]  # shape (T,)

# === Prepare input ===
current = waveform.unsqueeze(0).unsqueeze(0).cuda()  # (1, 1, T)

# === Apply denoiser iteratively ===
denoiser.eval()
outputs = [waveform.cpu().numpy()]  # store original as first output

with torch.no_grad():
    for _ in range(NUM_PASSES):
        current = denoiser(current)
        outputs.append(current.squeeze().cpu().numpy())

# === Playback original and final ===
print("🔊 Original Noisy Input:")
ipd.display(ipd.Audio(outputs[0], rate=sr))

print(f"🔊 After {NUM_PASSES} Passes:")
ipd.display(ipd.Audio(outputs[-1], rate=sr))


🔊 Original Noisy Input:


🔊 After 5 Passes:


# Listen to denoised output

In [58]:
denoiser.eval()
with torch.no_grad():
    test_waveform, _ = torchaudio.load(noisy_paths[0])
    test_tensor = test_waveform.unsqueeze(0).cuda()  # shape: (1, 1, T)
    denoised_tensor = denoiser(test_tensor).squeeze().cpu()

import IPython.display as ipd
print("🔊 Noisy input:")
ipd.display(ipd.Audio(test_waveform.numpy(), rate=10000))

print("🔊 Denoised output:")
ipd.display(ipd.Audio(denoised_tensor.numpy(), rate=10000))

🔊 Noisy input:


🔊 Denoised output:


In [50]:
import IPython.display as ipd
import torchaudio

# Pick a clean audio file from your training set
waveform, sr = torchaudio.load(clean_paths[0])  # or any other index
waveform = waveform[0]  # shape (T,)

# Add synthetic noise to simulate what was done during training
import torch
import numpy as np

noise = torch.randn_like(waveform)
snr_db = 10  # adjust this for testing
alpha = 10 ** (-snr_db / 20)
noisy = waveform + alpha * noise

# Run the denoiser
denoiser.eval()
with torch.no_grad():
    input_tensor = noisy.unsqueeze(0).unsqueeze(0).cuda()  # (1, 1, T)
    denoised_tensor = denoiser(input_tensor).squeeze().cpu()

# 🔊 Listen
print("🔊 Noisy input:")
ipd.display(ipd.Audio(noisy.numpy(), rate=sr))

print("🔊 Denoised output:")
ipd.display(ipd.Audio(denoised_tensor.numpy(), rate=sr))

print("🔊 Original clean:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))



🔊 Noisy input:


🔊 Denoised output:


🔊 Original clean:


# Check that model output is a good measure of whaliness

In [32]:
import numpy as np
import torchaudio
from tqdm import tqdm
import matplotlib.pyplot as plt

def load_and_prepare(path, target_len=39124):
    waveform, sr = torchaudio.load(path)
    waveform = waveform.squeeze(0).numpy()
    waveform = waveform[:target_len] if len(waveform) > target_len else np.pad(waveform, (0, target_len - len(waveform)))
    waveform = waveform.astype(np.float32)[np.newaxis, :, np.newaxis]  # shape (1, T, 1)
    return waveform

# Run model on clean and noisy sets
clean_scores = []
print("Processing clean files...")
for path in tqdm(clean_paths):
    audio = load_and_prepare(path)
    score = embedding_model.model(audio, False, None).numpy().item()
    clean_scores.append(score)

noisy_scores = []
print("Processing noisy files...")
for path in tqdm(noisy_paths):
    audio = load_and_prepare(path)
    score = embedding_model.model(audio, False, None).numpy().item()
    noisy_scores.append(score)

# Print means
print(f"\n✅ Clean mean score: {np.mean(clean_scores):.4f}")
print(f"✅ Noisy mean score: {np.mean(noisy_scores):.4f}")

# Optional: histogram plot
plt.hist(clean_scores, bins=30, alpha=0.7, label='Clean', color='blue')
plt.hist(noisy_scores, bins=30, alpha=0.7, label='Noisy', color='orange')
plt.xlabel("Whale classifier output")
plt.ylabel("Count")
plt.title("Distribution of Whale Scores")
plt.legend()
plt.grid(True)
plt.show()


Processing clean files...


 42%|████▏     | 231/550 [00:05<00:08, 39.48it/s]


KeyboardInterrupt: 

# Train PyTorch surrogate

## Helpers: crop/pad and TF score in batches

In [9]:
import numpy as np
import torch
import torchaudio
from tqdm import tqdm

TARGET_LEN = 39124  # TF model expects (B, 39124, 1)

def load_crop_pad_to_len(path, target_len=TARGET_LEN):
    wav, sr = torchaudio.load(path)   # (1, T)
    x = wav[0]                        # (T,)
    if x.numel() < target_len:
        x = torch.nn.functional.pad(x, (0, target_len - x.numel()))
    else:
        x = x[:target_len]
    return x  # torch.Tensor, shape (T,)

def to_tf_batch(t_batch):  # t_batch: torch (B, T)
    x = t_batch.detach().cpu().numpy().astype(np.float32)  # (B, T)
    x = x[..., np.newaxis]  # (B, T, 1)
    return x

@torch.no_grad()
def tf_scores_for_paths(paths, batch_size=32):
    scores = []
    # load tensors first to keep IO out of TF loop
    tensors = [load_crop_pad_to_len(p) for p in paths]
    for i in tqdm(range(0, len(tensors), batch_size), desc="TF scoring"):
        batch = torch.stack(tensors[i:i+batch_size], dim=0)  # (B, T)
        np_batch = to_tf_batch(batch)                        # (B, T, 1)
        s = embedding_model.model(np_batch, False, None).numpy().astype(np.float32)  # (B, 1)
        scores.append(s)
    return np.vstack(scores)  # (N, 1)


## Precompute teacher (TF) scores for clean + noisy sets

In [12]:
import glob
noisy_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/noisy/*.wav'))
clean_paths = sorted(glob.glob('/content/drive/MyDrive/whale_denoising/clean/*.wav'))

# Optionally subsample to speed up distillation
clean_paths_distill = clean_paths  # or clean_paths[:400]
noisy_paths_distill = noisy_paths  # or noisy_paths[:400]

print("Scoring clean set with TF model...")
tf_clean_scores = tf_scores_for_paths(clean_paths_distill)  # (Nc, 1)
print("Scoring noisy set with TF model...")
tf_noisy_scores = tf_scores_for_paths(noisy_paths_distill)  # (Nn, 1)

# Build tensors for student training: inputs + targets
# We'll train on both clean and noisy so the student learns the full range
X_wave = [load_crop_pad_to_len(p) for p in (clean_paths_distill + noisy_paths_distill)]
Y_score = np.vstack([tf_clean_scores, tf_noisy_scores]).astype(np.float32)  # (N, 1)

X_wave = torch.stack(X_wave, dim=0)        # (N, T)
Y_score = torch.from_numpy(Y_score).float()# (N, 1)

print("Student dataset:", X_wave.shape, Y_score.shape)


Scoring clean set with TF model...


TF scoring: 100%|██████████| 18/18 [00:02<00:00,  7.72it/s]


Scoring noisy set with TF model...


TF scoring: 100%|██████████| 29/29 [00:01<00:00, 18.26it/s]


Student dataset: torch.Size([1473, 39124]) torch.Size([1473, 1])


## Train the PyTorch “whale‑score” student (ScoreNet)

In [13]:
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim

# Simple 1D conv regressor -> scalar
class ScoreNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(1, 32,  nine := 9, padding=4), nn.ReLU(),
            nn.Conv1d(32, 64, nine, padding=4),      nn.ReLU(),
            nn.Conv1d(64, 128, nine, padding=4),     nn.ReLU(),
            nn.AdaptiveAvgPool1d(64),  # shrink time
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*64, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # predict TF logit-ish score
        )
    def forward(self, x):  # x: (B, 1, T)
        h = self.net(x)
        return self.head(h)  # (B, 1)

score_net = ScoreNet().cuda()
crit = nn.MSELoss()
opt  = optim.Adam(score_net.parameters(), lr=1e-3)

# Dataloader
student_ds = TensorDataset(X_wave.unsqueeze(1), Y_score)  # (N,1,T), (N,1)
student_dl = DataLoader(student_ds, batch_size=32, shuffle=True)

# Train student to mimic TF scores
for ep in range(10):
    score_net.train()
    tot = 0
    for xb, yb in student_dl:
        xb = xb.cuda()
        yb = yb.cuda()
        pred = score_net(xb)
        loss = crit(pred, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        tot += loss.item()
    print(f"[ScoreNet] epoch {ep+1}: MSE {tot/len(student_dl):.4f}")

# Freeze for denoiser training
for p in score_net.parameters():
    p.requires_grad = False
score_net.eval()


[ScoreNet] epoch 1: MSE 11.0376
[ScoreNet] epoch 2: MSE 9.7639
[ScoreNet] epoch 3: MSE 9.5054
[ScoreNet] epoch 4: MSE 9.0791
[ScoreNet] epoch 5: MSE 9.3942
[ScoreNet] epoch 6: MSE 8.7101
[ScoreNet] epoch 7: MSE 7.0776
[ScoreNet] epoch 8: MSE 6.1633
[ScoreNet] epoch 9: MSE 5.6876
[ScoreNet] epoch 10: MSE 5.7477


ScoreNet(
  (net): Sequential(
    (0): Conv1d(1, 32, kernel_size=(9,), stride=(1,), padding=(4,))
    (1): ReLU()
    (2): Conv1d(32, 64, kernel_size=(9,), stride=(1,), padding=(4,))
    (3): ReLU()
    (4): Conv1d(64, 128, kernel_size=(9,), stride=(1,), padding=(4,))
    (5): ReLU()
    (6): AdaptiveAvgPool1d(output_size=64)
  )
  (head): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=8192, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
)

## Fine‑tune the denoiser with whale‑score guidance

In [18]:
# Reuse your real-noise dataset & loader:
# train_dataset = WhaleDenoisingDataset(clean_paths, noisy_paths, snr_db_range=(10,25))
# train_loader  = DataLoader(train_dataset, batch_size=16, shuffle=True)

lambda_score = 0.3   # weight for score guidance (tune 0.1–1.0)
num_epochs   = 30

for epoch in range(num_epochs):
    denoiser.train()
    total_recon, total_score, nb = 0.0, 0.0, 0

    for noisy, clean in tqdm(train_loader):
        noisy = noisy.unsqueeze(1).cuda()   # (B,1,T)
        clean = clean.unsqueeze(1).cuda()   # (B,1,T)
        denoised = denoiser(noisy)

        # 1) Supervised denoising target
        recon_loss = mse_loss(denoised, clean)

        # 2) Whale-awareness via student (fully differentiable)
        with torch.no_grad():
            score_target = score_net(clean)       # (B,1) – target whale-ness
        score_out = score_net(denoised)           # (B,1) – denoised whale-ness
        score_loss = (score_out - score_target).pow(2).mean()

        total = recon_loss + lambda_score * score_loss

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        total_recon += recon_loss.item()
        total_score += score_loss.item()
        nb += 1

    print(f"[Denoiser] epoch {epoch+1}/{num_epochs} | recon {total_recon/nb:.4f} | score {total_score/nb:.4f}")


100%|██████████| 35/35 [00:13<00:00,  2.59it/s]


[Denoiser] epoch 1/30 | recon 0.0024 | score 7.0433


100%|██████████| 35/35 [00:12<00:00,  2.78it/s]


[Denoiser] epoch 2/30 | recon 0.0020 | score 2.4891


100%|██████████| 35/35 [00:12<00:00,  2.77it/s]


[Denoiser] epoch 3/30 | recon 0.0018 | score 1.5933


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 4/30 | recon 0.0016 | score 1.3318


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 5/30 | recon 0.0013 | score 1.1291


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 6/30 | recon 0.0010 | score 0.9533


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 7/30 | recon 0.0007 | score 0.6047


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 8/30 | recon 0.0005 | score 0.3563


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 9/30 | recon 0.0005 | score 0.3636


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 10/30 | recon 0.0004 | score 0.2687


100%|██████████| 35/35 [00:12<00:00,  2.81it/s]


[Denoiser] epoch 11/30 | recon 0.0004 | score 0.2155


100%|██████████| 35/35 [00:12<00:00,  2.82it/s]


[Denoiser] epoch 12/30 | recon 0.0003 | score 0.2170


100%|██████████| 35/35 [00:12<00:00,  2.81it/s]


[Denoiser] epoch 13/30 | recon 0.0003 | score 0.1786


100%|██████████| 35/35 [00:12<00:00,  2.81it/s]


[Denoiser] epoch 14/30 | recon 0.0004 | score 0.2748


100%|██████████| 35/35 [00:12<00:00,  2.83it/s]


[Denoiser] epoch 15/30 | recon 0.0004 | score 0.2500


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 16/30 | recon 0.0004 | score 0.2567


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 17/30 | recon 0.0004 | score 0.2429


100%|██████████| 35/35 [00:12<00:00,  2.76it/s]


[Denoiser] epoch 18/30 | recon 0.0004 | score 0.2582


100%|██████████| 35/35 [00:12<00:00,  2.78it/s]


[Denoiser] epoch 19/30 | recon 0.0004 | score 0.2540


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 20/30 | recon 0.0004 | score 0.2105


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 21/30 | recon 0.0004 | score 0.2533


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 22/30 | recon 0.0003 | score 0.2534


100%|██████████| 35/35 [00:12<00:00,  2.78it/s]


[Denoiser] epoch 23/30 | recon 0.0004 | score 0.2311


100%|██████████| 35/35 [00:12<00:00,  2.79it/s]


[Denoiser] epoch 24/30 | recon 0.0003 | score 0.1997


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 25/30 | recon 0.0003 | score 0.2015


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 26/30 | recon 0.0004 | score 0.2193


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 27/30 | recon 0.0003 | score 0.2006


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]


[Denoiser] epoch 28/30 | recon 0.0003 | score 0.1988


100%|██████████| 35/35 [00:12<00:00,  2.81it/s]


[Denoiser] epoch 29/30 | recon 0.0003 | score 0.1864


100%|██████████| 35/35 [00:12<00:00,  2.80it/s]

[Denoiser] epoch 30/30 | recon 0.0003 | score 0.1661





## (Optional) Monitor with the real TF whale score

In [19]:
@torch.no_grad()
def tf_whale_score_from_tensor(x):  # x: (B,1,T) torch
    npb = to_tf_batch(x.squeeze(1)) # (B,T,1)
    return embedding_model.model(npb, False, None).numpy()  # (B,1)

# After each epoch (or on a val batch):
denoiser.eval()
noisy_b, clean_b = next(iter(train_loader))
noisy_b  = noisy_b.unsqueeze(1).cuda()
denoised_b = denoiser(noisy_b)

tf_score_noisy    = tf_whale_score_from_tensor(noisy_b).mean()
tf_score_denoised = tf_whale_score_from_tensor(denoised_b).mean()
print(f"TF whale score — noisy: {tf_score_noisy:.3f}, denoised: {tf_score_denoised:.3f}")


TF whale score — noisy: 4.324, denoised: 0.947


## Listen to the difference

In [20]:
import IPython.display as ipd
import random
import torchaudio

# Pick a random noisy file to test
noisy_path = random.choice(noisy_paths)
print(f"Testing on: {noisy_path}")

# Load noisy audio
waveform, sr = torchaudio.load(noisy_path)
waveform = waveform[0]  # (T,)

# Prepare input for denoiser
input_tensor = waveform.unsqueeze(0).unsqueeze(0).cuda()  # (1, 1, T)

# Apply denoiser
denoiser.eval()
with torch.no_grad():
    denoised_tensor = denoiser(input_tensor).squeeze().cpu()

# Listen
print("🔊 Noisy input:")
ipd.display(ipd.Audio(waveform.numpy(), rate=sr))

print("🔊 Denoised output:")
ipd.display(ipd.Audio(denoised_tensor.numpy(), rate=sr))


Testing on: /content/drive/MyDrive/whale_denoising/noisy/whale_167_10k.wav
🔊 Noisy input:


🔊 Denoised output:


## Save denoised files to disk

In [22]:
import os

output_dir = "/content/drive/MyDrive/whale_denoising/denoised_batch"
os.makedirs(output_dir, exist_ok=True)

denoiser.eval()
for noisy_path in noisy_paths:
    waveform, sr = torchaudio.load(noisy_path)
    waveform = waveform[0]
    input_tensor = waveform.unsqueeze(0).unsqueeze(0).cuda()

    with torch.no_grad():
        denoised_tensor = denoiser(input_tensor).squeeze().cpu()

    # Match filename, save as WAV
    filename = os.path.basename(noisy_path)
    out_path = os.path.join(output_dir, filename)
    torchaudio.save(out_path, denoised_tensor.unsqueeze(0), sr)

print(f"All denoised files saved to: {output_dir}")


All denoised files saved to: /content/drive/MyDrive/whale_denoising/denoised_batch
