# Overview

This notebook uses perch-hoplite to compute and save embeddings for a set of audio files using a pre-trained model. This is the first step in the agile modeling process. If the data you wish to search and classify has already been embedded with a pre-trained model into a perch-hoplite database, proceed directly to the step 2 Colab notebook ([2_agile_modeling_v2.ipynb](https://github.com/google-research/perch-hoplite/blob/main/perch_hoplite/agile/2_agile_modeling_v2.ipynb)).

## [Optional] perch-hoplite installation for hosted runtimes

If you have not already installed perch-hoplite (particularly if you are using a hosted Colab runtime), install it from the GitHub source to make sure you have the most recent version, e.g. `!pip install git+https://github.com/google-research/perch-hoplite.git`. After installation, you must restart your runtime before running anything else: in the top menu, select "Runtime", then "Restart session".

**If you want to use the Perch V2 model, you must be on a GPU runtime and must additionally install the experimental TensorFlow 2.20 release candidate (2.20.0rc0 or a later 2.20 release) by uncommenting and running the cell below.**

In [None]:
# !pip install tensorflow[and-cuda]~=2.20.0rc0

In [None]:
from etils import epath
from IPython.display import display
# import ipywidgets as widgets
import numpy as np
from agile import colab_utils
from agile import embed
from agile import source_info
from db import brutalism
from db import interface

# Embed the audio data

In [None]:
dataset_name = 'testset_pond'  # @param {type:'string'}
dataset_base_path = '/home/reindert/Valentin_REVO/agile_modeling_freshwater/Data/test_audio/'  #@param {type:'string'}
dataset_fileglob = '*.flac'  # @param {type:'string'}
db_path = '/home/reindert/Valentin_REVO/agile_modeling_freshwater/db/'  # @param {type:'string'}
# The form field is free text: an empty value or the literal text 'None'
# both mean "no explicit database path", so normalise either one to None.
if db_path == 'None' or not db_path:
  db_path = None

model_choice = 'perch_v2'  #@param['perch_v2','perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']

use_file_sharding = True  # @param {type:'boolean'}
shard_length_in_seconds = 60  # @param {type:'number'}

# Describe which audio files to embed and how they are cut up beforehand.
audio_glob = source_info.AudioSourceConfig(
    dataset_name=dataset_name,
    base_path=dataset_base_path,
    file_glob=dataset_fileglob,
    # Files shorter than one second are filtered by this threshold.
    min_audio_len_s=1.0,
    # NOTE(review): -2 looks like a sentinel rather than a real rate (possibly
    # "use the model's sample rate"); confirm against AudioSourceConfig.
    target_sample_rate_hz=-2,
    # Only split files into fixed-length shards when sharding is switched on.
    shard_len_s=(
        None if not use_file_sharding else float(shard_length_in_seconds)
    ),
)

# Bundle the audio source, database location and model choice into a single
# configuration object used by the following cells.
configs = colab_utils.load_configs(
    source_info.AudioSources((audio_glob,)),
    db_path,
    db_key='sqlite_usearch',
    model_config_key=model_choice,
)
configs

In [None]:
# Load the database described by `configs` and count what it already holds.
# `db` is an ordinary module-level name here; a `global` statement at the top
# level of a notebook cell is a no-op and does not let functions rebind it.
db = configs.db_config.load_db()
num_embeddings = db.count_embeddings()

print('Initialized DB located at ', configs.db_config.db_config.db_path)

def drop_and_reload_db(_) -> interface.HopliteDBInterface:
  """Delete the on-disk database files and load a fresh database in their place.

  Intended as the ``on_click`` callback of the confirmation button, which is
  why it takes a single ignored argument.

  Args:
    _: The widget that triggered the callback (unused).

  Returns:
    The newly loaded database. It is also rebound to the notebook-level ``db``
    name, so later cells use it instead of the handle to the deleted files.
  """
  # Without this declaration the reloaded database would only be bound to a
  # local name, and every later cell would keep using the deleted database.
  global db
  # Named `db_dir` so it does not shadow the notebook-level `db_path` setting.
  db_dir = epath.Path(configs.db_config.db_config.db_path)
  for fp in db_dir.glob('hoplite.sqlite*'):
    fp.unlink()
  # The usearch index may not exist yet (e.g. nothing was ever indexed).
  index_path = db_dir / 'usearch.index'
  if index_path.exists():
    index_path.unlink()
  print('\n Deleted previous db at: ', configs.db_config.db_config.db_path)
  db = configs.db_config.load_db()
  return db

drop_existing_db = False  #@param {type:'boolean'}

# Deleting is only offered when there is something to delete, and it still
# requires an explicit button click as a second confirmation.
if num_embeddings > 0 and drop_existing_db:
  # Imported here because the top-of-notebook ipywidgets import is commented
  # out; this confirmation button is the only place the widgets are needed.
  import ipywidgets as widgets

  print('Existing DB contains datasets: ', db.get_dataset_names())
  print('num embeddings: ', num_embeddings)
  print('\n\nClick the button below to confirm you really want to drop the database at ')
  print(f'{configs.db_config.db_config.db_path}\n')
  print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\n')
  print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\n')

  button = widgets.Button(description='Delete database?')
  button.on_click(drop_and_reload_db)
  display(button)

In [None]:
# Embed every audio file matched by the configured source and write the
# resulting embeddings into `db`, then report the database's new total.
print(f'Embedding dataset: {audio_glob.dataset_name}')

worker = embed.EmbedWorker(
    model_config=configs.model_config,
    db=db,
    audio_sources=configs.audio_sources_config,
)
worker.process_all(target_dataset_name=audio_glob.dataset_name)

print(
    '\n\nEmbedding complete, total embeddings: ',
    db.count_embeddings(),
)

In [None]:
#@title Per dataset statistics { vertical-output: true }

# Report, for each dataset stored in the database, how many embeddings the
# database returns for it (across all of its sources).
for ds_name in db.get_dataset_names():
  print(f'\nDataset \'{ds_name}\':')
  ds_embeddings = db.get_embeddings_by_source(ds_name, source_id=None)
  print('\tnum embeddings: ', ds_embeddings.shape[0])

In [None]:
#@title Show example embedding search
#@markdown As an example (and to show that the embedding process worked), this
#@markdown selects a single embedding from the database and outputs the embedding ids of the
#@markdown top-K nearest neighbors in the database (K is set by `top_k`, default 128).

top_k = 128  # @param {type:'integer'}

if db.count_embeddings() == 0:
  print('The database contains no embeddings; run the embedding cell first.')
else:
  # Search the same `db` handle the other cells use rather than `worker.db`.
  # This keeps the cell working when the embedding cell was skipped (DB
  # already populated) or when the DB was reloaded after the worker was built.
  query_id = db.get_one_embedding_id()
  q = db.get_embedding(query_id)
  results, scores = brutalism.brute_search(
      db, query_embedding=q, search_list_size=top_k, score_fn=np.dot
  )
  print(f'Query embedding id: {query_id}')
  print([int(r.embedding_id) for r in results])