<a href="https://colab.research.google.com/github/rodgpt/MAR_FUTURA/blob/main/Boat%20Detector/Agile/RunModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RunModel

This notebook embeds all WAV files in a folder into a Hoplite DB, loads a previously-trained AGILE linear classifier, and writes an inference CSV.


In [1]:
# @title Imports
#!pip install git+https://github.com/google-research/perch-hoplite.git
from etils import epath
import os

from perch_hoplite.agile import audio_loader
from perch_hoplite.agile import classifier
from perch_hoplite.agile import colab_utils
from perch_hoplite.agile import embed
from perch_hoplite.agile import source_info
from perch_hoplite.agile.classifier import LinearClassifier
from perch_hoplite.zoo import model_configs
from pathlib import Path
import shutil # Import shutil for file operations



  if not hasattr(np, "object"):


In [None]:
# @title Configuration

# -----------------------------
# PATHS (Local vs Colab)
# -----------------------------


# For running in Colab
#from google.colab import drive
#drive.mount('/content/drive')

#input_audio_dir = "/content/drive/Shareddrives/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32"
#db_path =         "/content/drive/Shareddrives/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32"
#classifier_path = "/content/drive/MyDrive/Agile/ 'agile_classifier_v2.pt'"
#output_csv_filepath = "/content/drive/Shareddrives/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32/inference.csv"


# For running locally.
# Option A: read the audio from, and write the DB and CSV into, the shared Google Drive folder.
# NOTE: these three values are overwritten by the "/Volumes/Untitled" assignments
# just below, so as written they have no effect. Comment out one of the two sets
# to choose where the run reads and writes.
input_audio_dir =     "/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Unidades compartidas/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32"
db_path =             "/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Unidades compartidas/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32"
output_csv_filepath = "/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Unidades compartidas/MAR FUTURA/Hydrophones/Matanzas/13-11-25/32/inference.csv"

# Option B (the values actually in effect): audio, DB and CSV all live on the mounted volume.
input_audio_dir      = "/Volumes/Untitled"
db_path              = "/Volumes/Untitled"
output_csv_filepath  = "/Volumes/Untitled/inference.csv"

# Location of the previously trained classifier that this notebook loads.
# NOTE(review): the last path component is " 'agile_classifier_v2.pt'" -- it includes a
# leading space and literal single quotes. Confirm the file on disk really has that name.
classifier_path = "/Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Mi unidad/Agile/ 'agile_classifier_v2.pt'"


# -----------------------------
# USER SETTINGS
# -----------------------------
# Name given to this batch of recordings; used as the dataset name of the audio
# source and as the target dataset when embedding.
dataset_name = 'RunDataset'
# Glob applied under input_audio_dir; the character classes match the "wav"
# extension in any mix of upper and lower case.
dataset_fileglob = '*.[wW][aA][vV]'

# Embedding model choice MUST match how you embedded when you trained the classifier.
model_choice = 'perch_8'

# Optional sharding (keep consistent with training if possible).
# When use_file_sharding is False the shard length is ignored (None is passed instead).
use_file_sharding = True
shard_length_in_seconds = 5

# Performance knobs
# - audio_worker_threads: parallel audio loading/processing
# - embed_batch_size: how many sources are queued per dispatch
# If you overload your machine, lower these.
audio_worker_threads = 8
embed_batch_size = 32

# Inference threshold. Higher => fewer detections.
logit_threshold = 2
# Forwarded unchanged as the `labels` argument of classifier.write_inference_csv.
# NOTE(review): presumably None means "use every label the classifier knows" -- confirm.
labels = None


# Validate the inputs before creating anything. Without these checks an unmounted
# volume or a mistyped classifier path only surfaces later: mkdir(parents=True)
# would try to create the missing input folder itself, and the classifier is not
# opened until the embedding pass has finished.
if not Path(input_audio_dir).is_dir():
  raise FileNotFoundError(
      f'input_audio_dir does not exist or is not a directory: {input_audio_dir}'
  )
if not Path(classifier_path).exists():
  raise FileNotFoundError(f'classifier_path does not exist: {classifier_path}')

# Create directories
Path(output_csv_filepath).parent.mkdir(parents=True, exist_ok=True)
Path(db_path).mkdir(parents=True, exist_ok=True)

# Describe the recordings to embed: every file under input_audio_dir that matches
# the glob, optionally cut into fixed-length shards.
audio_glob = source_info.AudioSourceConfig(
    dataset_name=dataset_name,
    base_path=input_audio_dir,
    file_glob=dataset_fileglob,
    min_audio_len_s=1.0,
    # NOTE(review): -2 looks like a library sentinel rather than a real rate
    # (presumably "use the embedding model's sample rate") -- confirm.
    target_sample_rate_hz=-2,
    shard_len_s=None if not use_file_sharding else float(shard_length_in_seconds),
)

# Bundle the single audio source with the DB location and embedding model choice.
run_sources = source_info.AudioSources((audio_glob,))
configs = colab_utils.load_configs(
    run_sources,
    db_path,
    db_key='sqlite_usearch',
    model_config_key=model_choice,
)

# The surfperch preset needs its model handle overridden before use.
if model_choice == 'surfperch':
  configs.model_config.model_config.tfhub_path = 'google/surfperch/1'

# Echo the effective settings so the cell output records what this run used.
for _setting_name, _setting_value in (
    ('input_audio_dir', input_audio_dir),
    ('db_path', db_path),
    ('classifier_path', classifier_path),
    ('output_csv_filepath', output_csv_filepath),
    ('audio_worker_threads', audio_worker_threads),
    ('embed_batch_size', embed_batch_size),
):
  print(f'{_setting_name}:', _setting_value)


input_audio_dir: /Volumes/Untitled
db_path: /Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Mi unidad/Agile/outputs/db
classifier_path: /Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Mi unidad/Agile/ 'agile_classifier_v2.pt'
output_csv_filepath: /Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Mi unidad/Agile/outputs/my_outputTRY.csv
audio_worker_threads: 8
embed_batch_size: 32


In [None]:
#@title Embed folder, load classifier, and run inference

# 1) Load the trained classifier first. It is only used in the last step, but
#    loading it up front means a wrong or unreadable classifier_path fails right
#    away instead of after the (potentially very long) embedding pass.
linear_classifier = LinearClassifier.load(classifier_path)

# 2) Connect/create DB
db = configs.db_config.load_db()
print('Initialized DB located at', configs.db_config.db_config.db_path)

# 3) Embed all files in the folder
print(f'Embedding dataset: {audio_glob.dataset_name}')
worker = embed.EmbedWorker(
    audio_sources=configs.audio_sources_config,
    db=db,
    model_config=configs.model_config,
    audio_worker_threads=int(audio_worker_threads),
)
worker.process_all(target_dataset_name=audio_glob.dataset_name, batch_size=int(embed_batch_size))
print('Embedding complete, total embeddings:', db.count_embeddings())

# 4) Load embedding model (needed for audio loader in some workflows; kept for parity)
#    The model and audio-source settings are read back from the DB metadata. The
#    loader built here is discarded; nothing below uses it.
db_model_config = db.get_metadata('model_config')
embed_config = db.get_metadata('audio_sources')
model_class = model_configs.get_model_class(db_model_config.model_key)
embedding_model = model_class.from_config(db_model_config.model_config)
audio_sources = source_info.AudioSources.from_config_dict(embed_config)
# Fall back to 5.0 s when the model object does not expose window_size_s.
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)
_ = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    window_size_s=window_size_s,
    sample_rate_hz=embedding_model.sample_rate,
)

# 5) Run the classifier over the DB and write the inference CSV
classifier.write_inference_csv(
    linear_classifier,
    db,
    output_csv_filepath,
    logit_threshold,
    labels=labels,
)
print('Done. Wrote:', output_csv_filepath)


Initialized DB located at /Users/Rodrigo/Library/CloudStorage/GoogleDrive-royanedel@marfutura.org/Mi unidad/Agile/outputs/db
Embedding dataset: RunDataset


	Audioread support is deprecated in librosa 0.10.0 and will be removed in version 1.0.
  sr = librosa.get_samplerate(filepath)
ERROR:root:Failed to parse audio file (/Volumes/Untitled/20251203_165600.WAV) : .
	Audioread support is deprecated in librosa 0.10.0 and will be removed in version 1.0.
  sr = librosa.get_samplerate(filepath)
ERROR:root:Failed to parse audio file (/Volumes/Untitled/20251203_202800.WAV) : .
100%|██████████| 15337/15337 [10:01<00:00, 25.50it/s]
  0%|          | 0/15337 [00:00<?, ?it/s]2026-01-06 15:58:54.412411: I external/local_xla/xla/service/service.cc:163] XLA service 0x72bc55e00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2026-01-06 15:58:54.412618: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): Host, Default Version
2026-01-06 15:58:54.604045: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTO

In [None]:
#@title Plot detections over time (detections/hour)

import re
import pandas as pd
import matplotlib.pyplot as plt

# Read back the CSV written by the inference step.
csv_path = output_csv_filepath

df = pd.read_csv(csv_path)
print('rows:', len(df))
print('columns:', list(df.columns))

# Optional: focus on one label (e.g. boat). Set to None to include all labels.
focus_label = 'boat'
if focus_label and 'label' in df.columns:
  labels_present = sorted(df['label'].dropna().astype(str).unique())
  # .copy() gives the filtered rows their own data, so adding columns to `df`
  # afterwards does not raise pandas' SettingWithCopyWarning.
  df = df[df['label'] == focus_label].copy()
  if df.empty:
    # Without this check the run still fails further down, but with a message
    # that blames filename parsing rather than the label filter.
    raise RuntimeError(
        f'No rows with label {focus_label!r} in {csv_path}. '
        f'Labels present: {labels_present}'
    )

# Recordings carry their start time in the file name as YYYYMMDD_HHMMSS,
# e.g. ZAPALLAR_20241122_143550_5sec.wav -> 20241122_143550.
_dt_re = re.compile(r'(\d{8})_(\d{6})')

def extract_dt(fname: str):
  """Return the timestamp embedded in a file name.

  Args:
    fname: File name or path; it is converted with str() before matching.

  Returns:
    The pandas timestamp built from the first `YYYYMMDD_HHMMSS` occurrence, or
    `pd.NaT` when there is no such occurrence or the digits are not a valid
    date/time.
  """
  match = _dt_re.search(str(fname))
  if match is None:
    return pd.NaT
  # Date and time digits concatenated, matching the format string below.
  stamp = ''.join(match.groups())
  return pd.to_datetime(stamp, format='%Y%m%d%H%M%S', errors='coerce')

def _plot_detection_counts(counts, title, xlabel, ylabel):
  """Draw one line chart of detection counts over time and show it.

  Args:
    counts: pandas Series of counts indexed by time bin.
    title: Figure title.
    xlabel: Label for the x axis.
    ylabel: Label for the y axis.
  """
  fig, ax = plt.subplots(figsize=(12, 4))
  counts.plot(ax=ax)
  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.grid(True, alpha=0.3)
  fig.tight_layout()
  plt.show()


# Fail with a clear message if the CSV lacks the columns used below.
missing_cols = [col for col in ('filename', 'idx') if col not in df.columns]
if missing_cols:
  raise KeyError(
      f'Inference CSV is missing expected column(s) {missing_cols}; '
      f'found {list(df.columns)}'
  )

# Build new frames with .assign() instead of writing columns into `df` in place:
# `df` may be a filtered slice, and in-place writes on a slice raise
# SettingWithCopyWarning. pd.to_datetime keeps the column datetime-typed even
# when there are no rows.
df = df.assign(file_dt=pd.to_datetime(df['filename'].apply(extract_dt)))

# If window_start exists, shift timestamp by that many seconds.
if 'window_start' in df.columns:
  window_start_s = pd.to_numeric(df['window_start'], errors='coerce')
  df = df.assign(
      window_start=window_start_s,
      # Unparseable offsets count as 0 s, i.e. the start of the file.
      dt=df['file_dt'] + pd.to_timedelta(window_start_s.fillna(0), unit='s'),
  )
else:
  df = df.assign(dt=df['file_dt'])

# Drop rows where we can't parse time
plot_df = df.dropna(subset=['dt'])
if plot_df.empty:
  raise RuntimeError('No rows had a parseable datetime in filename. Adjust extract_dt() regex/format.')

plot_df = plot_df.set_index('dt').sort_index()

# Appended to both titles so the figures say which label they show.
label_suffix = f" ({focus_label})" if focus_label else ''

# '1h' instead of '1H': the upper-case hourly alias is deprecated in recent pandas.
detections_per_hour = plot_df['idx'].resample('1h').count()
_plot_detection_counts(
    detections_per_hour,
    'Detections per hour' + label_suffix,
    'Time',
    'Detections / hour',
)

# Optional: show daily totals too
show_daily = True

if show_daily:
  daily = plot_df['idx'].resample('1D').count()
  _plot_detection_counts(
      daily,
      'Detections per day' + label_suffix,
      'Date',
      'Detections / day',
  )
