## 0. Goal
Vary the inputs to UMAP:
1. Use the VAE trained only on calls.
2. Use the model trained on the proportional datasets.
3. Vary the sliding-window duration.

In [1]:
import os, sys, importlib, librosa, glob, h5py, tqdm, pickle
from scipy.io import wavfile
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from joblib import Parallel, delayed
import random
import umap, hdbscan
from collections import Counter, OrderedDict
import seaborn as sns
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from skimage import transform
import gc
import colorsys
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors

# Save PDFs with TrueType (Type 42) fonts so text remains editable in vector editors.
plt.rcParams.update({'pdf.fonttype': 42})

In [2]:
# Make the local utility modules (vae_goffinet, hopkins) importable.
cluster_script_path = '/home/zz367/ProjectsU/EphysMONAO/Jupyter/MatlabCodes/ZZ_callClustering/'
# Only insert once: re-running this cell would otherwise keep adding duplicate sys.path entries.
if cluster_script_path not in sys.path:
    sys.path.insert(1, cluster_script_path)
import vae_goffinet, hopkins
# Reload so edits to vae_goffinet.py are picked up without restarting the kernel;
# the trailing semicolon suppresses the module repr in the cell output.
importlib.reload(vae_goffinet);

<module 'vae_goffinet' from '/home/zz367/ProjectsU/EphysMONAO/Jupyter/MatlabCodes/ZZ_callClustering/vae_goffinet.py'>

In [3]:
# Spectrogram colormap: entry 0 (the lowest color) is opaque black,
# followed by 255 evenly spaced samples of the 'jet' colormap.
jet = plt.get_cmap('jet', 255)
jet_colors = jet(np.linspace(0, 1, 255))
black_rgba = np.array([[0, 0, 0, 1]])
custom_colors = np.concatenate((black_rgba, jet_colors), axis=0)
custom_cmap = ListedColormap(custom_colors)

## 1. Inputs

In [4]:
# ---- data locations ----
fd_z4 = '/mnt/z4'
fd_data = os.path.join(fd_z4, 'zz367', 'EphysMONAO', 'Analyzed', 'vaeWav')
birdID = 'pair5RigCCU29'

# ---- spectrogram settings ----
# Color limits used when computing spectrograms; they depend on audio amplitude,
# so they may need to differ between birds.
clims = [1.5, 7]
spec_suffix = 'Spectrogram2'             # spectrogram dataset folder
spec_run = 'spec_goffinet_traj_256_236'  # spectrogram run within that folder

# ---- syllable subtypes analyzed together ----
v_all = ['v4', 'v5']

# ---- VAE model and latent outputs ----
# Run-specific folders (traj_chop_<w>_1_<w>) are built per window width in the sweep loop.
vae_suffix = 'VAE5'         # folder holding the trained VAE runs
apply_suffix = 'applySyl5'  # folder holding the VAE latents computed on syllables

In [12]:
# Output root for this parameter search; a sub-folder per VAE run is created in the sweep loop.
search_name = 'paramSearch5'
fd_save_base = os.path.join(fd_data, birdID, 'Traj', apply_suffix, search_name)
print(fd_save_base)

/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5RigCCU29/Traj/applySyl5/paramSearch5


In [7]:
# Load the syllable-level info table that is later joined onto the window-level info.
fd_latent_base = os.path.join(fd_data, birdID, 'Traj', apply_suffix)
fn_merged = os.path.join(fd_latent_base, f'{birdID}.info_merged.csv')
info_syl = pd.read_csv(fn_merged, index_col=0)
# Unique syllable id "<fn_wav>_<s_idx>", built column-wise (vectorized) rather than via
# per-row label lookups, which are slow and misbehave if index labels repeat.
info_syl['syl_id'] = info_syl['fn_wav'].astype(str) + '_' + info_syl['s_idx'].astype(str)
# Move the original index into a column named 'index'; the sweep loop merges window info
# with right_on='index', so pin that name even if the CSV's index header is named.
info_syl = info_syl.rename_axis('index').reset_index()

In [8]:
# UMAP hyper-parameters shared by every window width in the sweep.
param_umap = dict(
    n_components=2,
    n_neighbors=25,
    min_dist=0,
    metric='euclidean',
)

## 2. UMAP on latents for v4 and v5

In [9]:
# Sliding-window widths to sweep; each width w maps to a VAE run named traj_chop_<w>_1_<w>.
win_list = [16, 24, 32, 40, 48, 64, 80]

In [None]:
# Sweep over sliding-window widths. For each VAE run: pool the window-level latents of the
# call subtypes in `v_all`, embed them with UMAP, save the embedding and the UMAP model,
# and plot every sampled syllable as a trajectory in UMAP space.
#
# NOTE(review): the sweep starts at index 1, so win_list[0] is skipped -- presumably it was
# processed in an earlier run. TODO confirm, or set wi_start = 0 to run every width.
wi_start = 1
num_syl = 10000      # max number of syllables sampled per call subtype
hdbscan_thre = 0.9   # syllables with any row at hdbscan_prob >= this are eligible for sampling
sample_seed = 1118   # seed for random syllable sampling (re-applied per subtype)

# Loop-invariant plotting/naming settings.
col_full = ['#a65628','#4daf4a','#984ea3','#e41a1c','#ff7f00','#f781bf','#377eb8','#737373']
col_dict = OrderedDict(zip(v_all, col_full[0:len(v_all)]))
syl_str = ''.join(v_all)
traj_kw = dict(marker='o', linestyle='-', markersize=3, linewidth=1, alpha=0.025,
               rasterized=True, markeredgecolor='none')
start_kw = dict(marker='^', markersize=6, color='blue', alpha=0.1,
                rasterized=True, markeredgecolor='none')
end_kw = dict(marker='s', markersize=6, color='red', alpha=0.1,
              rasterized=True, markeredgecolor='none')

for wi in range(wi_start, len(win_list)):
    w = win_list[wi]
    rn = f'{w}_1_{w}'
    vae_run = f'traj_chop_{rn}'
    fd_vae = os.path.join(fd_data, birdID, 'Traj', vae_suffix, vae_run)  # not used in this cell
    fd_latent = os.path.join(fd_latent_base, vae_run)
    fd_save = os.path.join(fd_save_base, vae_run)
    os.makedirs(fd_save, exist_ok=True)
    print(fd_save)

    # ---- gather sampled latents and info for each call subtype ----
    latent_parts, info_parts = [], []
    for v in v_all:
        fn = os.path.join(fd_latent, f'latentM.{v}.csv')
        latent = np.loadtxt(fn, delimiter=',', ndmin=2)
        info = pd.read_csv(os.path.join(fd_latent, f'info.{v}.csv'), index_col=0)

        # Attach syllable info to every window row. A left merge keeps the row order of `info`,
        # so row k of merged_info matches row k of `latent`; validate= raises if
        # info_syl['index'] has duplicates (which would duplicate rows and break that alignment).
        merged_info = info.merge(info_syl, left_on='ri', right_on='index', how='left',
                                 validate='many_to_one')
        assert len(merged_info) == latent.shape[0], (
            f'{v}: {len(merged_info)} info rows vs {latent.shape[0]} latent rows')

        # Eligible syllables: those with at least one row passing the HDBSCAN probability threshold.
        info_pass = merged_info[merged_info['hdbscan_prob'] >= hdbscan_thre]
        if info_pass.empty:
            print(f'{v}: no rows pass hdbscan_prob >= {hdbscan_thre}; skipped')
            continue
        syl_uniq = sorted(set(info_pass['syl_id']))
        num_sample = min(num_syl, len(syl_uniq))
        print(f'Total syllable {len(set(merged_info["syl_id"]))}. Pass threshold {len(syl_uniq)}. Sampled {num_sample}')
        random.seed(sample_seed)
        syl_rd = random.sample(syl_uniq, num_sample)

        # Boolean mask keeps every window of the sampled syllables, aligned with `latent` rows.
        keep = merged_info['syl_id'].isin(set(syl_rd)).to_numpy()
        info_rd = merged_info[keep]
        latent_rd = latent[keep, :]
        print(latent_rd.shape, info_rd.shape)
        latent_parts.append(latent_rd)
        info_parts.append(info_rd)

    if not latent_parts:
        print(f'{rn}: no data passed the filters; UMAP skipped')
        continue
    latent_comb = np.vstack(latent_parts)
    info_comb = pd.concat(info_parts, ignore_index=True)
    print(latent_comb.shape, info_comb.shape)

    umap_model, embed = vae_goffinet.ZZ_runUMAP_v1(latent_comb, param_umap, random_state=1118, meta_info=info_comb)

    # ---- save the embedding and the UMAP model ----
    fn_embed = os.path.join(fd_save, f'{birdID}.{syl_str}.hop1ms.embedding.csv')
    embed.to_csv(fn_embed)
    print(fn_embed)

    fn_umap = os.path.join(fd_save, f'UMAPmodel_{birdID}.{syl_str}.p')
    with open(fn_umap, 'wb') as fh:  # context manager so the file is flushed and closed
        pickle.dump(umap_model, fh)

    # ---- plot: panel 0 overlays all subtypes, panel vi+1 shows v_all[vi] alone ----
    n_panels = len(v_all) + 1
    fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=[5 * n_panels, 5.5], sharex=True, sharey=True)
    # One pass over the embedding instead of a full boolean scan per syllable.
    for _syl_id, embed_s in embed.groupby('syl_id', sort=False):
        x = embed_s['umap1'].to_numpy()
        y = embed_s['umap2'].to_numpy()
        v = embed_s['call_subtype'].iloc[0]
        for ax in (axes[0], axes[v_all.index(v) + 1]):
            ax.plot(x, y, color=col_dict[v], **traj_kw)
            # mark the start with a triangle and the end with a square
            ax.plot(x[0], y[0], **start_kw)
            ax.plot(x[-1], y[-1], **end_kw)

    legends = [Patch(facecolor=col, label=name) for name, col in col_dict.items()]
    axes[0].legend(handles=legends, loc='lower right', fontsize=12)
    panel_titles = [f'{birdID}: {rn}'] + [f'{birdID}: {rn} ({name})' for name in v_all]
    for ax, title in zip(axes, panel_titles):
        ax.set_title(title, fontsize=10)
        ax.set_xlabel('UMAP axis 1', fontsize=12)
        ax.set_ylabel('UMAP axis 2', fontsize=12)
    plt.tight_layout()

    fn_fig = os.path.join(fd_save, f'{birdID}.{rn}.embedding.pdf')
    fig.savefig(fn_fig, dpi=600)
    # Show now, then close so figures don't accumulate in memory across window widths.
    plt.show()
    plt.close(fig)

/mnt/z4/zz367/EphysMONAO/Analyzed/vaeWav/pair5RigCCU29/Traj/applySyl5/paramSearch5/traj_chop_24_1_24
Total syllable 2421. Pass threshold 1228. Sampled 1228
(243248, 32) (243248, 21)
Total syllable 1519. Pass threshold 1359. Sampled 1359
(265590, 32) (265590, 21)
(508838, 32) (508838, 21)
UMAP(min_dist=0, n_jobs=1, n_neighbors=25, random_state=1118, verbose=True)
Thu Jul 31 11:02:55 2025 Construct fuzzy simplicial set
Thu Jul 31 11:02:55 2025 Finding Nearest Neighbors
Thu Jul 31 11:02:55 2025 Building RP forest with 41 trees


  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")
