# Extracting Respiration Signals with Face Transformer

This notebook shows how to extract respiration signals from videos using a transformer model and normalized faces.

In [None]:
import respiration.utils as utils

# Identifiers of the trained transformer runs to evaluate. Each one names a directory below
# models/transformer that load_model() reads (the values look like YYYYMMDD_HHMMSS timestamps).
model_ids = [
    '20240617_134349',
    '20240617_213641',
    '20240618_093129',
]

In [None]:
# Edge length in pixels: frames are resized to image_size x image_size and the model is built for this size
image_size = 256
# Torch device shared by the data loader (frames, ground truth) and the model
device = utils.get_torch_device()

In [None]:
import torch


def preprocess_frames(frames: torch.Tensor) -> torch.Tensor:
    """
    Min-max normalize a clip and pair every frame with its successor.

    :param frames: Frames of one clip. The indexing below requires a 4-D tensor whose first dimension is time
        (presumably (time, channels, height, width) — TODO confirm against the caller)
    :return: Tensor with the dimensions (time - 1, channels, 2, height, width); entry ``i`` holds the
        normalized frames ``i`` and ``i + 1``. The first dimension is empty for clips with fewer than two frames
    """
    # Normalize the frames over the whole clip
    lowest = frames.min()
    value_range = frames.max() - lowest
    # A clip in which every pixel has the same value has a range of zero; dividing by it would turn the whole
    # clip into NaNs. Divide by one instead, which maps such a clip to all zeros.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    frames = (frames - lowest) / value_range

    # Create a new tensor in the dimensions (batch, channels, 2, height, width)
    return torch.stack([frames[:-1], frames[1:]]).permute(1, 2, 0, 3, 4)


def signal_diff(time_series: torch.Tensor) -> torch.Tensor:
    """
    Return the first-order difference of a time series.

    :param time_series: The signal; differences are taken along its first dimension
    :return: ``time_series[i + 1] - time_series[i]`` for every consecutive pair, i.e. one element fewer than
        the input
    """
    lowest = time_series.min()

    # Lift the signal so that its smallest value is zero whenever it dips below zero
    shifted = time_series - lowest if lowest < 0 else time_series

    # Subtract each sample from its successor
    return shifted[1:] - shifted[:-1]

In [None]:
import math
import torch
import respiration.utils as utils

from torchvision import transforms


class ScenarioLoaderChunks:
    """
    A data loader for one (subject, setting) scenario of the VitalCamSet dataset. Indexing or iterating yields
    ``(frames, gt_waveform)`` pairs: the resized video frames of one chunk and the matching slice of the min-max
    normalized breathing signal. Every chunk holds ``frames_per_segment`` frames, except the last one, which
    holds whatever remains.

    The class relies on the module-level names ``dataset``, ``device`` and ``image_size``.

    NOTE(review): the ground truth is sliced with frame indices, so it is assumed to contain one sample per video
    frame. No resampling happens here — confirm against ``dataset.get_breathing_signal``.

    The loader is its own iterator: ``iter(loader)`` resets a position stored on the instance, so two loops over
    the same instance share that position.
    """
    subject: str
    setting: str
    frames_per_segment: int

    def __init__(self,
                 subject: str,
                 setting: str,
                 frames_per_segment: int = 20):
        """
        :param subject: The subject identifier passed on to the dataset
        :param setting: The recording setting passed on to the dataset
        :param frames_per_segment: The maximal number of frames per chunk
        """
        self.subject = subject
        self.setting = setting
        self.frames_per_segment = frames_per_segment

        self.video_path = dataset.get_video_path(subject, setting)
        self.total_frames = utils.get_frame_count(self.video_path)

        # The pipeline is identical for every frame, so it is built once instead of once per chunk
        self._preprocess = transforms.Compose([
            transforms.ToPILImage(mode='RGB'),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

        # Normalized ground truth of the whole scenario; filled on first use by _ground_truth()
        self._gt_waveform = None

    def __len__(self) -> int:
        """Return the number of chunks, counting a shorter trailing chunk."""
        return math.ceil(self.total_frames / self.frames_per_segment)

    def __iter__(self):
        self.current_index = 0
        return self

    def __next__(self):
        if self.current_index >= len(self):
            raise StopIteration

        item = self[self.current_index]
        self.current_index += 1
        return item

    def _ground_truth(self) -> torch.Tensor:
        """Return the scenario's breathing signal scaled to [0, 1]; loaded once and then reused."""
        if self._gt_waveform is None:
            signal = dataset.get_breathing_signal(self.subject, self.setting)
            signal = torch.tensor(signal.copy(), dtype=torch.float32, device=device)

            # Normalize the signal between 0 and 1. A flat signal has a range of zero; divide by one instead
            # so that it becomes all zeros rather than NaNs.
            lowest = signal.min()
            value_range = signal.max() - lowest
            value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
            self._gt_waveform = (signal - lowest) / value_range

        return self._gt_waveform

    def __getitem__(self, index) -> 'tuple[torch.Tensor, torch.Tensor]':
        """
        Return the frames and the ground truth signal for the given index
        :param index: The index of the chunk, ``0 <= index < len(self)``
        :return: The frames of the chunk, stacked along a new first dimension, and the ground truth samples at
            the same positions; both on ``device``
        :raises IndexError: If the index is negative or not smaller than the number of chunks
        """

        # Negative indices are rejected as well: they would turn into a negative frame offset below
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        start = index * self.frames_per_segment
        # The last chunk may hold fewer frames than frames_per_segment
        size = min(self.frames_per_segment, self.total_frames - start)

        # Load the video frames
        frames, _ = utils.read_video_rgb(self.video_path, size, start)
        frames = torch.stack([self._preprocess(frame) for frame in frames], dim=0)
        frames = frames.to(device)

        # Take exactly as many ground truth samples as frames were requested, also for the last chunk
        gt_waveform = self._ground_truth()[start:start + size]

        return frames, gt_waveform

In [None]:
from vit_pytorch.simple_vit_3d import SimpleViT


def load_model(model_id: str) -> (SimpleViT, dict):
    """
    Rebuild a trained SimpleViT from the manifest and weights stored for one training run.

    :param model_id: The name of the run directory below models/transformer
    :return: The model, moved to ``device`` and switched to evaluation mode, and the parsed manifest
    """
    run_dir = utils.dir_path('models', 'transformer', model_id)
    manifest = utils.read_json(utils.join_paths(run_dir, 'manifest.json'))

    # The architecture is taken from the manifest; the number of frames per sample is set to the frame patch size
    hyper_parameters = dict(
        image_size=image_size,
        frames=manifest['frame_patch_size'],
        image_patch_size=manifest['image_patch_size'],
        frame_patch_size=manifest['frame_patch_size'],
        num_classes=1,
        dim=manifest['embedding_dim'],
        heads=manifest['heads'],
        mlp_dim=manifest['mlp_dim'],
        depth=manifest['depth'],
    )
    model = SimpleViT(**hyper_parameters).to(device)

    # Restore the weights of the last entry in the manifest's list of trained models
    # (presumably the best checkpoint of the training process — confirm how the list is ordered)
    last_entry = manifest['trained_models'][-1]
    weights_path = utils.join_paths(run_dir, last_entry['model'])
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    model.eval()

    return model, manifest

In [None]:
from respiration.dataset import VitalCamSet

# Shared dataset accessor: ScenarioLoaderChunks and the comparison cell use it to look up video paths and
# breathing signals
dataset = VitalCamSet()

In [None]:
from tqdm.auto import tqdm

# Run every model over the test scenarios listed in its own manifest and collect one predicted signal per
# (model, scenario) pair.
#
# NOTE(review): chunks do not overlap and preprocess_frames() pairs frames only within a chunk, so a chunk of
# N frames contributes N - 1 values. The predicted signal is therefore shorter than the video, and its samples
# drift against a per-frame ground truth by one position per chunk.
predictions = []

for model_id in tqdm(model_ids):
    model, manifest = load_model(model_id)
    scenarios = manifest['testing_scenarios']

    for subject, setting in scenarios:
        loader = ScenarioLoaderChunks(subject, setting)

        prediction = []

        for frames, _ in loader:
            if frames.shape[0] < 2:
                # A trailing chunk with a single frame cannot form a frame pair, so there is nothing to predict
                continue

            frames = preprocess_frames(frames)
            # Disable gradient computation and reduce memory consumption.
            with torch.no_grad():
                # flatten() always yields a 1-D tensor. squeeze() would collapse the output for a single
                # frame pair to a 0-d tensor, whose tolist() is a plain float that list.extend() rejects.
                outputs = model(frames).flatten()
            prediction.extend(outputs.tolist())

        predictions.append({
            'subject': subject,
            'setting': setting,
            'model': model_id,
            'signal': prediction,
        })

In [None]:
import pandas as pd

# One row per (model, scenario) pair with the columns subject, setting, model and signal
df = pd.DataFrame(predictions)

output_dir = utils.dir_path('outputs', 'signals', mkdir=True)

# Save the evaluation dataframe
# Each 'signal' cell is a Python list, so it ends up in the CSV as the list's string representation
csv_path = utils.join_paths(output_dir, 'transformer_predictions.csv')
df.to_csv(csv_path, index=False)

# Last expression of the cell: shows the first rows in the notebook
df.head()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import respiration.analysis as analysis

# Inspect a single prediction; the index is hard-coded and requires at least four collected predictions
prediction = predictions[3]
subject = prediction['subject']
setting = prediction['setting']

gt_signal = dataset.get_breathing_signal(subject, setting)
prediction_signal = np.array(prediction['signal'])

# Compare against the sample-to-sample difference of the ground truth rather than the raw waveform
gt_signal = torch.tensor(gt_signal, dtype=torch.float32)
gt_signal = signal_diff(gt_signal)
gt_signal = np.array(gt_signal)

# The ground truth is cut to the length of the prediction.
# NOTE(review): the prediction holds one value fewer than the number of frames of every chunk, so cutting the
# tail does not line the two signals up sample by sample — confirm that this offset is acceptable.
# NOTE(review): 30 is passed as the third positional argument, presumably the sampling rate in Hz — confirm
# against SignalComparator.
compare = analysis.SignalComparator(
    prediction_signal,
    gt_signal[:len(prediction_signal)],
    30,
    detrend_tarvainen=False,
    filter_signal=True,
)

# Plot the two signals as exposed by the comparator
plt.figure(figsize=(20, 5))
plt.plot(compare.ground_truth, label='Ground Truth')
plt.plot(compare.prediction, label='Prediction')
plt.legend()
plt.show()

In [None]:
# Last expression of the cell: shows what SignalComparator.errors() reports for the pair compared above
compare.errors()