# Extract Respiration Signal with RhythmFormer

In [None]:
import torch
import respiration.utils as utils
from respiration.extractor.rhythm_former import *

# The pre-trained PPG models: every checkpoint lives under data/rhythm_former
# as "<name>.pth" and is registered with a 128x128 input dimension.
models = {
    name: {
        'model_path': utils.file_path('data', 'rhythm_former', f'{name}.pth'),
        'input_dimension': (128, 128),
    }
    for name in (
        'MMPD_intra_RhythmFormer',
        'PURE_cross_RhythmFormer',
        'UBFC_cross_RhythmFormer',
    )
}

In [None]:
# The fine-tuned Respiratory models
fine_tuned_ids = [
    '20240726_104536',
    '20240726_155024',
    '20240726_171101',
    '20240726_212436',
    '20240727_115407',
    '20240727_170156',
    '20240731_113403',
    '20240801_124757',
    '20240801_195728',
    '20240802_155121',
    '20240803_164403',
    '20240804_191911',
    '20240805_104628',
    '20240805_200748',
    '20240809_162808',
    '20240809_234509',
    '20240812_153436',
    '20240812_204742',
    '20240813_101414',
]

# Register every fine-tuned run in `models` under the name "RF_<run id>",
# using the information stored in the run's manifest.json.
for model_id in fine_tuned_ids:
    model_dir = utils.dir_path('models', 'rhythm_former', model_id, 'RhythmFormer')
    manifest_path = utils.file_path(model_dir, 'manifest.json')
    manifest = utils.read_json(manifest_path)

    model_name = f'RF_{model_id}'
    # The last entry of manifest['models'] is used
    # (presumably the final checkpoint of the run -- TODO confirm).
    model_path = manifest['models'][-1]['model_file']
    # Manifests without a 'size' entry fall back to the 128x128 default.
    # NOTE(review): a size read from JSON is a list, not a tuple -- confirm
    # that utils.preprocess_frames accepts both.
    input_dimension = manifest.get('size', (128, 128))

    models[model_name] = {
        'model_path': model_path,
        'input_dimension': input_dimension,
    }

In [None]:
from respiration.dataset import V4VDataset

# Instantiate the V4V dataset wrapper and fetch its metadata. Each metadata
# entry is read further down through its 'vital' and 'video' keys.
dataset = V4VDataset()
videos = dataset.get_metadata()

In [None]:
from tqdm.auto import tqdm
from datetime import datetime

device = utils.get_torch_device()
predictions = []

# Number of frames passed to the model per forward pass.
chunk_size = 100

# Videos that are skipped entirely. The trailing numbers are annotations kept
# from the original notebook; their meaning is not documented here.
exclusions = [
    'F044_T4.mkv',  # 46
    'F020_T3.mkv',  # 54
    'F014_T10.mkv',  # 98
    'F011_T2.mkv',  # 42
    'F014_T7.mkv',  # 15, 14
    #### OK: 92, 52, 44, 41
]

# Run every registered model on each respiration-rate ('RR') video among the
# first 200 metadata entries and collect one prediction record per
# (video, model) pair.
for data in tqdm(videos[:200]):
    if data['vital'] != 'RR':
        continue

    video = data['video']
    if video in exclusions:
        continue

    frames_raw, _ = dataset.get_video_rgb(video)

    print(f'Processing {video}...')

    for model_name, spec in models.items():
        model_path = spec['model_path']
        model = RhythmFormer()
        # Fix model loading: some keys have an extra 'module.' prefix, so the
        # model is wrapped in DataParallel before the weights are loaded.
        model = torch.nn.DataParallel(model)
        model.to(device)
        # Without this call the network keeps its random initial weights and
        # the extracted signals would be meaningless.
        model.load_state_dict(torch.load(model_path, map_location=device))
        # Inference mode only needs to be set once per model.
        model.eval()

        frames = utils.preprocess_frames(frames_raw, spec['input_dimension'])
        # Chunks are taken along dimension 1 of `frames`
        # (presumably the time axis -- TODO confirm).
        num_frames = frames.size(1)

        outputs = []
        start_time = datetime.now()

        with torch.no_grad():
            # Only full chunks are processed; a trailing remainder shorter
            # than chunk_size is dropped.
            for start in range(0, num_frames - chunk_size + 1, chunk_size):
                chunk = frames[:, start:start + chunk_size]
                output = model(chunk)
                outputs.extend(output.squeeze().cpu().numpy().tolist())

        predictions.append({
            'video': video,
            'model': model_name,
            # Inference time only; model loading and preprocessing are excluded.
            'time': datetime.now() - start_time,
            'signal': outputs,
        })

        # Release the model and the preprocessed frames before the next model.
        del model, frames

In [None]:
import pandas as pd
import respiration.utils as utils

# Resolve the output location through the project path helpers
# (mkdir=True presumably creates the directory if it is missing -- TODO confirm).
signal_dir = utils.dir_path('outputs', 'signals', mkdir=True)
signal_file = utils.file_path(signal_dir, 'rhythm_former_v4v.csv')

# One row per entry of `predictions`; the CSV is written without the index column.
df = pd.DataFrame(predictions)
df.to_csv(signal_file, index=False)
df.head()  # last expression of the cell: displays the first rows as a quick check