# Extract Respiration Signal with RhythmFormer

In [None]:
from respiration.dataset import VitalCamSet

# Recording under analysis: one subject in one lighting scenario.
subject, scenario = 'Proband21', '101_natural_lighting'

dataset = VitalCamSet()

In [None]:
# Read only the first 30 * 20 = 600 RGB frames of the recording
# (presumably 20 s of video at 30 fps — confirm against the dataset).
frames, meta = dataset.get_video_rgb(subject, scenario, num_frames=30 * 20, show_progress=True)

In [None]:
# Reference breathing signal, truncated to at most as many samples as there are
# video frames. NOTE(review): this only aligns the two in time if the breathing
# signal is sampled at the video frame rate — confirm.
gt_respiration = dataset.get_breathing_signal(subject, scenario)[:len(frames)]

In [None]:
import torch
import respiration.utils as utils
from respiration.extractor.rhythm_former import *

device = utils.get_torch_device()

# Checkpoint selection: exactly one `model_checkpoint` assignment is active;
# uncomment another line to switch models.
#
# Pretrained PPG models:
# model_checkpoint = utils.file_path('data', 'rhythm_former', 'MMPD_intra_RhythmFormer.pth')
# model_checkpoint = utils.file_path('data', 'rhythm_former', 'PURE_cross_RhythmFormer.pth')
# model_checkpoint = utils.file_path('data', 'rhythm_former', 'UBFC_cross_RhythmFormer.pth')
#
# Fine-tuned respiration models (training-run directory, checkpoint file):
# model_checkpoint = utils.file_path('models', 'rhythm_former', '20240721_173436', 'RhythmFormer', 'RhythmFormer_4.pth')
# model_checkpoint = utils.file_path('models', 'rhythm_former', '20240721_181857', 'RhythmFormer', 'RhythmFormer_4.pth')
# model_checkpoint = utils.file_path('models', 'rhythm_former', '20240721_215042', 'RhythmFormer', 'RhythmFormer_6.pth')
# model_checkpoint = utils.file_path('models', 'rhythm_former', '20240721_185122', 'RhythmFormer', 'RhythmFormer_9.pth')
# model_checkpoint = utils.file_path('models', 'rhythm_former', '20240722_115720', 'RhythmFormer', 'RhythmFormer_1.pth')
model_checkpoint = utils.file_path(
    'models', 'rhythm_former', '20240722_185129', 'RhythmFormer', 'RhythmFormer_8.pth',
)

# The saved state-dict keys carry a 'module.' prefix (presumably because the
# checkpoint was written from a DataParallel-wrapped model). Wrapping the fresh
# network in DataParallel gives its keys the same prefix, so loading succeeds.
model = torch.nn.DataParallel(RhythmFormer())
model.to(device)

# With the default strict=True, load_state_dict raises on any key mismatch.
# On success, the returned result (displayed by the bare expression below)
# reports that all keys matched.
key_matching = model.load_state_dict(torch.load(model_checkpoint, map_location=device))
key_matching

In [None]:
import torchvision.transforms as transforms

# Per-frame preprocessing: resize to 128x128 and convert to a float tensor.
# ToTensor turns an RGB PIL image into a (3, H, W) float tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# `frames` is iterated one frame at a time. Each frame is presumably an RGB
# image array of shape (H, W, 3), as returned by get_video_rgb — confirm.
transformed_frames = [transform(frame) for frame in frames]

# Stack into a single (T, 3, 128, 128) tensor on the model's device.
# Stacking is required here: the model consumes one tensor, not a list.
frames_torch = torch.stack(transformed_frames).to(device)

# Add a leading batch dimension. The whole clip becomes ONE batch of size 1,
# giving shape (1, T, 3, 128, 128). It is not split into batches of 20.
frames_torch = frames_torch.unsqueeze(0)
frames_torch.shape

In [None]:
# One forward pass over the whole clip in inference mode. eval() switches
# layers such as dropout to evaluation behaviour; no_grad() skips autograd
# bookkeeping.
model.eval()
with torch.no_grad():
    output = model(frames_torch.to(device))
print(output.shape)

In [None]:
import matplotlib.pyplot as plt

# Two vertically stacked panels: the raw (unfiltered) network output on top,
# the reference breathing signal below.
fig, axs = plt.subplots(2, 1, figsize=(20, 5))

axs[0].set_title('Prediction')
axs[0].plot(output.cpu().numpy().flatten())

axs[1].set_title('Ground Truth')
axs[1].plot(gt_respiration)

plt.tight_layout()
plt.show()

In [None]:
import respiration.analysis as analysis

# Band-pass the flattened network output. The positional arguments after the
# signal are 30, 0.08 and 0.6. These are presumably the sampling rate (Hz,
# matching the 30 fps assumed when loading frames) and the low/high cut-offs
# (Hz, i.e. roughly 4.8-36 breaths per minute). Confirm against the signature
# of analysis.butterworth_filter.
output_processed = analysis.butterworth_filter(
    output.cpu().numpy().flatten(),
    30,
    0.08,
    0.6,
)

# Filtered prediction on top, ground truth below.
_, axs = plt.subplots(2, 1, figsize=(20, 5))

axs[0].set_title('Prediction')
axs[0].plot(output_processed)

axs[1].set_title('Ground Truth')
axs[1].plot(gt_respiration)

plt.tight_layout()
plt.show()