# Extracting Respiration Signals with the Face Transformer

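This notebook generates respiratory-signal predictions on the test scenarios of every model listed in `model_ids` (all trained with the face-transformer / SimpleViT architecture). It saves the predictions to CSV and compares one of them against the ground-truth breathing signal.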

In [None]:
from typing import Tuple

import respiration.utils as utils

# Training-run directory names under models/transformer/ whose models are evaluated below.
model_ids = ['20240710_194632']

In [None]:
# Torch device chosen by the project helper; models and loaders below are placed on it.
device = utils.get_torch_device()
# Input resolution passed to SimpleViT; assumed to match the resolution used during training — TODO confirm.
image_size = 256

In [None]:
import torch
from vit_pytorch import SimpleViT


def load_model(model_id: str) -> Tuple[SimpleViT, dict]:
    """Rebuild a trained SimpleViT from its manifest and load its weights.

    The training run lives in ``models/transformer/<model_id>`` and must contain
    a ``manifest.json`` describing the architecture and the saved checkpoints.

    :param model_id: Name of the training-run directory under ``models/transformer``.
    :return: ``(model, manifest)`` -- the model on ``device`` in eval mode, and the
        parsed manifest dictionary.
    """
    model_dir = utils.dir_path('models', 'transformer', model_id)
    manifest_path = utils.join_paths(model_dir, 'manifest.json')
    manifest = utils.read_json(manifest_path)

    # Architecture hyper-parameters come from the manifest so the module matches the
    # saved state dict; image_size is the notebook-level global.
    model = SimpleViT(
        image_size=image_size,
        patch_size=manifest['image_patch_size'],
        num_classes=1,
        dim=manifest['embedding_dim'],
        heads=manifest['heads'],
        mlp_dim=manifest['mlp_dim'],
        depth=manifest['depth'],
    ).to(device)

    # Use the last checkpoint recorded in the manifest. It is presumably the best one
    # (e.g. the trainer only records improving checkpoints) — TODO confirm.
    model_path = utils.join_paths(model_dir, manifest['trained_models'][-1]['model'])
    # NOTE(review): only a state dict is loaded here, so torch.load(..., weights_only=True)
    # (torch >= 1.13) would avoid unpickling arbitrary objects.
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Inference only: disable dropout and similar training-time behavior.
    model.eval()

    return model, manifest

In [None]:
from tqdm.auto import tqdm
from respiration.dataset import ScenarioLoader

# One entry per (model, test scenario): the per-frame predicted respiratory signal.
predictions = []

for model_id in tqdm(model_ids):
    model, manifest = load_model(model_id)
    scenarios = manifest['testing_scenarios']

    for subject, setting in scenarios:
        print(f'Processing {subject} - {setting}')
        loader = ScenarioLoader(subject, setting, manifest['num_frames'], device)

        prediction = []

        # The loader also yields ground-truth targets, which are not needed here.
        for frames, _ in loader:
            frames = utils.normalize_frames(frames)
            # Disable gradient computation and reduce memory consumption.
            with torch.no_grad():
                outputs = model(frames)
            # flatten() always yields a 1-D tensor. squeeze() would turn a single-frame
            # chunk into a 0-d tensor, whose tolist() is a bare float that
            # list.extend() rejects with TypeError.
            prediction.extend(outputs.flatten().tolist())

        predictions.append({
            'subject': subject,
            'setting': setting,
            'model': model_id,
            'signal': prediction,
        })

In [None]:
import pandas as pd

# Destination for the predicted signals (directory is created if missing).
output_dir = utils.dir_path('outputs', 'signals', mkdir=True)
csv_path = utils.join_paths(output_dir, 'transformer_predictions.csv')

# Persist the predictions table: one row per (model, scenario). pandas writes the
# list-valued 'signal' column as its Python repr string.
df = pd.DataFrame(predictions)
df.to_csv(csv_path, index=False)

df.head()

## Evaluate a Sample Prediction

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import respiration.analysis as analysis
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()

# Which entry of `predictions` to inspect; must be < len(predictions), which depends
# on how many testing scenarios the manifests list.
sample_index = 3
prediction = predictions[sample_index]
subject = prediction['subject']
setting = prediction['setting']

gt_signal = dataset.get_breathing_signal(subject, setting)
prediction_signal = np.array(prediction['signal'])

# The prediction length depends on how many frames the loader yielded, so it can
# differ from the ground-truth length. Compare only the shared prefix so both
# signals always have the same length.
common_length = min(len(prediction_signal), len(gt_signal))

compare = analysis.SignalComparator(
    prediction_signal[:common_length],
    gt_signal[:common_length],
    30,  # presumably the sampling rate (video fps) — TODO confirm against SignalComparator
    detrend_tarvainen=False,
    filter_signal=True,
)

plt.figure(figsize=(20, 5))
plt.plot(compare.ground_truth, label='Ground Truth')
plt.plot(compare.prediction, label='Prediction')
plt.legend()
plt.show()

In [None]:
# Show the error metrics reported by `compare` (the SignalComparator built in the
# previous cell). As the cell's last expression, the result is displayed as output.
compare.errors()