In [None]:
import respiration.dataset as repository

# Recording under test: one subject in one capture scenario.
scenario = '101_natural_lighting'
subject = 'Proband05'

# Dataset handle obtained from the project's `from_default()` factory.
dataset = repository.from_default()

In [None]:
# Load the RGB frames of the selected recording together with its metadata;
# later cells read the video frame rate from `meta.fps`.
frames, meta = dataset.get_video_rgb(subject, scenario)

In [None]:
import respiration.utils as utils

# Torch device picked by the project helper; both the model and the input
# frames are placed on it in later cells.
device = utils.get_torch_device()

In [None]:
import os
import torch
from respiration.extractor.efficient_phys import EfficientPhys

# Model input geometry: frames are down-sampled to `dim` (presumably dim x dim
# pixels) and processed in windows of `frame_depth` frames -- TODO confirm
# against the EfficientPhys implementation.
frame_depth = 20
dim = 72

# Checkpoint from the rPPG-Toolbox data directory (by its file name,
# presumably EfficientPhys weights trained on PURE).
model_checkpoint = os.path.join('..', 'data', 'rPPG-Toolbox', 'PURE_EfficientPhys.pth')

# Some keys in the checkpoint carry an extra 'module.' prefix. Wrapping the
# network in nn.DataParallel *before* loading makes the model's own parameter
# names carry the same prefix, so the state dict loads without renaming.
model = torch.nn.DataParallel(
    EfficientPhys(img_size=dim, frame_depth=frame_depth)
)
model.to(device)

# `key_matching` holds load_state_dict's report of missing / unexpected keys.
key_matching = model.load_state_dict(
    torch.load(model_checkpoint, map_location=device)
)

In [None]:
# Switch the network to evaluation mode (nn.Module.eval()) before inference.
# As the last expression of the cell, the returned module is also displayed.
model.eval()

In [None]:
# Number of frames in the loaded video; the cell output below it is the
# largest multiple of `frame_depth` that does not exceed that count.
total_frames = len(frames)
total_frames - total_frames % frame_depth

In [None]:
# The chunk must contain k * frame_depth + 1 frames (presumably because the
# model works on frame-to-frame differences, which consumes one frame -- TODO
# confirm against EfficientPhys). Use at most 100 windows, but never more than
# the video actually provides, so the length stays valid for recordings with
# fewer than frame_depth * 100 + 1 frames instead of silently slicing short.
chunk_size = frame_depth * min(100, (len(frames) - 1) // frame_depth) + 1
assert chunk_size > 1, 'video is too short for a single window of frame_depth frames'

frames_chunk = frames[:chunk_size]

# Spatially down-sample with the project helper, using the same `dim` the
# model was built with.
frames_chunk = utils.down_sample_video(frames_chunk, dim)

# Move the last axis to position 1 (presumably (N, H, W, C) -> (N, C, H, W))
# and place the float32 tensor on the same device as the model.
frames_chunk = torch.tensor(frames_chunk, dtype=torch.float32, device=device).permute(0, 3, 1, 2)

frames_chunk.shape

In [None]:
# Pure inference: run the forward pass with autograd disabled so that no
# computation graph is retained for the whole chunk. The predicted values are
# unchanged; only memory use and speed improve.
with torch.no_grad():
    out = model(frames_chunk)
out.shape

In [None]:
# Convert the prediction to a NumPy array: drop it from the autograd graph,
# copy it to host memory, and squeeze away singleton dimensions.
# Note: this rebinds `out`, so the cell is meant to run once per forward pass.
out = out.detach().cpu().numpy().squeeze()

In [None]:
import matplotlib.pyplot as plt

# Raw model output for the chunk, before any detrending or filtering.
# Titled and labelled so the figure is readable on its own, in line with the
# comparison figure further down.
plt.plot(out)
plt.title('EfficientPhys output (raw)')
plt.xlabel('Sample')
plt.show()

In [None]:
import respiration.preprocessing as preprocessing

# Detrend the raw model output, then band-pass it between 0.75 and 2.5
# (presumably Hz, i.e. a heart-rate band) using the video frame rate as the
# sampling rate -- the model output is on the video time base.
detrended = preprocessing.detrend_tarvainen(out)
filtered = preprocessing.butterworth_filter(detrended, meta.fps, 0.75, 2.5)

# Titled and labelled so the figure is readable on its own, in line with the
# comparison figure further down.
plt.plot(filtered)
plt.title('rPPG (detrended and band-pass filtered)')
plt.xlabel('Sample')
plt.show()

In [None]:
# Reference pulse signal for the same recording; the second return value is
# not needed here.
pulse, _ = dataset.get_unisens_entry(subject, scenario, utils.VitalSigns.pulse)

# Resample onto the video time base FIRST (one sample per video frame), so that
# `meta.fps` really is the sampling rate of the signal handed to the band-pass.
# Filtering before resampling would apply the cut-offs at the wrong frequencies
# whenever the sensor's rate differs from the video frame rate.
# (Assumes resample_signal(x, n) returns x resampled to n samples -- confirm.)
pulse = preprocessing.resample_signal(pulse, len(frames))

# Same conditioning as the rPPG signal: detrend, then 0.75-2.5 band-pass.
pulse = preprocessing.detrend_tarvainen(pulse)
pulse = preprocessing.butterworth_filter(pulse, meta.fps, 0.75, 2.5)

# Cut to chunk_size - 1 samples so the reference lines up with `filtered`
# (presumably the model yields one value per consecutive frame pair).
pulse = pulse[:chunk_size - 1]

In [None]:
# Stack the model estimate (top) over the reference pulse (bottom) in one figure.
_, axs = plt.subplots(nrows=2, ncols=1, figsize=(10, 10))

# Draw both traces first ...
axs[0].plot(filtered)
axs[1].plot(pulse)

# ... then label each panel.
axs[0].set_title('rPPG')
axs[1].set_title('Pulse')

plt.show()

In [None]:
import respiration.analysis as analysis

# Agreement between the filtered rPPG estimate and the reference pulse, as
# computed by the project's analysis helpers: correlation and squared error.
# NOTE(review): the MSE compares raw amplitudes; the model output and the
# sensor signal may be on different scales -- confirm whether normalisation
# is intended before interpreting this number.
print('Pearson:', analysis.pearson_correlation(filtered, pulse))
print('MSE:', analysis.distance_mse(filtered, pulse))