In [None]:
#@title Imports
from etils import epath
from IPython.display import display
import ipywidgets as widgets
import numpy as np

from perch_hoplite.agile import colab_utils
from perch_hoplite.agile import embed
from perch_hoplite.agile import source_info
from perch_hoplite.db import brutalism
from perch_hoplite.db import interface

## Embed

In [None]:
#@title Configuration { vertical-output: true }

#@markdown Configure the raw dataset location(s).  The format is a mapping from
#@markdown a dataset_name to a (base_path, fileglob) pair.  Note that the file
#@markdown globs are case sensitive.  The dataset name can be anything you want.
#
#@markdown This structure allows you to move your data around without having to
#@markdown re-embed the dataset.  The generated embedding database will be
#@markdown placed in the base path. This allows you to simply swap out
#@markdown the base path here if you ever move your dataset.

#@markdown By default we only process one dataset at a time.  Re-run this entire notebook
#@markdown once per dataset.  The embeddings database will be located in the
#@markdown `dataset_base_path`.

#@markdown For example, we might set dataset_base_path to '/home/me/myproject',
#@markdown and use the glob '*/*.wav' if all of the audio files have filepaths
#@markdown like '/home/me/myproject/site_XYZ/audio_ABC.wav'
dataset_name = ''  #@param {type:'string'}
dataset_base_path = ''  #@param {type:'string'}
dataset_fileglob = '*.wav'  #@param {type:'string'}
#@markdown You do not need to change this unless you want to maintain multiple
#@markdown distinct embedding databases.
db_path = None #@param

#@markdown Choose a supported model: `perch_8` or `birdnet_V2.3` are most common
#@markdown for birds. Other choices include `surfperch` for coral reefs or
#@markdown `multispecies_whale` for marine mammals.
model_choice = 'perch_8'  #@param['perch_8', 'humpback', 'multispecies_whale', 'surfperch', 'birdnet_V2.3']

#@markdown File sharding automatically splits audio files into one-minute chunks
#@markdown for embedding. This limits both system and GPU memory usage,
#@markdown especially useful when working with long files (>1 hour).
use_file_sharding = True  #@param {type:'boolean'}

# Describe the single audio source built from the form values above.
# NOTE(review): target_sample_rate_hz=-2 looks like a sentinel rather than a
# real rate (presumably "use the embedding model's sample rate") -- confirm
# against source_info.AudioSourceConfig.
# shard_len_s is 60.0 (the one-minute chunks described above) when file
# sharding is enabled, and None otherwise.
audio_glob = source_info.AudioSourceConfig(
    dataset_name=dataset_name,
    base_path=dataset_base_path,
    file_glob=dataset_fileglob,
    min_audio_len_s=1.0,
    target_sample_rate_hz=-2,
    shard_len_s=60.0 if use_file_sharding else None,
)

# Bundle the audio source, DB location, and model choice into `configs`;
# later cells read configs.db_config, configs.audio_sources_config and
# configs.model_config from it.
configs = colab_utils.load_configs(
    source_info.AudioSources((audio_glob,)),
    db_path,
    model_config_key=model_choice,
    db_key = 'sqlite_usearch')
# Bare trailing expression: shows the resulting configs as the cell output.
configs

In [None]:
#@title Initialize the DB { vertical-output: true }
# NOTE: a `global` statement at module (cell) level has no effect; `db` is
# already a notebook-wide name here. Any function that needs to rebind `db`
# must declare `global db` inside its own body.
global db
# Load the database described by the `configs` object built earlier.
db = configs.db_config.load_db()
# Embedding count at load time; used below to decide whether to offer the
# "drop database" prompt.
num_embeddings = db.count_embeddings()

print('Initialized DB located at ', configs.db_config.db_config.db_path)

def drop_and_reload_db(_) -> interface.HopliteDBInterface:
  """Deletes the on-disk database files and reloads the database.

  Written to be used as an `ipywidgets.Button.on_click` callback, which is why
  it accepts (and ignores) a single positional argument.

  Removes every `hoplite.sqlite*` file and the `usearch.index` file found in
  the configured DB directory, then calls `configs.db_config.load_db()` again
  and rebinds the notebook-wide `db` to the result.

  Args:
    _: The widget that triggered the callback (unused).

  Returns:
    The newly loaded database, which is also stored in the global `db`.
  """
  # Without this declaration the assignment at the bottom would only create a
  # function-local name, and later cells would keep using the handle to the
  # database that was just deleted.
  global db
  db_path = epath.Path(configs.db_config.db_config.db_path)
  for fp in db_path.glob('hoplite.sqlite*'):
    fp.unlink()
  # The index file may be absent (e.g. after a partially completed earlier
  # deletion); do not abort before the reload in that case.
  index_path = db_path / 'usearch.index'
  if index_path.exists():
    index_path.unlink()
  print('\n Deleted previous db at: ', configs.db_config.db_config.db_path)
  db = configs.db_config.load_db()
  return db

# NOTE(review): this defaults to True, so the delete prompt below is offered
# whenever the loaded DB already contains embeddings. Nothing is deleted until
# the button is actually clicked.
drop_existing_db = True  #@param[True, False]

# Only show the confirmation UI when there is existing data AND the user has
# opted in to dropping it.
if num_embeddings > 0 and drop_existing_db:
  print('Existing DB contains datasets: ', db.get_dataset_names())
  print('num embeddings: ', num_embeddings)
  print('\n\nClick the button below to confirm you really want to drop the database at ')
  print(f'{configs.db_config.db_config.db_path}\n')
  print(f'This will permanently delete all {num_embeddings} embeddings from the existing database.\n')
  print('If you do NOT want to delete this data, set `drop_existing_db` above to `False` and re-run this cell.\n')

  # The deletion itself happens in the click callback, not in this cell.
  button = widgets.Button(description='Delete database?')
  button.on_click(drop_and_reload_db)
  display(button)

In [None]:
#@title Run the embedding { vertical-output: true }

print(f'Embedding dataset: {audio_glob.dataset_name}')

# Build the worker from the audio sources and model config held in `configs`,
# writing into the notebook-wide `db`.
worker = embed.EmbedWorker(
    audio_sources=configs.audio_sources_config,
    db=db,
    model_config=configs.model_config)

# Process only the dataset named in the configuration cell.
worker.process_all(target_dataset_name=audio_glob.dataset_name)

# Report the DB-wide embedding count after processing.
print('\n\nEmbedding complete, total embeddings: ', db.count_embeddings())

In [None]:
#@title Per dataset statistics { vertical-output: true }

# For every dataset name stored in the DB, print a header line followed by the
# size of the first axis of what the DB returns for that dataset.
for dataset in db.get_dataset_names():
  print(f'\nDataset \'{dataset}\':')
  source_embeddings = db.get_embeddings_by_source(dataset, source_id=None)
  print('\tnum embeddings: ', source_embeddings.shape[0])

In [None]:
#@title Show example embedding

q = db.get_embedding(db.get_one_embedding_id())
%time results, scores = brutalism.brute_search(worker.db, query_embedding=q, search_list_size=128, score_fn=np.dot)
print([int(r.embedding_id) for r in results])