# Overview

This notebook uses perch-hoplite to compute and save embeddings for a set of audio files using a pre-trained model. This is the first step in the agile modeling process. If the data you wish to search and classify is already embedded with a pre-trained model into a perch-hoplite database, then proceed to the step 2 colab notebook ([2_agile_modeling_v2.ipynb](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite/agile/2_agile_modeling_v2.ipynb)).

## [Optional] perch-hoplite installation for hosted runtimes

If you have not already installed perch-hoplite (particularly if you are using a hosted Colab runtime), install it from the GitHub source to ensure you have the most recent version. After installation, you will need to restart your runtime before running anything else: go to the top menu, select "Runtime", then "Restart Session".

In [None]:
#@title Only run this code if you need to install perch-hoplite
# Installs perch-hoplite directly from its GitHub repository. Restart the
# runtime after this finishes and before running any other cell.
!pip install git+https://github.com/google-research/perch-hoplite.git

In [None]:
# @title Imports
from etils import epath
from IPython.display import display
import ipywidgets as widgets
import numpy as np
from perch_hoplite.agile import colab_utils
from perch_hoplite.agile import embed
from perch_hoplite.agile import source_info
from perch_hoplite.db import brutalism
from perch_hoplite.db import interface

# Connect to Google Drive

In [None]:
import os
from google.colab import drive

# Mount Google Drive at /content/drive; the paths used by later cells live
# under this mount point.
drive.mount('/content/drive')

# Create Target Folder

In [None]:
#@title Create a new folder in Drive (if it doesn't already exist) within your Google drive.
# Root of the mounted Google Drive.
base_dir = '/content/drive/MyDrive/'
#@markdown Name of your new folder in Drive
new_folder_name = 'whale_denoising' #@param

# Build the path with os.path.join so it stays well-formed whether or not
# base_dir ends with a path separator.
drive_output_directory = os.path.join(base_dir, new_folder_name)

try:
  # The existence probe is only there so the message can say whether the
  # folder was newly created; exist_ok=True keeps makedirs itself safe.
  if os.path.exists(drive_output_directory):
    print(f'Directory {drive_output_directory} already exists.')
  else:
    os.makedirs(drive_output_directory, exist_ok=True)
    print(f'Directory {drive_output_directory} created successfully.')
except OSError as e:
  # Best effort: report the problem instead of stopping the notebook.
  print("Error:", e)

# Embed the audio data

In [None]:
# @title Configuration { vertical-output: true }

# This cell only collects the user-facing settings; the configuration objects
# are built from these values afterwards.

# @markdown Configure the raw dataset and output location(s).  The format is a mapping from
# @markdown a dataset_name to a (base_path, fileglob) pair.  Note that the file
# @markdown globs are case sensitive.  The dataset name can be anything you want.
#
# @markdown This structure allows you to move your data around without having to
# @markdown re-embed the dataset.  The generated embedding database will be
# @markdown placed in the base path. This allows you to simply swap out
# @markdown the base path here if you ever move your dataset.

# @markdown By default we only process one dataset at a time.  Re-run this entire notebook
# @markdown once per dataset.

# @markdown For example, we might set dataset_base_path to '/home/me/myproject',
# @markdown and use the glob '\*/\*.wav' if all of the audio files have filepaths
# @markdown like '/home/me/myproject/site_XYZ/audio_ABC.wav' (e.g. audio files are contained in subfolders of the base directory).

# @markdown 1. Create a unique name for the database that will store the embeddings for the target data.
dataset_name = 'WhaleDenoising'  # @param {type:'string'}
# @markdown 2. Input the filepath for the folder that contains the input audio files.
# The form field for this setting is disabled: the path is taken from
# drive_output_directory, the Drive folder defined earlier in this notebook.
# Assign a literal path here instead to embed audio stored elsewhere.
dataset_base_path = drive_output_directory
# @markdown 3. Input the file pattern for the audio files within that folder that you want to embed. Some examples for how to input:
# @markdown - All files in the base directory of a specific type (not subdirectories): e.g. `*.wav` (or `*.flac` etc) will generate embeddings for all .wav files (or whichever format) in the dataset_base_path
# @markdown - All files in one level of subdirectories within the base directory: `*/*.flac` will generate embeddings for all .flac files
# @markdown - Single file: `myfile.wav` will only embed the audio from that specific file.
dataset_fileglob = '*/*.wav'  # @param {type:'string'}

# @markdown 4. [Optional] If saving the embeddings database to a new directory, specify here.
# @markdown Otherwise, leave blank - by default the embeddings database output will be saved within
# @markdown dataset_base_path where the audio is located. You do not need to specify db_path unless you want to maintain multiple
# @markdown distinct embedding databases, or if you would like to save the output
# @markdown in a different folder. If your input audio data is accessed
# @markdown from a public URL, we recommend specifying a separate output directory here.
# The form field for this setting is disabled as well: the database location
# is pinned to drive_output_directory.
db_path = drive_output_directory
# Normalise an empty value or the literal text 'None' to a real None before
# the value is handed on.
if not db_path or db_path == 'None':
  db_path = None

# @markdown 5. Choose a supported model to generate embeddings: `perch_8` or `birdnet_v2.3` are most common
# @markdown for birds. Other choices include `surfperch` for coral reefs or
# @markdown `multispecies_whale` for marine mammals.
model_choice = 'humpback'  #@param['humpback', 'multispecies_whale', 'perch_8', 'surfperch', 'birdnet_V2.3']

# @markdown 6. [Optional] Shard the audio for embeddings. File sharding automatically splits audio files into smaller chunks
# @markdown for creating embeddings. This limits both system and GPU memory usage,
# @markdown especially useful when working with long files (>1 hour).
use_file_sharding = True  # @param {type:'boolean'}
# @markdown If you want to change the length in seconds for the shards, specify here.
shard_length_in_seconds = 60  # @param {type:'number'}

# The shard length is only passed through when sharding is switched on.
if use_file_sharding:
  shard_len_s = float(shard_length_in_seconds)
else:
  shard_len_s = None

# Describe the single audio source that will be embedded.
# NOTE(review): target_sample_rate_hz=-2 is presumably a sentinel rather than
# a literal rate -- confirm its meaning against AudioSourceConfig.
audio_glob = source_info.AudioSourceConfig(
    dataset_name=dataset_name,
    base_path=dataset_base_path,
    file_glob=dataset_fileglob,
    min_audio_len_s=1.0,
    target_sample_rate_hz=-2,
    shard_len_s=shard_len_s,
)
audio_sources = source_info.AudioSources((audio_glob,))

# Combine the audio source, database location, model choice and database
# backend key into one configuration bundle.
configs = colab_utils.load_configs(
    audio_sources,
    db_path,
    model_config_key=model_choice,
    db_key='sqlite_usearch',
)
# Trailing bare expression so the notebook shows the resulting configuration.
configs

In [None]:
#@title Initialize the hoplite database (DB) { vertical-output: true }
# NOTE: a `global` statement at cell (module) scope has no effect; `db` below
# is a module-level name regardless.
global db
# Load the database described by the configuration and record how many
# embeddings it already holds.
db = configs.db_config.load_db()
num_embeddings = db.count_embeddings()

print('Initialized DB located at ', configs.db_config.db_config.db_path)

def drop_and_reload_db(_) -> interface.HopliteDBInterface:
  """Delete the on-disk database files and load a fresh database.

  Usable as a widget click handler: the single positional argument (whatever
  the caller passes) is ignored.

  Returns:
    The newly loaded database, which is also rebound to the module-level `db`.
  """
  # Without this declaration the assignment at the end would only bind a
  # local name, leaving the notebook's `db` pointing at the deleted database.
  global db
  db_path = epath.Path(configs.db_config.db_config.db_path)
  for fp in db_path.glob('hoplite.sqlite*'):
    fp.unlink()
  # The index file may already be gone (e.g. after an interrupted earlier
  # drop); skip it instead of raising.
  index_path = db_path / 'usearch.index'
  if index_path.exists():
    index_path.unlink()
  print('\n Deleted previous db at: ', configs.db_config.db_config.db_path)
  db = configs.db_config.load_db()
  return db

#@markdown If `drop_existing_db` set to True, when the database already exists and contains embeddings,
#@markdown then those existing embeddings will be erased. You will be prompted to confirm you wish to delete those existing
#@markdown embeddings. If you want to keep existing embeddings in the database, then set to False, which will append the new
#@markdown embeddings to the database.
drop_existing_db = True  #@param {type:'boolean'}

# Only ask for confirmation when there is something to lose: the user asked
# for a drop AND the database is not empty.
if drop_existing_db and num_embeddings > 0:
  print('Existing DB contains datasets: ', db.get_dataset_names())
  print('num embeddings: ', num_embeddings)
  print('\n\nClick the button below to confirm you really want to drop the database at ')
  print(f'{configs.db_config.db_config.db_path}\n')
  print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\n')
  print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\n')

  # Nothing is deleted until the button is clicked.
  button = widgets.Button(description='Delete database?')
  button.on_click(drop_and_reload_db)
  display(button)

In [None]:
#@title Run the embedding { vertical-output: true }

target_dataset = audio_glob.dataset_name
print(f'Embedding dataset: {target_dataset}')

# The worker ties together the audio sources, the target database and the
# embedding model configuration.
worker = embed.EmbedWorker(
    db=db,
    model_config=configs.model_config,
    audio_sources=configs.audio_sources_config,
)

worker.process_all(target_dataset_name=target_dataset)

total_embeddings = db.count_embeddings()
print('\n\nEmbedding complete, total embeddings: ', total_embeddings)

# Extract the embedding function

In [None]:
#@title Per dataset statistics { vertical-output: true }

# Report, for every dataset name stored in the database, how many embeddings
# belong to it.
for dataset in db.get_dataset_names():
  print(f'\nDataset \'{dataset}\':')
  dataset_embeddings = db.get_embeddings_by_source(dataset, source_id=None)
  print('\tnum embeddings: ', dataset_embeddings.shape[0])

In [None]:
#@title Show example embedding search
#@markdown As an example (and to show that the embedding process worked), this
#@markdown selects a single embedding from the database and outputs the embedding ids of the
#@markdown top-K (k = 128) nearest neighbors in the database.

# Use a single embedding that is already stored in the database as the query.
q = db.get_embedding(db.get_one_embedding_id())
# Search the worker's database with np.dot as the scoring function and a
# result list of size 128; the %time magic reports how long the search takes.
%time results, scores = brutalism.brute_search(worker.db, query_embedding=q, search_list_size=128, score_fn=np.dot)
# Print the embedding ids of the returned matches.
print([int(r.embedding_id) for r in results])

In [None]:
# Id of the embedding to show; replace 1 with the embedding_id you want.
embedding_id_to_display = 1

# Fetch and show the vector. Any failure is reported as a message instead of
# a traceback.
try:
  embedding_vector = db.get_embedding(embedding_id_to_display)
  print(f"Embedding vector for ID {embedding_id_to_display}:")
  display(embedding_vector)
except Exception as err:
  print(f"Error retrieving embedding for ID {embedding_id_to_display}: {err}")

In [None]:
# Render the database handle itself using the notebook's rich display.
display(db)

In [None]:
import sqlite3
import os

# Construct the full path to the SQLite database file
db_file_path = os.path.join(configs.db_config.db_config.db_path, 'hoplite.sqlite')

# Start from a known state so a failed connection cannot leave a stale handle
# from an earlier run of this cell in place.
conn = None
cursor = None

# sqlite3.connect() silently creates an empty database when the file is
# missing, which would make every later query fail with "no such table".
# Check for the file first and report the real problem.
if not os.path.isfile(db_file_path):
    print(f"Error connecting to the database: no file at {db_file_path}")
else:
    try:
        conn = sqlite3.connect(db_file_path)
        cursor = conn.cursor()
        print(f"Connected to the database at: {db_file_path}")
    except sqlite3.Error as e:
        print(f"Error connecting to the database: {e}")

Now that you're connected to the database, you can execute SQL queries in the cells below.

Here are some example queries to get you started:

*   **List all tables:**

In [None]:
# sqlite_master holds one row per schema object; keep only the tables.
tables = cursor.execute(
    "SELECT name FROM sqlite_master WHERE type='table';"
).fetchall()
print("Tables in the database:")
# Each row is a 1-tuple holding the table name.
for (tbl,) in tables:
    print(tbl)

In [None]:
# Table to describe; change this to any table name listed in the database.
table_name = 'hoplite_labels'
try:
    columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
    print(f"\nColumns in the '{table_name}' table:")
    for column_info in columns:
        # Field 1 of each table_info row is printed as the column name.
        print(column_info[1])
except sqlite3.Error as err:
    print(f"Error querying table info: {err}")

In [None]:
# Print every row of the hoplite_labels table.
try:
    rows = cursor.execute("SELECT * FROM hoplite_labels;").fetchall()
    print("Contents of the 'hoplite_labels' table:")
    for label_row in rows:
        print(label_row)
except sqlite3.Error as err:
    print(f"Error querying hoplite_labels table: {err}")

In [None]:
# Print every row of the hoplite_sources table.
try:
    rows = cursor.execute("SELECT * FROM hoplite_sources;").fetchall()
    print("Contents of the 'hoplite_sources' table:")
    for source_row in rows:
        print(source_row)
except sqlite3.Error as err:
    print(f"Error querying hoplite_sources table: {err}")

In [None]:
# Rebuild hoplite_labels from the folder each source file lives in.
#
# The DELETE and the INSERT are committed together: if the INSERT fails, the
# rollback restores the previous labels instead of leaving the table empty.
#
# NOTE(review): the INSERT writes the text 'user_provided' into the `type`
# column. Confirm that this matches what the hoplite label reader expects for
# that column (it may require a numeric label-type code).
try:
    # Delete existing labels (not committed yet).
    cursor.execute("DELETE FROM hoplite_labels;")

    # Insert new labels based on source folder
    cursor.execute("""
        INSERT INTO hoplite_labels (embedding_id, label, type, provenance)
        SELECT
            he.id,
            CASE
                -- Check if the source path contains 'clean/'
                WHEN INSTR(hs.source, 'clean/') > 0 THEN 'clean'
                -- Check if the source path contains 'noisy/'
                WHEN INSTR(hs.source, 'noisy/') > 0 THEN 'noisy'
                ELSE 'unknown'
            END AS label,
            'user_provided' AS type,
            'labeled based on source folder' AS provenance
        FROM hoplite_embeddings he
        JOIN hoplite_sources hs ON he.source_idx = hs.id;
    """)

    # Make both statements durable in one step.
    conn.commit()
    print("Successfully deleted existing labels from hoplite_labels table.")
    print("Successfully populated hoplite_labels table.")

    # Check the counts of each label
    cursor.execute("SELECT label, COUNT(*) FROM hoplite_labels GROUP BY label;")
    label_counts = cursor.fetchall()
    print("\nCounts of labels in the hoplite_labels table:")
    for label, count in label_counts:
        print(f"Label: {label}, Count: {count}")

except sqlite3.Error as e:
    # Undo any uncommitted DELETE/INSERT; this is a no-op after a successful
    # commit.
    conn.rollback()
    print(f"Error populating hoplite_labels table: {e}")

In [None]:
# Print every row of hoplite_labels again to inspect its current contents.
try:
    rows = cursor.execute("SELECT * FROM hoplite_labels;").fetchall()
    print("Contents of the 'hoplite_labels' table:")
    for label_row in rows:
        print(label_row)
except sqlite3.Error as err:
    print(f"Error querying hoplite_labels table: {err}")

# Train a Denoiser

In [None]:
# 📀 Denoising Dataset (whale-aware, no labels)
from torch.utils.data import Dataset, DataLoader
import torchaudio
import glob

class WhaleNoisyOnlyDataset(Dataset):
    """Dataset of unlabeled noisy recordings, served one waveform at a time.

    Args:
        paths: Sequence of audio file paths that torchaudio.load can read.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Load the file at position ``idx`` and return a 1-D waveform."""
        # The second value returned by torchaudio.load is not needed here.
        waveform, _ = torchaudio.load(self.paths[idx])
        if waveform.size(0) == 1:
            # Single-channel file: drop the leading channel axis.
            return waveform.squeeze(0)
        # squeeze(0) is a no-op when the leading axis is larger than 1, which
        # would hand a 2-D tensor to code expecting one waveform per item.
        # Average over that axis so the result is always 1-D.
        return waveform.mean(dim=0)

# Use just the noisy paths (no labels, no pairing)
noisy_glob = "/content/drive/MyDrive/whale_denoising/noisy/*.wav"
noisy_paths = sorted(glob.glob(noisy_glob))
if not noisy_paths:
    # Stop here with an actionable message: training on zero files is
    # meaningless and would otherwise fail later in a less obvious way.
    raise ValueError(
        f"No audio files matched {noisy_glob!r}; check that Drive is mounted "
        "and that the noisy/ folder contains .wav files."
    )

noisy_dataset = WhaleNoisyOnlyDataset(noisy_paths)
# NOTE(review): the default batching stacks the waveforms of a batch, which
# presumably requires all clips to have the same length -- confirm, or supply
# a collate_fn that pads/crops.
train_loader = DataLoader(noisy_dataset, batch_size=16, shuffle=True)


In [None]:
import torch
import torch.nn as nn

class Denoiser1D(nn.Module):
    """Stack of four 1-D convolutions with ReLU between consecutive layers.

    Channel widths go 1 -> 32 -> 64 -> 32 -> 1. Every convolution uses
    kernel_size=15 with padding=7, so the time dimension keeps its length.
    """

    def __init__(self):
        super().__init__()
        channel_plan = (1, 32, 64, 32, 1)
        stages = []
        for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv1d(in_ch, out_ch, kernel_size=15, padding=7))
            stages.append(nn.ReLU())
        # The output convolution is not followed by an activation.
        stages.pop()
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.net(x)

# Prefer the GPU, but fall back to the CPU so this cell also runs on runtimes
# without CUDA instead of failing on a hard-coded .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
denoiser = Denoiser1D().to(device)
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()


In [None]:
from tqdm import tqdm

# The embedding loss below needs `embed_fn`, which no visible cell of this
# notebook defines. Fail up front with a clear message rather than midway
# through the first batch.
# NOTE(review): for the embedding loss to train the denoiser, embed_fn
# presumably has to be a differentiable torch callable taking a
# (batch, samples) tensor -- confirm.
if "embed_fn" not in globals():
    raise NameError(
        "embed_fn is not defined; define a callable that maps a batch of "
        "waveforms to embeddings before running the training loop."
    )

# Keep the data on whatever device the model's parameters live on.
train_device = next(denoiser.parameters()).device

for epoch in range(10):
    denoiser.train()
    total_loss = 0
    # The dataset yields bare waveforms (no labels), so each batch is a single
    # tensor and must not be unpacked into a pair.
    for noisy in tqdm(train_loader):
        noisy = noisy.unsqueeze(1).to(train_device)  # shape: (B, 1, T)
        denoised = denoiser(noisy)

        # Optional: Encourage identity if no noise model exists
        reconstruction_loss = mse_loss(denoised, noisy)

        # Whale-aware embedding loss: the target embedding is computed without
        # gradient tracking; only the denoised branch is differentiated.
        with torch.no_grad():
            embed_noisy = embed_fn(noisy.squeeze(1))
        embed_denoised = embed_fn(denoised.squeeze(1))
        embedding_loss = ((embed_denoised - embed_noisy)**2).mean()

        total = reconstruction_loss + 0.1 * embedding_loss

        optimizer.zero_grad()
        total.backward()
        optimizer.step()
        total_loss += total.item()

    # Guard against an empty loader so the report cannot divide by zero.
    num_batches = max(len(train_loader), 1)
    print(f"Epoch {epoch+1}, loss: {total_loss/num_batches:.4f}")
