# Extract respiratory signals with FlowNet2 optical flow

FlowNet 2.0, introduced in "FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks", is a deep learning model for optical flow estimation. The optical flow directions and magnitudes can be used to extract respiratory signals from videos. This notebook demonstrates how to use FlowNet2 (the FlowNet2-SD variant) to extract respiratory signals from videos.

In [None]:
from respiration.dataset import VitalCamSet

# Recording that is analysed throughout this notebook
subject, setting = 'Proband16', '101_natural_lighting'

# Resolve the video file of that recording through the dataset helper
dataset = VitalCamSet()
video_path = dataset.get_video_path(subject, setting)

## Define FlowNet 2.0 model

In [None]:
import numpy as np
import torch.nn as nn


def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """Build a convolution block followed by a LeakyReLU(0.1) activation.

    Args:
        batchNorm: When truthy, a ``BatchNorm2d`` layer is inserted after the
            convolution and the convolution is created without a bias term.
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Size of the square convolution kernel.
        stride: Stride of the convolution.

    Returns:
        An ``nn.Sequential`` with the layers ``[Conv2d, (BatchNorm2d), LeakyReLU]``.
    """
    half_kernel = (kernel_size - 1) // 2

    # The bias is only kept when no batch norm layer follows the convolution
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
            bias=not batchNorm,
        ),
    ]
    if batchNorm:
        layers.append(nn.BatchNorm2d(out_planes))
    layers.append(nn.LeakyReLU(0.1, inplace=True))

    return nn.Sequential(*layers)


def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias=True):
    """Build a convolution block without an activation function.

    Args:
        batchNorm: When truthy, a ``BatchNorm2d`` layer follows the convolution.
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Size of the square convolution kernel.
        stride: Stride of the convolution.
        bias: Whether the convolution has a bias term (used in both variants).

    Returns:
        An ``nn.Sequential`` with the layers ``[Conv2d, (BatchNorm2d)]``.
    """
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=(kernel_size - 1) // 2,
        bias=bias,
    )
    if not batchNorm:
        return nn.Sequential(convolution)

    return nn.Sequential(convolution, nn.BatchNorm2d(out_planes))


def predict_flow(in_planes):
    """Create the 3x3 convolution that maps ``in_planes`` feature channels to a 2-channel output."""
    flow_channels = 2
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=flow_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
    )


def deconv(in_planes, out_planes):
    """Build a transposed-convolution block followed by LeakyReLU(0.1).

    The transposed convolution uses kernel size 4, stride 2 and padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.

    Returns:
        An ``nn.Sequential`` with the layers ``[ConvTranspose2d, LeakyReLU]``.
    """
    transposed_conv = nn.ConvTranspose2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=True,
    )
    activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    return nn.Sequential(transposed_conv, activation)


class tofp16(nn.Module):
    """Module that casts its input tensor to half precision."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # `half()` returns the tensor converted to float16
        half_precision = input.half()
        return half_precision


class tofp32(nn.Module):
    """Module that casts its input tensor to single precision."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # `float()` returns the tensor converted to float32
        single_precision = input.float()
        return single_precision


def init_deconv_bilinear(weight):
    """Overwrite a (transposed) convolution weight with a bilinear kernel.

    Every ``[..., :, :]`` slice of ``weight`` (i.e. every combination of the
    leading channel dimensions) receives the same 2-D kernel, which is the
    outer product of a 1-D profile along the height and a 1-D profile along
    the width.

    Compared to the previous implementation:
    * the kernel is indexed as ``[row, column]``, so non-square kernels no
      longer raise an ``IndexError``;
    * each axis derives its profile from its own size (for square kernels the
      result is unchanged);
    * the kernel is created with ``new_tensor`` so the function does not depend
      on a module-level ``torch`` import and matches the dtype/device of
      ``weight``.

    Args:
        weight: Tensor or parameter whose last two dimensions are the kernel
            height and width. It is modified in place through ``weight.data``.
    """

    def axis_profile(size):
        # 1-D profile: 1 - |position / f - c| with f = ceil(size / 2)
        factor = np.ceil(size / 2.0)
        center = (2 * factor - 1 - factor % 2) / (2.0 * factor)
        return 1 - np.abs(np.arange(size) / factor - center)

    shape = weight.size()
    height, width = shape[-2], shape[-1]

    # Rows correspond to the height axis, columns to the width axis
    bilinear = np.outer(axis_profile(height), axis_profile(width))

    # `copy_` broadcasts the (H, W) kernel over all leading dimensions
    kernel = weight.data.new_tensor(bilinear)
    weight.data.copy_(kernel)


def save_grad(grads, name):
    """Create a hook that stores the value it receives in ``grads[name]``.

    Args:
        grads: Mutable mapping that collects the stored values.
        name: Key under which the hook stores its argument.

    Returns:
        A one-argument callable that writes its argument to ``grads[name]``
        and returns ``None``.
    """

    def hook(grad):
        grads[name] = grad
        # Explicitly return nothing: the hook only records the value
        return None

    return hook

In [None]:
import torch.nn as nn
from torch.nn import init


class FlowNetSD(nn.Module):
    """Encoder/decoder network that predicts 2-channel flow maps at several scales.

    The encoder halves the spatial resolution six times (``conv1`` ... ``conv6``).
    The decoder repeatedly upsamples, concatenates the matching encoder features
    and the upsampled coarser flow, and predicts a refined flow.

    Attribute names and their registration order are part of the checkpoint
    format (``state_dict`` keys) and of the weight initialisation order, so they
    must not be changed.

    Args:
        batchNorm: Whether the convolution blocks contain batch normalisation.
    """

    def __init__(self, batchNorm=True):
        super().__init__()

        self.batchNorm = batchNorm
        use_bn = self.batchNorm

        # Encoder: the input has 6 channels
        self.conv0 = conv(use_bn, 6, 64)
        self.conv1 = conv(use_bn, 64, 64, stride=2)
        self.conv1_1 = conv(use_bn, 64, 128)
        self.conv2 = conv(use_bn, 128, 128, stride=2)
        self.conv2_1 = conv(use_bn, 128, 128)
        self.conv3 = conv(use_bn, 128, 256, stride=2)
        self.conv3_1 = conv(use_bn, 256, 256)
        self.conv4 = conv(use_bn, 256, 512, stride=2)
        self.conv4_1 = conv(use_bn, 512, 512)
        self.conv5 = conv(use_bn, 512, 512, stride=2)
        self.conv5_1 = conv(use_bn, 512, 512)
        self.conv6 = conv(use_bn, 512, 1024, stride=2)
        self.conv6_1 = conv(use_bn, 1024, 1024)

        # Decoder: inputs are concatenations of
        # (encoder features, upsampled decoder features, 2-channel upsampled flow),
        # e.g. 512 + 512 + 2 = 1026
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        # Convolutions applied to each concatenation before the flow prediction
        self.inter_conv5 = i_conv(use_bn, 1026, 512)
        self.inter_conv4 = i_conv(use_bn, 770, 256)
        self.inter_conv3 = i_conv(use_bn, 386, 128)
        self.inter_conv2 = i_conv(use_bn, 194, 64)

        # Flow prediction heads, one per scale
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(512)
        self.predict_flow4 = predict_flow(256)
        self.predict_flow3 = predict_flow(128)
        self.predict_flow2 = predict_flow(64)

        # Learned 2x upsampling of each predicted flow to the next finer scale
        for flow_upsampler in (
                'upsampled_flow6_to_5',
                'upsampled_flow5_to_4',
                'upsampled_flow4_to_3',
                'upsampled_flow3_to_2',
        ):
            setattr(self, flow_upsampler, nn.ConvTranspose2d(2, 2, 4, 2, 1))

        # Initialise all (transposed) convolutions: uniform bias, Xavier-uniform weight
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                continue
            if module.bias is not None:
                init.uniform_(module.bias)
            init.xavier_uniform_(module.weight)

        # Created after the initialisation loop; has no parameters
        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, x):
        """Predict flow maps for the 6-channel input ``x``.

        Returns:
            In training mode the tuple ``(flow2, flow3, flow4, flow5, flow6)``;
            in evaluation mode the one-element tuple ``(flow2,)``.
        """
        # Encoder
        stem = self.conv0(x)
        level1 = self.conv1_1(self.conv1(stem))
        skip2 = self.conv2_1(self.conv2(level1))
        skip3 = self.conv3_1(self.conv3(skip2))
        skip4 = self.conv4_1(self.conv4(skip3))
        skip5 = self.conv5_1(self.conv5(skip4))
        bottleneck = self.conv6_1(self.conv6(skip5))

        # Coarsest flow prediction
        flow6 = self.predict_flow6(bottleneck)

        # Decoder: each level fuses skip features, upsampled features and upsampled flow
        merged5 = torch.cat((skip5, self.deconv5(bottleneck), self.upsampled_flow6_to_5(flow6)), 1)
        flow5 = self.predict_flow5(self.inter_conv5(merged5))

        merged4 = torch.cat((skip4, self.deconv4(merged5), self.upsampled_flow5_to_4(flow5)), 1)
        flow4 = self.predict_flow4(self.inter_conv4(merged4))

        merged3 = torch.cat((skip3, self.deconv3(merged4), self.upsampled_flow4_to_3(flow4)), 1)
        flow3 = self.predict_flow3(self.inter_conv3(merged3))

        merged2 = torch.cat((skip2, self.deconv2(merged3), self.upsampled_flow3_to_2(flow3)), 1)
        flow2 = self.predict_flow2(self.inter_conv2(merged2))

        if not self.training:
            return flow2,
        return flow2, flow3, flow4, flow5, flow6

In [None]:
class FlowNet2SD(FlowNetSD):
    """FlowNetSD wrapper that normalises image pairs and rescales the predicted flow.

    The network itself is inherited unchanged from ``FlowNetSD``; this class only
    adds the input normalisation and, in evaluation mode, the scaling and
    upsampling of the finest flow prediction. The forward pass delegates to
    ``FlowNetSD.forward`` instead of duplicating the encoder/decoder code.

    Args:
        args: Configuration mapping; only ``args['rgb_max']`` is read here.
        batchNorm: Whether the convolution blocks contain batch normalisation.
        div_flow: Factor the predicted flow is multiplied with in evaluation mode.
    """

    def __init__(self, args: dict, batchNorm=False, div_flow=20):
        super().__init__(batchNorm=batchNorm)
        self.rgb_max = args['rgb_max']
        self.div_flow = div_flow

    def forward(self, inputs):
        """Estimate the flow for a batch of image pairs.

        Args:
            inputs: 5-D tensor whose third dimension holds the two images of a
                pair (indices 0 and 1 are read); assumed layout
                ``(batch, channels, 2, height, width)`` — TODO confirm with callers.

        Returns:
            In training mode the tuple ``(flow2, flow3, flow4, flow5, flow6)``;
            in evaluation mode ``flow2 * div_flow`` upsampled by a factor of 4.
        """
        # Mean over all remaining dimensions for every (batch, channel) entry
        leading_dims = inputs.size()[:2]
        rgb_mean = inputs.contiguous().view(leading_dims + (-1,)).mean(dim=-1).view(leading_dims + (1, 1, 1,))
        x = (inputs - rgb_mean) / self.rgb_max

        # Stack both images of each pair along the channel dimension
        x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1)

        # The parent returns all scales in training mode and `(flow2,)` otherwise
        flows = super().forward(x)

        if self.training:
            return flows
        return self.upsample1(flows[0] * self.div_flow)

## Load the FlowNet model

In [None]:
import torch
import respiration.utils as utils

device = utils.get_torch_device()
path = utils.file_path('data', 'flownet', 'FlowNet2-SD_checkpoint.pth')

# NOTE(security): `torch.load` unpickles the file, so only load checkpoints
# from a trusted source.
# Deserialise onto the CPU first: without `map_location` a checkpoint that was
# saved on a CUDA device cannot be loaded on a machine without that device.
# The model is moved to the target device below.
loaded = torch.load(path, map_location='cpu')

args = {
    'rgb_max': 255,
    'div_flow': 20,
}
# `div_flow` is a separate constructor argument; pass the configured value
# explicitly instead of silently relying on the default
model = FlowNet2SD(args, batchNorm=False, div_flow=args['div_flow'])
model.load_state_dict(loaded['state_dict'])
model = model.to(device)

# Switch to evaluation mode (the model then returns only the final flow)
model.eval()

## Extract optical flow from the video

In [None]:
import cv2


def resize_and_center_frames(frames: np.ndarray, target_size: (int, int)):
    """Resize frames to fit into ``target_size`` and centre them on a black canvas.

    Each frame is scaled so that it fits completely inside the target while
    keeping its aspect ratio. The axis that limits the scaling is chosen by
    comparing the frame's aspect ratio with the target's aspect ratio, so
    non-square targets work as well (previously a wide frame could become
    taller than the canvas and the copy failed with a shape mismatch).

    Args:
        frames: Iterable of frames with shape ``(height, width, 3)``.
        target_size: ``(width, height)`` of the output frames.

    Returns:
        A list of ``uint8`` arrays with shape ``(target height, target width, 3)``.
        An empty input yields an empty list.
    """
    target_width, target_height = target_size[0], target_size[1]
    target_aspect = target_width / target_height

    resized_frames = []

    for frame in frames:
        height, width = frame.shape[:2]
        aspect_ratio = width / height

        if aspect_ratio > target_aspect:
            # Frame is relatively wider than the target: the width is the limit
            new_width = target_width
            new_height = int(new_width / aspect_ratio)
        else:
            # Frame is relatively taller (or equal): the height is the limit
            new_height = target_height
            new_width = int(new_height * aspect_ratio)

        # Keep at least one pixel and never exceed the canvas
        new_width = min(max(new_width, 1), target_width)
        new_height = min(max(new_height, 1), target_height)

        resized_frame = cv2.resize(frame, (new_width, new_height))

        # Black canvas of the target size
        centered_frame = np.zeros((target_height, target_width, 3), dtype=np.uint8)

        # Top-left corner that centres the resized frame on the canvas
        x_offset = (target_width - new_width) // 2
        y_offset = (target_height - new_height) // 2

        centered_frame[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized_frame

        resized_frames.append(centered_frame)

    return resized_frames

In [None]:
import math
import torch
import respiration.utils as utils

from tqdm.auto import tqdm
from torchvision import transforms

param = utils.get_video_params(video_path)

# Only process the first 30 seconds of the video.
# `int` keeps the frame count usable as an array size for fractional frame rates.
duration_seconds = 30
param.num_frames = int(param.fps * duration_seconds)

# Number of frames that are processed at once
batch_size = 120

# A batch of `batch_size` frames yields `batch_size - 1` flows (one per pair of
# neighbouring frames). Consecutive batches therefore overlap by one frame, and
# the number of batches has to be derived from the number of flows.
flows_per_batch = batch_size - 1
num_flows = param.num_frames - 1
batches = math.ceil(num_flows / flows_per_batch)

new_dim = 640

# Store the optical flows vectors (N, 2, H, W)
optical_flows = np.zeros((param.num_frames - 1, 2, new_dim, new_dim), dtype=np.float32)

# The preprocessing pipeline does not change between batches
preprocess = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    # Center Crop the frames
    transforms.CenterCrop((new_dim, new_dim)),
    transforms.ToTensor()
])

# Extract the optical flow from the video in batches
for batch in tqdm(range(batches)):
    # Index of the first frame of this batch (== index of its first flow)
    start = flows_per_batch * batch
    num_frames = min(batch_size, param.num_frames - start)
    frames, _ = utils.read_video_rgb(video_path, num_frames, start)

    frames = torch.stack([preprocess(frame) for frame in frames], dim=0)

    # `ToTensor` scales the pixel values to [0, 1], but the model divides its
    # input by `rgb_max`. Scale back to [0, rgb_max] so that the frames are not
    # normalised twice.
    frames = frames.to(device) * model.rgb_max

    # Fold the frames into pairs of neighbouring frames: (T - 1, C, 2, H, W)
    unfolded_frames = frames.unfold(0, 2, 1).permute(0, 1, 4, 2, 3)

    with torch.no_grad():
        flows = model(unfolded_frames)

    # Release the input tensors before the next batch is loaded
    del frames, unfolded_frames

    # Copy all flows of this batch into the result array at once
    optical_flows[start:start + flows.shape[0]] = flows.cpu().numpy()

## Visualize the optical flow

In [None]:
# Directory in which the figures of this notebook are stored (created on demand)
figure_path_parts = ('outputs', 'figures', 'flownet')
figure_dir = utils.dir_path(*figure_path_parts, mkdir=True)

In [None]:
import matplotlib.pyplot as plt
import respiration.extractor.optical_flow_raft as raft

frames, _ = utils.read_video_rgb(video_path, 1, 1)

# The optical flow was computed on centre-cropped frames. Crop the preview
# frame with the same transform so that the arrows (and any region detected on
# this frame later on) line up with the flow field; letterbox-resizing the
# frame would show a different image region.
center_crop = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.CenterCrop((new_dim, new_dim)),
])
# `np.array` copies the image so the frame stays writable
frames = [np.array(center_crop(frame)) for frame in frames]

arrow_frame = raft.draw_flow(frames[0], optical_flows[0])
flow_frame = raft.image_from_flow(optical_flows[0])

fig, ax = plt.subplots(1, 2, figsize=(20, 8))
ax[0].imshow(arrow_frame)
ax[0].set_title('Optical flow arrows')
ax[1].imshow(flow_frame)
ax[1].set_title('Optical flow magnitude')

utils.savefig(fig, figure_dir, 'optical_flow')

## Extract the respiratory signal

1. Find the region of interest (ROI) on the chest
2. Calculate the motion magnitude in the ROI
3. Plot the motion magnitude over time

In [None]:
import respiration.roi as roi

# Find the chest region
# Detect the chest region on the preview frame; the result is unpacked as
# (x, y, width, height)
x, y, w, h = roi.detect_chest(frames[0])

# Keep only the optical flow vectors inside the chest region for every frame pair
roi_flows = optical_flows[:, :, y:y + h, x:x + w]

# Absolute value of flow channel 1 only (sqrt(v ** 2) == |v|); channel 0 is ignored.
# NOTE(review): this is not the Euclidean magnitude over both components.
# Channel 1 is presumably the vertical displacement — confirm that using a
# single component is intended.
magnitudes = np.sqrt(roi_flows[:, 1] ** 2)

In [None]:
# Reduce the two spatial axes of the ROI so that one value per flow remains
spatial_axes = (1, 2)
mean_curve = magnitudes.mean(axis=spatial_axes)
std_curve = magnitudes.std(axis=spatial_axes)

In [None]:
# Plot the mean ROI motion per frame with a band of one standard deviation
fig, ax = plt.subplots(1, 1, figsize=(20, 8))
ax.plot(mean_curve, label='Mean')
ax.fill_between(
    np.arange(len(mean_curve)),
    mean_curve - std_curve,
    mean_curve + std_curve,
    alpha=0.3,
    label='Standard deviation')

ax.set_title('Motion magnitude in the ROI')
ax.set_xlabel('Frame')
ax.set_ylabel('Motion magnitude')

# Without the legend the labels assigned above are never displayed
ax.legend()

utils.savefig(fig, figure_dir, 'roi_magnitudes')

In [None]:
import respiration.analysis as analysis

# Reference breathing signal of the recording. The slice starts at index 1 and
# ends before `param.num_frames`, i.e. it has one sample less than the number of
# processed frames.
breathing_signal = dataset.get_breathing_signal(subject, setting)
respiratory_gt = breathing_signal[1:param.num_frames]

# Compare the extracted curve with the reference signal and print the errors
comparator = analysis.SignalComparator(
    respiratory_gt,
    mean_curve,
    sample_rate=param.fps,
)
error_metrics = comparator.errors()
utils.pretty_print(error_metrics)

In [None]:
# Plot the comparator's prediction and ground-truth signals in one figure
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 8))

for series, series_label in (
        (comparator.prediction, 'Prediction'),
        (comparator.ground_truth, 'Ground truth'),
):
    ax.plot(series, label=series_label)

ax.set(title='Respiratory signal', xlabel='Frame', ylabel='Motion magnitude')
ax.legend()