In [1]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

sys.path.append("../src")
from gait_gft import *  # project module; presumably also provides GraphModel, SkeletonPreprocessor, GaitTrial and nx


# This notebook is expected to run one folder below the project root;
# configuration files live in <project>/config.
path_to_project = os.path.join('..')
misc_path = os.path.join(path_to_project, 'config')

# Time thresholds handed to SkeletonPreprocessor in the preprocessing section.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on the
# platform's default encoding (e.g. cp1252 on Windows).
with open(os.path.join(misc_path, 'time_thresholds.json'), 'r', encoding='utf-8') as f:
    time_thresholds = json.load(f)

# Node -> [x, y] drawing layout; JSON has no tuple type, so the lists are
# converted back to tuples for networkx.
with open(os.path.join(misc_path, 'node_positions.json'), 'r', encoding='utf-8') as f:
    pos_loaded = json.load(f)
pos = {k: tuple(v) for k, v in pos_loaded.items()}

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns

class Visualizer:
    """Plotting helpers for a GraphModel: its Laplacian spectrum and signals defined on its nodes."""

    def __init__(self):
        pass  # Theme settings, font scaling, etc. can be configured here if needed.

    def plot_graph_basis(self, graph: GraphModel, figsize=(6.4, 8), cmap='seismic', save_path=None):
        """
        Plot the Laplacian eigenvalues (scatter, top) above the eigenvector matrix U (heatmap, bottom).

        Parameters
        ----------
        graph : GraphModel
            Object exposing ``.U`` (one eigenvector per column, one node per row),
            ``.Lambda`` (eigenvalues), ``.node_list`` (row labels) and ``.n`` (node count).
        figsize : tuple
            Size of the figure.
        cmap : str
            Diverging colormap used for both the heatmap and its colorbar.
        save_path : str or None
            If provided, save the figure to this path (300 dpi).
        """
        U = np.asarray(graph.U)
        lambdas = np.asarray(graph.Lambda)
        nodes = graph.node_list
        n = graph.n

        fig, (ax1, ax2) = plt.subplots(
            nrows=2, figsize=figsize, sharex=True,
            gridspec_kw={'height_ratios': [1, 3]}
        )

        # --- Upper: eigenvalues, placed at x = k - 0.5 so each marker sits over heatmap column k.
        y_margin = 0.1
        ax1.scatter(
            np.arange(1, n + 1) - 0.5, lambdas,
            color='#9B0014', edgecolor='black', s=50, marker='o', alpha=0.8
        )
        ax1.set_ylabel("Eigenvalue", fontsize=12)
        ax1.set_title(r"$L_N$ Eigenvalues", fontsize=14, pad=10)
        ax1.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
        ax1.tick_params(axis='both', which='major', labelsize=10)
        # Always show at least [0, 2] (presumably the spectral range of the normalized Laplacian L_N).
        ax1.set_ylim(-y_margin, max(2, lambdas.max()) + y_margin)

        # --- Lower: heatmap of U. Symmetric limits make 0 map to the centre of the diverging
        # colormap (what seaborn's center=0 does anyway) and let the external colorbar below
        # use exactly the same value -> colour mapping.
        vlim = float(np.abs(U).max()) or 1.0
        sns.heatmap(U, cmap=cmap, center=0, vmin=-vlim, vmax=vlim, cbar=False,
                    xticklabels=False, yticklabels=False, square=True, ax=ax2)
        # The x axis is shared with ax1; pin it to the heatmap columns so the markers align.
        ax2.set_xlim(0, U.shape[1])

        ax2.set_xlabel("Eigenpair Index", fontsize=12)
        ax2.set_ylabel("Graph Node Index", fontsize=12)
        ax2.set_title(r"$L_N$ Eigenvectors", fontsize=14, pad=10)

        # Vertical separators between eigenvector columns.
        for x in range(1, n):
            ax2.axvline(x, color='black', linewidth=0.7)

        ax2.set_xticks(np.arange(n) + 0.5)
        ax2.set_xticklabels(range(1, n + 1), fontsize=8, rotation=90)

        ax2.set_yticks(np.arange(n) + 0.5)
        ax2.set_yticklabels(nodes, fontsize=8)

        # --- Layout: run tight_layout BEFORE adding the colorbar axes. Axes created with
        # fig.add_axes are not handled by tight_layout (it emits a warning), and a colorbar
        # positioned from ax2's pre-layout box would be misaligned once the layout moves ax2.
        fig.tight_layout(rect=[0, 0, 0.9, 1], h_pad=1.5)

        # --- External colorbar, vertically centred on the post-layout heatmap.
        heatmap_box = ax2.get_position()
        cbar_height = 0.32
        cbar_bottom = heatmap_box.y0 + (heatmap_box.height - cbar_height) / 2
        cbar_ax = fig.add_axes([0.91, cbar_bottom, 0.02, cbar_height])

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=-vlim, vmax=vlim))
        sm.set_array([])
        tick_frac = 0.85  # presumably pulls the +/- ticks inward, away from the bar ends
        cbar = fig.colorbar(sm, cax=cbar_ax, ticks=[-vlim * tick_frac, 0, vlim * tick_frac])
        cbar.ax.set_yticklabels(["-", "0", "+"], fontsize=12)

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()


    def plot_graph_signal(self, signal, ax, graph,
                          pos: dict = None,
                          norm_edges=None,
                          font_size=3, edge_width=1,
                          edge_cmap=plt.cm.gray_r,
                          node_size=100,
                          nodes_cmap=plt.cm.seismic,
                          draw_labels=False,
                          grayscale_edges=True,
                          signal_min_max=None,
                          constant_color=None,
                          draw_edges=True,
                          rotation=0,
                          font_color='black',
                          axis_margin=10,
                          ):
        """
        Plot a graph signal on the given axis, colouring nodes with a colormap centred at 0.

        Parameters
        ----------
        signal : array-like
            1D array of signal values, one per node. NOTE(review): values are matched to nodes
            in ``graph.graph``'s iteration order (networkx default) -- confirm that this equals
            ``graph.node_list``.
        ax : matplotlib.axes.Axes
            Axis on which to draw.
        graph : GraphModel
            Provides ``.graph`` (networkx graph; edge attribute ``'weight'``, default 1.0)
            and ``.node_list`` (used for labels).
        pos : dict
            Required node layout, mapping node -> (x, y).
        norm_edges : matplotlib.colors.Normalize or None
            Colour limits (its ``vmin``/``vmax``) for grayscale edges. If None, the limits
            are ``[min(0, min weight), max weight]``.
        font_size, font_color, rotation :
            Label font size, colour and rotation (only used when ``draw_labels`` is True).
        edge_width : float
            Edge line width.
        edge_cmap : Colormap
            Colormap applied to edge weights when ``grayscale_edges`` is True.
        node_size : float
            Node marker size.
        nodes_cmap : Colormap
            Colormap for node values.
        draw_labels : bool
            Draw node names as labels.
        grayscale_edges : bool
            Colour edges by weight with ``edge_cmap``; otherwise use networkx's default colour.
        signal_min_max : (float, float) or None
            Colour range for the signal, made symmetric around 0 using its larger absolute
            bound. If None, ``max(|signal|)`` is used.
        constant_color : colour or None
            If given, every node gets this colour and ``signal`` is not used for colouring.
        draw_edges : bool
            Whether to draw edges at all.
        axis_margin : float
            Padding added around the extent of ``pos`` when setting the axis limits.

        Raises
        ------
        ValueError
            If ``pos`` is None.
        """
        if pos is None:
            raise ValueError("A position dictionary `pos` must be provided.")

        nx_graph = graph.graph

        # --- Nodes: symmetric colour range so that 0 maps to the centre of nodes_cmap.
        if constant_color is not None:
            node_colors = constant_color
        else:
            if signal_min_max is None:
                vlim = np.max(np.abs(signal))
            else:
                vlim = max(abs(signal_min_max[0]), abs(signal_min_max[1]))
            if vlim == 0:
                # All-zero signal/range: Normalize(0, 0) maps every value to the bottom of the
                # colormap; widen the range so zeros land on the centre colour instead.
                vlim = 1.0
            norm_signal = plt.Normalize(vmin=-vlim, vmax=vlim)
            node_colors = nodes_cmap(norm_signal(signal))

        nx.draw_networkx_nodes(
            nx_graph, pos,
            node_color=node_colors,
            node_size=node_size,
            ax=ax
        )

        # --- Edges (skipped for an edgeless graph: min()/max() of no weights would raise).
        if draw_edges and nx_graph.number_of_edges() > 0:
            if grayscale_edges:
                edges_weights = np.array([
                    nx_graph[u][v].get('weight', 1.0)
                    for u, v in nx_graph.edges()
                ])
                if norm_edges is None:
                    # Anchor the grey scale at 0 so positive weights never render as white.
                    edge_vmin, edge_vmax = min(0, edges_weights.min()), edges_weights.max()
                else:
                    edge_vmin, edge_vmax = norm_edges.vmin, norm_edges.vmax
                nx.draw_networkx_edges(
                    nx_graph, pos,
                    edge_color=edges_weights,
                    edge_cmap=edge_cmap,
                    width=edge_width,
                    edge_vmin=edge_vmin,
                    edge_vmax=edge_vmax,
                    ax=ax
                )
            else:
                nx.draw_networkx_edges(nx_graph, pos, width=edge_width, ax=ax)

        # --- Labels: each node is labelled with its own name.
        if draw_labels:
            label_texts = nx.draw_networkx_labels(
                nx_graph, pos,
                labels={node: node for node in graph.node_list},
                font_size=font_size,
                font_weight='bold',
                font_color=font_color,
                ax=ax
            )
            for text in label_texts.values():
                text.set_rotation(rotation)

        ax.axis("off")

        # Frame the drawing on the full extent of `pos`, plus a margin.
        xs, ys = zip(*pos.values())
        ax.set_xlim(min(xs) - axis_margin, max(xs) + axis_margin)
        ax.set_ylim(min(ys) - axis_margin, max(ys) + axis_margin)


# Graph Definition

In [3]:
# links.txt is tab-separated: field 1 of every row is a node name, and from the
# second row on, fields 1.. give the edge endpoints.
with open(os.path.join(misc_path, 'links.txt'), 'r', encoding='utf-8') as f:
    text = f.read()

# splitlines() also drops '\r' from Windows line endings; blank lines (e.g. the empty
# string produced by a trailing newline) are skipped instead of raising IndexError.
link_rows = [line.split('\t') for line in text.splitlines() if line.strip()]

nodes = [row[1] for row in link_rows]
# The first row contributes a node but no edge (presumably the root joint -- TODO confirm).
edges = [tuple(row[1:]) for row in link_rows[1:]]

skeleton = GraphModel(edges, node_list=nodes)

In [None]:
# Spectrum of the skeleton graph: eigenvalues on top, eigenvector matrix below.
viz = Visualizer()
viz.plot_graph_basis(graph=skeleton)

  plt.tight_layout(rect=[0, 0, 0.9, 1], h_pad=1.5)


# Data Preprocessing

In [None]:
# The raw recordings live in a zip archive under <project>/data; the preprocessor is
# configured with the node list parsed from links.txt and the configured time thresholds.
path_to_zip = os.path.join(path_to_project, os.path.join('data', 'HDA_proj_A2.zip'))
preproc = SkeletonPreprocessor(path_to_zip, nodes, time_thresholds)

# Load and preprocess a single trial.
r = preproc.load_and_process('subject1/normal/trial3/skeleton.csv')

# Signal and Gait Definition

In [None]:
# Wrap the preprocessed trial on the skeleton graph; per-channel GFT results are
# read later through `gait.gft_dict[channel]`.
gait = GaitTrial(r, skeleton, label='normal')

In [None]:
# Example setup
viz = Visualizer()

signal = gait.get_velocity_matrix('x')[:, 10]  # shape: (n_nodes,)

# Plot the signal
fig, ax = plt.subplots(figsize=(4, 4))

viz.plot_graph_signal(
    signal=signal,
    ax=ax,
    graph=skeleton,
    pos=pos
)

plt.tight_layout()
plt.show()

In [None]:
# GFT coefficients over time for each velocity channel, on one shared symmetric colour scale.
frame = 16  # time index highlighted by the dashed vertical line

channels = ['x', 'y', 'z']
fig, axes = plt.subplots(1, 3, figsize=(17, 5), sharey=True)
plt.subplots_adjust(wspace=0.1)

# A single colour limit for all panels keeps magnitudes comparable across channels.
v_max = max(gait.gft_dict[channel].X_hat.abs().max().max() for channel in channels)

for channel, ax in zip(channels, axes):
    df = gait.gft_dict[channel].X_hat

    sns.heatmap(df, cmap='seismic', center=0, cbar=False,
                vmin=-v_max, vmax=v_max,
                yticklabels=False,
                ax=ax)

    # Keep "GFT" as plain text: mathtext drops spaces and italicises every letter.
    ax.set_title(f"GFT $V_{channel}$", fontsize=16)
    ax.set_xlabel('Timestep', fontsize=10)

    # Heatmap column i spans [i, i + 1], so +0.5 puts the line on the centre of column `frame`.
    ax.axvline(frame + 0.5, color='green', linestyle='dashed', linewidth=1.5)

# Shared sign-only colour bar below the panels. It sits outside the figure area;
# bbox_inches='tight' keeps it in the saved image.
gradient_ax = fig.add_axes([0.1, -0.05, 0.8, 0.025])
gradient = np.linspace(-1, 1, 256).reshape(1, -1)
gradient_ax.imshow(gradient, aspect='auto', cmap='seismic', extent=[-1, 1, 0, 1])
gradient_ax.set_xticks([-1, 0, 1])
gradient_ax.set_xticklabels(['-', '0', '+'])
gradient_ax.set_yticks([])
gradient_ax.tick_params(axis='x', labelsize=20)
fig.savefig('normal_gait_nodes_vs_time_signals_grid.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# TODO: the gait object should:
# - have a filtering method that stores its output in `filtered signal`
# - be able to recover the original positions from the filtered signal

# After defining the plotting and animation objects,
# write a notebook showing how to use them.

# Organize the file/folder structure for uploading to GitHub.

# Cite the thesis and the paper the data comes from.

In [None]:
# Deliverables:
# - A coherent notebook showing results (so the reader can interpret the code)
# - Reproduce the most relevant figures from the thesis

    # U matrix and eigenvalues [skeleton]
    # grid of skeletons as eigenvectors [skeleton]

    # signal vs. GFT [custom]
    # energy plot [signal]
    # columns, one per pathology type [custom]

# - The most relevant animations (skeleton, skeleton with vectors, skeleton with/without vectors and lateral modes (list them),
        # grid of 6 skeletons, [GFT(t) + energy, with the skeleton alongside, a vertical bar sweeping through the normal modes])