# Train RhythmFormer

This notebook trains and fine-tunes the RhythmFormer model on the VitalCam dataset.

In [None]:
import respiration.utils as utils
from respiration.dataset import VitalCamSet

# Hyperparameters shared by every training run in this notebook.

# Fraction of the scenarios used for training; the remainder is used for testing.
split_ratio = 0.8

# Number of passes over the training scenarios per run.
epochs = 22

# Frequency band handed to the loss function
# (presumably in Hz -- confirm against HybridLoss).
min_freq = 0.1
max_freq = 0.5

# Step size of the Adam optimizer.
learning_rate = 0.009

dataset = VitalCamSet()

device = utils.get_torch_device()

# Metadata common to all runs. Each run copies this dict, extends it and
# writes it to a JSON manifest. The device is recorded as a string: the return
# type of utils.get_torch_device() is not visible here, and a torch.device
# object is not a JSON-native type.
meta_manifest = {
    'split_ratio': split_ratio,
    'device': str(device),
    'epochs': epochs,
    'min_freq': min_freq,
    'max_freq': max_freq,
    'learning_rate': learning_rate,
    'model': 'RhythmFormer',
}

In [None]:
# Values shared by all runs. 'split' is passed to the data loader, the
# '*_weight' entries are passed to HybridLoss, and 'setting' (added per run
# below) selects the dataset scenarios.
_shared_config = {
    'image_size': (128, 128),
    'split': 5,
    'pearson_weight': 1.0,
    'frequency_weight': 1.0,
    'norm_weight': 0.0,
    'mse_weight': 1.0,
    'spectral_convergence_weight': 1.0,
    'spectral_magnitude_weight': 1.0,
    'spectral_magnitude_norm': 'L1',
}

# One entry per training run; the runs are executed in list order.
configurations = [
    # All loss terms except the norm term, natural lighting.
    {**_shared_config, 'setting': '101_natural_lighting'},
    # Same loss weighting on the normalized-face setting.
    {**_shared_config, 'setting': '303_normalized_face'},
    # Norm term instead of the Pearson, frequency and MSE terms; the spectral
    # terms stay enabled.
    {
        **_shared_config,
        'pearson_weight': 0.0,
        'frequency_weight': 0.0,
        'norm_weight': 1.0,
        'mse_weight': 0.0,
        'setting': '101_natural_lighting',
    },
]

In [None]:
from pytz import timezone
from datetime import datetime

from respiration.training import HybridLoss
import torch

from tqdm.auto import tqdm
from respiration.extractor.rhythm_former import RhythmFormer
from respiration.training import VitalCamLoader

# The run identifier is always expressed in Berlin time, independent of the
# zone the machine is configured with.
zone = timezone('Europe/Berlin')

# Train one model per configuration. Every run gets its own model, optimizer,
# output directory and manifest; the manifest is rewritten after each epoch.
for config in configurations:

    # Shallow copy of the shared metadata; the list-valued entries
    # ('models', 'losses') are created fresh for this run further down.
    manifest = meta_manifest.copy()
    image_size = config['image_size']

    # The timestamp is the unique identifier for this training run
    timestamp = datetime.now().astimezone(zone).strftime('%Y%m%d_%H%M%S')
    manifest['timestamp'] = timestamp
    manifest['config'] = config
    # NOTE: start_time/end_time are naive local times, whereas the timestamp
    # above is expressed in Europe/Berlin.
    manifest['start_time'] = datetime.now().isoformat()
    manifest['image_size'] = image_size

    print('timestamp:', timestamp)

    # NOTE(review): config['spectral_magnitude_norm'] is not forwarded to
    # HybridLoss -- confirm whether the loss should receive it.
    loss_fn = HybridLoss(
        min_freq=min_freq,
        max_freq=max_freq,
        pearson_weight=config['pearson_weight'],
        frequency_weight=config['frequency_weight'],
        norm_weight=config['norm_weight'],
        mse_weight=config['mse_weight'],
        spectral_convergence_weight=config['spectral_convergence_weight'],
        spectral_magnitude_weight=config['spectral_magnitude_weight'],
    )
    manifest['loss_fn'] = 'HybridLoss'

    manifest['loss_fn_config'] = loss_fn.get_config()

    model = RhythmFormer(
        image_size=image_size,
    )
    # Fix model loading: some keys have an extra 'module.' prefix, so the
    # model is wrapped in DataParallel.
    model = torch.nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    manifest['optimizer'] = 'Adam'

    model_dir = utils.dir_path(
        'models',
        'rhythm_former_v2',
        timestamp,
        mkdir=True)
    best_loss = float('inf')

    manifest['models'] = []
    manifest['losses'] = []

    # The split size heavily influences the frequency resolution of the PSD in the loss function.
    # The smaller the split size, the higher the frequency resolution. However, the GPU memory for training
    # is limited. There is a trade-off between frequency resolution, memory consumption and image size.
    # Split 5:  Step size 0.04 Hz (1.25 BPM)
    # Split 10: Step size 0.08 Hz (2.50 BPM)
    split = config['split']
    manifest['split'] = split

    scenarios_all = dataset.get_scenarios([config['setting']])

    # The first `split_ratio` share of the scenarios is used for training,
    # the rest for testing (no shuffling).
    split_index = int(len(scenarios_all) * split_ratio)

    training = scenarios_all[:split_index]
    manifest['training_scenarios'] = training

    testing = scenarios_all[split_index:]
    manifest['testing_scenarios'] = testing

    # Fail early with a clear message: an empty split would otherwise only
    # surface as a ZeroDivisionError when the losses are averaged.
    if len(training) == 0 or len(testing) == 0:
        raise ValueError(
            f"Setting {config['setting']!r} yields {len(training)} training and "
            f"{len(testing)} testing scenarios; both splits must be non-empty")

    for epoch in range(epochs):
        model.train()
        # The loaders are re-created for every epoch.
        training_loader = VitalCamLoader(training, split=split, device=device)
        testing_loader = VitalCamLoader(testing, split=split, device=device)

        train_loss = 0
        for (frames, target) in tqdm(training_loader, desc='Training'):
            frames = utils.preprocess_frames(frames, image_size, device)

            # Forward pass
            output = model(frames).squeeze(0)

            # Compute the loss
            loss = loss_fn(output, target)
            train_loss += loss.item()

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Drop the references before the next batch is loaded
            del frames, target, output, loss

        testing_loss = 0
        model.eval()
        # Evaluation only: no gradients are tracked and no weights are updated
        with torch.no_grad():
            for (frames, target) in tqdm(testing_loader, desc='Testing'):
                frames = utils.preprocess_frames(frames, image_size, device)

                output = model(frames).squeeze(0)

                loss = loss_fn(output, target)
                testing_loss += loss.item()

                del frames, target, output, loss

        # Compute the average loss
        train_loss /= len(training_loader)
        testing_loss /= len(testing_loader)

        # Keep a checkpoint whenever the test loss improves on the best so far
        if testing_loss < best_loss:
            best_loss = testing_loss
            model_file = utils.file_path(model_dir, f'RF_{epoch}.pth')

            manifest['best_loss'] = best_loss
            manifest['models'].append({
                'epoch': epoch,
                'model_file': model_file,
                'train_loss': train_loss,
                'test_loss': testing_loss,
            })

            torch.save(model.state_dict(), model_file)

        print(f'[{epoch + 1}/{epochs}] '
              f'train-loss={train_loss:.3f} '
              f'test-loss={testing_loss:.3f}')
        manifest['losses'].append({
            'epoch': epoch,
            'train_loss': train_loss,
            'test_loss': testing_loss,
        })

        # Save the manifest after every epoch so that an interrupted run
        # still leaves its progress on disk
        manifest['epoch'] = epoch
        manifest['end_time'] = datetime.now().isoformat()
        manifest_file = utils.file_path(model_dir, 'manifest.json')
        utils.write_json(manifest_file, manifest)