In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import pandas as pd
from cuml.manifold.umap import UMAP as cumlUMAP
from avgn.utils.paths import DATA_DIR, most_recent_subdirectory, ensure_dir
from joblib import Parallel, delayed



In [3]:
import pomegranate
from pomegranate import DiscreteDistribution, HiddenMarkovModel
pomegranate.utils.disable_gpu()
from hmmlearn import hmm

### Fit models code

In [4]:
def AIC(log_likelihood, k):
    """ AIC given log_likelihood and # parameters (k)
    """
    aic = 2 * k - 2 * log_likelihood
    return aic


def BIC(log_likelihood, n, k):
    """ BIC given log_likelihood, number of observations (n) and # parameters (k)
    """
    bic = np.log(n) * k - 2 * log_likelihood
    return bic

def FOMM(seqs, prop_test=0.5):
    """ create a FOMM in pomegranite
    """
    if prop_test == 0:
        seqs_train = seqs_test = seqs
    else:
        # split into train and test for cross validation
        training_mask = np.random.choice(
            np.arange(len(seqs)), size=int(len(seqs) * prop_test), replace=False
        )
        testing_mask = np.array(
            [i for i in np.arange(len(seqs)) if i not in training_mask]
        )

        seqs_train = np.array(seqs)[training_mask]
        seqs_test = np.array(seqs)[testing_mask]

        # make sure test set doesn't contain any data that train doesnt
        assert np.all(
            [
                i in np.unique(np.concatenate(seqs_train))
                for i in np.unique(np.concatenate(seqs_test))
            ]
        )

    # lengths of sequences
    seq_lens = [len(i) for i in seqs_train]

    # get states
    unique_states = np.unique(np.concatenate(seqs_train))

    # get start probabilities
    seq_starts = np.array([i[0] for i in seqs_train])
    start_probs = [np.sum(seq_starts == i) / len(seqs_train) for i in unique_states]

    end_states = [seq[-1] for seq in seqs]
    end_probs = [
        np.sum(end_states == i) / (np.sum(np.concatenate(seqs) == i) + 1)
        for i in np.arange(len(unique_states))
    ]

    # transition probs
    trans_mat = np.zeros((len(unique_states), len(unique_states)))
    for seq in seqs_train:
        for i, j in zip(seq[:-1], seq[1:]):
            trans_mat[i, j] += 1
    # smooth to nonzero probabilities
    trans_mat = (trans_mat.T / trans_mat.sum(axis=1)).T  # np.sum(trans_mat, axis=1)

    # smooth emissions
    emission_prob = np.identity(len(unique_states)) + 1e-5
    emission_prob = (emission_prob.T / emission_prob.sum(axis=1)).T

    # number of datapoints
    test_seq_lens = [len(i) for i in seqs_test]
    n_data = np.sum(test_seq_lens)

    # initialize pomegranate model

    transmat = trans_mat
    start_probs = start_probs
    dists = emission_prob

    states = [
        DiscreteDistribution({vis: d[i] for i, vis in enumerate(unique_states)})
        for d in dists
    ]
    pom_model = HiddenMarkovModel.from_matrix(
        transition_probabilities=transmat,
        distributions=states,
        starts=start_probs,
        ends=end_probs,  # discluding ends and merge makes models equal log prob
        merge="None",
    )
    pom_model.bake()
    pom_log_probability = np.sum([pom_model.log_probability(seq) for seq in seqs_test])
    

    # number of params in model
    num_params = (
        pom_model.edge_count() + pom_model.node_count() + pom_model.state_count()  # no hidden states in FOMM
    )

    # AIC and BIC
    aic = AIC(pom_log_probability, num_params)
    bic = BIC(pom_log_probability, n_data, num_params)
    return (
        pom_model,
        seqs_train,
        seqs_test,
        pom_log_probability,
        num_params,
        n_data,
        aic,
        bic,
    )

def fit_fixed_latent(seqs, latent_seqs, verbose=False):

    unique_latent_labels = np.unique(np.concatenate(latent_seqs))
    n_components = len(unique_latent_labels)

    # convert latent sequences to correct format
    label_seqs_str = [
        ["None-start"] + ["s" + str(i) for i in seq] + ["None-end"]
        for seq in latent_seqs
    ]
    
    pom_model = HiddenMarkovModel.from_samples(
        distribution=DiscreteDistribution,
        n_components=len(unique_latent_labels),
        X=seqs,
        labels=label_seqs_str,
        end_state=True,
        algorithm="labeled",
        verbose=verbose,
    )

    log_prob = [pom_model.log_probability(seq) for seq in seqs]

    sum_log_prob = np.sum(log_prob)
    
    num_params = (
        pom_model.state_count() + pom_model.edge_count() + pom_model.node_count()
    )

    n_data = np.sum([len(i) for i in seqs])

    aic = AIC(sum_log_prob, num_params)
    bic = BIC(sum_log_prob, n_data, num_params)

    return pom_model, log_prob, sum_log_prob, n_components, num_params, n_data, aic, bic

In [5]:
DATASET_ID = 'koumura_bengalese_finch'
embeddings_dfs = list(DATA_DIR.glob('bf_label_dfs/'+DATASET_ID+'/*.pickle'))
DATASET_ID = 'bengalese_finch_sober'
embeddings_dfs = embeddings_dfs + list(DATA_DIR.glob('bf_label_dfs/'+DATASET_ID+'/*.pickle'))

In [6]:
embeddings_dfs

[PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird9.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird4.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird10.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird6.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird0.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird2.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird1.pickle'),
 PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data/bf_label_dfs/koumura_bengalese_finch/Bird3.pickle'),
 PosixPath('/mnt/cube/t

In [7]:
for loc in tqdm(embeddings_dfs):
    # read dataframe
    indv_df = pd.read_pickle(loc).sort_values(by=["key", "start_time"])
    indv = indv_df.indv.unique()[0]

    # Get seqs
    hand_seqs = [
        list(indv_df[indv_df.syllables_sequence_id == seqid]["labels_num"].values)
        for seqid in indv_df.syllables_sequence_id.unique()
    ]

    results_df_FOMM = pd.DataFrame(
        [FOMM(hand_seqs, prop_test=0)],
        columns=[
            "pom_model",
            "seqs_train",
            "seqs_test",
            "pom_log_probability",
            "n_params",
            "n_data",
            "aic",
            "bic",
        ],
    )
    results_df_FOMM["indv"] = indv
    save_loc = DATA_DIR / "HMM_fits" / "FOMM" / (indv + ".pickle")
    ensure_dir(save_loc)
    results_df_FOMM.to_pickle(save_loc)

    ### HDBSCAN as latent
    # HDBSCAN seqs
    
    for hdbscan_labels in ["hdbscan_labels_num", "hdbscan_labels-0.1_num",  "hdbscan_labels-0.25_num"]:
        hdbscan_latent_seqs = [
            list(
                indv_df[indv_df.syllables_sequence_id == seqid][hdbscan_labels].values
            )
            for seqid in indv_df.syllables_sequence_id.unique()
        ]

        # make latent df
        results_df_umap_hidden = pd.DataFrame(
            [fit_fixed_latent(hand_seqs, hdbscan_latent_seqs, verbose=False)],
            columns=[
                "pom_model",
                "log_prob",
                "sum_log_prob",
                "n_components",
                "num_params",
                "n_data",
                "aic",
                "bic",
            ],
        )
        results_df_umap_hidden["indv"] = indv
        save_loc = DATA_DIR / "HMM_fits"  / hdbscan_labels / "HDBSCAN" / (indv + ".pickle")
        ensure_dir(save_loc)
        results_df_umap_hidden.to_pickle(save_loc)

    ### second order model
    seqs_second_order_states = [
        list(
            indv_df[indv_df.syllables_sequence_id == seqid][
                "seqs_second_order_states"
            ].values
        )
        for seqid in indv_df.syllables_sequence_id.unique()
    ]

    results_df_second_order_hidden = pd.DataFrame(
        [fit_fixed_latent(hand_seqs, seqs_second_order_states, verbose=False)],
        columns=[
            "pom_model",
            "log_prob",
            "sum_log_prob",
            "n_components",
            "num_params",
            "n_data",
            "aic",
            "bic",
        ],
    )
    results_df_second_order_hidden["indv"] = indv
    save_loc = DATA_DIR / "SOMM" / (indv + ".pickle")
    ensure_dir(save_loc)
    results_df_second_order_hidden.to_pickle(save_loc)

    print(
        "---{}---\nAIC: \n\tSOMM: {}\n\tFOMM: {} \n\tHDBSCAN: {} \nLL: \n\tSOMM: {}\n\tFOMM: {} \n\tHDBSCAN: {}".format(
            indv,
            round(results_df_second_order_hidden.aic.values[0]),
            round(results_df_umap_hidden.aic.values[0]),
            round(results_df_FOMM.aic.values[0]),
            round(results_df_second_order_hidden.sum_log_prob.values[0]),
            round(results_df_umap_hidden.sum_log_prob.values[0]),
            round(results_df_FOMM.pom_log_probability.values[0]),
        )
    )

HBox(children=(IntProgress(value=0, max=15), HTML(value='')))

---Bird9---\AIC: 
	SOMM: 19098.0
	FOMM: 22824.0 
	HDBSCAN: 23352.0 
LL: 
	SOMM: -9108.0
	FOMM: -11243.0 
	HDBSCAN: -11644.0
---Bird4---\AIC: 
	SOMM: 31000.0
	FOMM: 31417.0 
	HDBSCAN: 32006.0 
LL: 
	SOMM: -15100.0
	FOMM: -15627.0 
	HDBSCAN: -15973.0
---Bird10---\AIC: 
	SOMM: 10432.0
	FOMM: 7775.0 
	HDBSCAN: 8867.0 
LL: 
	SOMM: -3100.0
	FOMM: -3446.0 
	HDBSCAN: -4367.0
---Bird6---\AIC: 
	SOMM: 23582.0
	FOMM: 24004.0 
	HDBSCAN: 25202.0 
LL: 
	SOMM: -11307.0
	FOMM: -11858.0 
	HDBSCAN: -12570.0
---Bird0---\AIC: 
	SOMM: 10503.0
	FOMM: 10124.0 
	HDBSCAN: 13339.0 
LL: 
	SOMM: -4027.0
	FOMM: -4738.0 
	HDBSCAN: -6619.0
---Bird2---\AIC: 
	SOMM: 20745.0
	FOMM: 18910.0 
	HDBSCAN: 18680.0 
LL: 
	SOMM: -8348.0
	FOMM: -8879.0 
	HDBSCAN: -9263.0
---Bird1---\AIC: 
	SOMM: 45255.0
	FOMM: 43808.0 
	HDBSCAN: 53080.0 
LL: 
	SOMM: -18783.0
	FOMM: -21420.0 
	HDBSCAN: -26454.0
---Bird3---\AIC: 
	SOMM: 43372.0
	FOMM: 43977.0 
	HDBSCAN: 45204.0 
LL: 
	SOMM: -19922.0
	FOMM: -21732.0 
	HDBSCAN: -22544.0
---Bird5---



---gy6or6---\AIC: 
	SOMM: 22492.0
	FOMM: 36469.0 
	HDBSCAN: 37280.0 
LL: 
	SOMM: -9130.0
	FOMM: -17911.0 
	HDBSCAN: -18546.0
---or60yw70---\AIC: 
	SOMM: 36763.0
	FOMM: 33728.0 
	HDBSCAN: 44021.0 
LL: 
	SOMM: -16172.0
	FOMM: -16608.0 
	HDBSCAN: -21937.0
---gr41rd51---\AIC: 
	SOMM: 68759.0
	FOMM: 57241.0 
	HDBSCAN: 45385.0 
LL: 
	SOMM: -19003.0
	FOMM: -28297.0 
	HDBSCAN: -22533.0



In [8]:
DATA_DIR

PosixPath('/mnt/cube/tsainbur/Projects/github_repos/avgn_paper/data')