# Correlation & Bland-Altman Plots

This notebook creates correlation and Bland-Altman plots for the different models. The models are grouped into categories: optical flow, rPPG, Respiration Rhythm Former, SimpleViT, pretrained models, and a `random` model. The correlation plots show the prediction against the ground truth (together with their correlation), while the Bland-Altman plots show the difference between the prediction and the ground truth against the mean of the two.

## Grouping the models

In [None]:
import pandas as pd
import respiration.utils as utils

# Resolve the analysis directory and the two CSV files stored inside it
analysis_dir = utils.dir_path('outputs', 'analysis')
metrics_file = utils.join_paths(analysis_dir, 'metrics.csv')
frequencies_file = utils.join_paths(analysis_dir, 'frequencies.csv')

# `experiment_analysis` is used below for its 'model' column; `frequencies`
# provides the 'model', 'prediction' and 'ground_truth' columns for the plots
experiment_analysis, frequencies = (
    pd.read_csv(csv_file) for csv_file in (metrics_file, frequencies_file)
)

In [None]:
# Root directory for the figures saved by this notebook; the 'correlation'
# and 'bland_altman' sub-directories are derived from it further down.
# NOTE(review): `mkdir=True` presumably creates the directory when it is
# missing - confirm against utils.dir_path
figure_dir = utils.dir_path('outputs', 'figures', mkdir=True)

In [None]:
# Display the distinct model names found in the metrics table; useful for
# cross-checking the hard-coded model lists defined in the next cell
experiment_analysis['model'].unique()

In [None]:
# Hard-coded model names per category. Every model must be listed exactly
# once: a repeated name is plotted twice and multiplies its rows when the
# per-model records are merged on 'model' at the end of the notebook.
optical_flow_models = [
    'raft_large',
    'raft_small',
    'FlowNet2',
    'FlowNet2C',
    'FlowNet2CS',
    'FlowNet2CSS',
    'FlowNet2S',
    'FlowNet2SD',
    'lucas_kanade',
    'pixel_intensity_rgb',
    'pixel_intensity_grey',
]

ppg_models = [
    # RhythmFormer
    'MMPD_intra_RhythmFormer',
    'PURE_cross_RhythmFormer',
    'UBFC_cross_RhythmFormer',

    # EfficientPhys
    'UBFC-rPPG_EfficientPhys',
    'BP4D_PseudoLabel_EfficientPhys',
    'SCAMPS_EfficientPhys',
    'PURE_EfficientPhys',
    'MA-UBFC_efficientphys',

    # TSCAN
    'BP4D_PseudoLabel_TSCAN',
    'MA-UBFC_tscan',
    'PURE_TSCAN',
    'SCAMPS_TSCAN',
    'UBFC-rPPG_TSCAN',

    # DeepPhys
    'BP4D_PseudoLabel_DeepPhys',
    'MA-UBFC_deepphys',
    'PURE_DeepPhys',
    'SCAMPS_DeepPhys',
    'UBFC-rPPG_DeepPhys',
]

pretrained = [
    'mtts_can',
    'big_small',
]

# The remaining groups are derived from the model names in the metrics table
model_names = experiment_analysis['model'].unique()

# Get all models that start with RF_
rrf = [model for model in model_names if model.startswith('RF_')]

# Get all models that start with SimpleViT_
simpleViT = [model for model in model_names if model.startswith('SimpleViT_')]

print(f'OF: {len(optical_flow_models)}, PPG: {len(ppg_models)}, RF: {len(rrf)}, SimpleViT: {len(simpleViT)}')

In [None]:
# One entry per figure. 'label' is passed to utils.savefig when the figure is
# written, 'dimensions' is the (rows, columns) shape of the subplot grid and
# 'models' lists the models that each get one subplot.
groups = [
    {'label': label, 'dimensions': dimensions, 'models': models}
    for label, dimensions, models in (
        ('of', (3, 4), optical_flow_models),
        ('rppg', (4, 5), ppg_models),
        ('rf', (6, 6), rrf),
        ('simple_vit', (2, 5), simpleViT),
        ('pretrained', (1, 2), pretrained),
        ('random', (1, 1), ['random']),
    )
]

## Correlation Plots

In [None]:
# Per-model records holding the MAE, the Pearson correlation coefficient
# (PCC) and the p value. The list is filled by the correlation plotting cell
# below and merged with the Bland-Altman records at the end of the notebook.
records_corr = []

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from respiration.analysis import pearson_correlation

corr_dir = utils.dir_path(figure_dir, 'correlation', mkdir=True)

# One figure per model group with one subplot per model: a hexbin density of
# prediction vs. ground truth plus a linear trend line, annotated with the
# Pearson correlation (PCC), its p value and the mean absolute error (MAE).
# The same numbers are appended to `records_corr`.
for group in groups:
    models = group['models']
    n_rows, n_cols = group['dimensions']

    # Some model lists are built dynamically from the metrics table, so add
    # rows (ceiling division) when the configured grid has too few cells
    n_rows = max(n_rows, -(-len(models) // n_cols))

    # squeeze=False always yields a 2-D array of axes, also for 1xN, Nx1 and
    # 1x1 grids, so no manual reshaping is needed
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)

    for idx, model in enumerate(models):
        ax = axs[idx // n_cols, idx % n_cols]
        ax.set_title(f'{model}')

        data = frequencies[frequencies['model'] == model]
        if data.empty:
            # Without samples there is nothing to fit or plot (np.polyfit
            # raises on empty input), so leave the subplot empty
            print(f'No frequency data for model {model!r}, skipping')
            continue

        # Transform the values from Hz to breaths per minute
        preds = np.array(data['prediction'].values) * 60
        gts = np.array(data['ground_truth'].values) * 60

        mae = np.mean(np.abs(preds - gts))

        if preds.std() > 0:
            pcc, p = pearson_correlation(preds, gts)
            ax.text(0.1, 0.9, f'PCC: {round(pcc, 3)}', transform=ax.transAxes, color='red')
            ax.text(0.1, 0.8, f'P: {round(p, 3)}', transform=ax.transAxes, color='red')
        else:
            # Constant predictions: the correlation is undefined. The figure
            # keeps the 'PCC: 0.0' label, the record stores NaN so the model
            # still shows up (with its MAE) in the merged metrics.
            pcc, p = np.nan, np.nan
            ax.text(0.1, 0.9, 'PCC: 0.0', transform=ax.transAxes, color='red')

        # The MAE is well defined in both cases
        ax.text(0.1, 0.7, f'MAE: {round(mae, 1)}', transform=ax.transAxes, color='red')

        records_corr.append({
            'model': model,
            'mae': mae,
            'pcc': pcc,
            'p': p,
        })

        # Add a linear trend line, evaluated at the distinct ground-truth values
        gt_values = np.unique(gts)
        trend = np.poly1d(np.polyfit(gts, preds, 1))
        ax.plot(gt_values, trend(gt_values), color='red')

        # Create a heatmap of the (ground truth, prediction) pairs
        ax.hexbin(gts, preds, gridsize=30, cmap='viridis', extent=[0, 30, 0, 30])

        # Show the range 0 to 30 for the x- and y-axis
        ax.set_xlim(0, 30)
        ax.set_ylim(0, 30)

        # Name the x- and y-axis
        ax.set_xlabel('Ground truth (bpm)')
        ax.set_ylabel('Prediction (bpm)')

    # Hide the grid cells that have no model assigned
    for unused_ax in axs.ravel()[len(models):]:
        unused_ax.set_visible(False)

    fig.tight_layout()

    # Save the figure
    utils.savefig(fig, corr_dir, group['label'])

    # Don't show the plot
    plt.close(fig)

## Bland-Altman Plots

In [None]:
# Per-model records holding the mean and the standard deviation of the
# difference (prediction - ground truth). The list is filled by the
# Bland-Altman plotting cell below and merged with the correlation records
# at the end of the notebook.
records_altman = []

In [None]:
altman_dir = utils.dir_path(figure_dir, 'bland_altman', mkdir=True)

# One figure per model group with one subplot per model: a hexbin density of
# the difference (prediction - ground truth) against the mean of the two,
# with the mean difference (red) and the 95% limits of agreement, i.e.
# mean +/- 1.96 * SD (green). Mean and SD are appended to `records_altman`.
for group in groups:
    models = group['models']
    n_rows, n_cols = group['dimensions']

    # Some model lists are built dynamically from the metrics table, so add
    # rows (ceiling division) when the configured grid has too few cells
    n_rows = max(n_rows, -(-len(models) // n_cols))

    # squeeze=False always yields a 2-D array of axes, also for 1xN, Nx1 and
    # 1x1 grids, so no manual reshaping is needed
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)

    for idx, model in enumerate(models):
        ax = axs[idx // n_cols, idx % n_cols]
        ax.set_title(f'{model}')

        data = frequencies[frequencies['model'] == model]
        if data.empty:
            # Mean and SD of an empty sample are NaN; skip the model instead
            # of recording and plotting meaningless values
            print(f'No frequency data for model {model!r}, skipping')
            continue

        # Transform the values from Hz to breaths per minute
        preds = np.array(data['prediction'].values) * 60
        gts = np.array(data['ground_truth'].values) * 60

        # Calculate the difference between the two values
        diff = preds - gts
        diff_mean = diff.mean()
        diff_std = diff.std()

        records_altman.append({
            'model': model,
            'diff_mean': diff_mean,
            'diff_std': diff_std,
        })

        # Calculate the mean of the two values
        mean = (preds + gts) / 2

        # Half-width of the 95% limits of agreement
        limit = 1.96 * diff_std

        # Create a heatmap
        ax.hexbin(mean, diff, gridsize=30, cmap='viridis', extent=[0, 30, -20, 20])

        # Name the x- and y-axis
        ax.set_xlabel('Mean (bpm)')
        ax.set_ylabel('Difference (bpm)')

        # Add a horizontal line at the mean difference
        ax.axhline(diff_mean, color='red', linestyle='--')

        # Add the 95% limits of agreement
        ax.axhline(diff_mean + limit, color='green', linestyle='--')
        ax.axhline(diff_mean - limit, color='green', linestyle='--')

        # Annotate both; the green text reports the half-width of the limits
        # drawn in green (not the raw standard deviation)
        ax.text(0.1, 0.9, f'Mean: {round(diff_mean, 1)}', transform=ax.transAxes, color='red')
        ax.text(0.1, 0.8, f'LoA: ±{round(limit, 1)}', transform=ax.transAxes, color='green')

        # Set the y-axis to -20 to 20 and the x-axis to 6 to 30
        ax.set_ylim(-20, 20)
        ax.set_xlim(6, 30)

    # Hide the grid cells that have no model assigned
    for unused_ax in axs.ravel()[len(models):]:
        unused_ax.set_visible(False)

    fig.tight_layout()

    # Save the figure
    utils.savefig(fig, altman_dir, group['label'])

    # Don't show the plot
    plt.close(fig)

In [None]:
# Merge the correlation and Bland-Altman records into one row per model and
# write them next to the other analysis outputs.
#
# Explicit columns keep the 'model' key available even when a record list is
# empty. The record lists are append-only, so a model can occur more than
# once (a plotting cell was re-run, or a model is listed twice in a group);
# only its latest record is kept so the join cannot multiply rows.
corr_metrics = pd.DataFrame(records_corr, columns=['model', 'mae', 'pcc', 'p'])
altman_metrics = pd.DataFrame(records_altman, columns=['model', 'diff_mean', 'diff_std'])

metrics = pd.merge(
    corr_metrics.drop_duplicates(subset='model', keep='last'),
    altman_metrics.drop_duplicates(subset='model', keep='last'),
    on='model',
)
metrics.to_csv(utils.join_paths(analysis_dir, 'metrics_average.csv'), index=False)