# Loading libraries and shared constants

In [None]:
from utils.preprocessing_functions import *
from utils.model_fitting_functions import *
from utils.plotting_functions import *
import utils.behavior_utils

# Feature labels and their per-feature range labels, re-exported from
# utils.behavior_utils (the range labels are used as colorbar labels in the
# PCA projection plots).
FEATURE_NAMES, FEAT_RANGES = (
    utils.behavior_utils.FEATURE_NAMES,
    utils.behavior_utils.FEAT_RANGES,
)

# Loading and preprocessing the data

In [None]:
# --- Paths and experiment metadata -----------------------------------------
images_dir = "Embeddings/"              # output folder for saved figures
directory_path = 'PainH5s/'             # input folder handed to preprocess_data
known_distance_mm = 40                  # passed to preprocess_data; presumably a reference length for scaling (TODO confirm)
stim_types = ["cb", "db", "lp", "hp"]   # stimulus labels passed to preprocess_data
group_types = ["Base", "24h", "4h"]     # group labels passed to preprocess_data

# --- Load, fill and smooth --------------------------------------------------
# preprocess_data returns the tracked locations plus per-session metadata;
# fill_missing and compound_smoother (project helpers) then post-process the
# locations, the latter with a 3-sample window.
(locations,
 session_ends,
 stims,
 groups) = preprocess_data(directory_path, known_distance_mm, stim_types, group_types)
locations = fill_missing(locations)
roll_mean = compound_smoother(locations, window_size=3)

In [None]:
# Plot the filled locations next to their smoothed version; node indices
# 0/1/2 are passed as the toe, center and heel nodes respectively.
plot_paw_trajectory(
    locations=locations,
    smoothed_locations=roll_mean,
    toe_index=0,
    center_index=1,
    heel_index=2,
)

# Extracting features

In [None]:
# Derive per-row features from the smoothed locations (window_size=15). The
# helper also returns norm_min / norm_max, presumably normalisation bounds.
paw_withdrawal_features, norm_min, norm_max = instance_node_velocities(
    my_locations=roll_mean,
    session_ends=session_ends,
    window_size=15,
)

# Overview of all rows, then a zoom on the first 250 rows.
plot_withdrawal_features(paw_withdrawal_features, node_features=FEATURE_NAMES)
plot_withdrawal_features(paw_withdrawal_features[:250], node_features=FEATURE_NAMES)

# Training a variational autoencoder

In [None]:
# Train the variational autoencoder (project helper) on the paw-withdrawal
# features and keep the flattened encodings for the PCA below.
#
# Training is stochastic, so seed the RNGs first to make reruns start from the
# same state.
# NOTE(review): the deep-learning framework used inside
# train_variational_autoencoder presumably has its own RNG (e.g.
# tf.random.set_seed / torch.manual_seed); seed that too once confirmed.
import random

import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

encoded_data_flat = train_variational_autoencoder(
    paw_withdrawal_features,
    sliding_window_size=100,
    epochs_n=20,
    batch_size=8,
    validation_split=0.2,
)

# Dimensionality Reduction by PCA

In [None]:
# Fit a 2-component PCA on the flattened VAE encodings and plot PC1 vs PC2.
# Import PCA here so this cell runs on a fresh kernel. Otherwise it would rely
# on the utils star imports, because the explicit import only comes in the
# following 3-D cell.
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
pca_result = pca.fit_transform(encoded_data_flat)

fig, ax = plt.subplots()
ax.scatter(pca_result[:, 0], pca_result[:, 1], s=0.5)
ax.set(xlabel='PC1', ylabel='PC2', title='PCA Visualization of Encoded Data')
plt.show()

In [None]:
import plotly.graph_objs as go
from sklearn.decomposition import PCA

# Refit PCA with three components. This replaces the 2-component pca_result
# from the previous cell.
pca = PCA(n_components=3)
pca_result = pca.fit_transform(encoded_data_flat)

# One semi-transparent marker per encoded sample, all in a single fixed colour.
trace = go.Scatter3d(
    x=pca_result[:, 0], y=pca_result[:, 1], z=pca_result[:, 2],
    mode='markers',
    marker={'size': 3, 'color': 'blue', 'opacity': 0.5},
)

# Label the three scene axes after the principal components they show.
layout = go.Layout(
    title='PCA Visualization of Encoded Data (3D)',
    scene={
        'xaxis': {'title': 'PC1'},
        'yaxis': {'title': 'PC2'},
        'zaxis': {'title': 'PC3'},
    },
)

fig = go.Figure(data=[trace], layout=layout)
fig.show()


In [None]:
# 3-D vector field over the full PCA embedding on a 50x50x50 grid, rendered
# at 1000x800.
plot_3D_vector_field(
    embedding_data=pca_result,
    session_ends=session_ends,
    grid_rows=50,
    grid_cols=50,
    grid_depth=50,
    scale=9,
    width=1000,
    height=800,
)

In [None]:
# Horizontal slice: PC1 vs PC2 of the 3-D PCA embedding. One panel per
# paw-withdrawal feature, with points coloured by that feature's value.
# datetime is imported here because this cell uses it before any later cell
# would have imported it.
import datetime
import math
import os

N_feats = paw_withdrawal_features.shape[1]

# Fixed column count; the number of rows is derived from the feature count so
# that no feature is silently dropped (24 features still give a 4 x 6 grid).
num_cols = 6
num_rows = max(1, math.ceil(N_feats / num_cols))

# squeeze=False keeps `axes` 2-D even when a single row is enough.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 2.25 * num_rows), squeeze=False)
axes = axes.flatten()

# NOTE(review): colours use the first pca_result.shape[0] feature rows. This
# assumes embedding row k lines up with feature row k; confirm against the
# windowing done in train_variational_autoencoder.
n_points = pca_result.shape[0]
for i, ax in enumerate(axes):
    ax.axis('off')  # panels are drawn without axes; unused trailing panels stay blank
    if i >= N_feats:
        continue
    scatter = ax.scatter(pca_result[:, 0], pca_result[:, 1], c=paw_withdrawal_features[:n_points, i], cmap='rainbow', s=1)
    ax.set_title(FEATURE_NAMES[i])
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(FEAT_RANGES[i], rotation=270, va='bottom')

fig.suptitle("PCA XY Embedding of Paw Withdrawal Data")
fig.tight_layout()

# savefig raises FileNotFoundError when the output folder does not exist yet.
os.makedirs(images_dir, exist_ok=True)
fig.savefig(os.path.join(images_dir, "PCA_1_2_Projection_" + datetime.datetime.now().strftime("%Y-%m-%d") + ".png"))
plt.show()

In [None]:
# Sagittal slice: PC1 vs PC3 of the 3-D PCA embedding. One panel per
# paw-withdrawal feature, with points coloured by that feature's value.
import datetime
import math
import os

N_feats = paw_withdrawal_features.shape[1]

# Fixed column count; the number of rows is derived from the feature count so
# that no feature is silently dropped (24 features still give a 4 x 6 grid).
num_cols = 6
num_rows = max(1, math.ceil(N_feats / num_cols))

# squeeze=False keeps `axes` 2-D even when a single row is enough.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 2.25 * num_rows), squeeze=False)
axes = axes.flatten()

# NOTE(review): colours use the first pca_result.shape[0] feature rows. This
# assumes embedding row k lines up with feature row k; confirm against the
# windowing done in train_variational_autoencoder.
n_points = pca_result.shape[0]
for i, ax in enumerate(axes):
    ax.axis('off')  # panels are drawn without axes; unused trailing panels stay blank
    if i >= N_feats:
        continue
    scatter = ax.scatter(pca_result[:, 0], pca_result[:, 2], c=paw_withdrawal_features[:n_points, i], cmap='rainbow', s=1)
    ax.set_title(FEATURE_NAMES[i])
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(FEAT_RANGES[i], rotation=270, va='bottom')

fig.suptitle("PCA XZ Embedding of Paw Withdrawal Data")
fig.tight_layout()

# savefig raises FileNotFoundError when the output folder does not exist yet.
os.makedirs(images_dir, exist_ok=True)
fig.savefig(os.path.join(images_dir, "PCA_1_3_Projection_" + datetime.datetime.now().strftime("%Y-%m-%d") + ".png"))
plt.show()

In [None]:
# Coronal slice: PC2 vs PC3 of the 3-D PCA embedding. One panel per
# paw-withdrawal feature, with points coloured by that feature's value.
import datetime
import math
import os

N_feats = paw_withdrawal_features.shape[1]

# Fixed column count; the number of rows is derived from the feature count so
# that no feature is silently dropped (24 features still give a 4 x 6 grid).
num_cols = 6
num_rows = max(1, math.ceil(N_feats / num_cols))

# squeeze=False keeps `axes` 2-D even when a single row is enough.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 2.25 * num_rows), squeeze=False)
axes = axes.flatten()

# NOTE(review): colours use the first pca_result.shape[0] feature rows. This
# assumes embedding row k lines up with feature row k; confirm against the
# windowing done in train_variational_autoencoder.
n_points = pca_result.shape[0]
for i, ax in enumerate(axes):
    ax.axis('off')  # panels are drawn without axes; unused trailing panels stay blank
    if i >= N_feats:
        continue
    scatter = ax.scatter(pca_result[:, 1], pca_result[:, 2], c=paw_withdrawal_features[:n_points, i], cmap='rainbow', s=1)
    ax.set_title(FEATURE_NAMES[i])
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(FEAT_RANGES[i], rotation=270, va='bottom')

fig.suptitle("PCA YZ Embedding of Paw Withdrawal Data")
fig.tight_layout()

# savefig raises FileNotFoundError when the output folder does not exist yet.
os.makedirs(images_dir, exist_ok=True)
fig.savefig(os.path.join(images_dir, "PCA_2_3_Projection_" + datetime.datetime.now().strftime("%Y-%m-%d") + ".png"))
plt.show()

In [None]:
# Per-feature average heatmaps over the PC1/PC3 plane on a 15x15 grid, using
# all points (percentage=100).
plot_average_heatmap_subplots(
    data=pca_result[:, [0, 2]],
    features=paw_withdrawal_features,
    feature_names=FEATURE_NAMES,
    feature_ranges=FEAT_RANGES,
    percentage=100,
    grid_rows=15,
    grid_cols=15,
)

# Trajectory Modeling (TM) and Energy Landscape Analysis (ELA)

In [None]:
# Global displacement-direction fields on the three pairwise PC slices.
# All three share the session boundaries and the 32x32 grid; only the slice
# and the scale differ.
_direction_kwargs = dict(session_ends=session_ends, grid_rows=32, grid_cols=32)

# pseudohorizontal: PC1 vs PC2
plot_displacement_vector_field_direction_global(embedding_data=pca_result[:, 0:2], scale=1, **_direction_kwargs)
# pseudocoronal: PC2 vs PC3
plot_displacement_vector_field_direction_global(embedding_data=pca_result[:, 1:3], scale=5, **_direction_kwargs)
# pseudosagittal: PC1 vs PC3
plot_displacement_vector_field_direction_global(embedding_data=pca_result[:, [0, 2]], scale=1, **_direction_kwargs)

In [None]:
# Displacement-direction fields split by stimulus, on the PC2/PC3 plane.
plot_displacement_vector_field_direction_unique_stims(
    embedding_data=pca_result[:, [1, 2]],
    session_ends=session_ends,
    stims=stims,
    grid_rows=32,
    grid_cols=32,
    scale=80,
)

In [None]:
# Displacement-direction fields split by group and stimulus, on the PC2/PC3 plane.
plot_displacement_vector_field_direction_unique_groups_stims(
    embedding_data=pca_result[:, [1, 2]],
    session_ends=session_ends,
    stims=stims,
    groups=groups,
    grid_rows=32,
    grid_cols=32,
    scale=80,
)

In [None]:
# Global displacement-velocity field on one 2-D slice of the PCA embedding.
# Pick the slice with VELOCITY_GLOBAL_PROJECTION instead of toggling
# commented-out calls; each option keeps the scale it was used with.
VELOCITY_GLOBAL_PROJECTION = "xy"  # "xy" = PC1/PC2, "yz" = PC2/PC3, "xz" = PC1/PC3
_velocity_global_options = {
    # projection: (PC column selector, scale)
    "xy": (slice(0, 2), 10),
    "yz": (slice(1, 3), 3),
    "xz": ([0, 2], 3),
}
_pc_cols, _scale = _velocity_global_options[VELOCITY_GLOBAL_PROJECTION]
plot_displacement_vector_field_velocity_global(
    embedding_data=pca_result[:, _pc_cols],
    session_ends=session_ends,
    grid_rows=32,
    grid_cols=32,
    scale=_scale,
)

In [None]:
# Displacement-velocity fields split by stimulus, on one 2-D slice of the PCA
# embedding. Pick the slice with VELOCITY_STIMS_PROJECTION instead of toggling
# commented-out calls; each option keeps the scale it was used with.
VELOCITY_STIMS_PROJECTION = "xy"  # "xy" = PC1/PC2, "yz" = PC2/PC3, "xz" = PC1/PC3
_velocity_stims_options = {
    # projection: (PC column selector, scale)
    "xy": (slice(0, 2), 80),
    "yz": (slice(1, 3), 30),
    "xz": ([0, 2], 30),
}
_pc_cols, _scale = _velocity_stims_options[VELOCITY_STIMS_PROJECTION]
plot_displacement_vector_field_velocity_unique_stims(
    embedding_data=pca_result[:, _pc_cols],
    session_ends=session_ends,
    stims=stims,
    grid_rows=32,
    grid_cols=32,
    scale=_scale,
)

In [None]:
# Displacement-velocity fields split by group and stimulus, on the PC2/PC3 plane.
plot_displacement_vector_field_velocity_unique_groups_stims(
    embedding_data=pca_result[:, 1:3],
    session_ends=session_ends,
    stims=stims,
    groups=groups,
    grid_rows=32,
    grid_cols=32,
    scale=80,
)

In [None]:
# Overlay one session (OVERLAY_SESSION) on the global embedding field for a
# chosen 2-D slice of the PCA embedding. Pick the slice with OVERLAY_PROJECTION
# instead of toggling commented-out calls; each option keeps the
# embedding_scale it was used with.
OVERLAY_PROJECTION = "xy"  # "xy" = PC1/PC2, "yz" = PC2/PC3, "xz" = PC1/PC3
OVERLAY_SESSION = 1
_overlay_options = {
    # projection: (PC column selector, embedding_scale)
    "xy": (slice(0, 2), 10),
    "yz": (slice(1, 3), 17),
    "xz": ([0, 2], 17),
}
_pc_cols, _embedding_scale = _overlay_options[OVERLAY_PROJECTION]
plot_unique_session_against_global_embedding(
    embedding_data=pca_result[:, _pc_cols],
    session_ends=session_ends,
    session=OVERLAY_SESSION,
    grid_rows=32,
    grid_cols=32,
    embedding_scale=_embedding_scale,
    session_scale=20,
    overlay_color="purple",
)

In [None]:
# Global curl / streamline plot on one 2-D slice of the PCA embedding.
# Pick the slice with CURL_PROJECTION instead of toggling commented-out calls;
# each option keeps the grid size it was used with.
CURL_PROJECTION = "xy"  # "xy" = PC1/PC2, "yz" = PC2/PC3, "xz" = PC1/PC3
_curl_options = {
    # projection: (PC column selector, grid size used for both rows and cols)
    "xy": (slice(0, 2), 24),
    "yz": (slice(1, 3), 12),
    "xz": ([0, 2], 12),
}
_pc_cols, _grid_n = _curl_options[CURL_PROJECTION]
plot_curl_streamlines_global(
    embedding_data=pca_result[:, _pc_cols],
    session_ends=session_ends,
    grid_rows=_grid_n,
    grid_cols=_grid_n,
    scale=1,
)