In [None]:
# ===============================================================
# PARTE 1: INSTALAÇÃO DAS DEPENDÊNCIAS
# ===============================================================

!rm -rf monocular-demos
!pip install git+https://github.com/peabody124/GaitTransformer
!git clone https://github.com/IntelligentSensingAndRehabilitation/monocular-demos.git
!pip install -q fastapi uvicorn pyngrok python-multipart nest-asyncio

%cd monocular-demos
!pip install .
%cd ..

%env MUJOCO_GL=egl

# limit jax and TF from consuming all GPU memory
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

#from tqdm.notebook import tqdm

import os
import jax
import jax.numpy as jnp
import tensorflow_hub as hub
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cv2

from tqdm import tqdm
from tqdm import trange
from mpl_toolkits.mplot3d import Axes3D
from jaxtyping import Integer, Float, Array, PRNGKeyArray
from typing import Tuple, Dict

import equinox as eqx
import optax

import monocular_demos
from monocular_demos.biomechanics_mjx.forward_kinematics import ForwardKinematics
from monocular_demos.biomechanics_mjx.monocular_trajectory import KineticsWrapper, get_default_wrapper
from monocular_demos.biomechanics_mjx.visualize import render_trajectory, jupyter_embed_video


from gait_transformer.gait_phase_transformer import load_default_model, gait_phase_stride_inference
from gait_transformer.visualization import make_overlay, jupyter_embed_video
from gait_transformer.gait_phase_kalman import gait_kalman_smoother, compute_phases, get_event_times

# Report which accelerators TensorFlow and JAX can see, and let TensorFlow grow its
# GPU memory on demand instead of reserving all of it up front.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    print("TensorFlow is using the GPU")
else:
    print("TensorFlow is not using the GPU")

if gpus:
    try:
        # Memory growth must be the same across GPUs and must be set before the
        # GPUs are initialized (listing logical devices initializes them).
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised when the GPUs were already initialized before memory growth was set.
        print(e)

# List the devices JAX will run on (the header line promises the list that follows).
num_devices = jax.local_device_count()
print(f"Found {num_devices} JAX devices:")
for device in jax.local_devices():
    print(f"  {device}")

In [None]:
# ===============================================================
# PARTE 2: CÓDIGO DE PROCESSAMENTO ADAPTADO
# ===============================================================
# Skeleton edges drawn in the 3D panel, as pairs of joint names. Only joints whose
# name contains an uppercase letter are shown there; these pairs follow common human
# anatomy and can be edited to display other connections.
_ARESTAS_ESQUELETO_MAIUSCULAS = [
    ('CHip', 'LHip'), ('CHip', 'RHip'),
    ('LHip', 'LKnee'), ('RHip', 'RKnee'),
    ('LKnee', 'LAnkle'), ('RKnee', 'RAnkle'),
    ('LAnkle', 'LFoot'), ('RAnkle', 'RFoot'),
    ('CHip', 'Neck'), ('Neck', 'Head'),
    ('Neck', 'LShoulder'), ('Neck', 'RShoulder'),
    ('LShoulder', 'LElbow'), ('RShoulder', 'RElbow'),
    ('LElbow', 'LWrist'), ('RElbow', 'RWrist'),
    ('LWrist', 'LHand'), ('RWrist', 'RHand'),
]

# Maps the joint chosen in the interface to the (right, left) joint names of the
# ForwardKinematics model and the title of the angle plot. The names must match
# entries of fk.joint_names; otherwise the angle plot is skipped.
MAPA_ARTICULACOES = {
    # Lower limbs
    "Joelho":    {'nomes': ['knee_angle_r', 'knee_angle_l'],   'titulo': 'Flexão do Joelho'},
    "Quadril":   {'nomes': ['hip_flexion_r', 'hip_flexion_l'], 'titulo': 'Flexão do Quadril'},
    "Tornozelo": {'nomes': ['ankle_angle_r', 'ankle_angle_l'], 'titulo': 'Angulação do Tornozelo'},

    # Upper limbs
    "Ombro":     {'nomes': ['arm_flex_r', 'arm_flex_l'],       'titulo': 'Flexão do Ombro'},
    "Cotovelo":  {'nomes': ['elbow_flex_r', 'elbow_flex_l'],   'titulo': 'Flexão do Cotovelo'},
    "Punho":     {'nomes': ['wrist_flex_r', 'wrist_flex_l'],   'titulo': 'Flexão do Punho'},
}

# Selection used when the requested joint is not a key of MAPA_ARTICULACOES.
ARTICULACAO_PADRAO = "Joelho"


def _salvar_figura(fig, path: str) -> str:
    """Save ``fig`` to ``path``, close it so figures do not accumulate, and return ``path``."""
    fig.savefig(path)
    plt.close(fig)
    return path


def _detectar_poses(model, video_filepath: str, skeleton: str, rotated: bool = False) -> dict:
    """Run MeTRAbs batched pose detection over a whole video.

    Returns the dictionary produced by ``model.detect_poses_batched`` with every
    entry concatenated along the first (frame) axis across all batches.

    Raises:
        ValueError: if the video yields no frame batches.
    """
    from monocular_demos.utils import video_reader

    vid, n_frames = video_reader(video_filepath)
    print(f'About to processs {video_filepath} which has {n_frames} frames')

    accumulated = None
    # the progress total assumes batches of 8 frames
    for frame_batch in tqdm(vid, total=n_frames // 8):
        if rotated:
            # portrait phone videos may need transposing to be detected
            frame_batch = frame_batch.transpose(0, 2, 1, 3)

        pred = model.detect_poses_batched(frame_batch, skeleton=skeleton)

        if accumulated is None:
            accumulated = pred
        else:
            # concatenate the ragged tensors along the batch axis, key by key
            for key in accumulated.keys():
                accumulated[key] = tf.concat([accumulated[key], pred[key]], axis=0)

    if accumulated is None:
        raise ValueError(f"Nenhum frame lido de {video_filepath}")
    return accumulated


def _extrair_primeira_pessoa(accumulated: dict):
    """Keep the first detected person of every frame, dropping frames with no detection.

    Returns:
        (boxes, pose3d, pose2d) as NumPy arrays stacked over the retained frames.
    """
    num_people = [p.shape[0] for p in accumulated['poses2d']]
    if 0 in num_people:
        print('**WARNING** some frames with no people, make fail')

    boxes = np.array([p[0] for p in accumulated['boxes'] if len(p) > 0])
    pose3d = np.array([p[0] for p in accumulated['poses3d'] if len(p) > 0])
    pose2d = np.array([p[0] for p in accumulated['poses2d'] if len(p) > 0])
    return boxes, pose3d, pose2d


def _plot_esqueleto(pose3d, frame_idx: int, joint_names, joint_edges, path: str) -> str:
    """Save a two-panel skeleton figure for frame ``frame_idx`` and return ``path``.

    Left panel: 3D scatter of the joints whose name contains an uppercase letter,
    labelled and connected by ``_ARESTAS_ESQUELETO_MAIUSCULAS``. Right panel: every
    joint, centred on the mean joint position, plotted as keypoint column 0 against
    negated column 1 and connected by ``joint_edges``.
    """
    joint_names = list(joint_names)
    nomes_maiusculos = [n for n in joint_names if any(c.isupper() for c in n)]
    kp = pose3d[frame_idx, [joint_names.index(n) for n in nomes_maiusculos], :]
    arestas = [(nomes_maiusculos.index(a), nomes_maiusculos.index(b))
               for a, b in _ARESTAS_ESQUELETO_MAIUSCULAS]

    # 3D plot axes (X, Y, Z) show keypoint columns (0, 2, 1)
    xs, ys, zs = kp[:, 0], kp[:, 2], kp[:, 1]

    fig = plt.figure(figsize=(10, 8))
    fig.suptitle(f'Visualização do Esqueleto no Frame {frame_idx}')
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(xs, ys, zs)
    # labels use the same (X, Y, Z) mapping as the scattered points
    for nome, x, y, z in zip(nomes_maiusculos, xs, ys, zs):
        ax.text(x, y, z, nome, fontsize=9)
    for i, j in arestas:
        ax.plot([xs[i], xs[j]], [ys[i], ys[j]], [zs[i], zs[j]], 'k-', linewidth=1)
    ax.set_xlabel('X')
    ax.set_ylabel('Y (element 2)')
    ax.set_zlabel('Z (element 1)')
    ax.set_title('3D Scatter Plot and Skeleton')

    # equal aspect ratio: a cube of side max_range centred on the joints
    max_range = max(xs.max() - xs.min(), ys.max() - ys.min(), zs.max() - zs.min())
    half = max_range * 0.5
    mid_x = (xs.max() + xs.min()) * 0.5
    mid_y = (ys.max() + ys.min()) * 0.5
    mid_z = (zs.max() + zs.min()) * 0.5
    ax.set_xlim(mid_x - half, mid_x + half)
    ax.set_ylim(mid_y - half, mid_y + half)
    # z limits are given high-to-low, which inverts the Z axis
    ax.set_zlim(mid_z + half, mid_z - half)

    ax = fig.add_subplot(122)
    pose = pose3d[frame_idx] - np.mean(pose3d[frame_idx], axis=0)
    pose = pose[:, [0, 2, 1]]
    pose[:, 2] *= -1
    ax.plot(pose[:, 0], pose[:, 2], '.')
    for e in joint_edges:
        ax.plot(pose[e, 0], pose[e, 2], 'k')
    for i, p in enumerate(pose):
        ax.text(p[0] + 0.05, p[2], f'{i}: {joint_names[i]}', fontsize=8)
    ax.axis('equal')

    return _salvar_figura(fig, path)


def _loss_cinematica(
    model: "KineticsWrapper",
    x: 'Float[Array, "times"]',
    y: 'Float[Array, "times keypoints 3"]',
    site_offset_regularization=1e-1,
) -> "Tuple[Float, Dict]":
    """Keypoint fitting loss for the kinematic model.

    Evaluates ``model`` at timestamps ``x`` and takes the mean squared error (x100)
    between the predicted site positions and the detected keypoints ``y``, plus an
    L2 penalty on ``model.site_offsets`` weighted by ``site_offset_regularization``.

    Returns:
        (loss, metrics) where metrics is {"loss": total, "kp_err": keypoint term}.
    """
    (state, _, _), _, _ = model(x, skip_vel=True, skip_action=True)
    pred_kp3d = state.site_xpos

    kp_err = jnp.mean((pred_kp3d - y) ** 2) * 100
    l_site_offset = jnp.sum(jnp.square(model.site_offsets))
    total = kp_err + l_site_offset * site_offset_regularization

    # "loss" is kept as the first key of the metrics dictionary
    return total, {"loss": total, "kp_err": kp_err}


def _passo_otimizacao(model, opt_state, data, loss_grad, optimizer, **kwargs):
    """One optimizer step: returns (loss value, updated model, new opt_state, metrics)."""
    x, targets = data
    (val, metrics), grads = loss_grad(model, x=x, y=targets, **kwargs)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return val, model, opt_state, metrics


def _ajustar_modelo(
    model: "KineticsWrapper",
    dataset: tuple,
    lr_end_value: float = 1e-8,
    lr_init_value: float = 1e-4,
    max_iters: int = 5000,
    clip_by_global_norm: float = 0.1,
):
    """Fit ``model`` to ``dataset = (timestamps, keypoints)`` with AdamW.

    The learning rate decays exponentially from ``lr_init_value`` towards
    ``lr_end_value`` in steps of 10 iterations. The AdamW updates then have NaNs
    zeroed and are clipped by global norm.

    Returns:
        (fitted_model, metrics of the last iteration).
    """
    transition_steps = 10
    # max(1, ...) keeps the exponent finite when max_iters < transition_steps
    n_transicoes = max(1, max_iters // transition_steps)
    lr_decay_rate = (lr_end_value / lr_init_value) ** (1.0 / n_transicoes)
    learning_rate = optax.warmup_exponential_decay_schedule(
        init_value=0,
        warmup_steps=0,
        peak_value=lr_init_value,
        end_value=lr_end_value,
        decay_rate=lr_decay_rate,
        transition_steps=transition_steps,
    )

    optimizer = optax.chain(
        optax.adamw(learning_rate=learning_rate, b1=0.8, weight_decay=1e-5),
        optax.zero_nans(),
        optax.clip_by_global_norm(clip_by_global_norm),
    )
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    loss_grad = eqx.filter_value_and_grad(_loss_cinematica, has_aux=True)
    step = eqx.filter_jit(_passo_otimizacao)

    # max(1, ...) avoids a modulo by zero when max_iters < 10
    log_every = max(1, max_iters // 10)
    metrics = {}
    counter = trange(max_iters)
    for i in counter:
        val, model, opt_state, metrics = step(model, opt_state, dataset, loss_grad, optimizer)

        if i > 0 and i % log_every == 0:
            print(f"\niter: {i} loss: {val}.")

        if i % 50 == 0:
            metrics = {k: v.item() for k, v in metrics.items()}
            counter.set_postfix(metrics)

    return model, metrics


def processador_de_video(video_filepath: str, output_dir: str, joint_selection: str = "Joelho", height_mm: int = 1778, fps: float = 30.0):
    """Run the full biomechanical analysis pipeline on one video.

    The pipeline runs these steps in order:
    1. Save the first frame.
    2. Detect 3D keypoints with MeTRAbs (bml_movi_87) and plot the skeleton.
    3. Fit the kinematic model and plot the selected joint angles.
    4. Render the reconstruction video.
    5. Detect keypoints again (mpi_inf_3dhp_17) for the gait transformer.
    6. Save the gait-phase plot, the overlay video and the Kalman filter plots.

    Args:
        video_filepath: path of the input video.
        output_dir: folder for every output file (created if missing).
        joint_selection: key of MAPA_ARTICULACOES; unknown values fall back to
            ARTICULACAO_PADRAO.
        height_mm: participant height passed to the gait transformer.
        fps: frame rate used to convert frame indices into seconds.

    Returns:
        dict mapping a result name to the path of the generated file.

    Raises:
        ValueError: if the video cannot be read or no person is detected.
    """
    base_filename = os.path.splitext(os.path.basename(video_filepath))[0]
    os.makedirs(output_dir, exist_ok=True)
    results = {}
    print(f"Iniciando análise para o vídeo: {video_filepath}")
    print(f"Os resultados serão salvos em: {output_dir}")

    def caminho(sufixo: str) -> str:
        # every output is named "<video name>_<suffix>" inside output_dir
        return os.path.join(output_dir, f"{base_filename}_{sufixo}")

    # --- Plot 1: first frame --------------------------------------------------
    cap = cv2.VideoCapture(video_filepath)
    try:
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret or frame is None:
        raise ValueError(f"Não foi possível ler o vídeo: {video_filepath}")

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    ax.set_title(f'Frame Inicial de {os.path.basename(video_filepath)}')
    ax.axis('off')
    results['grafico_frame_inicial'] = _salvar_figura(fig, caminho("01_frame_inicial.png"))

    # Set to True if the frames appear transposed
    # (the gait transformer overlay will then not be correct).
    rotated = False

    # --- Keypoints with the bml_movi_87 skeleton --------------------------------
    model = hub.load('https://bit.ly/metrabs_l')  # Takes about 3 minutes
    from monocular_demos.utils import joint_names as movi_joint_names

    skeleton_movi = 'bml_movi_87'
    movi_joint_edges = model.per_skeleton_joint_edges[skeleton_movi].numpy()

    accumulated = _detectar_poses(model, video_filepath, skeleton_movi, rotated)
    # frames without a detected person are dropped so every later step works
    _, pose3d, _ = _extrair_primeira_pessoa(accumulated)
    if len(pose3d) == 0:
        raise ValueError(f"Nenhuma pessoa detectada em {video_filepath}")

    # Checkpoint the keypoints in this job's own folder (not a shared CWD file).
    np.savez(caminho("keypoints3d.npz"), pose3d)

    results['grafico_esqueleto_3d_2d'] = _plot_esqueleto(
        pose3d, 0, movi_joint_names, movi_joint_edges, caminho("02_visualizacao_esqueleto.png"))

    # --- Fit the kinematic model ---------------------------------------------
    # Reorder keypoint columns to (0, 2, 1), negate the new third column and
    # convert mm to m; then subtract each frame's per-coordinate minimum.
    pose = pose3d[:, :, [0, 2, 1]]
    pose[:, :, 2] *= -1
    pose /= 1000.0
    pose = pose - np.min(pose, axis=1, keepdims=True)

    timestamps = jnp.arange(len(pose)) / fps
    dataset = (timestamps, pose)

    updated_model, _ = _ajustar_modelo(get_default_wrapper(), dataset)

    # --- Plot 3: selected joint angles -----------------------------------------
    chave = joint_selection if joint_selection in MAPA_ARTICULACOES else ARTICULACAO_PADRAO
    selecao = MAPA_ARTICULACOES[chave]

    # pass the timestamps through the fitted model to get the joint angles
    _, (ang, _, _), _ = updated_model(dataset[0], skip_vel=True, skip_action=True)

    fk = ForwardKinematics()
    try:
        joint_idxs = jnp.array([fk.joint_names.index(n) for n in selecao['nomes']])
    except ValueError as e:
        # the configured names are not joints of this model: skip only this plot
        print(f"Erro ao encontrar articulações: {e}")
    else:
        fig, ax = plt.subplots(figsize=(10, 5))
        # negated angles, in degrees, one line per joint of selecao['nomes']
        ax.plot(dataset[0], -np.degrees(ang[:, joint_idxs]))
        ax.set_xlabel('Tempo (s)')
        ax.set_ylabel('Ângulo (graus)')
        ax.set_title(f"{selecao['titulo']} ao Longo do Tempo")
        ax.legend(['Direito', 'Esquerdo'])  # selecao['nomes'] lists right then left
        # the validated key (not the raw request value) is used in the file name
        results['grafico_angulos'] = _salvar_figura(fig, caminho(f"03_angulo_{chave.lower()}.png"))

    # --- MuJoCo reconstruction video ------------------------------------------
    video_reconstrucao = caminho("reconstrucao.mp4")
    render_trajectory(ang, video_reconstrucao, xml_path=None)
    results['video_reconstrucao'] = video_reconstrucao

    # --- Keypoints with the mpi_inf_3dhp_17 skeleton (gait transformer input) --
    skeleton_mpi = 'mpi_inf_3dhp_17'
    mpi_joint_names = model.per_skeleton_joint_names[skeleton_mpi].numpy().astype(str).tolist()

    accumulated = _detectar_poses(model, video_filepath, skeleton_mpi, rotated)
    _, pose3d_mpi, pose2d_mpi = _extrair_primeira_pessoa(accumulated)
    if len(pose3d_mpi) == 0:
        raise ValueError(f"Nenhuma pessoa detectada em {video_filepath}")

    # joint order of the GAST-Net algorithm the gait transformer was trained on
    expected_order = ['pelv', 'rhip', 'rkne', 'rank', 'lhip', 'lkne', 'lank', 'spin', 'neck', 'head', 'htop', 'lsho', 'lelb', 'lwri', 'rsho', 'relb', 'rwri']
    expected_order_idx = np.array([mpi_joint_names.index(j) for j in expected_order])

    # Centre on the pelvis, convert mm to m, reorder columns to (0, 2, 1),
    # put the joints in GAST-Net order and negate the third column.
    keypoints = pose3d_mpi - pose3d_mpi[:, mpi_joint_names.index('pelv'), None]
    keypoints = keypoints / 1000.0
    keypoints = keypoints[:, :, [0, 2, 1]]
    keypoints = keypoints[:, expected_order_idx]
    keypoints[:, :, 2] *= -1

    transformer_model = load_default_model()
    # window length processed by the transformer; should cover one or two gait cycles
    janela = 90
    phase, stride = gait_phase_stride_inference(keypoints, height_mm, transformer_model, janela)

    # time axis from the frames actually analysed (may be fewer than the video's
    # frame count when frames without a person were dropped)
    tempos_fase = np.arange(len(phase)) / fps

    # --- Plot 4: gait phase components -------------------------------------------
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(tempos_fase, phase[:, :4])
    ax.set_title('Componentes da Fase da Marcha')
    ax.set_xlabel('Tempo (s)')
    ax.set_ylabel('Valor do Componente')
    ax.legend(['cos(1)', 'cos(2)', 'cos(3)', 'cos(4)'])
    results['grafico_fase_marcha'] = _salvar_figura(fig, caminho("04_fase_marcha.png"))

    # The transformer outputs the four cos then the four sin; the overlay and the
    # Kalman filter expect alternating (cos, sin) pairs.
    phase_ordered = np.take(phase, [0, 4, 1, 5, 2, 6, 3, 7], axis=-1)

    # --- Overlay video -----------------------------------------------------------
    video_overlay = caminho("overlay.mp4")
    make_overlay(video_filepath, phase_ordered, stride, pose2d_mpi, video_overlay)
    results['video_overlay'] = video_overlay

    # --- Plots 5 and 6: Kalman filter ------------------------------------------
    kalman_state, _, errors = gait_kalman_smoother(phase_ordered)
    tempos_kalman = np.arange(len(kalman_state)) / fps

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(tempos_kalman, errors)
    ax.set_title('Erro de Reconstrução do Filtro de Kalman')
    ax.set_xlabel('Tempo (s)')
    ax.set_ylabel('Erro')
    results['grafico_erro_kalman'] = _salvar_figura(fig, caminho("05_erro_kalman.png"))

    fig, ax = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
    fig.suptitle('Estado Estimado pelo Filtro de Kalman')
    ax[0].plot(tempos_kalman, kalman_state[:, 0])
    ax[0].set_ylabel('Ciclos (rad)')
    ax[1].plot(tempos_kalman, kalman_state[:, 1], label='Cadência (rad/s)')
    # raw strings with balanced $...$ so matplotlib renders the phi symbols
    ax[1].plot(tempos_kalman, kalman_state[:, 2], label=r'$\phi_0$ (rad)')
    ax[1].plot(tempos_kalman, kalman_state[:, 3], label=r'$\phi_1$ (rad)')
    ax[1].plot(tempos_kalman, kalman_state[:, 4], label=r'$\phi_2$ (rad)')
    ax[1].set_ylabel('Fase')
    ax[1].set_xlabel('Tempo (s)')
    ax[1].legend()
    results['grafico_estado_kalman'] = _salvar_figura(fig, caminho("06_estado_kalman.png"))

    # --- Gait event timestamps -------------------------------------------------
    eventos = get_event_times(kalman_state, tempos_kalman)
    print(f"Tempos dos eventos da marcha: {eventos}")

    print("\nAnálise concluída com sucesso!")
    return results

# ===============================================================
# PARTE 3: LÓGICA DA API E GERENCIAMENTO DE TAREFAS
# ===============================================================
import uuid
import threading
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import FileResponse
import asyncio

# In-memory job registry: job_id -> {'status': 'processando' | 'concluido' | 'erro',
# 'resultados': list of result file names or None, and 'error_message' on failure}.
# It is written by the request handlers and by the background processing threads,
# and it is lost when the kernel restarts. A production system would keep this in a
# database such as Redis or similar.
jobs = {}

# Each job loads the pose model and trains with JAX on the (typically single) GPU,
# and the pipeline may write scratch files; running jobs one at a time avoids
# concurrent jobs competing for GPU memory or clobbering each other's files.
_processing_lock = threading.Lock()


def run_processing(job_id: str, video_path: str, output_path: str, joint_selection: str):
    """Run the analysis for one job (meant for a background thread) and record the outcome.

    On success, sets ``jobs[job_id]['resultados']`` to the base names of the generated
    files that exist on disk and ``'status'`` to ``'concluido'``. On any exception, prints
    the traceback and sets ``'status'`` to ``'erro'`` and ``'error_message'`` to the
    exception text. Jobs wait for each other; while waiting their status stays
    ``'processando'``.
    """
    import traceback

    try:
        with _processing_lock:
            print(f"--- [Job {job_id}] Iniciando processamento ({joint_selection}) ---")
            results_paths = processador_de_video(video_path, output_path, joint_selection=joint_selection)

        if not results_paths:
            raise RuntimeError("O processador retornou vazio (nenhum arquivo gerado).")

        print(f"--- [Job {job_id}] Processamento finalizado. Preparando lista de arquivos... ---")

        # Only base names are exposed by the API; files are served from the job folder.
        result_files = []
        for path in results_paths.values():
            if path and isinstance(path, str) and os.path.exists(path):
                result_files.append(os.path.basename(path))
            else:
                print(f"AVISO: Arquivo esperado não encontrado ou nulo: {path}")

        # 'resultados' is set before 'status' so a client polling for 'concluido'
        # always sees the file list.
        jobs[job_id]['resultados'] = result_files
        jobs[job_id]['status'] = 'concluido'

        print(f"--- [Job {job_id}] STATUS DEFINIDO COMO: CONCLUIDO ---")

    except Exception as e:
        traceback.print_exc()  # full error in the Colab log
        print(f"❌ ERRO no job {job_id}: {e}")
        jobs[job_id]['status'] = 'erro'
        jobs[job_id]['error_message'] = str(e)

# ===============================================================
# PARTE 4: DEFINIÇÃO DOS ENDPOINTS DA API COM FASTAPI
# ===============================================================
app = FastAPI()


@app.post("/processar")
async def processar_video(file: UploadFile = File(...), joint_selection: str = Form("Joelho")):
    """Receive a video, start its processing in a background thread and return a job_id.

    ``joint_selection`` is read from the multipart form and chooses which joint angle
    is plotted by the pipeline.
    """
    import shutil

    job_id = str(uuid.uuid4())

    # Folders for the uploaded video and for this job's results
    upload_dir = "uploads"
    results_dir = os.path.join("resultados", job_id)
    os.makedirs(upload_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    # The client-supplied file name is untrusted: keep only its last path component
    # so it cannot point outside upload_dir.
    nome_original = os.path.basename(file.filename or "video")
    video_path = os.path.join(upload_dir, f"{job_id}_{nome_original}")

    # Stream the upload to disk instead of reading the whole video into memory.
    await file.seek(0)
    with open(video_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    jobs[job_id] = {'status': 'processando', 'resultados': None}

    # Heavy processing runs in a separate thread so the API stays responsive
    thread = threading.Thread(target=run_processing, args=(job_id, video_path, results_dir, joint_selection))
    thread.start()

    return {"message": "Processamento iniciado", "job_id": job_id}

@app.get("/status/{job_id}")
async def get_status(job_id: str):
    """Return the stored status record of a processing job, or 404 if it is unknown."""
    registro = jobs.get(job_id)
    if registro:
        return registro
    raise HTTPException(status_code=404, detail="Job não encontrado")

@app.get("/resultados/{job_id}/{nome_arquivo}")
async def get_resultado(job_id: str, nome_arquivo: str):
    """Download one result file of a finished job.

    Only names listed in the job's ``'resultados'`` are served. This stops the URL from
    being used to reach other paths (e.g. ``..`` or unrelated files) under the results
    folder. Responds 404 when the job is unknown or unfinished, or the file is not
    available.
    """
    job = jobs.get(job_id)
    if not job or job['status'] != 'concluido':
        raise HTTPException(status_code=404, detail="Job não concluído ou não encontrado")

    if nome_arquivo not in (job.get('resultados') or []):
        raise HTTPException(status_code=404, detail="Arquivo não encontrado")

    file_path = os.path.join("resultados", job_id, nome_arquivo)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Arquivo não encontrado")

    return FileResponse(path=file_path, media_type='application/octet-stream', filename=nome_arquivo)



In [None]:
# ===============================================================
# CÉLULA 2 (VERSÃO FINAL E ROBUSTA): Execução da API com URL Estática
# ===============================================================
import uvicorn
from pyngrok import ngrok, conf
from google.colab import userdata

# 1. Configuração de Autenticação
try:
    NGROK_TOKEN = "33skyLswrrFrO1SFfmXox0W9aWT_7C2hC78WLg9zn1nMx13vj"
    # Opcional: Pegar o domínio dos secrets para não deixar exposto no código
    # Se não tiver no secret, você pode escrever a string direto abaixo
    NGROK_DOMAIN = "toucan-glorious-fowl.ngrok-free.app" # Ex: "rufino-api.ngrok-free.app"

    if NGROK_TOKEN:
        ngrok.set_auth_token(NGROK_TOKEN)
    else:
        print("ERRO CRÍTICO: Secret 'NGROK_AUTHTOKEN' não encontrado.")
except Exception as e:
    print(f"Erro ao carregar secrets: {e}")

# 2. Abre o túnel do ngrok com Domínio Estático
port = 8000

# Fecha túneis anteriores para evitar conflitos ao reexecutar a célula
ngrok.kill()

try:
    if NGROK_DOMAIN:
        # Tenta conectar usando o domínio fixo
        ssh_tunnel = ngrok.connect(port, domain=NGROK_DOMAIN)
        print(f"✅ API rodando em URL ESTÁTICA: {ssh_tunnel.public_url}")
    else:
        # Fallback: Se não tiver domínio configurado, abre um aleatório
        print("⚠️ NGROK_DOMAIN não definido. Usando URL aleatória.")
        ssh_tunnel = ngrok.connect(port)
        print(f"✅ API rodando em URL ALEATÓRIA: {ssh_tunnel.public_url}")
except Exception as e:
    print(f"❌ Erro ao conectar ngrok: {e}")
    print("Dica: Verifique se o domínio está correto no Dashboard do ngrok.")

# 3. Inicia o servidor Uvicorn
config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="info")
server = uvicorn.Server(config)

await server.serve()