## 0. Goal
Given spectrograms of entire syllables, systematically chop them into spectrogram windows and apply the trained VAE model to obtain latents for all windows. <br>
This notebook focuses on call analysis. <br>
The sliding-window duration is varied across runs (one trained VAE per window width).

In [1]:
import os

# Restrict this process to physical GPU 1. This is done in its own first cell,
# before torch is imported in the next cell; the visible GPU is then addressed as 'cuda:0'.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [2]:
import os, sys, importlib, librosa, glob, h5py, tqdm, pickle, gc
from scipy.io import wavfile
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from joblib import Parallel, delayed
import random
import umap, hdbscan
from collections import Counter
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from skimage import transform
import gc

# Embed TrueType (Type 42) fonts in saved PDFs so figure text stays editable.
plt.rcParams.update({'pdf.fonttype': 42})

In [3]:
# Make the project-local clustering/VAE helper scripts importable.
cluster_script_path = '/home/zz367/ProjectsU/EphysMONAO/Jupyter/MatlabCodes/ZZ_callClustering/'
sys.path.insert(1, cluster_script_path)

# vae_goffinet provides the spectrogram maker, window chopper and VAE class used below.
import vae_goffinet
import hopkins

# Reload so edits to vae_goffinet.py are picked up without restarting the kernel.
importlib.reload(vae_goffinet)

<module 'vae_goffinet' from '/home/zz367/ProjectsU/EphysMONAO/Jupyter/MatlabCodes/ZZ_callClustering/vae_goffinet.py'>

In [4]:
# Spectrogram colormap: the lowest entry is pure black, followed by the
# 255-entry 'jet' lookup table.
jet = plt.get_cmap('jet', 255)
jet_colors = jet(np.linspace(0, 1, 255))  # RGBA rows sampled from jet
custom_colors = np.concatenate((np.array([[0, 0, 0, 1]]), jet_colors), axis=0)
custom_cmap = ListedColormap(custom_colors)

## 1. Inputs

In [5]:
# ---- data location / bird -----------------------------------------------
fd_z4 = '/mnt/z4'
fd_data = os.path.join(fd_z4, 'zz367', 'EphysMONAO', 'Analyzed', 'vaeWav')
birdID = 'pair5RigCCU29'

# ---- spectrogram dataset --------------------------------------------------
# [min, max] log-spectrogram values used when computing spectrograms; they
# depend on the audio amplitude and may need changing per bird
clims = [1.5, 7]
# sub-folder and run name of the full-syllable spectrogram dataset
spec_suffix = 'Spectrogram2'
spec_run = 'spec_goffinet_traj_256_236'
# (frequency bins, time bins) of one spectrogram window
X_SHAPE = (128, 128)

# ---- analysis selection ---------------------------------------------------
# syllable label(s) to analyse
syl = ['v']
# folder with the trained VAE runs; the specific run (vae_run / fd_vae) is
# chosen per sliding-window width in the sweep further down
vae_suffix = 'VAE5'

In [6]:
# Parameters for spectrogram computation and sliding-window chopping, passed to
# the vae_goffinet helpers. The window shape is taken from X_SHAPE, which is set
# once in the input cell above (not redefined here).
# With nperseg=256, noverlap=236 and fs=20000, one spectrogram column is
# (256 - 236) / 20000 s = 1 ms (matching the 0.001 s steps of spec_t).
p = {
    'get_spec': vae_goffinet.get_specZZ, # spectrogram maker
    'max_dur': 1e9, # maximum syllable duration
    'min_freq': 250, # minimum frequency
    'max_freq': 7500, # maximum frequency, default 7500
    'num_freq_bins': X_SHAPE[0], # frequency bins per window, from X_SHAPE
    'num_time_bins': X_SHAPE[1], # time bins per window, from X_SHAPE
    'nperseg': 256, # FFT
    'noverlap': 236, # FFT, determines window overlap when calculating spectrograms
    'spec_min_val': clims[0], # minimum log-spectrogram value
    'spec_max_val': clims[1], # maximum log-spectrogram value
    'fs': 20000, # audio samplerate
    'mel': False, # frequency spacing, mel or linear
    'time_stretch': False, # stretch short syllables?
    'within_syll_normalize': False, # normalize spectrogram values on a spectrogram-by-spectrogram basis
    'pad': 0.08,  # padding added before syllable onset and after syllable offset when extracting syllables, unit is sec
    'win_frame': 32,  # duration of the sliding window, unit is spectrogram column
    'hop_frame': 1, # how much to slide for consecutive window, unit is spectrogram column
    'win_pad': 32, # how much to include before syllable onset, unit is spectrogram column, default to one sliding window
}

In [7]:
# Sub-folder (under <bird>/Traj/) that receives the merged syllable table and
# the per-window VAE latents produced below.
apply_suffix = 'applySyl5'

## 2. Get the call subtype data

In [8]:
# Per-syllable info table of the spectrogram dataset. The matching .h5 path
# (fn_spec) is what the window chopper reads later.
_spec_dir = os.path.join(fd_data, birdID, 'Traj', spec_suffix)
fn_spec = os.path.join(_spec_dir, birdID + '.' + spec_run + '.h5')
fn_spec_info = os.path.join(_spec_dir, birdID + '.' + spec_run + '.info.csv')
info_spec = pd.read_csv(fn_spec_info, index_col=0)
print(info_spec.shape)
info_spec.head()

(84254, 12)


Unnamed: 0,fn_wav,s_idx,istart,iend,label,spec_f,spec_t,i_start,i_end,zero_start,zero_end,rel_ori
0,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,0,660,3400,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,0,5000,940,0,1600
1,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,1,5740,12940,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,4140,14540,0,0,1600
2,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,2,18800,20000,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,17200,20860,0,740,1600
3,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,3,20860,24800,v,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,20000,26400,740,0,1600
4,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,4,28880,38480,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,27280,40080,0,0,1600


In [9]:
# Call-subtype table for 'v' syllables; it supplies the hdbscan_cluster /
# hdbscan_prob columns merged in below.
subtype_suffix = 'UMAPonVAE7'
subtype_run = 'spec_goffinet_nn_256_176'
fd_subtype = os.path.join(fd_data, birdID, subtype_suffix, 'v', subtype_run)
fn_subtype = os.path.join(fd_subtype, '{}.{}.embedding.csv'.format(birdID, subtype_run))
subtype = pd.read_csv(fn_subtype)
print(subtype.shape)

(23240, 44)


In [10]:
# Attach the call-subtype assignment to every syllable, matching rows on
# (fn_wav, s_idx, istart). Unmatched rows (here the non-call syllables) get NaN.
# validate='many_to_one' makes pandas raise if a key occurs more than once in
# the subtype table, which would otherwise silently duplicate syllable rows.
info_merged = info_spec.merge(
    subtype[['fn_wav', 's_idx', 'istart', 'hdbscan_cluster', 'hdbscan_prob']],
    on=['fn_wav', 's_idx', 'istart'], how='left', validate='many_to_one')
# a left merge with unique right-hand keys keeps exactly one row per syllable
assert len(info_merged) == len(info_spec), (len(info_merged), len(info_spec))
# NOTE(review): merge() returns a fresh 0..N-1 index, and the window sweep passes
# these index labels to ZZ_slideSylWin_v1 as row ids; this assumes info_spec's
# own index is also 0..N-1 — confirm.
print(info_merged.shape)
info_merged.head()

(84254, 14)


Unnamed: 0,fn_wav,s_idx,istart,iend,label,spec_f,spec_t,i_start,i_end,zero_start,zero_end,rel_ori,hdbscan_cluster,hdbscan_prob
0,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,0,660,3400,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,0,5000,940,0,1600,,
1,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,1,5740,12940,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,4140,14540,0,0,1600,,
2,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,2,18800,20000,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,17200,20860,0,740,1600,,
3,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,3,20860,24800,v,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,20000,26400,740,0,1600,6.0,0.954308
4,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,4,28880,38480,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,27280,40080,0,0,1600,,


In [11]:
# Sanity check: every call ('v') syllable must have matched a subtype row.
# .isna() handles both float NaN and the nullable Int64 <NA> that the column
# holds once it has been cast below, so this cell is safe to re-run.
temp = info_merged[info_merged['label']=='v']
unmatched = np.where(temp['hdbscan_cluster'].isna())
print(unmatched)
assert len(unmatched[0]) == 0, f'{len(unmatched[0])} call syllables have no subtype assignment'

(array([], dtype=int64),)


In [12]:
# Store cluster ids as nullable integers, then derive a string label per row,
# e.g. 'v6'; rows without a cluster become 'v<NA>'.
info_merged['hdbscan_cluster'] = info_merged['hdbscan_cluster'].astype('Int64')
info_merged['call_subtype'] = list(map('v{}'.format, info_merged['hdbscan_cluster']))

In [13]:
# peek at the first 20 rows of the merged table
info_merged.head(20)

Unnamed: 0,fn_wav,s_idx,istart,iend,label,spec_f,spec_t,i_start,i_end,zero_start,zero_end,rel_ori,hdbscan_cluster,hdbscan_prob,call_subtype
0,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,0,660,3400,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,0,5000,940,0,1600,,,v<NA>
1,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,1,5740,12940,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,4140,14540,0,0,1600,,,v<NA>
2,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,2,18800,20000,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,17200,20860,0,740,1600,,,v<NA>
3,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,3,20860,24800,v,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,20000,26400,740,0,1600,6.0,0.954308,v6
4,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,4,28880,38480,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,27280,40080,0,0,1600,,,v<NA>
5,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,5,42760,45060,e,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,41160,46660,0,0,1600,,,v<NA>
6,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,6,46760,50759,v,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,45160,52359,0,0,1600,6.0,0.667034,v6
7,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,7,55640,57560,x,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,54040,59160,0,0,1600,,,v<NA>
8,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,8,60900,70560,b,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,59300,72160,0,0,1600,,,v<NA>
9,/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5...,9,75540,76300,h,[ 250. 307.08661417 364.17322835 42...,[0. 0.001 0.002 0.003 0.004 0.005 0.006 0.0...,73940,77900,0,0,1600,,,v<NA>


In [14]:
# number of syllables per subtype label ('v<NA>' = rows without a subtype)
Counter(info_merged['call_subtype'].tolist())

Counter({'v<NA>': 61014,
         'v6': 6658,
         'v1': 3798,
         'v7': 3102,
         'v4': 2421,
         'v0': 2288,
         'v3': 1796,
         'v2': 1658,
         'v5': 1519})

In [15]:
# Save the merged table (spectrogram info + call subtype) into the output
# folder that also receives the per-window latents.
fd_save = os.path.join(fd_data, birdID, 'Traj', apply_suffix)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(fd_save, exist_ok=True)
print(fd_save)
fn_merged = os.path.join(fd_save, f'{birdID}.info_merged.csv')
info_merged.to_csv(fn_merged)

/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5RigCCU29/Traj/applySyl5


## 3. Select syllables, slide into spectrogram windows, obtain VAE latents

In [16]:
# Subtype labels to process: 'v0' up to 'v<largest cluster id among calls>'.
temp = info_merged[info_merged['label'] == 'v']
n_subtypes = max(temp['hdbscan_cluster']) + 1
syl_select = ['v' + str(ci) for ci in range(n_subtypes)]
print('Syllables to use:', syl_select)

Syllables to use: ['v0', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6', 'v7']


In [17]:
# Sliding-window widths to sweep (spectrogram columns); each width has its own
# trained VAE run named 'traj_chop_{w}_1_{w}'.
win_list = [16, 24, 32, 40, 48, 64, 80]

In [None]:
# Sweep the sliding-window width. For each width w (spectrogram columns):
#   1. load the VAE trained on w-column windows (run 'traj_chop_{w}_1_{w}'),
#   2. for every call subtype in syl_select, chop its syllables into windows with
#      vae_goffinet.ZZ_slideSylWin_v1 (parallel over syllables),
#   3. encode the windows with model.forwardZZ and save latentM / latentD / info
#      CSVs under <fd_save>/traj_chop_{w}_1_{w}/.

# Index of the first width in win_list to process. The original run started at 1,
# i.e. skipped w=16; set to 0 to include it.
WIN_START_IDX = 1
CHOP_BATCH = 48     # syllables chopped per parallel batch
CHOP_JOBS = 48      # joblib worker processes used for chopping
LOADER_BATCH = 64   # spectrogram windows per forward pass
DEVICE = 'cuda:0'   # first GPU visible to this process


def _to_numpy(x):
    """Return `x` as a float64 numpy array, moving a torch tensor to the CPU first.

    NOTE(review): forwardZZ's return types are not visible here, so both tensors
    and array-likes are accepted.
    """
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=np.float64)


def _chop_batch(parallel, row_ids, params):
    """Chop the syllables with index labels `row_ids` into spectrogram windows.

    Returns (windows stacked along axis 0, per-window info DataFrame), or
    (None, None) when no syllable in the batch produced a window.
    """
    res = parallel(
        delayed(vae_goffinet.ZZ_slideSylWin_v1)(fn_spec, ri, params, resize=True)
        for ri in row_ids)
    # each result holds (window arrays, info DataFrame); drop the empty ones
    specs = [arr for out in res if out[0] for arr in out[0]]
    infos = [out[1] for out in res if not out[1].empty]
    if not specs and not infos:
        return None, None
    assert specs and infos, 'window arrays and window info disagree on emptiness'
    spec_win = np.stack(specs, axis=0)
    info = pd.concat(infos, ignore_index=True)
    # latents are saved row-for-row against `info`, so the counts must agree
    assert spec_win.shape[0] == info.shape[0], (spec_win.shape, info.shape)
    return spec_win, info


def _encode(model, spec_win):
    """Run windows through `model.forwardZZ` and return its (mu, d) outputs as arrays.

    NOTE(review): `d` is saved as 'latentD'; presumably a per-dimension spread of
    the latent posterior — confirm in vae_goffinet.
    """
    loader = DataLoader(TensorDataset(torch.from_numpy(spec_win).float()),
                        batch_size=LOADER_BATCH, shuffle=False, num_workers=4)
    mus, ds = [], []
    with torch.no_grad():
        for (data,) in loader:
            _, _, _, mu, d = model.forwardZZ(data.to(DEVICE), return_latent_rec=True)
            mus.append(_to_numpy(mu))
            ds.append(_to_numpy(d))
    return np.concatenate(mus, axis=0), np.concatenate(ds, axis=0)


def _process_subtype(model, sub, params, parallel, fd_out):
    """Chop, encode and save all spectrogram windows of call subtype `sub` into `fd_out`."""
    idx_v = info_merged[info_merged['call_subtype'] == sub].index
    m_parts, d_parts, info_parts = [], [], []
    for bs in tqdm.tqdm(range(0, len(idx_v), CHOP_BATCH), desc=sub):
        spec_win, info = _chop_batch(parallel, idx_v[bs:bs + CHOP_BATCH], params)
        if spec_win is None:
            continue
        mu, d = _encode(model, spec_win)
        m_parts.append(mu)
        d_parts.append(d)
        info_parts.append(info)
        del spec_win

    if not info_parts:
        # e.g. a cluster id from range(max + 1) that has no syllables assigned
        print(f'{sub}: no spectrogram windows, nothing saved')
        return

    # concatenate once at the end instead of growing arrays inside the loop
    latent_all_m = np.concatenate(m_parts, axis=0)
    latent_all_d = np.concatenate(d_parts, axis=0)
    info_all = pd.concat(info_parts, ignore_index=True)
    np.savetxt(os.path.join(fd_out, f'latentM.{sub}.csv'), latent_all_m, delimiter=',')
    np.savetxt(os.path.join(fd_out, f'latentD.{sub}.csv'), latent_all_d, delimiter=',')
    info_all.to_csv(os.path.join(fd_out, f'info.{sub}.csv'))
    print(latent_all_m.shape)


for w in win_list[WIN_START_IDX:]:
    vae_run = f'traj_chop_{w}_1_{w}'
    fd_vae = os.path.join(fd_data, birdID, 'Traj', vae_suffix, vae_run)
    # outputs for this width go into their own sub-folder
    fd_save_this = os.path.join(fd_save, vae_run)
    os.makedirs(fd_save_this, exist_ok=True)
    print(fd_save_this)

    # load the VAE trained for this window width
    fn_vae = os.path.join(fd_vae, f'{birdID}_checkpoint_final.tar')
    print(fn_vae)
    model = vae_goffinet.VAE(save_dir=fd_vae)
    model.load_state(fn_vae)
    model.eval()

    # Per-width copy of the parameters: window length and pre-onset padding both
    # equal w. Copying (instead of assigning into `p`) leaves the parameter cell's
    # configuration unchanged, so re-running cells does not depend on hidden state.
    p_win = {**p, 'win_frame': w, 'win_pad': w}
    print(fn_spec)

    # one worker pool per width, reused across all subtypes and batches
    with Parallel(n_jobs=CHOP_JOBS, verbose=0) as parallel:
        for v in syl_select:
            _process_subtype(model, v, p_win, parallel, fd_save_this)
            gc.collect()

/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5RigCCU29/Traj/applySyl5/traj_chop_24_1_24
/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5RigCCU29/Traj/VAE5/traj_chop_24_1_24/pair5RigCCU29_checkpoint_final.tar
