# Extract Respiration Signal with RhythmFormer

In [None]:
from respiration.dataset import VitalCamSet

# Dataset accessor; later cells use it to list the scenarios to process
# (get_scenarios) and to load each video's RGB frames (get_video_rgb).
dataset = VitalCamSet()

In [None]:
import torch
import respiration.utils as utils
from respiration.extractor.rhythm_former import *

# Checkpoint files for the three RhythmFormer variants evaluated below.
model_mmpd = utils.file_path('data', 'rhythm_former', 'MMPD_intra_RhythmFormer.pth')
model_pure = utils.file_path('data', 'rhythm_former', 'PURE_cross_RhythmFormer.pth')
model_ubfc = utils.file_path('data', 'rhythm_former', 'UBFC_cross_RhythmFormer.pth')

# Torch device chosen by the project helper; tensors and models are moved here.
device = utils.get_torch_device()

# Registry mapping a model label (stored with every prediction) to the path
# of its checkpoint.
models = dict(
    MMPD_intra_RhythmFormer=model_mmpd,
    PURE_cross_RhythmFormer=model_pure,
    UBFC_cross_RhythmFormer=model_ubfc,
)

In [None]:
import torchvision.transforms as transforms


def preprocess_frames(frames):
    """Resize video frames to 128x128 and pack them into one batched tensor.

    Args:
        frames: Iterable of RGB frames accepted by ``ToPILImage(mode='RGB')``
            (presumably HxWx3 uint8 arrays as returned by the dataset --
            verify against the caller).

    Returns:
        A tensor on ``device`` holding the per-frame ``ToTensor`` outputs
        stacked along a new frame axis, with a leading batch axis of size 1.
    """
    pipeline = transforms.Compose([
        transforms.ToPILImage(mode='RGB'),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # Run the pipeline frame by frame, then stack along a new frame axis.
    per_frame = [pipeline(frame) for frame in frames]
    clip = torch.stack(per_frame)

    # Add the singleton batch axis and move the clip to the compute device.
    batch = clip.unsqueeze(0)
    return batch.to(device)

In [None]:
from datetime import datetime
from tqdm.auto import tqdm

# Restrict the run to the '101_natural_lighting' setting.
scenarios = dataset.get_scenarios(['101_natural_lighting'])

# Build every model once, up front. Loading a checkpoint does not depend on
# the video, so re-creating the model and re-reading the weights from disk
# for each (subject, setting) pair was pure overhead.
loaded_models = {}
for model_name, model_path in models.items():
    # The checkpoint keys carry a 'module.' prefix (as produced by
    # DataParallel), so wrap the model the same way to make the keys match.
    model = torch.nn.DataParallel(RhythmFormer())
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: switch to evaluation mode once, outside the timed region.
    model.eval()
    loaded_models[model_name] = model

# One record per (subject, setting, model) with the predicted signal.
predictions = []

for (subject, setting) in tqdm(scenarios):
    frames, params = dataset.get_video_rgb(subject, setting)
    frames = preprocess_frames(frames)

    for model_name, model in loaded_models.items():
        # Time only the forward pass.
        start_time = datetime.now()
        with torch.no_grad():
            output = model(frames)
        elapsed = datetime.now() - start_time

        predictions.append({
            'subject': subject,
            'setting': setting,
            'model': model_name,
            'time': elapsed,
            # Drop the singleton batch axis added by preprocess_frames.
            # (NumPy arrays have no `unsqueeze`; the previous
            # `.numpy().unsqueeze(0)` raised AttributeError.)
            'signal': output.squeeze(0).cpu().numpy().tolist(),
        })

    # Release the (potentially large) frame tensor before loading the next video.
    del frames, params

In [None]:
import pandas as pd
import respiration.utils as utils

# Output location for the extracted signals: outputs/signals/rhythm_former.csv
# (mkdir=True presumably creates the directory if it is missing -- confirm
# against utils.dir_path).
signal_dir = utils.dir_path('outputs', 'signals', mkdir=True)
signal_file = utils.file_path(signal_dir, 'rhythm_former.csv')

# One row per prediction record; columns follow the dict keys
# (subject, setting, model, time, signal).
df = pd.DataFrame(predictions)
# index=False keeps the positional row index out of the CSV.
df.to_csv(signal_file, index=False)
# Last expression of the cell: shows the first rows as a quick sanity check.
df.head()