# Extract respiratory signals with RAFT

Recurrent All-Pairs Field Transforms (RAFT) is a deep learning model for optical flow estimation. The optical flow directions and magnitudes can be used to extract respiratory signals from videos. This notebook demonstrates how to use RAFT to extract respiratory signals from videos.

In [None]:
from respiration.dataset import VitalCamSet

# Recording (subject + recording setting) whose video is analysed in this notebook
subject, setting = 'Proband11', '101_natural_lighting'

dataset = VitalCamSet()
video_path = dataset.get_video_path(subject, setting)

## Load the RAFT model

In [None]:
import respiration.utils as utils
import respiration.extractor.raft as raft

# Pick the torch device to run on, then load the 'raft_small' RAFT variant for it
# (presumably onto that device — confirm against raft.load_model).
device = utils.get_torch_device()
model = raft.load_model('raft_small', device)

## Extract optical flow from the video

In [None]:
import math
import torch
import numpy as np
from tqdm.auto import tqdm

param = utils.get_video_params(video_path)

# Only process the first 20 seconds of the video (or the whole video if it is shorter).
# int() keeps the frame count integral even if the reported FPS is a float; the count is
# used below as an array dimension and (later) as a slice bound.
analysis_seconds = 20
param.num_frames = min(param.num_frames, int(param.fps * analysis_seconds))

# Number of frames that are loaded and processed at once
batch_size = 20
# Consecutive batches share exactly one frame, so every neighbouring frame pair (i, i + 1)
# is handled by exactly one batch: a batch of k frames yields k - 1 flows, and no flow is
# computed twice. Starting only below num_frames - 1 guarantees each batch has >= 2 frames.
stride = batch_size - 1

# Store the optical flow vectors (N - 1, 2, H, W); flow i is computed from frame i to frame i + 1
optical_flows = np.zeros((param.num_frames - 1, 2, param.height, param.width), dtype=np.float32)

# Extract the optical flow from the video in batches
for start in tqdm(range(0, param.num_frames - 1, stride)):
    num_frames = min(start + batch_size, param.num_frames) - start
    chunk, _ = utils.read_video_rgb(video_path, num_frames, start)
    chunk = raft.preprocess(chunk, device)

    with torch.no_grad():
        # Pair every frame of the chunk with its successor
        flows = model(chunk[:-1], chunk[1:])

    # Release the frames before the next batch is loaded
    del chunk

    # Only keep the last flow iteration (assumes the model returns one flow prediction per
    # refinement iteration, as torchvision's RAFT does — TODO confirm for raft.load_model)
    flows = flows[-1]

    # Copy the whole batch of flows (num_frames - 1 of them) into the result at once
    optical_flows[start:start + flows.shape[0]] = flows.cpu().numpy()

## Visualize the optical flow

In [None]:
# Output directory for the figures saved by this notebook (outputs/figures/raft);
# mkdir=True presumably creates it when missing — confirm against utils.dir_path.
figure_dir = utils.dir_path('outputs', 'figures', 'raft', mkdir=True)

In [None]:
import matplotlib.pyplot as plt

# Read the first frame of the video. optical_flows[0] was computed from frame 0 to frame 1,
# so its vectors belong to the pixels of frame 0 — draw them on that frame, not on frame 1.
frames, _ = utils.read_video_rgb(video_path, 1, 0)
arrow_frame = raft.draw_flow(frames[0], optical_flows[0])
flow_frame = raft.image_from_flow(optical_flows[0])

fig, ax = plt.subplots(1, 2, figsize=(20, 8))
ax[0].imshow(arrow_frame)
ax[0].set_title('Optical flow arrows')
ax[1].imshow(flow_frame)
ax[1].set_title('Optical flow magnitude')

fig.tight_layout()

utils.savefig(fig, figure_dir, 'optical_flow')

## Extract the respiratory signal

1. Find the region of interest (ROI) on the chest
2. Calculate the vertical motion magnitude in the ROI (the absolute value of the vertical optical flow component)
3. Plot the motion magnitude over time

In [None]:
import respiration.roi as roi

# Find the chest region
# Find the chest region on the first frame; x/w index columns and y/h index rows (see slice below)
x, y, w, h = roi.detect_chest(frames[0])

# Restrict the optical flows to the chest region: shape (N - 1, 2, h, w)
roi_flows = optical_flows[:, :, y:y + h, x:x + w]

# Motion magnitude from the vertical (y) flow component only, channel 1 — presumably because
# breathing moves the chest mostly up and down. |y| equals the former sqrt(y ** 2).
# The full 2-D magnitude would be np.hypot(roi_flows[:, 0], roi_flows[:, 1]).
magnitudes = np.abs(roi_flows[:, 1])

In [None]:
# Per-frame mean and spread of the motion magnitude across all ROI pixels
# (axes 1 and 2 are the ROI rows and columns)
mean_curve = magnitudes.mean(axis=(1, 2))
std_curve = magnitudes.std(axis=(1, 2))

In [None]:
# Plot the mean ROI motion magnitude per frame with a +/- one standard deviation band
frame_idx = np.arange(len(mean_curve))

fig, ax = plt.subplots(1, 1, figsize=(20, 8))
ax.plot(frame_idx, mean_curve, label='Mean')
ax.fill_between(
    frame_idx,
    mean_curve - std_curve,
    mean_curve + std_curve,
    alpha=0.3,
    label='Standard deviation')

ax.set_title('Motion magnitude in the ROI')
ax.set_xlabel('Frame')
ax.set_ylabel('Motion magnitude')
# Without a legend the 'Mean' / 'Standard deviation' labels above are never shown
ax.legend()

utils.savefig(fig, figure_dir, 'roi_magnitudes')

## Compare the ground truth with the prediction

In [None]:
# Cut-off frequencies handed to analysis.butterworth_filter below; presumably in Hz,
# i.e. a 6-30 breaths-per-minute band — confirm against the filter's documentation.
lowpass, highpass = 0.1, 0.5

In [None]:
import respiration.analysis as analysis

# Filter and normalize the ground truth signal
# Ground truth: truncate to the analysed frames, normalize, then band-pass filter
raw_breathing = dataset.get_breathing_signal(subject, setting)[:param.num_frames]
ground_truth = analysis.butterworth_filter(
    analysis.normalize_signal(raw_breathing), param.fps, lowpass, highpass)

# Prediction: the same normalize + band-pass pipeline applied to the mean ROI motion curve
predicted = analysis.butterworth_filter(
    analysis.normalize_signal(mean_curve), param.fps, lowpass, highpass)

In [None]:
# One panel per signal: ground truth on top, prediction below
panels = [(ground_truth, 'Ground truth'), (predicted, 'Predicted')]

fig, ax = plt.subplots(2, 1, figsize=(20, 8))
for panel_ax, (curve, title) in zip(ax, panels):
    panel_ax.plot(curve, label=title)
    panel_ax.set_title(title)
    panel_ax.set_xlabel('Frame')
    panel_ax.set_ylabel('Normalized signal')

fig.tight_layout()

utils.savefig(fig, figure_dir, 'comparison')