In [1]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
from moseq2_extras.plotutil import plot_behav_dist_and_usage
from moseq2_viz.model.util import (get_transition_matrix,
                                   parse_model_results,
                                   results_to_dataframe,
                                   relabel_by_usage, get_syllable_statistics)
from moseq2_extras.stats import plot_syllable_usage_pca
from moseq2_viz.util import parse_index
import numpy as np
import re
import pandas as pd
import seaborn as sns
import sys
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import matplotlib.pyplot as plt
from random import shuffle, sample
from sklearn.manifold import TSNE
from scipy.spatial.distance import pdist, squareform
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
import matplotlib as mpl
%matplotlib qt

In [5]:
model_file = 'F:/moseq/2021-01-15_Meloxicam/2021-02-19_moseq/rST_model_1000.p'
index_file = 'F:/moseq/2021-01-15_Meloxicam/2021-02-19_moseq/moseq2-index.role.yaml'

max_syllable = 66

groups = ['baseline', '4hrs carrageenan', '24hrs saline', '24hrs meloxicam', 'baseline meloxicam']
palette = sns.color_palette(['#f06493', '#35fab3', '#647aa3', '#020887', '#ff0000'])
markers = ['o', 's', '^', 'P', 'X']


#groups = ['baseline', '4hrs carrageenan', '24hrs saline']
#palette = sns.color_palette(['#f06493', '#35fab3', '#020887'])
#markers = ['o', 's', '^']

In [3]:
_, sorted_index = parse_index(index_file)
model = parse_model_results(model_file, sort_labels_by_usage=True, count='usage')

labels = model['labels']
label_group = [sorted_index['files'][uuid]['group'] for uuid in model['keys']]

tm_vals = []
group_vals = []
usage_vals = []
frames_vals = []

for l, g, u in tqdm(list(zip(labels, label_group, model['keys'])), leave=False):
    if g in groups:
        group_vals.append(g)
        
        tm = get_transition_matrix([l], combine=True, max_syllable=max_syllable)
        tm_vals.append(tm.ravel())
        
        u, _ = get_syllable_statistics(l, count='usage')
        total_u = np.sum(list(u.values()))
        usage_vals.append(np.array(list(u.values())) / total_u)
        
        f, _ = get_syllable_statistics(l, count='frames')
        total_f = np.sum(list(f.values()))
        frames_vals.append(np.array(list(f.values())) / total_f)

print(len(tm_vals), len(group_vals), len(usage_vals), len(frames_vals))

#tm_vals
usage_vals = np.array(usage_vals)
frames_vals = np.array(frames_vals)
#usage_vals


                                                                                                                       

80 80 80 80




In [8]:
# 2D LDA with Transitions

lda = LinearDiscriminantAnalysis(n_components=2, solver='eigen', shrinkage=0.9)
lda_result = lda.fit(tm_vals, group_vals).transform(tm_vals)
print('LDA Score: {}'.format(lda.score(tm_vals, group_vals)))
print('LDA Explained Variance: {}'.format(lda.explained_variance_ratio_))
out_base = "transitions_LDA_2D_5group"


fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

sns.kdeplot(x=lda_result.T[0], y=lda_result.T[1], hue=group_vals, hue_order=groups, palette=palette, fill=True, alpha=0.5)
sns.scatterplot(ax=ax, x=lda_result.T[0], y=lda_result.T[1], hue=group_vals, hue_order=groups, style=group_vals, style_order=groups, markers=markers, palette=palette, legend="full")

ax.set_xlabel('LDA_1')
ax.set_ylabel('LDA_2')
ax.set_title('LDA Transitions')

plt.savefig('{}.png'.format(out_base))
plt.savefig('{}.pdf'.format(out_base))

LDA Score: 0.375
LDA Explained Variance: [0.01615296 0.0107511 ]


In [9]:
# 2D LDA with Usage

lda = LinearDiscriminantAnalysis(n_components=2, solver='eigen', shrinkage=0.1)
lda_result = lda.fit(usage_vals, group_vals).transform(usage_vals)
print('LDA Score: {}'.format(lda.score(usage_vals, group_vals)))
print('LDA Explained Variance: {}'.format(lda.explained_variance_ratio_))
out_base = "usage_LDA_2D_5group"

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

sns.kdeplot(x=lda_result.T[0], y=lda_result.T[1], hue=group_vals, hue_order=groups, palette=palette, fill=True, alpha=0.5)
sns.scatterplot(ax=ax, x=lda_result.T[0], y=lda_result.T[1], hue=group_vals, hue_order=groups, style=group_vals, style_order=groups, markers=markers, palette=palette, legend="full")

ax.set_xlabel('LDA_1')
ax.set_ylabel('LDA_2')
ax.set_title('LDA Usage')

plt.savefig('{}.png'.format(out_base))
plt.savefig('{}.pdf'.format(out_base))

LDA Score: 0.375
LDA Explained Variance: [0.2579315  0.15966811]


In [10]:
# 2D LDA with Frames

lda = LinearDiscriminantAnalysis(n_components=2, solver='eigen', shrinkage=0.1)
lda_result = lda.fit(frames_vals, group_vals).transform(frames_vals)
print('LDA Score: {}'.format(lda.score(frames_vals, group_vals)))
print('LDA Explained Variance: {}'.format(lda.explained_variance_ratio_))
out_base = "frames_LDA_2D_5group"

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

sns.kdeplot(x=lda_result.T[0], y=lda_result.T[1], hue=group_vals, hue_order=groups, palette=palette, fill=True, alpha=0.5)
sns.scatterplot(ax=ax, x=lda_result.T[0], y=lda_result.T[1], hue=group_vals, hue_order=groups, style=group_vals, style_order=groups, markers=markers, palette=palette, legend="full")

ax.set_xlabel('LDA_1')
ax.set_ylabel('LDA_2')
ax.set_title('LDA Frames')

plt.savefig('{}.png'.format(out_base))
plt.savefig('{}.pdf'.format(out_base))

LDA Score: 0.375
LDA Explained Variance: [0.23980199 0.15009701]


In [4]:
# 3D LDA with Transitions

lda = LinearDiscriminantAnalysis(n_components=3, solver='eigen', shrinkage=0.9)
lda_result = lda.fit(tm_vals, group_vals).transform(tm_vals)
print('LDA Score: {}'.format(lda.score(tm_vals, group_vals)))
print('LDA Explained Variance: {}'.format(lda.explained_variance_ratio_))

lgd_itms = [mpl.lines.Line2D([0],[0], linestyle="none", c=c, marker=m) for c, m in zip(palette, markers)]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
cs = [palette[groups.index(g)] for g in group_vals]
ms = [markers[groups.index(g)] for g in group_vals]
for d, c, m, g in zip(lda_result, cs, ms, group_vals):
    ax.scatter(d[0], d[1], d[2], c=[c], marker=m, label=g)

plt.legend(lgd_itms, groups)
ax.set_xlabel('LDA_1')
ax.set_ylabel('LDA_2')
ax.set_zlabel('LDA_3')
ax.set_title('LDA Transitions')
fig.show()

LDA Score: 0.5
LDA Explained Variance: [0.02356179 0.01153724]


IndexError: index 2 is out of bounds for axis 0 with size 2

In [17]:
# 3D LDA with Usage

lda = LinearDiscriminantAnalysis(n_components=3, solver='eigen', shrinkage=0.2)
lda_result = lda.fit(usage_vals, group_vals).transform(usage_vals)
print('LDA Score: {}'.format(lda.score(usage_vals, group_vals)))
print('LDA Explained Variance: {}'.format(lda.explained_variance_ratio_))

lgd_itms = [mpl.lines.Line2D([0],[0], linestyle="none", c=c, marker=m) for c, m in zip(palette, markers)]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
cs = [palette[groups.index(g)] for g in group_vals]
ms = [markers[groups.index(g)] for g in group_vals]
for d, c, m, g in zip(lda_result, cs, ms, group_vals):
    ax.scatter(d[0], d[1], d[2], c=[c], marker=m, label=g)

plt.legend(lgd_itms, groups)
ax.set_xlabel('LDA_1')
ax.set_ylabel('LDA_2')
ax.set_zlabel('LDA_3')
ax.set_title('LDA Usage')
fig.show()

LDA Score: 0.16666666666666666
LDA Explained Variance: [nan nan]


IndexError: index 2 is out of bounds for axis 0 with size 2

In [9]:
# 3D LDA with Frames

lda = LinearDiscriminantAnalysis(n_components=3, solver='eigen', shrinkage=0.1)
lda_result = lda.fit(frames_vals, group_vals).transform(frames_vals)
print('LDA Score: {}'.format(lda.score(frames_vals, group_vals)))
print('LDA Explained Variance: {}'.format(lda.explained_variance_ratio_))

lgd_itms = [mpl.lines.Line2D([0],[0], linestyle="none", c=c, marker=m) for c, m in zip(palette, markers)]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
cs = [palette[groups.index(g)] for g in group_vals]
ms = [markers[groups.index(g)] for g in group_vals]
for d, c, m, g in zip(lda_result, cs, ms, group_vals):
    ax.scatter(d[0], d[1], d[2], c=[c], marker=m, label=g)

plt.legend(lgd_itms, groups)
ax.set_xlabel('LDA_1')
ax.set_ylabel('LDA_2')
ax.set_zlabel('LDA_3')
ax.set_title('LDA Frames')
fig.show()

LDA Score: 0.16666666666666666
LDA Explained Variance: [nan nan]


IndexError: index 2 is out of bounds for axis 0 with size 2