In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

: 

In [None]:

# Read the drift-rate table: one drift value per subject and condition, plus the
# clinical questionnaire scores (BDI, LSAS) carried on each row.
df = pd.read_csv("ddm_absolute_drift_subjects_grandmean.csv")

# Quick sanity look: first rows, which condition labels are present, and summary
# statistics for the drift and questionnaire columns.
print(df.head(n=5))
print("\nUnique conditions:", df.Condition.unique())
print("\nDescriptive stats:\n", df.loc[:, ['Drift', 'BDI', 'LSAS']].describe())


: 

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Drift-rate distribution per condition: boxes summarise the spread and the
# overlaid strip shows every individual value.
fig, ax = plt.subplots(figsize=(8, 5))
# showfliers=False: the stripplot below already draws every observation, so letting
# the boxplot also draw its outlier markers would plot those points twice.
# NOTE(review): seaborn >= 0.13 warns about `palette` without `hue`; on 0.13+ use
# hue='Condition', legend=False for the same colouring without the warning.
sns.boxplot(data=df, x='Condition', y='Drift', palette="Set3", showfliers=False, ax=ax)
sns.stripplot(data=df, x='Condition', y='Drift', color='black', alpha=0.6, jitter=0.2, ax=ax)
ax.axhline(0, color='gray', linestyle='--')  # reference line at zero drift
ax.set_title("Drift Rate Distribution by Condition")
ax.set_ylabel("Drift Rate")
ax.tick_params(axis='x', labelrotation=15)
plt.show()

In [None]:
# Mean drift for each condition, averaged over all rows carrying that condition label.
condition_means = df.groupby('Condition')['Drift'].agg('mean')
print(condition_means)

In [None]:
# Condition-level summary for the interaction plot: mean drift and its standard
# error of the mean (pandas .sem(), ddof=1) across all rows of each condition.
mean_df = (
    df.groupby('Condition')['Drift']
      .agg(['mean', 'sem'])
      .rename(columns={'mean': 'Drift', 'sem': 'SEM'})
      .reset_index()
)
# Decode the two design factors from the condition label.
mean_df['Valence'] = mean_df['Condition'].apply(lambda x: 'Positive' if 'positive' in x else 'Negative')
mean_df['Axis'] = mean_df['Condition'].apply(lambda x: 'Affiliation' if 'affiliation' in x else 'Dominance')

# 2x2 tables (rows = Axis, columns = Valence) of the means and of their SEMs.
# pivot() raises if two conditions map onto the same (Axis, Valence) cell.
mean_matrix = mean_df.pivot(index='Axis', columns='Valence', values='Drift')
sem_matrix = mean_df.pivot(index='Axis', columns='Valence', values='SEM')
print(mean_matrix)  # to see values

# Interaction plot: Valence on x, mean drift on y, one line per circumplex Axis,
# with +/- 1 SEM error bars around each mean.
fig, ax = plt.subplots(figsize=(5,4))
axes = ['Affiliation','Dominance']
valences = ['Negative','Positive']
colors = {'Affiliation':'teal','Dominance':'orange'}
x_pos = list(range(len(valences)))  # numeric x positions; category names set as tick labels
for axis in axes:
    ax.errorbar(x_pos,
                [mean_matrix.loc[axis, v] for v in valences],
                yerr=[sem_matrix.loc[axis, v] for v in valences],
                marker='o', capsize=4, label=axis, color=colors[axis])
ax.set_xticks(x_pos)
ax.set_xticklabels(valences)
ax.axhline(0, color='gray', linestyle='--')
ax.set_xlabel("Valence")
ax.set_ylabel("Mean Drift Rate")
ax.set_title("Interaction of Valence and Circumplex Axis on Drift")
ax.legend(title="Axis")
plt.show()


In [None]:
import numpy as np
from scipy.stats import pearsonr

# Per-subject table: one row per Subject, one column per Condition.
# pivot() raises if any Subject x Condition pair appears more than once.
subj_means = df.pivot(index='Subject', columns='Condition', values='Drift')
# Valence summaries: mean of the two positive and of the two negative conditions
# (mean(axis=1) skips NaN, so a missing condition falls back to the other one).
subj_means['pos_mean'] = subj_means[['positive_highaffiliation','positive_highdominance']].mean(axis=1)
subj_means['neg_mean'] = subj_means[['negative_lowaffiliation','negative_lowdominance']].mean(axis=1)
subj_means['valence_diff'] = subj_means['pos_mean'] - subj_means['neg_mean']

# Attach the clinical scores. validate='one_to_one' makes the merge fail loudly if a
# subject has more than one distinct (BDI, LSAS) pair across its rows; otherwise
# drop_duplicates() would keep several rows for that subject and the merge would
# silently duplicate its drifts, inflating N in every correlation below.
subject_scores = df[['Subject','BDI','LSAS']].drop_duplicates()
subj_means = subj_means.merge(subject_scores, on='Subject', validate='one_to_one')

# Correlation between BDI and drifts
print("BDI vs pos_mean:", pearsonr(subj_means['pos_mean'], subj_means['BDI']))
print("BDI vs neg_mean:", pearsonr(subj_means['neg_mean'], subj_means['BDI']))
print("BDI vs valence_diff (pos-minus-neg):", pearsonr(subj_means['valence_diff'], subj_means['BDI']))


In [None]:
# Cross-check of the valence effect computed a second way: label each row by
# valence from its condition name, then average per subject.
valence_df = df.copy()
valence_df['Valence'] = valence_df['Condition'].apply(lambda x: 'Positive' if 'positive' in x else 'Negative')

# Per-subject average drift for positive/negative, and their difference.
valence_means = valence_df.groupby(['Subject', 'Valence'])['Drift'].mean().unstack()
valence_means['valence_diff'] = valence_means['Positive'] - valence_means['Negative']

# Merge BOTH clinical scores: the notebook later fits
# `valence_diff ~ BDI + LSAS` on this frame, which fails if LSAS is missing.
# validate='one_to_one' rejects subjects with inconsistent scores across rows.
valence_means = valence_means.merge(df[['Subject', 'BDI', 'LSAS']].drop_duplicates(),
                                    on='Subject', validate='one_to_one')

# Plot BDI against the valence contrast with a LOWESS smoother.
fig, ax = plt.subplots()
sns.regplot(data=valence_means, x='BDI', y='valence_diff', lowess=True, line_kws={'color': 'red'}, ax=ax)
ax.axhline(0, linestyle='--', color='gray')  # zero = no valence difference
ax.set_title("BDI vs Valence Drift Difference (Positive - Negative)")
ax.set_xlabel("BDI (Depression Score)")
ax.set_ylabel("Valence Drift Difference")
plt.show()

In [None]:
import statsmodels.formula.api as smf
# OLS of the positive-minus-negative drift difference on both clinical scores;
# only the coefficient table (summary table index 1) is shown.
model = (
    smf.ols(formula="valence_diff ~ BDI + LSAS", data=subj_means)
    .fit()
)
print(model.summary().tables[1])

In [None]:
# Pearson correlations of LSAS with each per-subject drift summary.
for label, column in [("LSAS vs pos_mean:", 'pos_mean'),
                      ("LSAS vs neg_mean:", 'neg_mean'),
                      ("LSAS vs valence_diff:", 'valence_diff')]:
    print(label, pearsonr(subj_means[column], subj_means['LSAS']))


In [None]:
# LSAS against drift in the two dominance conditions, each with a LOWESS smoother.
fig, ax = plt.subplots(figsize=(7, 5))
# Pass color= (not only line_kws) so each series' scatter points, smoother and
# legend marker share one colour. With line_kws alone, regplot takes the scatter
# colour from the default cycle, so it no longer matches the series' line.
sns.regplot(data=subj_means, x='LSAS', y='positive_highdominance', lowess=True, color='blue',
            scatter_kws={'alpha': 0.6}, label='Positive HighDominance', ax=ax)

sns.regplot(data=subj_means, x='LSAS', y='negative_lowdominance', lowess=True, color='red',
            scatter_kws={'alpha': 0.6}, label='Negative LowDominance', ax=ax)

ax.set_xlabel("LSAS (Social Anxiety Score)")
ax.set_ylabel("Drift Rate")
ax.set_title("LSAS vs Drift in Dominance-related Conditions")
ax.legend()
plt.show()

In [None]:
# Overall drift per subject = average of its positive- and negative-valence means
# (added to subj_means as a new column).
subj_means['drift_mean'] = subj_means.loc[:, ['pos_mean', 'neg_mean']].mean(axis=1)

# Multiple regression of that overall drift on both clinical scores; print only
# the coefficient table.
model2 = smf.ols(formula="drift_mean ~ BDI + LSAS", data=subj_means).fit()
print(model2.summary().tables[1])


In [None]:
# Tag each row with its circumplex axis (adds an 'Axis' column to df in place).
df['Axis'] = df['Condition'].map(lambda cond: 'Affiliation' if 'affiliation' in cond else 'Dominance')

# Subject x Axis table of mean drift, then the affiliation-minus-dominance contrast.
axis_means = (
    df.groupby(['Subject', 'Axis'])['Drift']
      .mean()
      .unstack()
)
axis_means['axis_diff'] = axis_means['Affiliation'] - axis_means['Dominance']

# Attach both clinical scores (one row per subject expected).
axis_means = axis_means.merge(df[['Subject', 'BDI', 'LSAS']].drop_duplicates(), on='Subject')

# BDI against the axis contrast, LOWESS smoother, and a zero reference line.
sns.regplot(data=axis_means, x='BDI', y='axis_diff', lowess=True, line_kws={'color': 'green'})
plt.axhline(0, linestyle='--', color='gray')
plt.title("BDI vs Drift Difference (Affiliation - Dominance)")
plt.xlabel("BDI (Depression Score)")
plt.ylabel("Axis Drift Difference")
plt.show()


In [None]:
import statsmodels.formula.api as smf

# Regression: valence difference (positive minus negative drift) on BDI and LSAS.
# The formula needs both scores in the frame; if valence_means was built with
# only BDI attached, add LSAS here rather than letting patsy fail on the name.
valence_data = valence_means
if 'LSAS' not in valence_data.columns:
    valence_data = valence_data.merge(df[['Subject', 'LSAS']].drop_duplicates(),
                                      on='Subject', validate='one_to_one')
model_val = smf.ols("valence_diff ~ BDI + LSAS", data=valence_data).fit()
print(model_val.summary())

# Regression: axis difference (affiliation minus dominance drift) on BDI and LSAS.
model_axis = smf.ols("axis_diff ~ BDI + LSAS", data=axis_means).fit()
print(model_axis.summary())
