In [None]:
import respiration.dataset as repository

# Build the dataset accessor via the repository module's default factory
dataset = repository.from_default()

# Recording to analyse: one subject in one recording scenario
subject = 'Proband16'
scenario = '101_natural_lighting'

In [None]:
# Load the frames of the selected recording together with its metadata.
# The frames are converted from BGR to RGB before preprocessing further down,
# and `meta.fps` is used later when filtering the extracted signal.
frames, meta = dataset.get_video_bgr(subject, scenario)

In [None]:
import torch

# Pick the compute backend, in order of preference:
#   'mps'  - Apple's Metal Performance Shaders backend (macOS only)
#   'cuda' - GPU via CUDA
#   'cpu'  - fallback when no accelerator is available
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

device

In [None]:
import torch
from respiration.extractor.r_ppg.big_small import BigSmall

model_path = '../data/deep_phys/BP4D_BigSmall_Multitask_Fold3.pth'

# Some checkpoint keys carry an extra 'module.' prefix; wrapping the network in
# nn.DataParallel makes the parameter names match when the weights are loaded.
model = torch.nn.DataParallel(BigSmall()).to(device)

# Restore the pretrained weights onto the selected device and show the
# key-matching report returned by load_state_dict.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=device)
key_matching = model.load_state_dict(checkpoint)
key_matching

In [None]:
# Put the model into evaluation mode before running inference
model.eval()

In [None]:
import cv2
import numpy as np


def preprocess_frames(frames: np.ndarray, big_res=144, small_res=9):
    """Build the two input streams of the BigSmall model from a clip of frames.

    The frames are centre-cropped to a square and converted to float32. Two
    per-frame inputs are then derived from the cropped clip:

    * small: normalized frame differences ``(f[t+1] - f[t]) / (f[t+1] + f[t])``,
      standardized over the whole clip and resized to ``small_res`` x ``small_res``.
      Differencing yields one frame fewer than the input, so an all-zero frame
      is appended at the END to keep both streams the same length.
    * big: the frames standardized over the whole clip and resized to
      ``big_res`` x ``big_res``.

    A clip without any variation (zero standard deviation) is mapped to zeros
    instead of NaN, and a single-frame clip yields one all-zero small input.

    :param frames: sequence of at least one frame, each shaped (height, width, channels)
    :param big_res: edge length in pixels of the big-branch frames
    :param small_res: edge length in pixels of the small-branch frames
    :return: ``(small_inputs, big_inputs)`` - two lists holding one array per input frame
    """

    def standardize(values):
        # Zero-mean / unit-variance scaling over the whole clip. A constant clip
        # has zero standard deviation, which would turn every value into NaN, so
        # it is mapped to zeros. An empty clip is passed through untouched.
        if values.size == 0:
            return values
        spread = np.std(values)
        centered = values - np.mean(values)
        if spread == 0:
            return np.zeros_like(centered)
        return centered / spread

    # Center crop frames to square shape
    h, w, _ = frames[0].shape
    crop_size = min(h, w)
    start_y = (h - crop_size) // 2
    start_x = (w - crop_size) // 2
    frames = [frame[start_y:start_y + crop_size, start_x:start_x + crop_size] for frame in frames]

    # Convert frames to floating point
    frames = np.array(frames, dtype=np.float32)

    # Generate Small branch inputs (normalized difference frames); the small
    # epsilon keeps the division defined where both frames are black
    diff_frames = frames[1:] - frames[:-1]
    sum_frames = frames[1:] + frames[:-1]
    small_inputs = standardize(diff_frames / (sum_frames + 1e-7))
    small_inputs = [cv2.resize(frame, (small_res, small_res)) for frame in small_inputs]

    # Differencing drops one frame: pad with a zero frame at the end so that the
    # small stream has as many entries as the big stream
    if small_inputs:
        padding = np.zeros_like(small_inputs[0])
    else:
        # Single-frame clip: there is no difference frame to copy the shape from
        padding = np.zeros((small_res, small_res, frames.shape[-1]), dtype=np.float32)
    small_inputs.append(padding)

    # Generate Big branch inputs (standardized raw frames)
    big_inputs = [cv2.resize(frame, (big_res, big_res)) for frame in standardize(frames)]

    return small_inputs, big_inputs

In [None]:
# Get the first 300 frames of the video (10 seconds, assuming a 30 fps
# recording - TODO confirm against meta.fps)
from respiration.utils import video

# The loader returns BGR frames; convert to RGB before building the model inputs
rgb_frames = video.bgr_to_rgb(frames[:300])
small, big = preprocess_frames(rgb_frames, big_res=144, small_res=9)

In [None]:
# Show the first big and small frame
import matplotlib.pyplot as plt

# Left panel: first small-branch input, right panel: first big-branch input
_, (small_ax, big_ax) = plt.subplots(1, 2, figsize=(15, 5))

small_ax.imshow(small[0])
small_ax.set_title('Small Frame')

big_ax.imshow(big[0])
big_ax.set_title('Big Frame')

In [None]:
# Stack each list of per-frame arrays into a single array and create the
# corresponding tensor on the selected device
small_tensor, big_tensor = (
    torch.tensor(np.stack(branch), device=device) for branch in (small, big)
)

In [None]:
# Inspect the tensor shapes before the axes are reordered for the model
small_tensor.shape, big_tensor.shape

In [None]:
# Move the channel axis from last to second position, i.e. reorder to the
# layout expected by the model: (frame_count, channels, height, width).
# NOTE: this rebinds the same names, so re-running the cell permutes again.
small_tensor, big_tensor = (
    tensor.permute(0, 3, 1, 2) for tensor in (small_tensor, big_tensor)
)

small_tensor.shape, big_tensor.shape

In [None]:
# Extract the signals. The model is called with the (big, small) input pair and
# returns three outputs, unpacked here as au_out, bvp_out and resp_out.
# Gradients are not needed for inference, so they are disabled.
with torch.no_grad():
    au_out, bvp_out, resp_out = model((big_tensor, small_tensor))

# Shape of the respiration output for this clip
resp_out.shape

In [None]:
import respiration.preprocessing as preprocessing

# Bring the respiration output back to the CPU as a flat numpy array
waveform = resp_out.cpu().numpy().squeeze()
# Detrend, then band-limit the signal. The band is the same 0.08 / 0.6 pair this
# notebook applies to the full-length respiration signal; the previous 0.8 / 3.0
# pair looks like a pulse band and would cut out the breathing component
# (presumably low/high cut-offs in Hz - confirm against preprocessing.butterworth_filter).
waveform = preprocessing.detrend_tarvainen(waveform)
waveform = preprocessing.butterworth_filter(waveform, meta.fps, 0.08, 0.6)

In [None]:
# Shape of the filtered waveform for the first clip
waveform.shape

In [None]:
# Plot the filtered respiration waveform of the first clip
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(waveform)
ax.set_title('Respiration Signal')
ax.set_xlabel('Frame')
ax.set_ylabel('Amplitude')
plt.show()

In [None]:
from tqdm.auto import tqdm

# Number of frames fed to the model per forward pass
chunk_size = 300

# One respiration output array per chunk; joined once after the loop instead of
# re-copying the growing result on every iteration
chunk_outputs = []

for inx in tqdm(range(0, len(frames), chunk_size)):
    # Slicing clamps at the end of the video, so the last chunk may be shorter
    rgb_frames = video.bgr_to_rgb(frames[inx:inx + chunk_size])
    small, big = preprocess_frames(rgb_frames, big_res=144, small_res=9)

    # Convert the frames to a tensor. Each list is stacked into one ndarray
    # first: creating a tensor directly from a list of ndarrays is very slow.
    small_tensor = torch.tensor(np.array(small), device=device)
    big_tensor = torch.tensor(np.array(big), device=device)

    # Transform the tensor to the shape expected by the model
    # (frame_count, channels, height, width)
    small_tensor = small_tensor.permute(0, 3, 1, 2)
    big_tensor = big_tensor.permute(0, 3, 1, 2)

    # Only the respiration output is kept; the other two outputs are discarded
    with torch.no_grad():
        _, _, resp_out = model((big_tensor, small_tensor))

    chunk_outputs.append(resp_out.cpu().numpy())

# Join the per-chunk outputs along the frame axis and drop singleton dimensions
waveform = np.concatenate(chunk_outputs, axis=0).squeeze()

In [None]:
# Detrend the full-length signal, then band-limit it. The raw `waveform` is kept
# untouched so both versions can be plotted side by side.
# (presumably 0.08 and 0.6 are the low/high cut-offs in Hz - confirm against
# preprocessing.butterworth_filter)
waveform_processed = preprocessing.detrend_tarvainen(waveform)
waveform_processed = preprocessing.butterworth_filter(waveform_processed, meta.fps, 0.08, 0.6)

In [None]:
# Compare the raw model output (top) with the detrended and filtered signal (bottom)
_, (raw_ax, processed_ax) = plt.subplots(2, 1, figsize=(15, 10))

raw_ax.plot(waveform)
raw_ax.set_title('Signal')
raw_ax.set_xlabel('Frame')
raw_ax.set_ylabel('Amplitude')

processed_ax.plot(waveform_processed)
processed_ax.set_title('Processed')
processed_ax.set_xlabel('Frame')
processed_ax.set_ylabel('Amplitude')