# Extract respiratory signals with FlowNet2 optical flow

FlowNet2, introduced in "FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks", is a deep learning model for optical flow estimation. The directions and magnitudes of the optical flow can be used to extract respiratory signals from videos. This notebook demonstrates how to use FlowNet2 to extract respiratory signals from videos.

In [None]:
from respiration.dataset import VitalCamSet

# Recording to analyse: subject identifier and recording setting
subject, setting = 'Proband16', '101_natural_lighting'

# Resolve the video file of that recording through the dataset helper
dataset = VitalCamSet()
video_path = dataset.get_video_path(subject, setting)

## Load the FlowNet2 model

In [None]:
import torch
import respiration.utils as utils

from respiration.extractor.flownet import (
    FlowNet2,
    resize_and_center_frames,
)

device = utils.get_torch_device()
path = utils.file_path('data', 'flownet', 'FlowNet2_checkpoint.pth')

# Load the checkpoint onto the CPU first: if it was saved from a CUDA device,
# torch.load would otherwise try to restore the tensors onto that device and
# fail on CPU-/MPS-only machines. The model is moved to `device` below.
loaded = torch.load(path, map_location='cpu')

model = FlowNet2(batch_norm=False)
model.load_state_dict(loaded['state_dict'])
model = model.to(device)

# Switch to inference mode (e.g. dropout disabled, batch-norm in eval mode)
model.eval()

## Extract optical flow from the video

In [None]:
import torch
import numpy as np
import respiration.utils as utils

from tqdm.auto import tqdm
from torchvision import transforms

param = utils.get_video_params(video_path)

# Only process the first 30 seconds of the video (or the whole video if it is
# shorter). The frame rate may be fractional, so cast to int: np.zeros and
# range below need an integer frame count.
analysis_seconds = 30
param.num_frames = min(param.num_frames, int(param.fps * analysis_seconds))

# Number of frames that are read and passed to the model at once
batch_size = 10

# Side length (pixels) of the square center crop fed to FlowNet2
new_dim = 640

# The preprocessing does not depend on the batch, so build it only once.
preprocess = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    # Center Crop the frames
    transforms.CenterCrop((new_dim, new_dim)),
    transforms.ToTensor(),
])

# One flow field per pair of consecutive frames: (N - 1, 2, H, W)
optical_flows = np.zeros((param.num_frames - 1, 2, new_dim, new_dim), dtype=np.float32)

# Consecutive batches overlap by one frame (stride batch_size - 1) so that the
# pair spanning two batches is not skipped. Stopping at num_frames - 1
# guarantees every batch holds at least two frames; a trailing single-frame
# batch would make `unfold` below fail.
for start in tqdm(range(0, param.num_frames - 1, batch_size - 1)):
    num_frames = min(batch_size, param.num_frames - start)
    frames, _ = utils.read_video_rgb(video_path, num_frames, start)

    frames = torch.stack([preprocess(frame) for frame in frames], dim=0)
    frames = frames.to(device)

    # Pair consecutive frames: (T, C, H, W) -> (T - 1, C, 2, H, W)
    unfolded_frames = frames.unfold(0, 2, 1).permute(0, 1, 4, 2, 3)

    with torch.no_grad():
        flows = model(unfolded_frames)

    # Release the input tensors before the next batch is loaded
    del frames, unfolded_frames

    # Copy the whole batch of flow fields to the host in a single transfer
    optical_flows[start:start + flows.shape[0]] = flows.cpu().numpy()

## Visualize the optical flow

In [None]:
# Output directory for this notebook's figures (mkdir=True presumably creates it if missing)
figure_dir = utils.dir_path('outputs', 'figures', 'flownet', mkdir=True)

In [None]:
import matplotlib.pyplot as plt
import respiration.extractor.optical_flow_raft as raft

from torchvision import transforms

# optical_flows[0] describes the motion from frame 0 to frame 1, so draw it on frame 0.
frames, _ = utils.read_video_rgb(video_path, 1, 0)

# Apply exactly the same crop as during flow extraction. This keeps the frame
# (and the chest ROI detected on it later) pixel-aligned with `optical_flows`.
to_model_input = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.CenterCrop((new_dim, new_dim)),
    transforms.ToTensor(),
])
frame = to_model_input(frames[0]).numpy().transpose(1, 2, 0)

# ToTensor scales to 0..1; convert back to uint8 0..255 (rounded, not truncated)
frame = np.round(frame * 255).astype(np.uint8)

arrow_frame = raft.draw_flow(frame, optical_flows[0])
flow_frame = raft.image_from_flow(optical_flows[0])

fig, ax = plt.subplots(1, 2, figsize=(20, 8))
ax[0].imshow(arrow_frame)
ax[0].set_title('Optical flow arrows')
ax[1].imshow(flow_frame)
ax[1].set_title('Optical flow magnitude')

utils.savefig(fig, figure_dir, 'optical_flow')

## Extract the respiratory signal

1. Find the region of interest (ROI) on the chest
2. Calculate the motion magnitude in the ROI
3. Plot the motion magnitude over time

In [None]:
import respiration.roi as roi

# Find the chest region on the reference frame (x, y: top-left corner; w, h: size)
x, y, w, h = roi.detect_chest(frame)

# Get only the optical flows in the chest region: (N - 1, 2, <=h, <=w)
roi_flows = optical_flows[:, :, y:y + h, x:x + w]

# Per-pixel motion magnitude sqrt(dx^2 + dy^2) from the x- and y-components.
# The previous code only used channel 1, i.e. |dy|, contradicting this intent.
magnitudes = np.hypot(roi_flows[:, 0], roi_flows[:, 1])

In [None]:
# Per-frame statistics over every pixel of the ROI (reduce the two spatial axes)
mean_curve = magnitudes.mean(axis=(1, 2))
std_curve = magnitudes.std(axis=(1, 2))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(20, 8))

# x positions shared by the mean curve and its standard-deviation band
frame_indices = np.arange(len(mean_curve))
ax.plot(frame_indices, mean_curve, label='Mean')
ax.fill_between(
    frame_indices,
    mean_curve - std_curve,
    mean_curve + std_curve,
    alpha=0.3,
    label='Standard deviation')

ax.set_title('Motion magnitude in the ROI')
ax.set_xlabel('Frame')
ax.set_ylabel('Motion magnitude')
# Without this call the 'Mean' / 'Standard deviation' labels are never displayed
ax.legend()

utils.savefig(fig, figure_dir, 'roi_magnitudes')

In [None]:
import respiration.analysis as analysis

# Align the ground truth with the flow-based signal: flow i describes the motion
# from frame i to frame i + 1, so the first sample is skipped and the slice has
# at most num_frames - 1 samples, like mean_curve.
# NOTE(review): assumes the breathing signal is sampled once per video frame — TODO confirm.
respiratory_gt = dataset.get_breathing_signal(subject, setting)[1:param.num_frames]

# Compare prediction against ground truth at the video frame rate and print the error metrics
comparator = analysis.SignalComparator(respiratory_gt, mean_curve, sample_rate=param.fps)
utils.pretty_print(comparator.errors())

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(20, 8))
ax.plot(comparator.prediction, label='Prediction')
ax.plot(comparator.ground_truth, label='Ground truth')
ax.set_title('Respiratory signal')
ax.set_xlabel('Frame')
# Both curves share this axis and the ground truth is not a motion magnitude,
# so use a neutral label.
ax.set_ylabel('Amplitude')
ax.legend()

# Save alongside the other figures of this notebook
utils.savefig(fig, figure_dir, 'respiratory_signal')