# Extract latents to create time series of latent USV representations

## Imports

In [1]:
import numpy as np
import pandas as pd
import os
import glob
import matplotlib.pyplot as plt

In [2]:
from ava.data.data_container import DataContainer
from ava.plotting.grid_plot import indexed_grid_plot_DC
import audio_utils.io
from itertools import repeat
from joblib import Parallel, delayed
from ava.preprocessing.utils import get_spec
from ava.models.vae import X_SHAPE, VAE
from ava.preprocessing.preprocess import process_sylls
import torch

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [3]:
%matplotlib inline

## Ava project path

In [4]:
root = '/mnt/labNAS/usv_calls/usv_note_analysis/03_div_cage_group01_18_song_empty'
audio_dirs = [os.path.join(root, 'audio')]
seg_dirs = [os.path.join(root, 'segs')]
proj_dirs = [os.path.join(root, 'projections')]
spec_dirs = [os.path.join(root, 'specs')]
model_filename = os.path.join(root, 'checkpoint_200.tar')
plots_dir = root
dc = DataContainer(projection_dirs=proj_dirs, spec_dirs=spec_dirs, plots_dir=plots_dir, model_filename=model_filename)

## Preprocess

In [None]:
params = {
    'preprocess': {
        'get_spec': get_spec,
        'max_dur': 0.2, # maximum syllable duration
        'min_freq': 10e3, # minimum frequency
        'max_freq': 100e3, # maximum frequency
        'num_freq_bins': X_SHAPE[0], # hard-coded
        'num_time_bins': X_SHAPE[1], # hard-coded
        'nperseg': 1024, # FFT
        'noverlap': 512, # FFT
        'spec_min_val': 1.5, # minimum log-spectrogram value
        'spec_noise_thres': 2.5, # if no value in spectogram is above this, co
        'spec_max_val': 4.5, # maximum log-spectrogram value
        'fs': 250000, # audio samplerate
        'mel': False, # frequency spacing, mel or linear
        'time_stretch': True, # stretch short syllables?
        'within_syll_normalize': True, # normalize spectrogram values on a
                                        # spectrogram-by-spectrogram basis
        'normalize_quantile': 0.0, # throw away values below this quantile
        'max_num_syllables': None, # maximum number of syllables per directory
        'sylls_per_file': 1, # syllable per file
        'sigma': 1,
        'flip': False,
        'notch_filter': False, # whether to apply notch filter for 60 kHz noise
        'fn': 54e3, # frequency of notch filter
        'q': 30, # quality factor
        'real_preprocess_params': ('min_freq', 'max_freq', 'spec_min_val', \
                'spec_max_val', 'max_dur'), # tunable parameters
        'int_preprocess_params': ('nperseg','noverlap'), # tunable parameters
        'binary_preprocess_params': ('time_stretch', 'mel', \
                'within_syll_normalize'), # tunable parameters
    },
}

In [None]:
n_jobs = os.cpu_count()-1
gen = zip(audio_dirs, seg_dirs, spec_dirs, repeat(params['preprocess']))
Parallel(n_jobs=n_jobs)(delayed(process_sylls)(*args) for args in gen)

In [5]:
onsets = dc.request('onsets')
offsets = dc.request('offsets')
ava_embedding_detections = pd.DataFrame(np.vstack([onsets, offsets]).T, columns=['onsets', 'offsets'])
ava_embedding_detections = ava_embedding_detections.sort_values('onsets',ascending=True)
ava_embedding_detections

Reading field: onsets
	Done with: onsets
Reading field: offsets
	Done with: offsets


Unnamed: 0,onsets,offsets
12160,0.49950,0.57300
12161,0.59950,0.68040
21356,0.66490,0.74190
12162,0.69940,0.78140
12163,0.80190,0.87090
...,...,...
8767,3784.29559,3784.35458
8768,3784.45457,3784.49856
19103,3791.94740,3792.02490
19104,3792.48640,3792.54380


In [6]:
usvseg_filepaths = glob.glob(os.path.join(root, 'combined detection files', '*_locations.csv'))
usvseg_detections = pd.concat([pd.read_csv(f) for f in usvseg_filepaths], ignore_index=True).rename(columns={'xEnd': 'end'})
usvseg_detections = usvseg_detections.sort_values('start',ascending=True)
usvseg_detections = usvseg_detections[usvseg_detections.code != 6]
offset = 0.015
usvseg_detections['start'] -= offset
usvseg_detections['end'] += offset
origin = {0: 1, 1: 0, 2: 1, 3: 0, 4: 1, 5: 0}
usvseg_detections['Left'] = usvseg_detections['code'].map(origin)
usvseg_detections = usvseg_detections.drop('detection_side', axis=1)
usvseg_detections

Unnamed: 0.1,Unnamed: 0,start,end,duration,maxfreq,maxamp,meanfreq,cvfreq,in_song,usvseg_index,...,r_3_y,r_4_x,r_4_y,r_5_x,r_5_y,r_6_x,r_6_y,hour,time,Left
12897,0,0.4995,0.5730,43.5,33.518,-74.43,33.371,0.0732,False,0,...,138.881461,346.171137,122.701708,356.390433,111.148986,367.034968,100.514674,2,2022-12-08 12:57:57,0
12898,1,0.5995,0.6804,51.0,35.249,-73.59,38.586,0.2041,False,1,...,139.330706,345.351129,125.206485,355.580691,112.454785,367.640277,102.192295,2,2022-12-08 12:57:57,0
12899,2,0.6649,0.7419,47.0,24.462,-65.25,27.155,0.1779,False,0,...,140.267890,345.325371,126.199516,355.556774,113.738124,368.286718,103.502654,2,2022-12-08 12:57:57,1
12900,3,0.6994,0.7814,52.0,26.710,-67.75,36.812,0.3191,False,2,...,140.591894,345.016189,126.852271,355.217769,114.697781,368.286638,105.112831,2,2022-12-08 12:57:57,0
12901,4,0.8019,0.8709,39.0,26.129,-58.83,31.683,0.2641,False,3,...,141.016699,344.934351,128.655562,355.560531,117.130088,370.816151,110.737177,2,2022-12-08 12:57:57,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6037,6037,3775.7858,3775.8318,16.0,20.112,-63.25,21.836,0.0644,False,3525,...,259.526129,361.631232,248.044911,347.613934,240.498409,341.226128,231.621510,0,2022-12-08 10:52:25,0
6043,6043,3776.8767,3776.9237,17.0,38.245,-62.17,39.073,0.0300,False,4664,...,260.778644,361.024796,245.537210,346.337154,240.465739,334.910936,237.945167,0,2022-12-08 10:52:25,1
6044,6044,3791.9474,3792.0249,47.5,20.134,-58.69,20.996,0.0575,False,4666,...,,,,,,,,0,2022-12-08 10:52:25,1
6045,6045,3792.4864,3792.5438,27.5,20.356,-68.72,22.576,0.0815,False,4667,...,,,,,,,,0,2022-12-08 10:52:25,1


In [7]:
song_filepaths = glob.glob(os.path.join(root, 'song detection files', '*_locations.csv'))
song_detections = pd.concat([pd.read_csv(f) for f in song_filepaths], ignore_index=True)
song_detections = song_detections.sort_values('start',ascending=True)
side = {'l':1, 'r':0}
song_detections['Left'] = song_detections['source'].map(side)
song_detections = song_detections.drop('source', axis=1)
song_detections

Unnamed: 0,#,start,end,duration,song_idx,l_1_x,l_1_y,l_2_x,l_2_y,l_3_x,...,r_3_y,r_4_x,r_4_y,r_5_x,r_5_y,r_6_x,r_6_y,hour,time,Left
2004,0,1.668804,1.686802,0.017998,0,277.475870,124.086723,265.513814,144.267367,289.444058,...,141.813165,345.369169,130.364942,358.084026,120.147029,373.369943,112.495091,2,2022-12-08 12:57:57,1
947,0,75.461669,75.473667,0.011999,0,274.340842,203.563430,280.621399,227.523872,298.193699,...,213.410980,345.359810,200.656331,360.604577,189.148283,375.959135,178.929658,1,2022-12-08 11:56:06,1
948,1,75.548658,75.561657,0.012998,0,274.306959,203.541976,278.084473,227.559155,298.173519,...,213.440327,344.695769,200.668791,359.350282,189.774713,375.898824,180.178736,1,2022-12-08 11:56:06,1
949,2,75.714638,75.727637,0.012998,0,274.305464,203.570065,278.087578,227.572733,298.159874,...,213.420018,344.113471,201.884115,359.998850,190.433075,377.199647,180.182538,1,2022-12-08 11:56:06,1
950,3,75.741635,75.755633,0.013998,0,273.120568,203.575236,278.066018,227.579165,298.159203,...,213.403902,344.113796,200.702175,359.388655,190.446227,377.162420,180.210364,1,2022-12-08 11:56:06,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
942,942,3783.891637,3783.951630,0.059993,12,266.824083,249.079578,236.678971,251.888873,286.609501,...,268.679964,346.976788,256.318619,338.689713,245.549012,344.110474,232.502373,0,2022-12-08 10:52:25,1
943,943,3784.023621,3784.086614,0.062992,12,267.158574,250.307073,248.365966,255.505798,287.546609,...,268.015684,346.661851,256.908301,338.686595,245.557117,343.152071,232.844127,0,2022-12-08 10:52:25,1
944,944,3784.160605,3784.220598,0.059993,12,269.353676,249.376027,254.058380,254.747772,289.795969,...,267.391566,345.399253,256.305443,337.739086,245.546384,342.857008,232.843338,0,2022-12-08 10:52:25,1
945,945,3784.295589,3784.354581,0.058993,12,269.070331,249.071061,254.368548,254.737409,290.734595,...,268.028284,346.640033,255.682495,338.689555,245.548073,343.795232,232.820272,0,2022-12-08 10:52:25,1


In [8]:
detections_locations_all = pd.concat([usvseg_detections,song_detections],axis=0)
detections_locations_all = detections_locations_all.sort_values('start',ascending=True)
detections_locations_all

Unnamed: 0.1,Unnamed: 0,start,end,duration,maxfreq,maxamp,meanfreq,cvfreq,in_song,usvseg_index,...,r_4_y,r_5_x,r_5_y,r_6_x,r_6_y,hour,time,Left,#,song_idx
12897,0.0,0.499500,0.573000,43.500000,33.518,-74.43,33.371,0.0732,False,0.0,...,122.701708,356.390433,111.148986,367.034968,100.514674,2,2022-12-08 12:57:57,0,,
12898,1.0,0.599500,0.680400,51.000000,35.249,-73.59,38.586,0.2041,False,1.0,...,125.206485,355.580691,112.454785,367.640277,102.192295,2,2022-12-08 12:57:57,0,,
12899,2.0,0.664900,0.741900,47.000000,24.462,-65.25,27.155,0.1779,False,0.0,...,126.199516,355.556774,113.738124,368.286718,103.502654,2,2022-12-08 12:57:57,1,,
12900,3.0,0.699400,0.781400,52.000000,26.710,-67.75,36.812,0.3191,False,2.0,...,126.852271,355.217769,114.697781,368.286638,105.112831,2,2022-12-08 12:57:57,0,,
12901,4.0,0.801900,0.870900,39.000000,26.129,-58.83,31.683,0.2641,False,3.0,...,128.655562,355.560531,117.130088,370.816151,110.737177,2,2022-12-08 12:57:57,0,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
945,,3784.295589,3784.354581,0.058993,,,,,,,...,255.682495,338.689555,245.548073,343.795232,232.820272,0,2022-12-08 10:52:25,1,945.0,12.0
946,,3784.454569,3784.498564,0.043995,,,,,,,...,255.640257,336.157744,246.803809,344.229905,231.987125,0,2022-12-08 10:52:25,1,946.0,12.0
6044,6044.0,3791.947400,3792.024900,47.500000,20.134,-58.69,20.996,0.0575,False,4666.0,...,,,,,,0,2022-12-08 10:52:25,1,,
6045,6045.0,3792.486400,3792.543800,27.500000,20.356,-68.72,22.576,0.0815,False,4667.0,...,,,,,,0,2022-12-08 10:52:25,1,,


In [9]:
offset = 0
usvseg_detections.reset_index(drop=True, inplace=True)
for i in range(detections_locations_all.shape[0]):
    ava_onset = ava_embedding_detections.iloc[i,0]
    usvseg_onset = detections_locations_all.iloc[i,1]
    if abs(ava_onset+offset-usvseg_onset)>1e-3:
        #detections_locations_all.drop(i, inplace=True)
        ava_embedding_detections.drop(ava_embedding_detections.index[i], inplace=True)
        print(f'ava detection {i} not matching with usvseg detection {i}')
        i-=1

ava detection 15 not matching with usvseg detection 15
ava detection 22 not matching with usvseg detection 22
ava detection 28 not matching with usvseg detection 28
ava detection 36 not matching with usvseg detection 36
ava detection 39 not matching with usvseg detection 39
ava detection 47 not matching with usvseg detection 47
ava detection 48 not matching with usvseg detection 48
ava detection 49 not matching with usvseg detection 49
ava detection 50 not matching with usvseg detection 50
ava detection 51 not matching with usvseg detection 51
ava detection 52 not matching with usvseg detection 52
ava detection 53 not matching with usvseg detection 53
ava detection 54 not matching with usvseg detection 54
ava detection 55 not matching with usvseg detection 55
ava detection 56 not matching with usvseg detection 56
ava detection 57 not matching with usvseg detection 57
ava detection 58 not matching with usvseg detection 58
ava detection 59 not matching with usvseg detection 59
ava detect

ava detection 5991 not matching with usvseg detection 5991
ava detection 6006 not matching with usvseg detection 6006
ava detection 6007 not matching with usvseg detection 6007
ava detection 6010 not matching with usvseg detection 6010
ava detection 6078 not matching with usvseg detection 6078
ava detection 6102 not matching with usvseg detection 6102
ava detection 6119 not matching with usvseg detection 6119
ava detection 6120 not matching with usvseg detection 6120
ava detection 6126 not matching with usvseg detection 6126
ava detection 6134 not matching with usvseg detection 6134
ava detection 6154 not matching with usvseg detection 6154
ava detection 6164 not matching with usvseg detection 6164
ava detection 6173 not matching with usvseg detection 6173
ava detection 6174 not matching with usvseg detection 6174
ava detection 6229 not matching with usvseg detection 6229
ava detection 6240 not matching with usvseg detection 6240
ava detection 6254 not matching with usvseg detection 62

ava detection 11323 not matching with usvseg detection 11323
ava detection 11338 not matching with usvseg detection 11338
ava detection 11344 not matching with usvseg detection 11344
ava detection 11345 not matching with usvseg detection 11345
ava detection 11403 not matching with usvseg detection 11403
ava detection 11404 not matching with usvseg detection 11404
ava detection 11407 not matching with usvseg detection 11407
ava detection 11420 not matching with usvseg detection 11420
ava detection 11452 not matching with usvseg detection 11452
ava detection 11465 not matching with usvseg detection 11465
ava detection 11479 not matching with usvseg detection 11479
ava detection 11482 not matching with usvseg detection 11482
ava detection 11483 not matching with usvseg detection 11483
ava detection 11485 not matching with usvseg detection 11485
ava detection 11487 not matching with usvseg detection 11487
ava detection 11492 not matching with usvseg detection 11492
ava detection 11510 not 

ava detection 15855 not matching with usvseg detection 15855
ava detection 15903 not matching with usvseg detection 15903
ava detection 15928 not matching with usvseg detection 15928
ava detection 15947 not matching with usvseg detection 15947
ava detection 15978 not matching with usvseg detection 15978
ava detection 15994 not matching with usvseg detection 15994
ava detection 16036 not matching with usvseg detection 16036
ava detection 16038 not matching with usvseg detection 16038
ava detection 16060 not matching with usvseg detection 16060
ava detection 16076 not matching with usvseg detection 16076
ava detection 16134 not matching with usvseg detection 16134
ava detection 16163 not matching with usvseg detection 16163
ava detection 16171 not matching with usvseg detection 16171
ava detection 16315 not matching with usvseg detection 16315
ava detection 16322 not matching with usvseg detection 16322
ava detection 16358 not matching with usvseg detection 16358
ava detection 16387 not 

In [10]:
ava_embedding_detections.shape

(21837, 2)

## Embed latent means

In [11]:
latent_means = dc.request('latent_means')

Reading field: latent_means
	Done with: latent_means


In [12]:
# Reorder latents so that they are in temporal sequence
latent_means = latent_means[ava_embedding_detections.index,:]

## Cliff's manual labels

In [None]:
ava_embedding_detections['label'] = usvseg_detections['manual_type'].iloc[:ava_embedding_detections.shape[0]].copy()
ava_embedding_detections = pd.concat([ava_embedding_detections, pd.DataFrame(latent_means)],axis=1)

In [None]:
ava_embedding_detections.dropna(subset='label',inplace=True)
# Mapping from IDs to tokens, grouping multiple IDs to one token
tokenizer = {1:1,2:2,3:2,4:2,5:2,6:3,7:3,8:4,9:5,10:6,11:7,12:0}
ava_embedding_detections['label'] = [tokenizer[mtype] for mtype in ava_embedding_detections.label]

In [None]:
ava_embedding_detections = ava_embedding_detections[ava_embedding_detections.label != 0]
ava_embedding_detections

## Append latents

In [13]:
detections_locations_all.reset_index(drop=True, inplace=True)
latents = pd.DataFrame(latent_means)
locations_latents_all = pd.concat([detections_locations_all, latents],axis=1)
locations_latents_all

Unnamed: 0.1,Unnamed: 0,start,end,duration,maxfreq,maxamp,meanfreq,cvfreq,in_song,usvseg_index,...,22,23,24,25,26,27,28,29,30,31
0,0.0,0.499500,0.573000,43.500000,33.518,-74.43,33.371,0.0732,False,0.0,...,-0.321951,-0.004976,-0.009897,1.071670,-0.048639,0.014466,-0.004134,-0.351411,0.035195,-0.008947
1,1.0,0.599500,0.680400,51.000000,35.249,-73.59,38.586,0.2041,False,1.0,...,0.716621,-0.006194,-0.003253,1.243603,-0.055185,-0.005874,0.004700,-0.595085,0.042283,0.004533
2,2.0,0.664900,0.741900,47.000000,24.462,-65.25,27.155,0.1779,False,0.0,...,-0.149974,-0.040295,0.002160,1.758230,-0.038398,-0.068613,-0.005556,-0.312204,0.076351,0.004217
3,3.0,0.699400,0.781400,52.000000,26.710,-67.75,36.812,0.3191,False,2.0,...,1.128774,-0.019966,0.004840,1.402047,-0.060514,-0.014199,0.022991,-0.611689,0.057459,0.004964
4,4.0,0.801900,0.870900,39.000000,26.129,-58.83,31.683,0.2641,False,3.0,...,-0.687927,-0.004163,0.010576,1.201925,-0.041650,-0.019108,-0.000796,-0.507738,0.073795,0.005437
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21832,,3784.295589,3784.354581,0.058993,,,,,,,...,-0.110711,-0.127445,-0.078923,-0.920009,-0.028742,-0.014675,0.007098,1.521373,0.120173,-0.008331
21833,,3784.454569,3784.498564,0.043995,,,,,,,...,-0.215791,-0.097987,-0.106544,-1.091271,-0.055277,0.013731,0.021517,1.565341,0.124663,-0.019403
21834,6044.0,3791.947400,3792.024900,47.500000,20.134,-58.69,20.996,0.0575,False,4666.0,...,-0.847291,-0.075674,-0.061599,-0.845709,-0.039407,-0.055683,0.004596,0.194059,0.024943,0.019616
21835,6045.0,3792.486400,3792.543800,27.500000,20.356,-68.72,22.576,0.0815,False,4667.0,...,-0.500720,-0.052311,-0.025521,-0.018938,-0.024468,-0.055605,-0.043400,0.067176,0.037409,-0.005561


## Save

In [14]:
locations_latents_all.to_csv(seg_dirs[0]+'/locations_latents_all.csv', index=False)