# Extract respiratory signals with RAFT

Recurrent All-Pairs Field Transforms (RAFT) is a deep learning model for optical flow estimation. The optical flow directions and magnitudes can be used to extract respiratory signals from videos. This notebook demonstrates how to use RAFT to extract respiratory signals from videos.

In [None]:
from respiration.dataset import VitalCamSet

# Limit this run to the scenarios recorded under the settings listed below
selected_settings = ['101_natural_lighting']

dataset = VitalCamSet()
scenarios = dataset.get_scenarios(selected_settings)

In [None]:
import os

# Output locations: everything produced by this notebook goes below
# `evaluation_dir`; the raw optical flow arrays go into its `flows` subfolder.
evaluation_dir = os.path.join('..', '..', 'evaluation', 'optical_flow_raft')
flows_dir = os.path.join(evaluation_dir, 'flows')

# Create both folders up front; already existing folders are left untouched
for directory in (evaluation_dir, flows_dir):
    os.makedirs(directory, exist_ok=True)

## Extract optical flows

In [None]:
from datetime import datetime
import respiration.utils as utils

# Torch device the RAFT models are loaded onto
device = utils.get_torch_device()

# RAFT variants to run over every scenario
raft_models = ['raft_large', 'raft_small']

# Run description that is written to disk after the extraction; one entry is
# appended to `flows` for every stored optical flow file.
manifest = dict(
    timestamp_start=datetime.now(),
    scenarios=scenarios,
    device=device,
    raft_models=raft_models,
    flows=[],
)

In [None]:
import torch
import numpy as np
from tqdm.auto import tqdm
import respiration.extractor.optical_flow_raft as raft

# Number of frames that are read and passed through the model at once.
# Consecutive chunks overlap by one frame, so a chunk of `batch_size` frames
# yields `batch_size - 1` optical flow fields.
batch_size = 10

if batch_size < 2:
    raise ValueError('batch_size must be at least 2 to form a pair of consecutive frames')

for (subject, setting) in tqdm(scenarios):
    print(f'Processing {subject} - {setting}')

    video_path = dataset.get_video_path(subject, setting)
    param = utils.get_video_params(video_path)

    for raft_model in raft_models:
        model = raft.load_model(raft_model, device)

        # Store the optical flow vectors (N, 2, H, W). Entry i holds the flow from
        # frame i to frame i + 1. The last entry stays zero because the final
        # frame has no successor; keeping N entries preserves the layout of the
        # stored files.
        optical_flows = np.zeros((param.num_frames, 2, param.height, param.width), dtype=np.float32)

        # Walk through the video in chunks that overlap by one frame, so that the
        # pair (last frame of a chunk, first frame of the next chunk) is not lost.
        for start in range(0, param.num_frames - 1, batch_size - 1):
            # Number of frames to read for this chunk (always >= 2)
            num_frames = min(batch_size, param.num_frames - start)

            chunk, _ = utils.read_video_rgb(video_path, num_frames, start)
            chunk = raft.preprocess(chunk, device)

            with torch.no_grad():
                # Pair every frame with its direct successor. Splitting the chunk
                # into even/odd frames would skip every second pair and leave
                # half of the output array empty.
                flows = model(chunk[:-1], chunk[1:])

            # Release the frames before the flows are copied to host memory
            del chunk

            # Only keep the last flow iteration
            flows = flows[-1]

            # Flow k of this chunk belongs to frame `start + k`
            optical_flows[start:start + flows.shape[0]] = flows.cpu().numpy()

        # Store the extracted optical flows
        filename = f'{subject}_{setting}_{raft_model}.npy'
        flow_file = os.path.join(flows_dir, filename)
        np.save(flow_file, optical_flows)

        # Garbage collect the optical flows (8.2GB)
        del optical_flows

        manifest['flows'].append({
            'subject': subject,
            'setting': setting,
            'model': raft_model,
            'filename': filename,
        })

In [None]:
# Record when the extraction finished and write the manifest into the evaluation directory
manifest_file = os.path.join(evaluation_dir, 'manifest.json')
manifest.update(timestamp_finish=datetime.now())
utils.write_json(manifest_file, manifest)

## Extract respiratory signals

In [None]:
import numpy as np
from tqdm.auto import tqdm

In [None]:
import respiration.roi as roi

# One record per (subject, setting, model, region of interest)
extracted_signals = []

for (subject, setting) in tqdm(scenarios):
    # The regions of interest are derived from the video's first frame only, so
    # they are computed once per scenario instead of once per RAFT model.
    first_frame = dataset.get_first_frame(subject, setting)
    roi_areas = roi.get_roi_areas(first_frame)

    for raft_model in raft_models:
        filename = f'{subject}_{setting}_{raft_model}.npy'
        flow_file = os.path.join(flows_dir, filename)
        if not os.path.exists(flow_file):
            raise FileNotFoundError(f'Optical flow file not found: {flow_file}')

        # Stored flow vectors with shape (N, 2, H, W)
        optical_flows = np.load(flow_file)

        for ((x, y, w, h), name) in roi_areas:
            # Select the motion vectors in the region of interest: (N, 2, h, w)
            flow_region = optical_flows[:, :, y:y + h, x:x + w]

            # Per-pixel magnitude of the motion vectors: (N, h, w). Only the
            # component axis is reduced; the spatial axes are kept so that the
            # statistics below are taken over individual vectors.
            magnitudes = np.hypot(flow_region[:, 0], flow_region[:, 1])

            # Mean and standard deviation of the magnitudes per frame: (N,)
            mean_curve = np.mean(magnitudes, axis=(1, 2))
            std_curve = np.std(magnitudes, axis=(1, 2))

            extracted_signals.append({
                'subject': subject,
                'setting': setting,
                'model': raft_model,
                'roi': name,
                'signal': mean_curve.tolist(),
                'signal_std': std_curve.tolist(),
            })

        # Free the flow array before the next file is loaded
        del optical_flows

In [None]:
import pandas as pd

# Collect all extracted signals in one table and store it as CSV (without the
# row index) in the evaluation directory
predictions_file = os.path.join(evaluation_dir, 'predictions.csv')
signals_df = pd.DataFrame(data=extracted_signals)
signals_df.to_csv(path_or_buf=predictions_file, index=False)

In [None]:
# Preview the first rows of the extracted signals
signals_df.head(n=5)