# Train Rhythm Former

This notebook trains the RhythmFormer model on the VitalCam dataset, either from scratch or by fine-tuning a pretrained checkpoint.

In [None]:
# Run metadata collected by the cells below; written to manifest.json next to
# the saved checkpoints during training.
manifest = dict()

In [None]:
import respiration.utils as utils
from respiration.dataset import (
    VitalCamSet,
    VitalCamLoader,
)

dataset = VitalCamSet()
# Scenario selection key as understood by VitalCamSet.get_scenarios.
scenarios_all = dataset.get_scenarios(['303_normalized_face'])

# Deterministic (unshuffled) split: the first `split_ratio` share of the
# scenarios is used for training, the remainder for testing.
split_ratio = 0.8
manifest['split_ratio'] = split_ratio

split_index = int(len(scenarios_all) * split_ratio)

training = scenarios_all[:split_index]
manifest['training_scenarios'] = training

testing = scenarios_all[split_index:]
manifest['testing_scenarios'] = testing

device = utils.get_torch_device()
# The manifest is serialized to JSON later; `device` may be a torch.device
# object, which the standard json encoder cannot handle, so store its string form.
manifest['device'] = str(device)

In [None]:
import torchvision.transforms as transforms


# Built once and reused: the pipeline is stateless, so there is no need to
# recreate it for every clip.
_FRAME_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(mode='RGB'),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])


def preprocess_frames(frames):
    """
    Resize every frame to 128x128 and stack the frames into a single clip tensor.

    :param frames: Iterable of RGB frames accepted by ``transforms.ToPILImage``
        (assumed to be HxWx3 uint8 arrays -- TODO confirm against VitalCamLoader)
    :return: Tensor of shape (1, T, 3, 128, 128) moved to the global ``device``
    """
    # Imported here because the notebook only imports torch in a later cell.
    import torch

    transformed_frames = torch.stack([_FRAME_TRANSFORM(frame) for frame in frames])

    # Add the leading batch dimension expected by the model
    return transformed_frames.unsqueeze(0).to(device)

In [None]:
# Pretrained RhythmFormer checkpoints stored under data/rhythm_former.
model_mmpd, model_pure, model_ubfc = (
    utils.file_path('data', 'rhythm_former', checkpoint_name)
    for checkpoint_name in (
        'MMPD_intra_RhythmFormer.pth',
        'PURE_cross_RhythmFormer.pth',
        'UBFC_cross_RhythmFormer.pth',
    )
)

# Run name -> checkpoint to start from. None means the freshly constructed
# RhythmFormer() weights are trained; uncomment an entry to fine-tune instead.
models = {
    'RhythmFormer': None,
    # 'RhythmFormer_MMPD': model_mmpd,
    # 'RhythmFormer_PURE': model_pure,
    # 'RhythmFormer_UBFC': model_ubfc,
}
manifest['models'] = models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def normal_sampling(mean: torch.Tensor, std: torch.Tensor, label_k: torch.Tensor):
    """
    Evaluate the Gaussian density N(mean, std**2) at every position in ``label_k``.

    :param mean: Center of the Gaussian
    :param std: Standard deviation of the Gaussian
    :param label_k: Positions at which the density is evaluated
    :return: Density values, same shape as ``label_k``
    """
    exponent = -((label_k - mean) ** 2) / (2 * std ** 2)
    normalizer = torch.sqrt(torch.tensor(2 * torch.pi)) * std
    return torch.exp(exponent) / normalizer


def filtered_periodogram(
        time_series: torch.Tensor,
        sampling_rate: int,
        min_freq: float = 0,
        max_freq: float = float('inf')) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the power spectral density (PSD) of a signal within a given frequency range.
    :param time_series: Respiratory signal
    :param sampling_rate: Sampling rate
    :param min_freq: minimum frequency
    :param max_freq: maximum frequency
    :return: Frequencies and FFT result
    """

    # Compute the power spectral density (PSD) using periodogram
    psd = (torch.fft.fft(time_series).abs() ** 2) / time_series.shape[0]

    psd = psd[:len(psd) // 2]
    freq = torch.fft.fftfreq(time_series.shape[0], 1 / sampling_rate)[:len(psd)]

    # Find the indices corresponding to the frequency range
    idx = (freq >= min_freq) & (freq <= max_freq)

    # Extract the frequencies and PSDs within the specified range
    freq_range = freq[idx]
    psd_range = psd[idx]

    # Make the psd sum to 1
    psd_range = psd_range / psd_range.sum()

    return freq_range, psd_range


def euclidean_distance(pred_psd: torch.Tensor, gt_psd: torch.Tensor):
    """
    L2 distance between the softmax-normalized versions of two PSD vectors.

    :param pred_psd: Predicted PSD
    :param gt_psd: Ground-truth PSD
    :return: Scalar tensor with the Euclidean distance
    """
    pred_distribution = torch.softmax(pred_psd, dim=0)
    gt_distribution = torch.softmax(gt_psd, dim=0)
    return torch.dist(pred_distribution, gt_distribution)


def cosine_distance(pred_psd: torch.Tensor, gt_psd: torch.Tensor):
    """
    Cosine distance (1 - cosine similarity) between two PSD vectors.

    :param pred_psd: Predicted PSD
    :param gt_psd: Ground-truth PSD
    :return: Scalar tensor; 0 for parallel vectors, 1 for orthogonal ones
    """
    similarity = F.cosine_similarity(pred_psd, gt_psd, dim=0)
    return 1 - similarity


def frequency_loss(pred_psd: torch.Tensor, gt_psd: torch.Tensor):
    """
    Cross-entropy of the predicted PSD against the dominant ground-truth frequency bin.

    :param pred_psd: Predicted PSD, passed to ``F.cross_entropy`` as logits
    :param gt_psd: Ground-truth PSD; only the index of its maximum is used
    :return: Scalar cross-entropy tensor

    NOTE(review): in this notebook ``pred_psd`` comes from ``filtered_periodogram``,
    which already normalizes it to sum to 1. Its entries therefore lie in [0, 1], so the
    softmax inside ``F.cross_entropy`` is close to uniform and the gradient is weak.
    ``pred_psd.log()`` (an NLL loss) may be what was intended -- confirm.
    """
    target_bin = gt_psd.argmax()
    return F.cross_entropy(pred_psd, target_bin)


def pearson_correlation(prediction: torch.Tensor, ground_truth: torch.Tensor):
    """
    Pearson-correlation loss ``1 - r`` between two signals.

    Despite the name, this returns a loss rather than the coefficient: 0 for perfectly
    correlated signals, 1 for uncorrelated ones and 2 for perfectly anti-correlated ones.
    All elements of each tensor are pooled when computing means and sums.

    :param prediction: Predicted signal
    :param ground_truth: Reference signal, same shape as ``prediction``
    :return: Scalar tensor ``1 - r``
    """
    pred_centered = prediction - torch.mean(prediction)
    gt_centered = ground_truth - torch.mean(ground_truth)

    num = torch.sum(pred_centered * gt_centered)
    # Clamp inside the sqrt: a constant signal (zero variance) then gives r = 0 instead
    # of 0/0 = NaN, and backward stays finite (sqrt has an infinite slope at 0).
    variance_product = torch.sum(pred_centered ** 2) * torch.sum(gt_centered ** 2)
    den = torch.sqrt(variance_product.clamp_min(1e-8))

    correlation = num / den

    # Bigger correlation means smaller loss
    return 1 - correlation


def norm_loss(pred_psd: torch.Tensor, gt_psd: torch.Tensor) -> torch.Tensor:
    """
    KL divergence between Gaussians centered on the predicted and ground-truth PSD peaks.

    Each PSD is reduced to the index of its maximum. A Gaussian with std 3 (in bins) is
    evaluated at every bin index around that peak, and KL(gt || pred) is summed over bins.

    :param pred_psd: Predicted PSD; only its argmax is used
    :param gt_psd: Ground-truth PSD; only its argmax is used
    :return: Scalar tensor with the summed KL divergence

    NOTE(review): torch.argmax is not differentiable, so this term carries no gradient
    back to the model. It changes the reported loss value but not the parameter updates.
    """
    import math

    std = 3.0
    # log of the Gaussian normalization constant sqrt(2*pi) * std
    log_normalizer = math.log(math.sqrt(2 * math.pi) * std)

    def _log_gaussian_at_peak(psd: torch.Tensor) -> torch.Tensor:
        # Log-density of N(argmax(psd), std**2) evaluated at every bin index.
        labels = torch.arange(psd.shape[0], device=psd.device, dtype=torch.float32)
        peak = torch.argmax(psd).to(torch.float32)
        return -((labels - peak) ** 2) / (2 * std ** 2) - log_normalizer

    pred_log = _log_gaussian_at_peak(pred_psd)
    gt_log = _log_gaussian_at_peak(gt_psd)

    # Stay in log space. Evaluating the densities first and then taking the log
    # underflows to log(0) = -inf for bins far from the peak, turning the sum into inf/NaN.
    return F.kl_div(pred_log, gt_log, reduction='sum', log_target=True)


class HRHybridLoss(nn.Module):
    """
    Hybrid loss combining a temporal term and two frequency-domain terms.

    total = pearson_weight * (1 - r)
          + freq_weight * frequency_loss(pred_psd, gt_psd)
          + norm_weight * norm_loss(pred_psd, gt_psd)

    Both PSDs are computed with ``filtered_periodogram`` over the band
    [min_freq, max_freq] Hz. The defaults reproduce the original hard-coded
    configuration: 30 Hz sampling, 0.08-0.6 Hz band, and weights 0.2 / 1.0 / 1.0.
    """

    def __init__(self,
                 sampling_rate: float = 30,
                 min_freq: float = 0.08,
                 max_freq: float = 0.6,
                 pearson_weight: float = 0.2,
                 freq_weight: float = 1.0,
                 norm_weight: float = 1.0,
                 verbose: bool = True):
        """
        :param sampling_rate: Sampling rate of the signals in Hz
        :param min_freq: Lower edge of the frequency band in Hz
        :param max_freq: Upper edge of the frequency band in Hz
        :param pearson_weight: Weight of the (1 - Pearson r) term
        :param freq_weight: Weight of the cross-entropy frequency term
        :param norm_weight: Weight of the Gaussian-KL peak term
        :param verbose: Print the individual loss components on every call
        """
        super().__init__()
        self.sampling_rate = sampling_rate
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.pearson_weight = pearson_weight
        self.freq_weight = freq_weight
        self.norm_weight = norm_weight
        self.verbose = verbose

    def forward(self, prediction, ground_truth):
        """
        :param prediction: Predicted signal
        :param ground_truth: Reference signal
        :return: Scalar tensor with the weighted total loss
        """
        # Temporal loss (1 - Pearson correlation)
        pearson = pearson_correlation(prediction, ground_truth)

        freq_pred, pred_psd = filtered_periodogram(
            prediction, self.sampling_rate, self.min_freq, self.max_freq)
        freq_gt, gt_psd = filtered_periodogram(
            ground_truth, self.sampling_rate, self.min_freq, self.max_freq)

        freq_loss = frequency_loss(pred_psd, gt_psd)
        norm_l = norm_loss(pred_psd, gt_psd)

        total_loss = (self.pearson_weight * pearson
                      + self.freq_weight * freq_loss
                      + self.norm_weight * norm_l)

        if self.verbose:
            # Dominant frequencies, reported for monitoring only
            pred_freq = freq_pred[torch.argmax(pred_psd)]
            gt_freq = freq_gt[torch.argmax(gt_psd)]
            print(
                f'pearson={pearson:.3f} '
                f'norm_l={norm_l:.3f} '
                f'freq_loss={freq_loss:.3f} '
                f'total_loss={total_loss.item():2.3f} '
                f'pred_freq={pred_freq.item():.2f} '
                f'gt_freq={gt_freq.item():.2f}')

        return total_loss

In [None]:
from pytz import timezone
from datetime import datetime

# Unique identifier for this training run (Berlin local time); it also names
# the output directory of the saved checkpoints.
zone = timezone('Europe/Berlin')
timestamp = datetime.now(zone).strftime('%Y%m%d_%H%M%S')
manifest['timestamp'] = timestamp

# Training hyperparameters
epochs = 22
manifest['epochs'] = epochs

loss_fn = HRHybridLoss()
manifest['loss_fn'] = 'HRHybridLoss'

# Alternative value tried: 1e-5
learning_rate = 9e-3
manifest['learning_rate'] = learning_rate

In [None]:
from tqdm.auto import tqdm
from respiration.extractor.rhythm_former import RhythmFormer

# Settings shared by every run. Each model below starts from a fresh copy of this
# dict, so run-specific entries ('model', 'models', 'best_loss', 'epoch', ...) never
# leak from one model's manifest into the next one.
base_manifest = manifest

for model_name, model_path in models.items():
    model = RhythmFormer()
    # Fix model loading: some checkpoint keys carry an extra 'module.' prefix,
    # which wrapping in DataParallel reproduces.
    model = torch.nn.DataParallel(model)
    model.to(device)

    manifest = base_manifest.copy()
    manifest['model'] = model_name
    # Replaces the name -> checkpoint mapping with the checkpoints saved by this run.
    manifest['models'] = []

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    manifest['optimizer'] = 'Adam'

    # Fine-tune from a pretrained checkpoint if one is configured
    if model_path is not None:
        model.load_state_dict(torch.load(model_path, map_location=device))

    model_dir = utils.dir_path(
        'models',
        'rhythm_former',
        timestamp,
        model_name,
        mkdir=True)
    best_loss = float('inf')

    for epoch in range(epochs):
        # Loaders are rebuilt every epoch (presumably they can only be iterated once -- confirm)
        training_loader = VitalCamLoader(training, parts=6, device=device)
        testing_loader = VitalCamLoader(testing, parts=6, device=device)

        # Training pass
        model.train()
        train_loss = 0
        for frames, target in tqdm(training_loader, desc='Training'):
            frames = preprocess_frames(frames)

            # Forward pass; drop the batch dimension added by preprocess_frames
            output = model(frames).squeeze(0)

            loss = loss_fn(output, target)
            train_loss += loss.item()

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Release references early to keep peak memory down
            del frames, target, output, loss

        # Evaluation pass
        testing_loss = 0
        model.eval()
        with torch.no_grad():
            for frames, target in tqdm(testing_loader, desc='Testing'):
                frames = preprocess_frames(frames)

                output = model(frames).squeeze(0)
                loss = loss_fn(output, target)
                testing_loss += loss.item()

                del frames, target, output, loss

        # Average the per-sample losses
        train_loss /= len(training_loader)
        testing_loss /= len(testing_loader)

        # Only checkpoint when the test loss improves
        if testing_loss < best_loss:
            best_loss = testing_loss
            model_file = utils.file_path(model_dir, f'{model_name}_{epoch}.pth')

            manifest['best_loss'] = best_loss
            manifest['models'].append({
                'epoch': epoch,
                'model_file': model_file,
                'train_loss': train_loss,
                'test_loss': testing_loss,
            })

            torch.save(model.state_dict(), model_file)

        print(f'{model_name}[{epoch + 1}/{epochs}] '
              f'train-loss={train_loss:.3f} '
              f'test-loss={testing_loss:.3f}')

        # Save the manifest after every epoch so progress survives interruptions
        manifest['epoch'] = epoch
        manifest_file = utils.file_path(model_dir, 'manifest.json')
        utils.write_json(manifest_file, manifest)