# Extract respiratory signals with RAFT

Recurrent All-Pairs Field Transforms (RAFT) is a deep learning model for optical flow estimation. The optical flow directions and magnitudes can be used to extract respiratory signals from videos. This notebook demonstrates how to use RAFT to extract respiratory signals from videos.

In [None]:
from respiration.dataset import VitalCamSet

# Handle to the VitalCamSet dataset; used below to look up videos and frames
dataset = VitalCamSet()

# Scenarios to process, each a (subject, setting) pair.
# NOTE(review): the argument presumably selects the settings to include - confirm.
scenarios = dataset.get_scenarios(['101_natural_lighting'])

# Names of the RAFT variants that are run on every scenario
raft_models = ['raft_large', 'raft_small']

In [None]:
import respiration.utils as utils

# Directory that receives the raw optical flow arrays (one file per video and model)
flows_dir = utils.dir_path('outputs', 'raft_flows', mkdir=True)

# Directory that receives the manifest and the exported signals
output_dir = utils.dir_path('outputs', 'signals', mkdir=True)
manifest_file = utils.join_paths(output_dir, 'raft_manifest.json')

## Part 1: Extract optical flows

This part is computationally expensive and may take a long time to complete. The optical flows are extracted from the videos and saved to disk as one `.npy` file per scenario and RAFT model in the `outputs/raft_flows` directory.

In [None]:
from datetime import datetime
import respiration.utils as utils

# Torch device on which the RAFT models are executed
device = utils.get_torch_device()

# Metadata describing this extraction run. 'flows' gets one entry per stored
# flow file and 'incomplete_rois' one entry per scenario with missing ROIs.
# NOTE(review): the datetime and device objects are not JSON-native; this
# relies on utils.write_json being able to serialize them - confirm.
manifest = dict(
    timestamp_start=datetime.now(),
    scenarios=scenarios,
    device=device,
    raft_models=raft_models,
    flows=[],
    incomplete_rois=[],
)

In [None]:
import os
import torch
import math
import numpy as np
from tqdm.auto import tqdm
import respiration.extractor.raft as raft

# Number of frames that are read and passed to the model at once
batch_size = 20

# Consecutive batches share exactly one frame: the last frame of a batch is the
# first frame of the next one. This covers the flow across the batch boundary
# without estimating any frame pair twice, and it guarantees that every batch
# holds at least two frames (a single-frame batch would call the model with an
# empty input).
batch_stride = batch_size - 1

for (subject, setting) in tqdm(scenarios):
    print(f'Processing {subject} - {setting}')

    video_path = dataset.get_video_path(subject, setting)
    param = utils.get_video_params(video_path)

    # A video with N frames has N - 1 pairs of consecutive frames
    num_flows = param.num_frames - 1

    for raft_model in raft_models:
        model = raft.load_model(raft_model, device)

        # Optical flow vectors (N - 1, 2, H, W), one per pair of consecutive frames
        optical_flows = np.zeros((num_flows, 2, param.height, param.width), dtype=np.float32)

        # `start` is both the index of the first frame of the batch and the
        # index of the first flow that the batch produces
        for start in range(0, num_flows, batch_stride):
            # Number of frames to read; the final batch may be shorter
            num_frames = min(start + batch_size, param.num_frames) - start

            chunk, _ = utils.read_video_rgb(video_path, num_frames, start)
            chunk = raft.preprocess(chunk, device)

            # Defensive guard in case the reader returns fewer frames than the
            # video parameters announce: a flow needs at least two frames.
            if chunk.shape[0] < 2:
                print(f'Warning: video ended early at frame {start} for {subject} - {setting}')
                del chunk
                break

            # Pair every frame with its successor.
            # NOTE(review): assumes raft.load_model returns the model in
            # inference mode, so a pair's flow does not depend on the other
            # pairs in the batch - confirm.
            with torch.no_grad():
                flows = model(chunk[:-1], chunk[1:])

            # Only keep the last flow iteration
            flows = flows[-1]

            # Copy the whole batch into the result array at once
            optical_flows[start:start + flows.shape[0]] = flows.cpu().numpy()

            # Release the batch tensors before the next batch is read
            del chunk
            del flows

        # Store the extracted optical flows
        filename = f'{subject}_{setting}_{raft_model}.npy'
        flow_file = os.path.join(flows_dir, filename)
        np.save(flow_file, optical_flows)

        # Release the optical flows before the next model runs (8.2GB)
        del optical_flows

        manifest['flows'].append({
            'subject': subject,
            'setting': setting,
            'model': raft_model,
            'filename': filename,
        })

In [None]:
# Stamp the end of the extraction run and persist the manifest next to the signals
manifest.update(timestamp_finish=datetime.now())
utils.write_json(manifest_file, manifest)

## Part 2: Export respiratory signals

This part reads the optical flows extracted in Part 1 and calculates the respiratory signals for each region of interest. The respiratory signals are saved to a CSV file in the `outputs/signals` directory.

In [None]:
# Replace the in-memory manifest with the content of `manifest_file`, the file
# that the last cell of Part 1 writes
manifest = utils.read_json(manifest_file)

In [None]:
import os
import numpy as np
import respiration.roi as roi
from tqdm.auto import tqdm

# One record per (subject, setting, model, roi) combination
extracted_signals = []

# Scenarios for which fewer than three ROIs were detected. Collected in a fresh
# list (one entry per scenario) so that re-running this cell does not
# accumulate duplicates in the manifest.
incomplete_rois = []

for (subject, setting) in tqdm(scenarios):
    # Video parameters and ROIs depend only on the scenario, not on the RAFT
    # model, so they are determined once per scenario
    video_path = dataset.get_video_path(subject, setting)
    params = utils.get_video_params(video_path)
    first_frame = dataset.get_first_frame(subject, setting)
    roi_areas = roi.get_roi_areas(first_frame)
    if len(roi_areas) < 3:
        print(f'Warning: only {len(roi_areas)} ROIs found for {subject} - {setting}')
        incomplete_rois.append({
            'subject': subject,
            'setting': setting,
            'rois': [name for (_, name) in roi_areas],
        })

    for raft_model in raft_models:
        filename = f'{subject}_{setting}_{raft_model}.npy'
        flow_file = os.path.join(flows_dir, filename)
        assert os.path.exists(flow_file), f'File not found: {flow_file}'

        # Optical flows written by Part 1 (N, 2, H, W)
        optical_flows = np.load(flow_file)

        for ((x, y, w, h), name) in roi_areas:
            # Select the motion vectors in the region of interest (N, 2, h, w)
            flow_region = optical_flows[:, :, y:y + h, x:x + w]

            # Horizontal motion (N, h, w)
            u = flow_region[:, 0, :, :]

            # Vertical motion (N, h, w)
            v = flow_region[:, 1, :, :]

            # Calculate the magnitudes of the motion vectors (N, h, w)
            magnitudes = np.sqrt(u ** 2 + v ** 2)

            # Per-frame mean and standard deviation of the magnitudes
            uv_mean_curve = np.mean(magnitudes, axis=(1, 2))
            uv_std_curve = np.std(magnitudes, axis=(1, 2))

            # Per-frame mean and standard deviation of the vertical motion
            v_mean_curve = v.mean(axis=(1, 2))
            v_std_curve = v.std(axis=(1, 2))

            # Store the extracted signals
            extracted_signals.append({
                'subject': subject,
                'setting': setting,
                'model': raft_model,
                'roi': name,
                'sampling_rate': params.fps,
                'signal_uv': uv_mean_curve.tolist(),
                'signal_uv_std': uv_std_curve.tolist(),
                'signal_v': v_mean_curve.tolist(),
                'signal_v_std': v_std_curve.tolist(),
            })

        # Release the flow array before the next file is loaded
        del optical_flows

# Persist the ROI findings; without this the update only lives in kernel memory
manifest['incomplete_rois'] = incomplete_rois
utils.write_json(manifest_file, manifest)

In [None]:
import pandas as pd

# Destination of the exported signals inside the signals output directory
predictions_file = os.path.join(output_dir, 'raft_predictions.csv')

# Turn the collected records into a table (one row per record) and write it
# without the index column
signals_df = pd.DataFrame(extracted_signals)
signals_df.to_csv(predictions_file, index=False)

In [None]:
# Preview the first rows of the exported signal table
signals_df.head()