# Extract Respiration Signal with RhythmFormer

In [None]:
from respiration.dataset import VitalCamSet

# Dataset handle used by the extraction loop below: it provides the test scenarios
# (get_scenarios) and the RGB video frames for each (subject, setting) pair (get_video_rgb).
dataset = VitalCamSet()

In [None]:
import torch
import respiration.utils as utils
from respiration.extractor.rhythm_former import *

device = utils.get_torch_device()

# Every entry takes one of two forms:
#   * {'name', 'path'}: a named checkpoint file stored under data/rhythm_former/<name>.pth
#   * {'id'}: a locally trained run, resolved later via models/rhythm_former/<id>/RhythmFormer
# The ids look like run timestamps (YYYYMMDD_HHMMSS) — presumably; confirm with the training logs.
models = [
    {'name': checkpoint, 'path': utils.file_path('data', 'rhythm_former', f'{checkpoint}.pth')}
    for checkpoint in (
        'MMPD_intra_RhythmFormer',
        'PURE_cross_RhythmFormer',
        'UBFC_cross_RhythmFormer',
    )
] + [
    {'id': run_id}
    for run_id in (
        '20240726_104536',
        '20240726_155024',
        '20240726_171101',
        '20240726_212436',
    )
]

In [None]:
import torchvision.transforms as transforms


def preprocess_frames(frames, size=(128, 128)):
    """Resize video frames and pack them into a single batched tensor on ``device``.

    Each frame goes through PIL (mode 'RGB'), is resized to ``size``, and is
    converted back to a float tensor with ``ToTensor`` (channels-first).

    :param frames: iterable of frames accepted by ``ToPILImage(mode='RGB')``
        (presumably H x W x 3 uint8 arrays — confirm against the dataset loader).
    :param size: target size passed to ``transforms.Resize``.
    :return: tensor of shape (1, T, C, H, W) on the global ``device``, where T is
        the number of frames.
    """
    pipeline = transforms.Compose([
        transforms.ToPILImage(mode='RGB'),
        transforms.Resize(size),
        transforms.ToTensor(),
    ])

    per_frame = [pipeline(image) for image in frames]
    clip = torch.stack(per_frame)  # (T, C, H, W)
    batched = clip.unsqueeze(0)    # add the batch axis -> (1, T, C, H, W)
    return batched.to(device)

In [None]:
from datetime import datetime
from tqdm.auto import tqdm

predictions = []

# Number of frames fed to the model per forward pass (bounds memory on long videos).
CHUNK_SIZE = 100


def _resolve_model(entry):
    """Return ``(model_name, model_path, testing, size)`` for one entry of ``models``.

    Entries with an ``'id'`` are read from their run manifest: the last checkpoint
    listed under ``'models'``, the ``'testing_scenarios'`` list, and the optional
    ``'size'`` (default ``(128, 128)``). Other entries use their explicit name and
    path and are evaluated on the ``101_natural_lighting`` scenarios at 128x128.
    """
    if 'id' in entry:
        model_dir = utils.dir_path('models', 'rhythm_former', entry['id'], 'RhythmFormer')
        manifest = utils.read_json(utils.file_path(model_dir, 'manifest.json'))
        return (
            f'RF_{entry["id"]}',
            manifest['models'][-1]['model_file'],
            manifest['testing_scenarios'],
            manifest['size'] if 'size' in manifest else (128, 128),
        )

    return (
        entry['name'],
        entry['path'],
        dataset.get_scenarios(['101_natural_lighting']),
        (128, 128),
    )


def _load_model(model_path):
    """Build a RhythmFormer wrapped in DataParallel, load *model_path*, and switch to eval mode."""
    net = RhythmFormer()
    # Some checkpoints store keys with an extra 'module.' prefix; wrapping the
    # network in DataParallel makes the state-dict keys line up.
    net = torch.nn.DataParallel(net)
    net.to(device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()
    return net


def _predict_signal(net, frames, chunk_size=CHUNK_SIZE):
    """Run *net* over *frames* in temporal chunks of *chunk_size*; return the outputs as one flat list.

    *frames* is the (1, T, C, H, W) tensor produced by ``preprocess_frames``; chunking
    happens along dimension 1.
    """
    signal = []
    with torch.no_grad():
        for start in range(0, frames.size(1), chunk_size):
            chunk = frames[:, start:start + chunk_size]  # slicing clamps at T
            output = net(chunk).squeeze()
            # A one-frame chunk squeezes to a 0-d tensor, whose tolist() is a bare
            # float that list.extend() rejects; atleast_1d keeps it iterable.
            signal.extend(torch.atleast_1d(output).cpu().numpy().tolist())
    return signal


for entry in tqdm(models):
    model_name, model_path, testing, size = _resolve_model(entry)

    print(f'Extracting signal with {model_name} model')

    net = _load_model(model_path)

    for (subject, setting) in testing:
        frames, params = dataset.get_video_rgb(subject, setting)
        frames = preprocess_frames(frames, size)

        start_time = datetime.now()
        signal = _predict_signal(net, frames)

        predictions.append({
            'subject': subject,
            'setting': setting,
            'model': model_name,
            'time': datetime.now() - start_time,
            'signal': signal,
        })

        # Release the (potentially large, on-device) video tensor before the next scenario.
        del frames, params

In [None]:
import pandas as pd
import respiration.utils as utils

# Persist one row per (subject, setting, model) prediction to outputs/signals/rhythm_former.csv.
# Note: the 'signal' list and 'time' timedelta columns are written via their string form.
signal_dir = utils.dir_path('outputs', 'signals', mkdir=True)
signal_file = utils.file_path(signal_dir, 'rhythm_former.csv')

df = pd.DataFrame(data=predictions)
df.to_csv(path_or_buf=signal_file, index=False)

# Last expression of the cell: the notebook displays a preview of the table.
df.head()