# MTTS-CAN

This notebook demonstrates how to extract respiration and pulse signals from a video using the MTTS-CAN (Multi-Task Temporal Shift Convolutional Attention Network) model.

In [None]:
from respiration.dataset import VitalCamSet

# Recording to analyse: one subject, recorded under natural lighting
subject, setting = 'Proband05', '101_natural_lighting'

dataset = VitalCamSet()

In [None]:
# Load the first 30 * 10 frames of the recording as RGB images
# (presumably 10 seconds at 30 fps — TODO confirm the recording frame rate).
video_kwargs = dict(num_frames=30 * 10, show_progress=True)
frames, meta = dataset.get_video_rgb(subject, setting, **video_kwargs)

## Preprocessing

The MTTS-CAN model expects two inputs: the resized frames and a version of them that has been normalized in the temporal domain.

In [None]:
from respiration.extractor.mtts_can import preprocess_frames_original
from respiration.extractor.mtts_can import preprocess_video_frames

# `preprocess_frames_original(frames)` is an alternative preprocessing path with the
# same return pair; it is imported so it can be swapped in for comparison.
resized, normalized = preprocess_video_frames(frames)

In [None]:
import matplotlib.pyplot as plt
import respiration.utils as utils

# Show the same frame (index 1) after resizing and after temporal normalization,
# side by side, then save the figure.
fig, (ax_resized, ax_normalized) = plt.subplots(1, 2, figsize=(15, 5))

panels = (
    (ax_resized, resized[1], 'Resized Frame'),
    (ax_normalized, normalized[1], 'Normalized Frame'),
)
for axis, image, title in panels:
    axis.imshow(image)
    axis.set_title(title)

figure_dir = utils.dir_path('outputs', 'figures', 'pre-process', mkdir=True)
utils.savefig(fig, figure_dir, 'mtts_can')

## Prediction

In [None]:
import matplotlib.pyplot as plt

from respiration.extractor.mtts_can import load_model

frame_depth = 10

# The model expects a number of frames that is a multiple of frame_depth, and both
# inputs passed to `model.predict` must contain the same number of frames. Derive
# the cut-off from the shorter of the two arrays so that a length mismatch between
# the resized and normalized frames cannot produce unequal inputs.
num_usable_frames = min(len(resized), len(normalized))
cut_off = (num_usable_frames // frame_depth) * frame_depth
if cut_off == 0:
    # Without this check the model would receive empty inputs and fail obscurely.
    raise ValueError(
        f'Need at least {frame_depth} frames for prediction, got {num_usable_frames}'
    )

input_resized = resized[:cut_off]
input_normalized = normalized[:cut_off]

In [None]:
# Run the (resized, normalized) frame pair through MTTS-CAN in batches of 100 frames.
# Output 0 is used as the pulse signal and output 1 as the respiration signal below.
model = load_model()
model_inputs = (input_resized, input_normalized)
predictions = model.predict(model_inputs, batch_size=100)

In [None]:
import numpy as np

# Integrate each model output with a cumulative sum (presumably the model predicts
# the first derivative of each signal — confirm against the MTTS-CAN training
# targets). Output 0 -> pulse, output 1 -> respiration.
pulse_prediction, respiration_prediction = (
    np.cumsum(predictions[output_index]) for output_index in (0, 1)
)

## Show the predicted pulse

In [None]:
import respiration.analysis as analysis
import respiration.utils as utils

# Ground-truth pulse signal for the same recording
pulse_gt = dataset.get_vital_sign(subject, setting, utils.VitalSigns.pulse)

# Only the first `cut_off` frames were fed to the model, so truncate the ground
# truth to the prediction's length before comparing.
pulse_gt = pulse_gt[:len(pulse_prediction)]

# Cardiac band: 0.75-2.5, i.e. 45-150 beats per minute assuming the cutoffs are in Hz.
pulse_compare = analysis.SignalCompare(
    pulse_prediction,
    pulse_gt,
    meta.fps,
    lowpass=0.75,
    highpass=2.5,
)

In [None]:
# Display the BPM error metrics computed by SignalCompare for the pulse signal
pulse_compare.bpm_errors()

In [None]:
# Display the distance metrics computed by SignalCompare for the pulse signal
pulse_compare.distances()

In [None]:
# Plot the pulse prediction against the ground truth.
# The x-axis is the sample index of the compared signals, not a time in seconds,
# so label it accordingly.
_ = plt.figure(figsize=(20, 10))
plt.plot(pulse_compare.prediction, label='Pulse Prediction')
plt.plot(pulse_compare.ground_truth, label='Pulse Ground Truth')
plt.title('Pulse Prediction')
plt.xlabel('Sample')
plt.ylabel('Pulse')
plt.legend()
plt.show()

## Show the predicted respiration

In [None]:
respiration_gt = dataset.get_breathing_signal(subject, setting)

# Not all frames are used for prediction --> cut the ground truth to the same length
respiration_gt = respiration_gt[:respiration_prediction.shape[0]]

pulse_compare = analysis.SignalCompare(
    pulse_prediction,
    respiration_prediction,
    meta.fps,
    lowpass=0.75,
    highpass=2.5,
)

In [None]:
pulse_compare.bpm_errors()

In [None]:
pulse_compare.distances()

In [None]:
# Plot the respiration prediction
_ = plt.figure(figsize=(20, 10))
plt.plot(pulse_compare.prediction, label='Respiration Prediction')
plt.plot(pulse_compare.ground_truth, label='Respiration Ground Truth')
plt.title('Respiration Prediction')
plt.xlabel('Time')
plt.ylabel('Respiration')