# RunModel

This notebook embeds all WAV files in a folder into a Hoplite DB, loads a previously trained AGILE linear classifier, writes an inference CSV, and plots detections per hour from that CSV.


In [None]:
# @title Imports
from etils import epath
import os

from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import classifier
from perch_hoplite.agile import colab_utils
from perch_hoplite.agile import embed
from perch_hoplite.agile import source_info
from perch_hoplite.agile.classifier import LinearClassifier
from perch_hoplite.zoo import model_configs


In [None]:
# @title Configuration { vertical-output: true }

# Everything a user is expected to edit for a run is defined in this section:
# file locations, the embedding model, sharding, performance knobs and the
# inference threshold. Lines tagged `# @param` are rendered as Colab form fields.

# -----------------------------
# PATHS (Local vs Colab)
# -----------------------------

# For running locally (Rod)
# Root folder from which the classifier and output CSV locations are derived
# below. This is a machine-specific absolute path: edit it (or switch to the
# Colab block underneath) before running on another machine.
base_agile_path = epath.Path(
    '/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Mi unidad/Agile'
)

# For running in Colab
# (uncomment the three lines below and comment out the local path above)
# from google.colab import drive
# drive.mount('/content/drive')
# base_agile_path = epath.Path('/content/drive/Shareddrives/MAR FUTURA/Agile')

# -----------------------------
# USER SETTINGS
# -----------------------------

# Folder containing audio to classify.
input_audio_dir = "/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Unidades compartidas/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32"
# Name the audio source is registered under; the same name is used as the
# target dataset when embedding.
dataset_name = 'RunDataset'  # @param {type:'string'}
# Matches the .wav extension in any letter case.
dataset_fileglob = '*.[wW][aA][vV]'  # @param {type:'string'}

# Where to store the embedding DB for this run.
# NOTE: this is currently the same folder as input_audio_dir, so the DB is
# created alongside the audio files.
db_path = "/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Unidades compartidas/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32"

# Saved classifier created in CreateModel.ipynb (LinearClassifier.save).
classifier_path = str(base_agile_path / 'Data' / 'agile_classifier_v2.pt')  # @param {type:'string'}

# Output CSV path.
output_csv_filepath = str(base_agile_path / 'RunResults' / 'inference.csv')  # @param {type:'string'}

# Embedding model choice MUST match how you embedded when you trained the classifier.
model_choice = 'perch_8'  #@param['perch_v2','perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']

# Optional sharding (keep consistent with training if possible).
# shard_length_in_seconds only takes effect when use_file_sharding is True.
use_file_sharding = True  # @param {type:'boolean'}
shard_length_in_seconds = 5  # @param {type:'number'}

# Performance knobs
# - audio_worker_threads: parallel audio loading/processing
# - embed_batch_size: how many sources are queued per dispatch
# If you overload your machine, lower these.
audio_worker_threads = 8  # @param {type:'integer'}
embed_batch_size = 32  # @param {type:'integer'}

# Inference threshold. Higher => fewer detections.
# Passed to classifier.write_inference_csv as the logit cut-off.
logit_threshold = 2  # @param
# Forwarded to classifier.write_inference_csv as `labels=`.
# NOTE(review): None presumably means "all classes known to the classifier" -
# confirm against write_inference_csv.
labels = None  # @param

# Fail fast on a mistyped audio folder. This check has to run BEFORE the mkdir
# calls below: db_path may point at the same folder as input_audio_dir, in which
# case mkdir(parents=True) would silently create the missing folder and the run
# would carry on with nothing to embed.
if not epath.Path(input_audio_dir).is_dir():
  raise FileNotFoundError(
      f'input_audio_dir does not exist or is not a directory: {input_audio_dir}'
  )

# Create the output folders (CSV parent folder and DB folder) if needed.
epath.Path(output_csv_filepath).parent.mkdir(parents=True, exist_ok=True)
epath.Path(db_path).mkdir(parents=True, exist_ok=True)

# Describe the audio to embed: every file under input_audio_dir matching the glob.
audio_glob = source_info.AudioSourceConfig(
    dataset_name=dataset_name,
    base_path=input_audio_dir,
    file_glob=dataset_fileglob,
    min_audio_len_s=1.0,
    # NOTE(review): -2 looks like a sentinel rather than a real sample rate -
    # presumably "use the embedding model's own sample rate". Confirm against
    # source_info.AudioSourceConfig before changing it.
    target_sample_rate_hz=-2,
    # Sharding is disabled by passing None.
    shard_len_s=float(shard_length_in_seconds) if use_file_sharding else None,
)

# Bundle the audio source, DB location and embedding model choice into the
# config object used by the embedding cell.
configs = colab_utils.load_configs(
    source_info.AudioSources((audio_glob,)),
    db_path,
    model_config_key=model_choice,
    db_key='sqlite_usearch',
)

# Correcting the model handle for surfperch
if model_choice == 'surfperch':
  configs.model_config.model_config.tfhub_path = 'google/surfperch/1'

# Echo the effective settings so they are recorded in the cell output.
print('input_audio_dir:', input_audio_dir)
print('db_path:', db_path)
print('classifier_path:', classifier_path)
print('output_csv_filepath:', output_csv_filepath)
print('audio_worker_threads:', audio_worker_threads)
print('embed_batch_size:', embed_batch_size)


In [None]:
#@title Embed folder, load classifier, and run inference { vertical-output: true }

# This cell runs the whole pipeline in order: open the DB, embed the audio,
# rebuild the embedding model from the DB metadata, then load the classifier
# and write the detections CSV.

# 0) Fail fast if the saved classifier is missing. Without this check a wrong
#    classifier_path is only discovered in step 4, after the embedding pass
#    (the slowest part of the run) has already completed.
if not epath.Path(classifier_path).exists():
  raise FileNotFoundError(f'Saved classifier not found: {classifier_path}')

# 1) Connect/create DB
db = configs.db_config.load_db()
print('Initialized DB located at', configs.db_config.db_config.db_path)

# 2) Embed all files in the folder
print(f'Embedding dataset: {audio_glob.dataset_name}')
worker = embed.EmbedWorker(
    audio_sources=configs.audio_sources_config,
    db=db,
    model_config=configs.model_config,
    audio_worker_threads=int(audio_worker_threads),
)
worker.process_all(target_dataset_name=audio_glob.dataset_name, batch_size=int(embed_batch_size))
print('Embedding complete, total embeddings:', db.count_embeddings())

# 3) Load embedding model (needed for audio loader in some workflows; kept for parity)
#    The model and audio sources are rebuilt from the metadata stored in the DB
#    rather than from `configs`.
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)
audio_sources = source_info.AudioSources.from_config_dict(embed_config)
# Fall back to 5.0 s when the model does not declare a window size.
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)
# The loader is built and immediately discarded; nothing below uses it.
# NOTE(review): if no other cell needs `embedding_model`, this whole step could
# be dropped to avoid what is presumably a second model load - confirm first.
_ = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=embedding_model.sample_rate,
)

# 4) Load trained classifier and write inference CSV
linear_classifier = LinearClassifier.load(classifier_path)
classifier.write_inference_csv(
    linear_classifier,
    db,
    output_csv_filepath,
    logit_threshold,
    labels=labels,
)
print('Done. Wrote:', output_csv_filepath)


In [None]:
#@title Plot detections over time (detections/hour)

import re
import pandas as pd
import matplotlib.pyplot as plt

# Load the detections written by the inference cell.
csv_path = output_csv_filepath

df = pd.read_csv(csv_path)
print('rows:', len(df))
print('columns:', list(df.columns))

# Optional: focus on one label (e.g. boat). Set to None to include all labels.
focus_label = 'boat'  # @param {type:'string'}
if focus_label and 'label' in df.columns:
  # Remember which labels exist so an empty result can be explained below.
  labels_in_csv = df['label'].dropna().unique().tolist()
  # .copy() makes the filtered frame independent of the frame read from disk.
  # Columns are added to `df` further down; doing that on a bare boolean-mask
  # slice triggers pandas' SettingWithCopyWarning.
  df = df[df['label'] == focus_label].copy()
  if df.empty:
    # Without this, the failure surfaces later as a misleading
    # "no parseable datetime" error even though the real cause is the filter.
    raise RuntimeError(
        f'No rows with label {focus_label!r} in {csv_path}. '
        f'Labels present: {labels_in_csv}'
    )

# Recording start times are encoded in the filename as YYYYMMDD_HHMMSS.
# Example: ZAPALLAR_20241122_143550_5sec.wav -> 20241122_143550
_dt_re = re.compile(r'(\d{8})_(\d{6})')

def extract_dt(fname: str):
  """Return the timestamp embedded in a filename, or `pd.NaT` if there is none.

  Args:
    fname: Filename (or path) to inspect; it is converted with `str()` first,
      so non-string values such as None or NaN are accepted.

  Returns:
    A `pd.Timestamp` built from the first `YYYYMMDD_HHMMSS` occurrence in the
    name. `pd.NaT` when the pattern is absent or the digits do not form a
    valid date/time.
  """
  match = _dt_re.search(str(fname))
  if match is None:
    return pd.NaT
  # Join the date and time digit groups into one 14-digit string for parsing.
  stamp_digits = ''.join(match.groups())
  return pd.to_datetime(stamp_digits, format='%Y%m%d%H%M%S', errors='coerce')

def _first_present(frame, candidates):
  """Return the first name in `candidates` that is a column of `frame`, else None."""
  return next((name for name in candidates if name in frame.columns), None)


def _plot_counts(counts, title, xlabel, ylabel):
  """Draw a single 12x4 line chart of `counts` with the given title and axis labels."""
  fig, ax = plt.subplots(figsize=(12, 4))
  counts.plot(ax=ax)
  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.grid(True, alpha=0.3)
  fig.tight_layout()
  plt.show()


# Column holding the audio file name. 'filename' is tried first; 'source_id' is
# accepted as a fallback.
# NOTE(review): the exact header written by classifier.write_inference_csv is
# not visible in this notebook - check the printed column list and extend the
# candidates if your CSV uses another name.
filename_col = _first_present(df, ('filename', 'source_id'))
if filename_col is None:
  raise KeyError(
      'No filename column found in the inference CSV; '
      f'available columns: {list(df.columns)}'
  )
df['file_dt'] = df[filename_col].apply(extract_dt)

# If a per-window offset column exists, shift the file timestamp by that many
# seconds. Unparseable or missing offsets are treated as 0.
offset_col = _first_present(df, ('window_start', 'offset'))
if offset_col is not None:
  df[offset_col] = pd.to_numeric(df[offset_col], errors='coerce')
  df['dt'] = df['file_dt'] + pd.to_timedelta(df[offset_col].fillna(0), unit='s')
else:
  df['dt'] = df['file_dt']

# Drop rows where we can't parse time
plot_df = df.dropna(subset=['dt']).copy()
if plot_df.empty:
  raise RuntimeError('No rows had a parseable datetime in filename. Adjust extract_dt() regex/format.')

plot_df = plot_df.set_index('dt').sort_index()

# Suffix appended to plot titles when a single label is being shown.
label_suffix = f" ({focus_label})" if focus_label else ''

# One detection per CSV row, so count rows per bucket; this avoids depending on
# any particular column being present. Lowercase 'h' is used because the
# uppercase 'H' alias is deprecated in recent pandas releases.
detections_per_hour = plot_df.resample('1h').size()
_plot_counts(
    detections_per_hour,
    'Detections per hour' + label_suffix,
    'Time',
    'Detections / hour',
)

# Optional: show daily totals too
show_daily = False  # @param {type:'boolean'}
if show_daily:
  daily = plot_df.resample('1D').size()
  _plot_counts(
      daily,
      'Detections per day' + label_suffix,
      'Date',
      'Detections / day',
  )
