# Extract respiratory signals with RAFT

Recurrent All-Pairs Field Transforms (RAFT) is a deep learning model for optical flow estimation. Breathing causes small, periodic movements in the video. The direction and magnitude of the estimated optical flow capture this motion and can therefore be used to extract respiratory signals from videos. This notebook demonstrates how to use RAFT to extract respiratory signals from videos.

In [None]:
import respiration.dataset as repository

# Subject and setting that identify the recording to load from the dataset
subject = 'Proband16'
setting = '201_shouldercheck'

# Resolve the file path of that recording via the default dataset repository
dataset = repository.from_default()
video_path = dataset.get_video_path(subject, setting)

In [None]:
import respiration.utils as utils
import respiration.extractor.optical_flow_raft as raft

# Pick the torch device via the project helper, then load the 'raft_large'
# variant of the RAFT model onto that device.
device = utils.get_torch_device()
model = raft.load_model('raft_large', device)

In [None]:
import torch
import numpy as np
from tqdm.auto import tqdm

# Number of frames loaded per batch. Consecutive batches share one frame so that
# the flow between the last frame of a batch and the first frame of the next
# batch is not lost; each batch therefore yields up to batch_size - 1 flow fields.
batch_size = 10
if batch_size < 2:
    raise ValueError('batch_size must be at least 2 to form a frame pair')

param = utils.get_video_params(video_path)

# Uncomment to limit processing to the first 30 * 6 frames (e.g. while experimenting)
# param.num_frames = 30 * 6

# Optical flow vectors (N, 2, H, W): optical_flows[i] is the flow from frame i to
# frame i + 1. The last entry has no successor frame and therefore stays zero.
optical_flows = np.zeros((param.num_frames, 2, param.height, param.width), dtype=np.float32)

# Step by batch_size - 1 so each batch starts on the last frame of the previous one;
# stopping at num_frames - 1 guarantees every batch contains at least one pair.
for start in tqdm(range(0, param.num_frames - 1, batch_size - 1)):
    end = min(start + batch_size, param.num_frames)
    num_frames = end - start

    frames, _ = utils.read_video_rgb(video_path, num_frames, start)
    frames = raft.preprocess(frames, device)

    with torch.no_grad():
        # Pair every frame with its direct successor: (0, 1), (1, 2), ..., (n-2, n-1)
        flows = model(frames[:-1], frames[1:])

    # Free the preprocessed frames before copying the results back
    del frames

    # Only keep the last flow iteration (the most refined estimate)
    flows = flows[-1]

    # flows[k] is the flow from frame start + k to frame start + k + 1
    optical_flows[start:start + flows.shape[0]] = flows.cpu().numpy()
    del flows

In [None]:
import numpy as np


def draw_flow(img, flow, step=20):
    """
    Draws a sparse grid of optical flow arrows on top of an image.
    Args:
    - img: The original image, either grayscale (HxW / HxWx1) or BGR (HxWx3).
    - flow: The optical flow vectors (HxWx2); flow[..., 0] is the horizontal and
      flow[..., 1] the vertical displacement in pixels.
    - step: Space between vectors to be drawn.
    Returns:
    - A BGR copy of the image with one arrow drawn every `step` pixels.
    """
    h, w = img.shape[:2]

    # Grid of sample points, offset by half a step so arrows are centered in their cells
    y, x = np.mgrid[step // 2:h:step, step // 2:w:step].reshape(2, -1).astype(int)
    fx, fy = flow[y, x].T

    # Create a BGR image to draw on; GRAY2BGR rejects images that already have 3 channels
    if img.ndim == 2 or img.shape[2] == 1:
        vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    else:
        vis = img.copy()

    for x0, y0, dx, dy in zip(x, y, fx, fy):
        # OpenCV expects integer point coordinates; round instead of truncating
        # so negative displacements are not biased toward zero.
        start_point = (int(x0), int(y0))
        end_point = (int(round(x0 + dx)), int(round(y0 + dy)))
        cv2.arrowedLine(
            vis,
            start_point,
            end_point,
            color=(255, 0, 0),  # blue in BGR
            thickness=1,
            tipLength=0.25,
        )

    return vis


def draw_flow_max(img, flow, max_length=20):
    """
    Draws an optical flow arrow at every pixel of the image.
    Arrows longer than `max_length` pixels, as well as non-finite vectors, are skipped.
    Args:
    - img: The original image, either grayscale (HxW / HxWx1) or BGR (HxWx3).
    - flow: The optical flow vectors (HxWx2), at least as large as the image.
    - max_length: Maximum arrow length in pixels that is still drawn.
    Returns:
    - A BGR copy of the image with the arrows drawn on it.
    """
    h, w = img.shape[:2]
    fx, fy = flow[:h, :w, 0], flow[:h, :w, 1]

    # Create a BGR image to draw on; GRAY2BGR rejects images that already have 3 channels
    if img.ndim == 2 or img.shape[2] == 1:
        vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    else:
        vis = img.copy()

    # Select the pixels to draw in one vectorized step instead of testing every pixel
    # in Python. `<=` is False for NaN, which previously made int() fail.
    length = np.sqrt(fx ** 2 + fy ** 2)
    ys, xs = np.nonzero(length <= max_length)  # row-major order, like the nested loops

    for y, x in zip(ys.tolist(), xs.tolist()):
        dx, dy = fx[y, x], fy[y, x]
        end_point = (int(round(x + dx)), int(round(y + dy)))
        cv2.arrowedLine(
            vis,
            (x, y),
            end_point,
            color=(255, 0, 0),  # blue in BGR
            thickness=1,
            tipLength=0.25,
        )

    return vis


def numpy_flow_to_image(flow: np.ndarray) -> np.ndarray:
    """
    Converts optical flow vectors to a color-coded image using torchvision's flow_to_image.
    Args:
    - flow: The optical flow vectors (HxWx2), floating point.
    Returns:
    - The flow visualization as a contiguous (HxWx3) array. NOTE(review): torchvision
      produces this presumably in RGB order; convert to BGR before writing it with OpenCV.
    """
    # flow_to_image expects channels first: (2, H, W)
    flow_tensor = torch.from_numpy(flow).permute(2, 0, 1)
    flow_image = flow_to_image(flow_tensor)

    # Back to channels last; make it contiguous so OpenCV can consume it directly
    return np.ascontiguousarray(flow_image.numpy().transpose(1, 2, 0))

In [None]:
import cv2
from torchvision.utils import flow_to_image

arrow_video_path = '../../reports/videos/optical_flow_arrow.avi'
motion_video_path = '../../reports/videos/optical_flow_motion.avi'

fourcc = cv2.VideoWriter_fourcc(*'XVID')
frame_size = (param.width, param.height)  # VideoWriter takes (width, height)

arrow_video = cv2.VideoWriter(arrow_video_path, fourcc, param.fps, frame_size)
motion_video = cv2.VideoWriter(motion_video_path, fourcc, param.fps, frame_size)
cap = cv2.VideoCapture(video_path)

try:
    # The OpenCV constructors do not raise when opening fails, so check explicitly
    # instead of silently producing no output.
    if not arrow_video.isOpened():
        raise RuntimeError(f'Could not open {arrow_video_path} for writing')
    if not motion_video.isOpened():
        raise RuntimeError(f'Could not open {motion_video_path} for writing')
    if not cap.isOpened():
        raise RuntimeError(f'Could not open {video_path} for reading')

    for idx in tqdm(range(param.num_frames)):
        ret, frame_ = cap.read()
        if not ret:
            break

        flow = optical_flows[idx].transpose(1, 2, 0)

        # cv2.VideoCapture delivers frames in BGR channel order
        frame = cv2.cvtColor(frame_, cv2.COLOR_BGR2GRAY)

        arrow_frame = draw_flow(frame, flow)
        arrow_video.write(arrow_frame)

        # The flow visualization is RGB, while cv2.VideoWriter expects BGR
        flow_frame = cv2.cvtColor(numpy_flow_to_image(flow), cv2.COLOR_RGB2BGR)
        motion_video.write(flow_frame)
finally:
    # Always release the handles, even if drawing or writing fails midway
    arrow_video.release()
    motion_video.release()
    cap.release()