# DataJoint Element Facemap: Pose Estimation

Open-source data pipeline to automate analyses and organize data

In this tutorial, we will walk through processing facial behavior videos using the [Facemap Pose Estimation framework](https://github.com/MouseLand/facemap). Distinct facial keypoints are determined using a pose estimation model and are stored in the DataJoint pipeline.

We will explain the following concepts as they relate to this pipeline:
- What is an Element versus a pipeline?
- Plot the pipeline with `dj.Diagram`
- Insert data into tables
- Query table contents
- Fetch table contents
- Run the pipeline for your experiments

For detailed documentation and tutorials on general DataJoint principles that support collaboration, automation, reproducibility, and visualizations:

- [DataJoint for Python - Interactive Tutorials](https://github.com/datajoint/datajoint-tutorials) - Fundamentals including table tiers, query operations, fetch operations, automated computations with the `make` function, etc.

- [DataJoint for Python - Documentation](https://datajoint.com/docs/core/datajoint-python/)

- [DataJoint Element for Facemap - Documentation](https://datajoint.com/docs/elements/element-facemap/)



# DataJoint Element for Facemap Deep Learning

#### Open-source data pipeline for processing and analyzing facial behavior videos.

Welcome to the tutorial for the DataJoint Element for facial pose estimation and ROI detection. This
tutorial aims to provide a comprehensive understanding of the open-source data pipeline
created using `element-facemap`.

This package is designed to seamlessly process, ingest, and track facial behavior videos,
along with the pose estimation models and keypoint results derived from them. By the end of this
tutorial you will have a clear grasp of setting up and integrating `element-facemap`
into your specific research projects and lab.

![flowchart](../images/diagram_flowchart.svg)

### Prerequisites

Please see the [datajoint tutorials GitHub
repository](https://github.com/datajoint/datajoint-tutorials/tree/main) before
proceeding.

A basic understanding of the following DataJoint concepts will be beneficial to your
understanding of this tutorial: 
1. The `Imported` and `Computed` table types in `datajoint-python`.
2. The functionality of the `.populate()` method. 

#### **Tutorial Overview**

+ Setup
+ *Activate* the DataJoint pipeline.
+ *Insert* subject and session metadata.
+ *Insert* video recording and model information.
+ Generate and run the pose estimation task.
+ Visualize the results.

### **Setup**

This tutorial examines behavioral video data (the example recording is stored as `.avi`
files). The goal is to store, track, and manage sessions of facial pose data, including
determining the coordinates of facial body parts and visualizing their trajectories.

The results of this Element can be combined with **other modalities** to create
a complete, customizable data pipeline for your specific lab or study. For instance, you
can combine `element-facemap` with `element-calcium-imaging` and
`element-array-ephys` to relate orofacial behavior to neural activity.

Let's start this tutorial by importing the packages necessary to run the notebook.

In [None]:
import os

if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
assert os.path.basename(os.getcwd()) == "element-facemap", (
    "Please move to the " + "element directory"
)

In [None]:
import datajoint as dj
import datetime
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

If the tutorial is run in Codespaces, a private, local database server is created and
made available for you. This is where we will insert and store our processed results.
Let's connect to the database server.

In [None]:
dj.conn()

### **Activate the DataJoint Pipeline**

This tutorial activates the `facemap_inference.py` module from `element-facemap`, along
with upstream dependencies from `element-animal` and `element-session`. Please refer to the
[`tutorial_pipeline.py`](./tutorial_pipeline.py) for the source code.

## Combine multiple Elements into a pipeline

Each DataJoint Element is a modular set of tables that can be combined into a complete pipeline.

Each Element contains one or more modules, and each module declares its own schema in the database.

This tutorial pipeline is assembled from four DataJoint Elements.


| Element | Source Code | Documentation | Description |
| -- | -- | -- | -- |
| Element Lab | [Link](https://github.com/datajoint/element-lab) | [Link](https://datajoint.com/docs/elements/element-lab) | Lab management related information, such as Lab, User, Project, Protocol, Source. |
| Element Animal | [Link](https://github.com/datajoint/element-animal) | [Link](https://datajoint.com/docs/elements/element-animal) | General animal metadata and surgery information. |
| Element Session | [Link](https://github.com/datajoint/element-session) | [Link](https://datajoint.com/docs/elements/element-session) | General information of experimental sessions. |
| Element Facemap | [Link](https://github.com/datajoint/element-facemap) | [Link](https://datajoint.com/docs/elements/element-facemap) |  Facemap orofacial bodypart prediction with singular value decomposition or deep learning pose estimation.|

Importing the modules for the first time creates the schemas and tables in the database. Subsequent imports do not recreate them; instead, they give access to the existing schemas and tables.

The Elements are imported and activated within the `tutorial_pipeline` script.

In [None]:
from tutorial_pipeline import (
    lab,
    subject,
    session,
    fbe,
    facemap_inference,
    Device
)

Each Python module (e.g. subject) contains a schema object that enables interaction with the schema in the database.

In [None]:
subject.schema

Each Python class in the module corresponds to a table in the database server.

In [None]:
subject.Subject()

## Facemap Inference Schemas Diagram

We can represent the `fbe` and `facemap_inference` schemas and their upstream dependencies, `subject` and `session`, using `dj.Diagram()`.

In [None]:
(
    dj.Diagram(subject.Subject)
    + dj.Diagram(session.Session)
    + dj.Diagram(fbe.VideoRecording)
    + dj.Diagram(fbe.VideoRecording.File)
    + dj.Diagram(facemap_inference)
)

As the diagram shows, this data pipeline includes tables for model and video file data,
task generation, and the results of model inference. Some tables, such as `subject.Subject` and `session.Session`,
are important for a complete pipeline but fall outside the scope of the `element-facemap`
tutorial, and therefore will not be explored extensively here. The primary focus of
this tutorial is the `facemap_inference` schema.

## Insert entries into `subject` and `session` manual tables

Let's start with the first table in the schema diagram (i.e. `subject.Subject` table).

To know what data to insert into the table, we can view its dependencies and attributes using the `.describe()` method and the `.heading` attribute.

In [None]:
print(subject.Subject.describe())

In [None]:
subject.Subject.heading

The cells above show all attributes of the subject table.
We will insert data into the
`subject.Subject` table. 

In [None]:
subject.Subject.insert1(
    dict(
        subject="subject1",
        subject_nickname="subject1_nickname",
        sex="U",
        subject_birth_date="2020-01-01",
        subject_description="Demo subject for Facemap Pose estimation processing.",
    )
)
subject.Subject()

Let's repeat the steps above for the `Session` table.

In [None]:
print(session.Session.describe())

In [None]:
session.Session.heading

Notice that `describe` displays the table's definition and highlights its dependencies, such as its reliance on the `Subject` table. These dependencies represent foreign key references, linking data across tables.

On the other hand, `heading` provides an exhaustive list of the table's attributes. This
list includes both the attributes declared in this table and any inherited from upstream
tables.
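To make the difference between `describe` and `heading` concrete, here is a toy model of how a full attribute list could be assembled: primary-key attributes inherited through foreign keys (such as `-> Subject`) come first, followed by attributes declared in the table itself. The function `assemble_heading` and the attribute names in the example are hypothetical; this is a simplified sketch, not DataJoint's internal implementation.

```python
def assemble_heading(inherited_keys, declared_primary, declared_secondary):
    """Hypothetical sketch: combine inherited and declared attributes.

    inherited_keys: primary-key attribute names of upstream tables that are
        referenced in this table's primary key (e.g. via ``-> Subject``)
    declared_primary / declared_secondary: attributes declared in this table
    """
    primary = []
    for name in list(inherited_keys) + list(declared_primary):
        if name not in primary:  # a shared upstream key appears only once
            primary.append(name)
    secondary = [name for name in declared_secondary if name not in primary]
    return primary, secondary


# A session table that references a subject table in its primary key
primary, secondary = assemble_heading(["subject"], ["session_datetime"], [])
print(primary)  # ['subject', 'session_datetime']
```

In this picture, `describe` would show only the `-> Subject` reference plus `session_datetime`, while `heading` would list the expanded `subject` attribute as well.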

With this understanding, let's move on to insert a session associated with our subject.

We will insert into the `session.Session` table by passing a dictionary to the `insert1` method.

In [None]:
session_key = dict(subject="subject1", session_datetime="2021-04-30 12:22:15")

In [None]:
session.Session.insert1(session_key)
session.Session()

Every experimental session produces a set of data files. The purpose of the `SessionDirectory` table is to locate these files. It references a directory path relative to a root directory, defined in `dj.config["custom"]`. More information about `dj.config` is provided in the [documentation](https://datajoint.com/docs/elements/user-guide/).

In [None]:
session.SessionDirectory.insert1(
    dict(**session_key, session_dir="subject1/session1")
)
session.SessionDirectory()
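Storing `session_dir` relative to a root directory means the same database entry can be resolved on different machines. The sketch below shows one way such a lookup could work when one or several roots are configured; it is similar in spirit to the `find_full_path` helper in `element-interface`, but `find_full_path_sketch` and the example root are hypothetical and are not the Element's actual code.

```python
from pathlib import Path


def find_full_path_sketch(root_directories, relative_path):
    """Return root / relative_path for the first root where that path exists.

    Hypothetical helper: root_directories may be one path or a list of paths,
    mirroring a root data directory configured in dj.config["custom"].
    """
    if isinstance(root_directories, (str, Path)):
        root_directories = [root_directories]
    for root in root_directories:
        candidate = Path(root) / relative_path
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"{relative_path} was not found under any of {list(root_directories)}"
    )


# Example (hypothetical root): find_full_path_sketch("/data/facemap", "subject1/session1")
```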

Next, let's insert a local PyTorch model file into the `facemap_inference.FacemapModel` table:
- Specify a unique `model_id`, a `model_description`, and the full local path to the model file.
- The default Facemap model is located in the hidden `.facemap` folder in your computer's home directory, i.e. `~/.facemap/models/facemap_model_state.pt`.

In [None]:
model_name = "facemap_model_state.pt"
model_id = 0
model_description = "test facemap model"
# Expand "~" explicitly; Python's file APIs do not expand it automatically
full_local_model_filepath = os.path.expanduser("~/.facemap/models/facemap_model_state.pt")
facemap_inference.FacemapModel.insert_new_model(
    model_name=model_name,
    model_id=model_id,
    model_description=model_description,
    full_model_path=full_local_model_filepath,
)
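Because `~` is a shell convention that Python's `open()` and `pathlib.Path` treat literally, it can help to expand and check the model path before inserting it, so that a missing file produces a clear error early. The helper `locate_model_file` below is a hypothetical sketch using only the standard library; it is not part of `element-facemap`.

```python
import os
from pathlib import Path


def locate_model_file(path_str):
    """Hypothetical helper: expand "~", make the path absolute, and check the file exists."""
    # Python's file APIs treat "~" literally, so expand it explicitly first
    model_path = Path(os.path.expanduser(path_str)).resolve()
    if not model_path.is_file():
        raise FileNotFoundError(f"No model file found at {model_path}")
    return model_path


# Example: locate_model_file("~/.facemap/models/facemap_model_state.pt")
```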

Let's display the `facemap_inference.FacemapModel` table queried with the model id of interest to verify insertion

In [None]:
facemap_inference.FacemapModel() & f'model_id={model_id}'

Next let's display the `facemap_inference.FacemapModel.File` and `facemap_inference.FacemapModel.BodyPart` part tables

In [None]:
facemap_inference.FacemapModel.File() & f'model_id={model_id}'

In [None]:
facemap_inference.FacemapModel.BodyPart() & f'model_id={model_id}'

As the diagram indicates, the `fbe.VideoRecording` table must contain data before an entry in
`facemap_inference.FacemapPoseEstimationTask` can be generated.

Next we will insert behavioral video recording data into `fbe.VideoRecording` and its part table, `fbe.VideoRecording.File`. 

In [None]:
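from pathlib import Path

video_recording_key = {**session_key, "recording_id": 0}

# get_facemap_root_data_dir() may return a single path or a list of paths;
# use the first root directory in either case
facemap_root_dir_path = fbe.get_facemap_root_data_dir()
if isinstance(facemap_root_dir_path, (list, tuple)):
    facemap_root_dir_path = facemap_root_dir_path[0]

# Resolve both paths so that relative_to() compares two absolute paths
vid_path = Path("./example_data/inbox/subject0/session0/*.avi").resolve()
video_recording_file_insert = {
    **video_recording_key,
    "file_id": 0,
    "file_path": vid_path.relative_to(Path(facemap_root_dir_path).resolve()).as_posix(),
}

fbe.VideoRecording.insert1(video_recording_key)
fbe.VideoRecording.File.insert1(video_recording_file_insert)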
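`Path.relative_to()` raises `ValueError` when one path is relative and the other absolute, or when the file lies outside the root. When several root directories are configured, a small helper can try each one in turn. The sketch below is hypothetical (`relative_to_any_root` is not an `element-facemap` function) and only illustrates the path arithmetic.

```python
from pathlib import Path


def relative_to_any_root(file_path, root_directories):
    """Hypothetical helper: express file_path relative to the first root that contains it."""
    if isinstance(root_directories, (str, Path)):
        root_directories = [root_directories]
    file_path = Path(file_path).resolve()
    for root in root_directories:
        try:
            return file_path.relative_to(Path(root).resolve())
        except ValueError:
            continue  # file_path is not inside this root; try the next one
    raise ValueError(f"{file_path} is not inside any of {list(root_directories)}")


# Example (hypothetical paths):
# relative_to_any_root("/data/inbox/subject0/session0/video.avi", ["/data/inbox"])
```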

With entries present in both the `facemap_inference.FacemapModel` and `fbe.VideoRecording` tables, the requirements for inserting into the `facemap_inference.FacemapPoseEstimationTask` table are met.
- `facemap_inference.FacemapPoseEstimationTask` is a staging table that pairs a specific `FacemapModel` with a `VideoRecording`.
- In this example, we set `task_mode` to `load` to load existing results, which is faster than running inference.
- To run processing instead, set `task_mode` to `trigger`. This step may take some time and can cause the database connection to be lost; if that happens, simply rerun the cell.

Pose estimation is then evaluated in the next table, `facemap_inference.FacemapPoseEstimation`, according to the specifications in `FacemapPoseEstimationTask`.

In [None]:
model_key = (facemap_inference.FacemapModel & f'model_id={model_id}').fetch1("KEY")
key = {**video_recording_key, **model_key}
task_description = "Demo Facemap Inference Task, loads processed results"
facemap_inference.FacemapPoseEstimationTask.insert_pose_estimation_task(key, task_description=task_description, task_mode="load")
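The `task_mode` switch described above can be pictured as choosing which steps run. The sketch below is a hypothetical illustration of that logic (`TaskMode` and `plan_task` are invented names, and the function only returns step labels without calling Facemap); it is not how `element-facemap` implements the task.

```python
from enum import Enum


class TaskMode(str, Enum):
    LOAD = "load"        # read existing Facemap output
    TRIGGER = "trigger"  # run inference first, then read its output


def plan_task(task_mode, output_exists):
    """Hypothetical sketch: return the steps a 'load' or 'trigger' task would take."""
    mode = TaskMode(task_mode)  # raises ValueError for an unknown mode
    steps = []
    if mode is TaskMode.TRIGGER:
        steps.append("run_inference")
    elif not output_exists:
        raise FileNotFoundError("task_mode='load' requires existing Facemap results")
    steps.append("load_results")
    return steps


print(plan_task("trigger", output_exists=False))  # ['run_inference', 'load_results']
```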

We can display the key in the `FacemapPoseEstimationTask` table to confirm it was inserted.

In [None]:
(facemap_inference.FacemapPoseEstimationTask() & key)

Next we will ingest the results into `FacemapPoseEstimation` and its part table `FacemapPoseEstimation.BodyPartPosition` for the key that we just inserted into the `FacemapPoseEstimationTask`

In [None]:
facemap_inference.FacemapPoseEstimation.populate(key, display_progress=True)

Once the cell above has completed, run the next cells to display the `FacemapPoseEstimation` tables.

In [None]:
facemap_inference.FacemapPoseEstimation()

In [None]:
facemap_inference.FacemapPoseEstimation.BodyPartPosition()
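`populate(key)` computes entries only for upstream keys that match the restriction and have not been computed yet (for `FacemapPoseEstimation`, presumably the entries of `FacemapPoseEstimationTask`), which is why rerunning it is safe. The bookkeeping can be sketched with plain dictionaries; `keys_to_populate` is a hypothetical helper, not DataJoint's API.

```python
def keys_to_populate(key_source, computed, restriction=None):
    """Hypothetical sketch of populate bookkeeping: return upstream keys that
    match the restriction and do not yet have a computed entry."""
    restriction = restriction or {}

    def as_tuple(key):
        # dicts are unhashable, so convert each key to a sorted tuple for set lookups
        return tuple(sorted(key.items()))

    done = {as_tuple(k) for k in computed}
    return [
        k
        for k in key_source
        if all(k.get(attr) == value for attr, value in restriction.items())
        and as_tuple(k) not in done
    ]


tasks = [{"recording_id": 0, "model_id": 0}, {"recording_id": 1, "model_id": 0}]
print(keys_to_populate(tasks, computed=[{"recording_id": 0, "model_id": 0}]))
# [{'recording_id': 1, 'model_id': 0}]
```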

### Visualize Pose Estimation Output

In [None]:
pe_query = {**session_key, 'recording_id': 0, 'model_id': model_id}
pose_estimation_key = (facemap_inference.FacemapPoseEstimation & pe_query).fetch1("KEY")

Get Trajectory of X and Y coordinates

In [None]:
# Specify all body parts, or set body_parts to a custom list (or a single body part name)
body_parts = "all"
model_name = (facemap_inference.FacemapModel & f'model_id={key["model_id"]}').fetch1("model_name")

if body_parts == "all":
    body_parts = (facemap_inference.FacemapPoseEstimation.BodyPartPosition & key).fetch("body_part")
elif isinstance(body_parts, str):
    # Wrap a single name; list("paw") would split it into individual characters
    body_parts = [body_parts]
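Normalizing the `body_parts` setting is a common place for a subtle bug: calling `list()` on a string splits it into characters. A standalone helper that accepts `"all"`, a single name, or a list, and rejects unknown names, could look like this. `normalize_body_parts` is hypothetical; the example names are taken from the keypoints used later in this notebook.

```python
def normalize_body_parts(body_parts, available):
    """Hypothetical helper: accept "all", a single name, or a list of names."""
    available = list(available)
    if isinstance(body_parts, str):
        # Wrap a single name; list("paw") would give ["p", "a", "w"]
        body_parts = available if body_parts == "all" else [body_parts]
    else:
        body_parts = list(body_parts)
    unknown = [bp for bp in body_parts if bp not in available]
    if unknown:
        raise ValueError(f"Unknown body parts: {unknown}")
    return body_parts


print(normalize_body_parts("paw", ["eye(front)", "paw", "whisker(I)"]))  # ['paw']
```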


In [None]:
# Construct a pandas DataFrame with MultiIndex columns (model, body part, coordinate)
frames = []
for body_part in body_parts:
    # Restrict by the full task key so only this recording and model are fetched
    x_pos, y_pos, likelihood = (
        facemap_inference.FacemapPoseEstimation.BodyPartPosition
        & key
        & {"body_part": body_part}
    ).fetch1("x_pos", "y_pos", "likelihood")
    a = np.vstack((x_pos, y_pos, likelihood)).T  # shape: frames x 3
    pdindex = pd.MultiIndex.from_product(
        [[model_name], [body_part], ["x", "y", "likelihood"]],
        names=["model", "bodyparts", "coords"],
    )
    frames.append(pd.DataFrame(a, columns=pdindex, index=range(a.shape[0])))
df = pd.concat(frames, axis=1)
df

In [None]:
# Keep only the x and y columns for this model (drop likelihood)
df_xy = df.iloc[:, df.columns.get_level_values(2).isin(["x", "y"])][model_name]
df_xy.mean()

Plot coordinates across time for each body part

In [None]:
df_xy.plot().legend(loc='best', prop={'size': 5})

In [None]:
df_flat = df_xy.copy()
df_flat.columns = df_flat.columns.map('_'.join)
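`df_flat.columns.map('_'.join)` joins each `(body part, coordinate)` column tuple into a single name such as `eye(front)_x`. Pairing those names back up by body part lets the x/y columns for plotting be derived rather than typed by hand. The helpers below are a hypothetical, standard-library sketch of that idea.

```python
def flatten_columns(columns):
    """Join each (body part, coordinate) column tuple with "_", as Index.map("_".join) does."""
    return ["_".join(col) for col in columns]


def xy_pairs(flat_columns):
    """Map each body part to its ("<part>_x", "<part>_y") column names."""
    found = {}
    for name in flat_columns:
        part, coord = name.rsplit("_", 1)  # split on the last "_" only
        found.setdefault(part, {})[coord] = name
    return {p: (c["x"], c["y"]) for p, c in found.items() if "x" in c and "y" in c}


flat = flatten_columns([("whisker(I)", "x"), ("whisker(I)", "y"), ("paw", "x"), ("paw", "y")])
print(flat)  # ['whisker(I)_x', 'whisker(I)_y', 'paw_x', 'paw_y']
print(xy_pairs(flat)["paw"])  # ('paw_x', 'paw_y')
```

In the notebook, a loop such as `for part, (x, y) in xy_pairs(df_flat.columns).items(): df_flat.plot(x=x, y=y, ax=...)` could replace hand-typed column names, which makes duplicated or misspelled entries less likely.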

Plot Trace Overlays of each body part across time

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(20, 15))

# Eye keypoints
df_flat.plot(x='eye(front)_x', y='eye(front)_y', ax=ax[0, 0])
df_flat.plot(x='eye(back)_x', y='eye(back)_y', ax=ax[0, 0])
df_flat.plot(x='eye(bottom)_x', y='eye(bottom)_y', ax=ax[0, 0])

# Nose keypoints
df_flat.plot(x='nose(tip)_x', y='nose(tip)_y', ax=ax[1, 0])
df_flat.plot(x='nose(bottom)_x', y='nose(bottom)_y', ax=ax[1, 0])
df_flat.plot(x='nose(r)_x', y='nose(r)_y', ax=ax[1, 0])
df_flat.plot(x='nosebridge_x', y='nosebridge_y', ax=ax[1, 0])

# Mouth, lower lip, and paw keypoints
df_flat.plot(x='mouth_x', y='mouth_y', ax=ax[0, 1])
df_flat.plot(x='lowerlip_x', y='lowerlip_y', ax=ax[0, 1])
df_flat.plot(x='paw_x', y='paw_y', ax=ax[0, 1])

# Whisker keypoints
df_flat.plot(x='whisker(I)_x', y='whisker(I)_y', ax=ax[1, 1])
df_flat.plot(x='whisker(II)_x', y='whisker(II)_y', ax=ax[1, 1])
df_flat.plot(x='whisker(III)_x', y='whisker(III)_y', ax=ax[1, 1])


### Visualize Keypoints Data

In [None]:
# One color per body part; plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated and later removed
colors = plt.get_cmap("jet")(np.linspace(0, 1.0, len(body_parts)))

(facemap_inference.FacemapPoseEstimation.BodyPartPosition & pose_estimation_key)

Fetch the keypoints_data from the database as a dictionary in order to index and reshape it.

In [None]:
keypoints_data = (facemap_inference.FacemapPoseEstimation.BodyPartPosition & pose_estimation_key).fetch(as_dict=True)

In [None]:
# Stack the per-body-part arrays; each matrix has shape (key points x frames)
pose_x_coord = np.stack([bp["x_pos"] for bp in keypoints_data])
pose_y_coord = np.stack([bp["y_pos"] for bp in keypoints_data])
pose_likelihood = np.stack([bp["likelihood"] for bp in keypoints_data])
pose_data = np.stack((pose_x_coord, pose_y_coord, pose_likelihood))  # size: 3 x key points x frames
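The reshaping step turns a list of per-body-part records into matrices whose rows are key points and whose columns are frames. The plots that follow label row `i` with `body_parts[i]`, so the labels are only correct if both were fetched in the same order; fetching with `order_by="body_part"` in both places is one way to guard against a mismatch. The standard-library sketch below (`stack_keypoints` is hypothetical) shows the layout and keeps the labels next to the rows they describe.

```python
def stack_keypoints(records, fields=("x_pos", "y_pos", "likelihood")):
    """Hypothetical sketch: arrange per-body-part records (as returned by
    fetch(as_dict=True)) into nested lists of shape 3 x key points x frames."""
    frame_counts = {len(r[f]) for r in records for f in fields}
    if len(frame_counts) > 1:
        raise ValueError(f"Inconsistent frame counts: {sorted(frame_counts)}")
    return [[list(r[f]) for r in records] for f in fields]


records = [
    {"body_part": "paw", "x_pos": [1.0, 2.0], "y_pos": [3.0, 4.0], "likelihood": [0.9, 0.8]},
    {"body_part": "mouth", "x_pos": [5.0, 6.0], "y_pos": [7.0, 8.0], "likelihood": [0.7, 0.6]},
]
pose = stack_keypoints(records)
labels = [r["body_part"] for r in records]  # row labels, in the same order as the rows
print(labels[1], pose[0][1])  # mouth [5.0, 6.0]
```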

Plot keypoints for a subset of frames

In [None]:
start_frame = 100
end_frame = 500

plt.figure(figsize=(15, 5), dpi=100)
for i, bodypart in enumerate(body_parts):
    plt.plot(np.arange(start_frame, end_frame), pose_x_coord[i, start_frame:end_frame], '-', c=colors[i], label=bodypart)
    plt.plot(np.arange(start_frame, end_frame), pose_y_coord[i, start_frame:end_frame], '--', c=colors[i])
plt.xlabel('Frame')
plt.ylabel('Keypoint coordinates')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.show()

Plot a subset of bodypart keypoints for a subset of frames

In [None]:
subset_bodyparts = ["whisker(I)",  "whisker(II)", "whisker(III)"]
start_frame = 100
end_frame = 500

plt.figure(figsize=(15, 5), dpi=100)
for i, bodypart in enumerate(body_parts):
    if bodypart in subset_bodyparts:
        plt.plot(np.arange(start_frame, end_frame), pose_x_coord[i, start_frame:end_frame], '-', c=colors[i], 
                 label=bodypart)
        plt.plot(np.arange(start_frame, end_frame), pose_y_coord[i, start_frame:end_frame], '--', c=colors[i])
plt.xlabel('Frame')
plt.ylabel('Keypoint coordinates')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.show()

#### Filter outliers in keypoints data

Use Facemap's `filter_outliers` function to remove outliers by applying a median filter to the keypoints data.

In [None]:
from facemap.utils import filter_outliers

# filter_outliers removes outliers in the keypoints data (see its docstring for details):
# facemap.utils.filter_outliers(x, y, filter_window=15, baseline_window=50, max_spike=25, max_diff=25)
#   x: x coordinates of keypoints
#   y: y coordinates of keypoints
#   filter_window: window size for median filter (default: 15)
#   baseline_window: window size for baseline estimation (default: 50)
#   max_spike: maximum spike size (default: 25)
#   max_diff: maximum difference between baseline and filtered signal (default: 25)

plt.figure(figsize=(15, 5), dpi=100)
for i, bodypart in enumerate(body_parts):
    if bodypart in subset_bodyparts:
        # Pass copies so the unfiltered arrays stay intact even if the function edits its inputs
        x, y = filter_outliers(pose_x_coord[i].copy(), pose_y_coord[i].copy())
        plt.plot(np.arange(start_frame, end_frame), x[start_frame:end_frame], '-', c=colors[i], label=bodypart)
        plt.plot(np.arange(start_frame, end_frame), y[start_frame:end_frame], '--', c=colors[i])
plt.xlabel('Frame')
plt.ylabel('Keypoint coordinates')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
plt.title('Filtered keypoints')
plt.show()
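To make the idea of median-filter outlier removal concrete, here is a deliberately simplified sketch on a single coordinate trace: samples that deviate from a running median by more than a threshold are replaced with that median. It is not the algorithm inside `facemap.utils.filter_outliers`, which has additional parameters such as `baseline_window` and `max_diff`; `running_median` and `replace_spikes` are hypothetical names.

```python
from statistics import median


def running_median(values, window):
    """Median over a window centred on each sample; the window shrinks at the edges."""
    half = window // 2
    return [median(values[max(0, i - half): i + half + 1]) for i in range(len(values))]


def replace_spikes(values, window=5, max_spike=25.0):
    """Simplified outlier removal: replace samples that deviate from the running
    median by more than max_spike with that median. Returns a new list."""
    baseline = running_median(values, window)
    return [b if abs(v - b) > max_spike else v for v, b in zip(values, baseline)]


trace = [10, 11, 200, 12, 13]  # one tracking glitch at index 2
print(replace_spikes(trace))  # [10, 11, 12, 12, 13]
```

Because it returns a new list, this sketch leaves the original trace untouched, which makes it easy to plot raw and filtered coordinates side by side.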