In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from adjustText import adjust_text
import numpy as np
from scipy import stats

### Prediction Pipeline Steps
- quadratic fitting: `polyfit`
- transform the data to z-score
- set the upper and lower boundaries: the upper boundary is the highest value across all participants; the lower boundary is the highest value within the SNC group
- identify all the overshooting participants

In [2]:
# Load the 'main' sheet of the prediction workbook. `skiprows=1` discards the
# first spreadsheet row before the header row is read.
data = pd.read_excel(
    '/Users/yilewang/workspaces/data4project/prediction_data.xlsx',
    sheet_name='main',
    skiprows=1,
)

In [34]:
## Step 0. Define function
def get_r_squared(x, y, degree=2):
    """Return the coefficient of determination (R^2) of a polynomial fit.

    A polynomial of the requested degree is fitted to ``(x, y)`` with
    ``np.polyfit`` and evaluated at the same ``x`` positions.

    Parameters
    ----------
    x, y : array-like of numbers
        Sample positions and observed values; both must have the same length.
    degree : int, default 2
        Degree of the fitted polynomial.

    Returns
    -------
    float
        ``1 - SS_res / SS_tot``. ``nan`` is returned when every value in ``y``
        is identical: the total sum of squares is then zero and R^2 is
        undefined (previously this divided by zero and produced ``-inf`` or
        ``nan`` depending on floating-point noise).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Fit first so that invalid input still fails inside np.polyfit as before.
    coefficients = np.polyfit(x, y, degree)
    y_pred = np.polyval(coefficients, x)

    # Constant y -> zero total variance -> R^2 is undefined.
    if y.min() == y.max():
        return np.nan

    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return 1 - (ss_res / ss_tot)

def get_z_scores(data):
    """Standardize ``data`` by delegating to ``scipy.stats.zscore`` with its defaults."""
    standardized = stats.zscore(data)
    return standardized

def get_overshooter(data):
    """Flag, per feature, the rows whose value overshoots the SNC maximum.

    For every feature column the upper bound is the column maximum over all
    rows and the lower bound is the column maximum over the rows whose
    ``group`` is ``'SNC'``. A row is flagged with 1 when its value is strictly
    above the lower bound and at most the upper bound, otherwise 0.

    Flags are computed row by row from the comparison mask. They are NOT
    looked up by ``caseid``, so two rows that share a ``caseid`` (for example
    the same id used in two groups) can no longer inherit each other's flag.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``'group'`` and ``'caseid'`` columns. Every column from
        the third position onwards is treated as a numeric feature.

    Returns
    -------
    pandas.DataFrame
        ``'group'``, ``'caseid'`` and one int64 0/1 column per feature, on the
        same index as ``data``. Missing values are never flagged; a feature
        with no non-missing SNC value yields all zeros.

    Raises
    ------
    ValueError
        If there are feature columns but no row has ``group == 'SNC'``.
    """
    df_overshooter = pd.DataFrame({'group': data['group'], 'caseid': data['caseid']})
    feature_columns = data.columns[2:]

    is_snc = data['group'] == 'SNC'
    if len(feature_columns) and not is_snc.any():
        raise ValueError("get_overshooter needs at least one row with group == 'SNC'")

    for column in feature_columns:
        values = data[column]
        high_bound = values.max()
        low_bound = values[is_snc].max()
        # The mask is already aligned with the rows, so convert it directly.
        in_window = (values <= high_bound) & (values > low_bound)
        df_overshooter[column] = in_window.astype('int64')
    return df_overshooter

def stripplot_get_coordinates(data, x_var, y_var):
    """Draw a jittered seaborn strip plot and return the plotted point positions.

    The plot is created with ``sns.stripplot`` (``jitter=True``,
    ``dodge=True``) without an explicit ``ax``, so seaborn chooses the target
    axes. The offsets of every collection on the returned axes are then
    gathered into one array.

    Parameters
    ----------
    data : pandas.DataFrame
        Table passed straight to ``sns.stripplot``.
    x_var, y_var : str
        Column names used for the x and y axes.

    Returns
    -------
    numpy.ndarray
        One ``(x, y)`` row per collected point.

    NOTE(review): points are collected collection by collection, so their
    order may differ from the row order of ``data`` -- confirm before pairing
    the result with per-row labels.
    """
    axes = sns.stripplot(data=data, x=x_var, y=y_var, jitter=True, dodge=True)
    points = [
        point
        for collection in axes.collections
        for point in collection.get_offsets()
    ]
    return np.array(points)


### Step 1. Get the list of features with a quadratic trend

In [15]:
# x positions assigned to the group means, one per group.
# NOTE(review): groupby returns the groups in sorted label order, so position k
# belongs to the k-th sorted label -- confirm this matches the intended group
# ordering.
x = [1, 2, 3, 4]
# A feature counts as "quadratic" when the linear fit of its group means stays
# below this R^2 while the quadratic fit exceeds it.
R2_THRESHOLD = 0.9

# Build the grouping once instead of once per column.
grouped = data.groupby('group')
if grouped.ngroups != len(x):
    raise ValueError(
        f"expected {len(x)} groups to match x={x}, found {grouped.ngroups}"
    )

quadratic_list = []
for column in data.columns[12:]:
    group_means = grouped[column].mean().values
    linear_fit = get_r_squared(x, group_means, degree=1)
    quadratic_fit = get_r_squared(x, group_means, degree=2)
    if linear_fit < R2_THRESHOLD and quadratic_fit > R2_THRESHOLD:
        quadratic_list.append(column)
print(quadratic_list)

['ignition_mCNG-R', 'ignition_mCNG-L', 'ignition_pCNG-L', 'ignition_PHG-L', 'ignition_AMY-R', 'ignition_AMY-L', 'ignition_sTEMp-R', 'ignition_sTEMp-L', 'ignition_mTEMp-R', 'ignition_mTEMp-L', 'Go_Gc', 'K21', 'MC_aCNG', 'MC_mCNG', 'MC_pCNG', 'MC_PHG', 'MC_AMY', 'MC_sTEMp', 'cluster_4', 'pcgl_freq_gamma', 'pcgr_freq_gamma', 'pcgl_freq_theta', 'freq_gamma_over_theta', 'pcgl_amp_gamma', 'pcgr_amp_gamma', 'pcgl_amp_theta', 'LI_freq_gamma', 'LI_freq_theta', 'LI_amp_gamma', 'wdc_HIP_R', 'wdc_AMY_L', 'wdc_AMY_R', 'wdc_sTEMp_R', 'wdc_mTEMp_L', 'wdc_mTEMp_R', 'wdc_average']


### Step 2. Apply the boundaries to flag overshooters

In [37]:
# Keep the identifiers plus the quadratic features and flag the overshooters.
pd_quadratic = pd.concat(
    [data[['group', 'caseid']], data[quadratic_list]],
    axis=1,
)
pd_overshoot = get_overshooter(pd_quadratic)

# Add the per-participant count of flagged features and the sex column, then
# export the table.
pd_overshoot = pd_overshoot.assign(
    sum=pd_overshoot[quadratic_list].sum(axis=1),
    sex=data['M_Sex'],
)
pd_overshoot.to_excel(
    '/Users/yilewang/workspaces/data4project/overshoot.xlsx',
    index=False,
)

In [None]:
### Visualization Module
# One labelled strip plot per column: every point is annotated with
# '<group>_<caseid>' and adjust_text moves the labels apart.
for column in data.columns[5:]:
    # Rows with a missing group or value are not drawn, so drop them up front;
    # otherwise there are fewer points than labels.
    plot_df = data.dropna(subset=['group', column])
    if plot_df.empty:
        continue

    # stripplot_get_coordinates returns points collection by collection
    # (presumably one collection per group, in order of first appearance).
    # Reorder the rows the same way, stably, so that row k lines up with
    # point k. The checks below verify this instead of trusting it.
    group_codes = pd.factorize(plot_df['group'])[0]
    plot_df = plot_df.iloc[np.argsort(group_codes, kind='stable')]

    figure = plt.figure(figsize=(15, 10))
    ax = figure.add_subplot(111)
    coordinates = stripplot_get_coordinates(plot_df, 'group', column)

    if len(coordinates) != len(plot_df):
        raise ValueError(
            f"{column}: {len(coordinates)} plotted points for {len(plot_df)} rows; "
            "labels cannot be matched to points"
        )
    # Jitter only moves points along the group axis, so the y coordinates must
    # equal the column values when rows and points are aligned.
    if pd.api.types.is_numeric_dtype(plot_df[column]) and not np.allclose(
        coordinates[:, 1], plot_df[column].to_numpy(dtype=float)
    ):
        raise ValueError(
            f"{column}: plotted points are not in row order; labels would be misplaced"
        )

    # Overlay markers at the jittered positions (the strip plot itself was
    # already drawn on this axes by stripplot_get_coordinates).
    ax.plot(coordinates[:, 0], coordinates[:, 1], 'o')

    # To include sex in the label, add plot_df['M_Sex'] to the zip and the f-string.
    texts = [
        ax.text(point_x, point_y, f'{group}_{caseid}', ha='center', va='center')
        for (point_x, point_y), group, caseid in zip(
            coordinates, plot_df['group'], plot_df['caseid']
        )
    ]
    adjust_text(
        texts,
        expand=(1.2, 2),  # grow text boxes 1.2x horizontally and 2x vertically
        arrowprops=dict(arrowstyle='->', color='red'),  # arrows link moved labels to points
    )
    plt.show()
    # Release the figure so a long loop does not accumulate open figures.
    plt.close(figure)