# Extract Signal with EfficientPhys

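This notebook demonstrates how to extract a signal from a video using the EfficientPhys model. The pretrained models extract rPPG signals from a video; the fine-tuned models are trained to extract respiratory signals. Here the pretrained `BP4D_PseudoLabel_EfficientPhys` checkpoint is applied to a short clip, and its output is compared with the ground-truth breathing signal.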

In [None]:
from respiration.dataset import VitalCamSet

dataset = VitalCamSet()

# Recording to analyse.
subject = 'Proband16'
setting = '101_natural_lighting'

# Load only a prefix of the recording.
# 30 * 12 presumably means 12 s of video at 30 fps — TODO confirm the recording's frame rate.
num_video_frames = 30 * 12

frames, meta = dataset.get_video_rgb(
    subject,
    setting,
    num_frames=num_video_frames,
    show_progress=True,
)

In [None]:
import os
import torch
import respiration.utils as utils
from respiration.extractor.efficient_phys import EfficientPhys

# Hyper-parameters of the pretrained checkpoint.
dim = 72  # frame size passed as img_size and used later to down-sample the video
frame_depth = 20  # temporal window length the model is built with

device = utils.get_torch_device()

# Wrap in nn.DataParallel before loading: DataParallel prefixes parameter names with
# "module.", which presumably matches how the checkpoint was saved — otherwise the
# state-dict keys would not match.
model = torch.nn.DataParallel(EfficientPhys(img_size=dim, frame_depth=frame_depth))
model.to(device)

model_checkpoint = os.path.join('..', '..', 'data', 'rPPG-Toolbox', 'BP4D_PseudoLabel_EfficientPhys.pth')
state_dict = torch.load(model_checkpoint, map_location=device)
key_matching = model.load_state_dict(state_dict)

In [None]:
# Trim the clip to k * frame_depth + 1 frames, i.e. one frame more than a whole number
# of windows. The extra frame is presumably needed because EfficientPhys differences
# consecutive frames internally, turning k * frame_depth + 1 frames into k full windows
# of frame differences — TODO confirm against respiration.extractor.efficient_phys.
num_windows = (frames.shape[0] - 1) // frame_depth
if num_windows < 1:
    raise ValueError(
        f'EfficientPhys needs at least {frame_depth + 1} frames, got {frames.shape[0]}'
    )
chunk_size = num_windows * frame_depth + 1
frames_chunk = frames[:chunk_size]

frames_chunk = utils.down_sample_video(frames_chunk, dim)

In [None]:
import matplotlib.pyplot as plt

# Save a preview of the first down-sampled frame so the pre-processing can be inspected.
figure_dir = utils.dir_path('outputs', 'figures', 'pre-process', mkdir=True)

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(frames_chunk[0])

utils.savefig(fig, figure_dir, 'efficient_phys')

In [None]:
# Convert to a float32 tensor on the target device, then reorder the axes from
# (T, H, W, C) to the (T, C, H, W) layout that is passed to the model.
# NOTE: this rebinds `frames_chunk`. Re-running this cell on its own would permute an
# already-permuted tensor, so re-run the previous cells first.
frames_chunk = torch.tensor(frames_chunk, dtype=torch.float32, device=device).permute(0, 3, 1, 2)

In [None]:
# nn.Module starts in training mode. Switch to evaluation mode so that layers such as
# dropout / batch norm (presumably present in EfficientPhys — TODO confirm) behave
# deterministically at inference time.
model.eval()

with torch.no_grad():
    out = model(frames_chunk)

# Under no_grad no autograd graph is built, so no detach() is required before .numpy().
prediction = out.cpu().numpy().squeeze()

In [None]:
import respiration.analysis as analysis

respiration_gt = dataset.get_breathing_signal(subject, setting)

# Align the ground truth with the prediction. The prediction covers fewer samples than
# the loaded video because the frames were trimmed above. Truncate to the common
# length so a shorter ground-truth signal cannot cause a length mismatch.
# NOTE(review): assumes the breathing signal is sampled at the video frame rate
# (only meta.fps is passed below) — TODO confirm.
num_samples = min(len(prediction), len(respiration_gt))
respiration_gt = respiration_gt[:num_samples]

compare = analysis.SignalComparator(
    prediction[:num_samples],
    respiration_gt,
    meta.fps,
)

In [None]:
# Error metrics between prediction and ground truth as computed by
# SignalComparator.errors() (metric definitions live in respiration.analysis).
compare.errors()

In [None]:
# Distance measures between the two signals as computed by
# SignalComparator.signal_distances() (definitions live in respiration.analysis).
compare.signal_distances()

In [None]:
import matplotlib.pyplot as plt

# Overlay the predicted and ground-truth signals held by the comparator.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(compare.prediction, label='Prediction')
ax.plot(compare.ground_truth, label='Ground Truth')
ax.set(
    title=f'EfficientPhys prediction vs. ground truth ({subject}, {setting})',
    xlabel='Sample',
    ylabel='Amplitude',
)
ax.legend()
plt.show()