# Compare Models

## Load the data

In [None]:
import pandas as pd
import respiration.utils as utils

# Resolve the analysis directory and the two CSV files stored inside it.
analysis_dir = utils.dir_path('outputs', 'analysis')
metrics_avg_file = utils.join_paths(analysis_dir, 'metrics_average.csv')
frequencies_file = utils.join_paths(analysis_dir, 'frequencies.csv')

# Load both tables. Later cells read the 'model', 'mae' and 'pcc' columns of
# `metrics_average` and the 'model', 'prediction' and 'ground_truth' columns
# of `frequencies`.
metrics_average = pd.read_csv(metrics_avg_file)
frequencies = pd.read_csv(frequencies_file)

In [None]:
# Output directory handed to utils.savefig by every plotting cell below.
# NOTE(review): mkdir=True presumably creates the directory when it is
# missing; confirm against utils.dir_path.
figure_dir = utils.dir_path(
    'outputs',
    'figures',
    'compare',
    mkdir=True,
)

## Group Models

In [None]:
# Model groups used throughout the notebook. Each entry is a dict with a
# display 'label' and the list of 'models' (values of the 'model' column)
# that belong to the group. The order of groups and of models is preserved.
groups = [
    {'label': group_label, 'models': list(model_names)}
    for group_label, model_names in (
        ('Optical Flow', (
            'raft_small',
            'FlowNet2CS',
            'lucas_kanade',
            'pixel_intensity_rgb',
        )),
        ('R-RhythmFormer', (
            'RF_20240902_210159',
        )),
        ('SimpleViT', (
            'SimpleViT_20240906_131118',
        )),
        ('Pretrained Respiration', (
            'mtts_can',
            'big_small',
        )),
        ('Pretrained rPPG', (
            'MMPD_intra_RhythmFormer',
            'UBFC-rPPG_TSCAN',
            'UBFC-rPPG_EfficientPhys',
            'BP4D_PseudoLabel_DeepPhys',
        )),
        ('Random', (
            'random',
        )),
    )
]

## Compare the MAE and PCC

In [None]:
# Build one record per metrics row whose model belongs to a group, tagging
# the row with that group's label. Records follow the order of `groups`.
records = []

for group in groups:
    in_group = metrics_average['model'].isin(group['models'])

    records.extend(
        {
            'model': entry['model'],
            'mae': entry['mae'],
            'pcc': entry['pcc'],
            'group': group['label'],
        }
        for _, entry in metrics_average[in_group].iterrows()
    )

points = pd.DataFrame(records)

In [None]:
# Plot the MAE and PCC for the psd method
import seaborn as sns
import matplotlib.pyplot as plt

# New 12x6 figure; seaborn draws onto its current axes and returns them.
fig = plt.figure(figsize=(12, 6))

# One point per model: both the colour and the marker shape encode the model.
ax = sns.scatterplot(
    data=points,
    x='mae',
    y='pcc',
    hue='model',
    style='model',
    s=250,
)

ax.set_xlabel('MAE (BPM)')
ax.set_ylabel('Correlation')
ax.set_title('MAE and Pearson Correlation for the different models')
fig.tight_layout()

# Fixed axis ranges: MAE in [0, 8], correlation in [0, 1]
ax.set_xlim(0, 8)
ax.set_ylim(0, 1)

utils.savefig(fig, figure_dir, 'mae_pcc')

plt.show()

In [None]:
# Plot the MAE and PCC for the psd method, colored by model group
import seaborn as sns
import matplotlib.pyplot as plt

# New 12x6 figure; seaborn draws onto its current axes and returns them.
fig = plt.figure(figsize=(12, 6))

# One point per model: the colour encodes the group, the marker shape the model.
ax = sns.scatterplot(
    data=points,
    x='mae',
    y='pcc',
    hue='group',
    style='model',
    s=250,
)

ax.set_xlabel('MAE (BPM)')
ax.set_ylabel('Correlation')
ax.set_title('MAE and Pearson Correlation for the different models')
fig.tight_layout()

# Fixed axis ranges: MAE in [0, 8], correlation in [0, 1]
ax.set_xlim(0, 8)
ax.set_ylim(0, 1)

utils.savefig(fig, figure_dir, 'mae_pcc_grouped')

plt.show()

## T-Test between the groups

In [None]:
import numpy as np
from scipy import stats

# Pairwise independent two-sample t-tests on the absolute prediction error of
# every ordered pair of groups (self-pairs included, so the table is square).

# The per-group absolute error and its mean do not depend on the other group
# of a pair, so compute them once per group instead of once per pair.
group_errors = []

for group in groups:
    group_rows = frequencies[frequencies['model'].isin(group['models'])]
    abs_error = np.abs(group_rows['prediction'] - group_rows['ground_truth'])
    group_errors.append((group['label'], abs_error, abs_error.mean()))

group_records = []

for label_1, error_1, mae_1 in group_errors:
    for label_2, error_2, mae_2 in group_errors:
        t_stat, p = stats.ttest_ind(error_1, error_2)

        group_records.append({
            'group1': label_1,
            'group2': label_2,
            # Errors are scaled by 60 for reporting (the correlation cell
            # describes the same scaling as Hz -> beats per minute).
            'mae_1': round(mae_1 * 60, 1),
            'mae_2': round(mae_2 * 60, 1),
            't_stat': round(t_stat, 3),
            'p_value': round(p, 3),
        })

group_records = pd.DataFrame(group_records)
group_records

In [None]:
# Visualize the t_stat in a heatmap
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))

# Pivot the pairwise records into a square group1 x group2 table of
# t-statistics.
heatmap = group_records.pivot_table(
    index='group1',
    columns='group2',
    values='t_stat',
)

# Boolean mask that hides every cell strictly above the diagonal (seaborn
# hides cells where the mask is True). The mask is built from the table's
# shape rather than from its values: using np.triu(heatmap) directly would
# leave any upper-triangle cell whose t-statistic is exactly 0 unmasked.
# k=1 keeps the diagonal (self-comparisons) visible.
matrix = np.triu(np.ones(heatmap.shape, dtype=bool), k=1)

sns.heatmap(
    data=heatmap,
    annot=True,
    fmt='.0f',
    mask=matrix,
)

# Keep the y-axis tick labels horizontal
plt.yticks(rotation=0)

plt.title('Group T-Statistics')
plt.tight_layout()

utils.savefig(plt.gcf(), figure_dir, 't_stat')

plt.show()

## Correlation Plots

In [None]:
# Flatten the per-group model lists into a single list, keeping group order.
models = [model_name for group_entry in groups for model_name in group_entry['models']]

# Displayed as the cell output: the total number of models
len(models)

In [None]:
# Subplot grid shape used by the correlation and Bland-Altman cells.
dim = (
    3,  # rows
    5,  # columns
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from respiration.analysis import pearson_correlation

# One panel per model on a dim[0] x dim[1] grid. squeeze=False makes `axs`
# a 2-D array for every grid shape, so (row, column) indexing always works.
_, axs = plt.subplots(dim[0], dim[1], figsize=(3 * dim[1], 3 * dim[0]), squeeze=False)

for idx, model in enumerate(models):
    data = frequencies[(frequencies['model'] == model)]

    # Transform the values from Hz to beats per minute
    preds = np.array(data['prediction'].values) * 60
    gts = np.array(data['ground_truth'].values) * 60

    mae = np.mean(np.abs(preds - gts))

    # Fill the grid row by row
    xy = (idx // dim[1], idx % dim[1])
    panel = axs[xy]

    if preds.std() > 0:
        pcc, p = pearson_correlation(preds, gts)
        panel.text(0.1, 0.9, f'PCC: {round(pcc, 3)}', transform=panel.transAxes, color='red')
        panel.text(0.1, 0.8, f'P: {round(p, 3)}', transform=panel.transAxes, color='red')
        panel.text(0.1, 0.7, f'MAE: {round(mae, 1)}', transform=panel.transAxes, color='red')
    else:
        # Constant predictions: the correlation is not computed and is
        # reported as 0.0.
        panel.text(0.1, 0.9, 'PCC: 0.0', transform=panel.transAxes, color='red')

    # Add a linear trend line (degree-1 least-squares fit), drawn in both cases
    gt_values = np.unique(gts)
    trend = np.poly1d(np.polyfit(gts, preds, 1))
    panel.plot(gt_values, trend(gt_values), color='red')

    # Density of (ground truth, prediction) pairs as hexagonal bins
    panel.hexbin(gts, preds, gridsize=30, cmap='viridis', extent=[0, 30, 0, 30])
    panel.set_title(f'{model}')

    # Show the range 0 to 30 for the x- and y-axis
    panel.set_xlim(0, 30)
    panel.set_ylim(0, 30)

    # Name the x- and y-axis
    panel.set_xlabel('Ground truth (bpm)')
    panel.set_ylabel('Prediction (bpm)')

# Hide the grid panels that have no model assigned to them
for unused_panel in axs.flat[len(models):]:
    unused_panel.set_axis_off()

plt.tight_layout()

# Save the figure
utils.savefig(plt.gcf(), figure_dir, 'correlation')

plt.show()

## Bland-Altman Plots

In [None]:
# One Bland-Altman panel per model on a dim[0] x dim[1] grid. squeeze=False
# makes `axs` a 2-D array for every grid shape, so (row, column) indexing
# always works.
_, axs = plt.subplots(dim[0], dim[1], figsize=(3 * dim[1], 3 * dim[0]), squeeze=False)

for idx, model in enumerate(models):
    data = frequencies[(frequencies['model'] == model)]

    # Transform the values from Hz to beats per minute
    preds = np.array(data['prediction'].values) * 60
    gts = np.array(data['ground_truth'].values) * 60

    # Calculate the difference between the two values
    diff = preds - gts
    diff_mean = diff.mean()
    diff_std = diff.std()

    # Half-width of the band drawn around the mean difference (1.96 * std)
    half_width = 1.96 * diff_std

    # Calculate the mean of the two values
    mean = (preds + gts) / 2

    # Fill the grid row by row
    xy = (idx // dim[1], idx % dim[1])
    panel = axs[xy]

    # Density of (mean, difference) pairs as hexagonal bins
    panel.hexbin(mean, diff, gridsize=30, cmap='viridis', extent=[0, 30, -20, 20])
    panel.set_title(f'{model}')

    # Name the x- and y-axis
    panel.set_xlabel('Mean (bpm)')
    panel.set_ylabel('Difference (bpm)')

    # Add a horizontal line at diff.mean()
    panel.axhline(diff_mean, color='red', linestyle='--')

    # Add the 95% band at mean +/- 1.96 * std
    panel.axhline(diff_mean + half_width, color='green', linestyle='--')
    panel.axhline(diff_mean - half_width, color='green', linestyle='--')

    # Annotate the mean and the band. The green text reports the same
    # half-width as the green lines; it previously showed the raw standard
    # deviation under the "CI" label.
    panel.text(0.1, 0.9, f'Mean: {round(diff_mean, 1)}', transform=panel.transAxes, color='red')
    panel.text(0.1, 0.8, f'CI: ±{round(half_width, 1)}', transform=panel.transAxes, color='green')

    # Set the y-axis to -20 to 20 and the x-axis to 6 to 30
    panel.set_ylim(-20, 20)
    panel.set_xlim(6, 30)

# Hide the grid panels that have no model assigned to them
for unused_panel in axs.flat[len(models):]:
    unused_panel.set_axis_off()

plt.tight_layout()

# Save the figure
utils.savefig(plt.gcf(), figure_dir, 'bland_altman')

plt.show()