# Extract Respiration Signal with RhythmFormer

In [None]:
import torch
import respiration.utils as utils
from respiration.dataset import VitalCamSet
from respiration.extractor.rhythm_former import *

dataset = VitalCamSet()
natural_lighting = dataset.get_scenarios(['101_natural_lighting'])

# The pre-trained PPG models. Each checkpoint lives at
# data/rhythm_former/<name>.pth, expects 128x128 input frames and is
# evaluated on the same natural-lighting scenarios.
models = {
    name: {
        'model_path': utils.file_path('data', 'rhythm_former', f'{name}.pth'),
        'testing_scenarios': natural_lighting,
        'input_dimension': (128, 128),
    }
    for name in (
        'MMPD_intra_RhythmFormer',
        'PURE_cross_RhythmFormer',
        'UBFC_cross_RhythmFormer',
    )
}

In [None]:
# The fine-tuned Respiratory models
# The fine-tuned Respiratory models, one per training-run ID. Each run
# directory holds a manifest.json describing its checkpoints and the
# scenarios to test on.
fine_tuned_ids = [
    '20240726_104536',
    '20240726_155024',
    '20240726_171101',
    '20240726_212436',
    '20240727_170156',
]

for model_id in fine_tuned_ids:
    model_dir = utils.dir_path('models', 'rhythm_former', model_id, 'RhythmFormer')
    manifest = utils.read_json(utils.file_path(model_dir, 'manifest.json'))

    # Register the run under "RF_<id>": take the last checkpoint listed in
    # the manifest, and fall back to 128x128 input when no 'size' is recorded.
    models[f'RF_{model_id}'] = {
        'model_path': manifest['models'][-1]['model_file'],
        'testing_scenarios': manifest['testing_scenarios'],
        'input_dimension': manifest.get('size', (128, 128)),
    }

In [None]:
from datetime import datetime
from tqdm.auto import tqdm

# Number of frames passed to the model per forward pass (bounds memory use).
CHUNK_SIZE = 100

device = utils.get_torch_device()
predictions = []

for model_name, spec in tqdm(models.items(), total=len(models)):
    print(f'Extracting signal with {model_name} model')

    model_path = spec['model_path']

    model = RhythmFormer()
    # Fix model loading: some keys have an extra 'module.' prefix, so wrap the
    # model in DataParallel before loading the state dict.
    model = torch.nn.DataParallel(model)
    model.to(device)

    # Load the weights and switch to inference mode once per model
    # (previously eval() was re-applied for every chunk).
    _ = model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    for (subject, setting) in spec['testing_scenarios']:
        frames, params = dataset.get_video_rgb(subject, setting)
        frames = preprocess_frames(frames, spec['input_dimension'])
        # NOTE(review): frames is sliced along dim 1 below, so time is assumed
        # to be dimension 1 of the preprocessed tensor — confirm.
        num_frames = frames.size(1)

        outputs = []

        start_time = datetime.now()
        with torch.no_grad():
            for start in range(0, num_frames, CHUNK_SIZE):
                # Slicing clamps at the end, so the last chunk may be shorter.
                # .to(device) is a no-op if preprocess_frames already placed
                # the frames on the model's device.
                chunk = frames[:, start:start + CHUNK_SIZE].to(device)
                output = model(chunk)
                # flatten() instead of squeeze(): a single-frame trailing chunk
                # would squeeze to a 0-d array whose tolist() is a bare float,
                # which list.extend() rejects with TypeError.
                outputs.extend(output.flatten().cpu().numpy().tolist())

        predictions.append({
            'subject': subject,
            'setting': setting,
            'model': model_name,
            'time': datetime.now() - start_time,
            'signal': outputs,
        })

        del frames, params

In [None]:
import pandas as pd
import respiration.utils as utils

# Write all extracted signals (one row per model/subject/setting) to
# outputs/signals/rhythm_former.csv, creating the directory if needed.
signal_file = utils.file_path(
    utils.dir_path('outputs', 'signals', mkdir=True),
    'rhythm_former.csv',
)

df = pd.DataFrame(predictions)
df.to_csv(signal_file, index=False)
df.head()