In [None]:
import scipy
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


# Standard error of a correlation coefficient, returned twice so it can serve
# as symmetric (lower, upper) error-bar half-widths.
def std_err(n,r):
    """Return ``[se, se]`` where ``se = sqrt((1 - r**2) / (n - 2))``.

    Parameters
    ----------
    n : number
        Sample size; enters the formula only through the ``n - 2`` denominator.
    r : number or array-like
        Correlation coefficient(s).

    Returns
    -------
    list
        Two-element list holding the same standard-error value twice.
    """
    unexplained = 1 - r**2
    dof = n - 2
    se = np.sqrt(unexplained / dof)
    return [se] * 2

In [None]:
# Supplement Figure 3: grouped bar chart of the correlation per target, one bar
# per model, with +/- std_err error bars. Reads fig_data/SuppFig3.csv and
# writes output/Supplement-Figure-3.pdf.
results_df_test = pd.read_csv("fig_data/SuppFig3.csv")

# Bar layout: four groups along the x axis, each bar `width` wide.
pos = list(range(4))
width = 0.25

# (value in the CSV 'model' column, legend label), in left-to-right bar order.
models = [
    ('cnn', 'CNN'),
    ('rr', 'Ridge Regression'),
    ('rf', 'Random Forest'),
]

# Correlations for each model, in CSV row order.
corr_by_model = {
    key: results_df_test.loc[results_df_test['model'] == key, 'correlation']
    for key, _label in models
}

# x coordinates of each model's bars: model i is shifted right by i bar widths.
x_by_model = {
    key: [p + i * width for p in pos]
    for i, (key, _label) in enumerate(models)
}

fig, ax = plt.subplots(figsize=(5.3, 5.3))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Draw every bar before any error bar, keeping the original call order.
for key, legend_label in models:
    ax.bar(x_by_model[key], corr_by_model[key], width,
           alpha=0.5, label=legend_label)

# Sample size handed to std_err.
# NOTE(review): presumably the number of examples the correlations were
# computed on -- confirm and consider moving to a named config constant.
n = 348

# Error-bar half-widths per model; after the transpose each array has one row
# for the lower and one for the upper half-width (identical values).
corr_ci = {
    key: np.array([std_err(n, r) for r in corr_by_model[key]]).transpose()
    for key, _label in models
}
# Individual names kept in case other cells refer to them.
corr_ci_cnn, corr_ci_rr, corr_ci_rf = (corr_ci[key] for key, _label in models)

for key in corr_ci:
    ax.errorbar(x_by_model[key], corr_by_model[key], yerr=corr_ci[key],
                fmt=' ', capsize=4)

# One tick per group, under the middle bar of the three.
ax.set_xticks([p + width for p in pos])

# The CSV holds rows for all three models, so the 'target' column has more
# entries than there are ticks. Keep one label per tick: recent Matplotlib
# raises ValueError when the label count differs from the fixed tick count.
# NOTE(review): assumes the first len(pos) rows list the targets in the same
# order as each model's bars -- confirm against the CSV.
labels = results_df_test['target'].tolist()[:len(pos)]
labels[3] = "knee flexion\nat max extension"
ax.set_xticklabels(labels)

ax.set_xlabel("Parameter")
ax.set_ylabel("Correlation")
leg = ax.legend()
leg.get_frame().set_linewidth(0)  # no border line around the legend
ax.set_ylim(0.5, 1)
ax.set_title("Model Performance", fontsize=14)
fig.savefig("output/Supplement-Figure-3.pdf", bbox_inches="tight", pad_inches=0.1)