In [None]:
import os
import pandas as pd
import pingouin as pg
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tools.draw import *


In [None]:
def remove_outliers_3sd(data):
    """Locate values lying more than three standard deviations from the mean.

    Despite the name, nothing is removed: the function only reports where the
    extreme values are and which cut-offs were applied.

    Parameters
    ----------
    data : sequence of numbers
        One-dimensional collection of observations.

    Returns
    -------
    tuple
        ``(outliers, lower_bound, upper_bound)``. ``outliers`` is the list of
        positions whose value lies strictly outside the bounds; the bounds are
        ``mean - 3 * std`` and ``mean + 3 * std``, with ``std`` computed by
        ``np.std`` using its default settings.
    """
    center = np.mean(data)
    spread = np.std(data)
    lower_bound, upper_bound = center - 3 * spread, center + 3 * spread

    outliers = []
    for position, observation in enumerate(data):
        # Values exactly on a bound are not counted as outliers.
        if observation < lower_bound or observation > upper_bound:
            outliers.append(position)
    return outliers, lower_bound, upper_bound

def filter_any_outliers(pre_data, post_data, n_sd=3):
    """Drop paired observations that are extreme in either measurement.

    The two inputs are treated as paired by position: element ``k`` of
    ``pre_data`` and element ``k`` of ``post_data`` form one pair. A pair is
    kept only when both of its values lie within ``n_sd`` standard deviations
    (inclusive) of the mean of their own series.

    Parameters
    ----------
    pre_data, post_data : sequence of numbers
        Equal-length one-dimensional collections (list, NumPy array or pandas
        Series). Pairing is positional, so a Series with a non-default index
        is handled correctly.
    n_sd : float, optional
        Half-width of the acceptance band in standard deviations. Defaults to
        3, the cut-off this function has always used.

    Returns
    -------
    tuple of list
        ``(filtered_pre, filtered_post)`` with the surviving pairs, in their
        original order.

    Raises
    ------
    ValueError
        If the two inputs differ in length, since the pairing would then be
        undefined.

    Notes
    -----
    The standard deviation is the NumPy default (``ddof=0``). A NaN never
    compares as inside the band, and a NaN anywhere in a series makes that
    series' mean and standard deviation NaN, so every pair is dropped.
    """
    if len(pre_data) != len(post_data):
        raise ValueError(
            f"pre_data and post_data must have the same length "
            f"(got {len(pre_data)} and {len(post_data)})"
        )
    if len(pre_data) == 0:
        # Nothing to filter; also avoids NumPy's empty-mean warning.
        return [], []

    def _within_band(values):
        # Boolean mask of the values inside mean +/- n_sd * std (inclusive).
        center = values.mean()
        spread = values.std()
        return (values >= center - n_sd * spread) & (values <= center + n_sd * spread)

    pre_values = np.asarray(pre_data, dtype=float)
    post_values = np.asarray(post_data, dtype=float)

    # A pair survives only if neither member is extreme.
    keep = _within_band(pre_values) & _within_band(post_values)

    # Iterate the original inputs so the returned elements are the caller's
    # own objects, selected strictly by position.
    filtered_pre = [item for item, ok in zip(pre_data, keep) if ok]
    filtered_post = [item for item, ok in zip(post_data, keep) if ok]

    return filtered_pre, filtered_post

##### best choice/emotion/MB-MF scores

In [None]:
## Read the input tables used by the cells below.
# Each CSV is listed together with the reader options it is opened with;
# the encodings passed differ from file to file.
emotion, best_choice, w_pun = (
    pd.read_csv(csv_path, **read_options)
    for csv_path, read_options in (
        ('datum/emotion.csv', {}),
        ('datum/bestchoice_bili.csv', {'encoding': 'ISO-8859-1'}),
        ('datum/params.csv', {'encoding': 'cp1252'}),
    )
)
# The score table is an Excel workbook rather than a CSV.
MB_rew = pd.read_excel('datum/scores.xlsx')

In [None]:
# Summary statistics for the rows of `emotion` whose `type` code is 2
# (code 2 is labelled 'ND' by the recoding used elsewhere in this notebook).
is_type_2 = emotion['type'].eq(2)
df_filtered = emotion.loc[is_type_2]
df_filtered.describe()

In [None]:
# One-sample t-test (against 0) of one score column of `MB_rew`, restricted to
# a single group x background combination.
ask_name = 'MF_pun'
ask_type = 'ND'
ask_bg = 'ad'

# Translate the requested labels into the numeric codes stored in the table:
# 'HC' -> 1, any other group -> 2; 'neutral' -> 1, any other background -> 2.
# This cell expects `type` / `background` to hold those numeric codes.
typing = 2 if ask_type != 'HC' else 1
bg = 2 if ask_bg != 'neutral' else 1

in_cell = (MB_rew['type'] == typing) & (MB_rew['background'] == bg)
scores = MB_rew.loc[in_cell, ask_name].tolist()
pg.ttest(scores, 0)

In [None]:
## Build the long-format frame (value / name / group) for the emotion,
## best-choice and model-score figures, and pick the matching y-axis label.
var = 'MB_rew'

# Table that holds each plottable column. The column name and the DataFrame
# variable name are not always the same, so the table is looked up explicitly.
# NOTE(review): 'w_rew' living in `w_pun` follows the pg.anova call kept at the
# bottom of this cell, and 'MF_pun' living in `MB_rew` follows the one-sample
# t-test cell; 'MB_pun' / 'MF_rew' are assumed to sit in `MB_rew` too - confirm.
source_tables = {
    'emotion': emotion,
    'best_choice': best_choice,
    'w_pun': w_pun,
    'w_rew': w_pun,
    'MB_pun': MB_rew,
    'MB_rew': MB_rew,
    'MF_pun': MB_rew,
    'MF_rew': MB_rew,
}

# y-axis label for each variable.
y_labels = {
    'emotion': 'adversion score (a.u.)',
    'best_choice': 'best choice rate (a.u.)',
    'w_pun': ' MB weight (w)\n  (a.u.)',
    'w_rew': ' MB weight (w)\n  (a.u.)',
    'MB_pun': ' MB score (a.u.)',
    'MB_rew': ' MB score (a.u.)',
    'MF_pun': ' MF score (a.u.)',
    'MF_rew': ' MF score (a.u.)',
}

if var not in source_tables:
    # Fail loudly instead of plotting with a missing or stale label.
    raise ValueError(f"unknown var {var!r}; expected one of {sorted(source_tables)}")

y_name = y_labels[var]
source = source_tables[var]

# Recode the numeric labels on copies only: the loaded tables keep their
# numeric codes, so cells that compare against 1 / 2 still work on re-run.
df_format = pd.DataFrame({
    'value': source[var],
    'name': source['background'].replace({1: 'neutral', 2: 'adversive'}),
    'group': source['type'].replace({1: 'HC', 2: 'ND'}),
})

# Number of observations in one group x background combination.
ask_group = 'HC'
ask_type = 'adversive'
in_cell = (df_format['group'] == ask_group) & (df_format['name'] == ask_type)
group_len = int(in_cell.sum())
group_len
# pg.anova(data=w_pun, dv='w_rew', between=['type','background'])

In [None]:
# Violin plot of the prepared frame: groups on the x-axis, split by background.
# `violin` and `basic_format` are not defined in this notebook; presumably they
# come from the `tools.draw` star import.
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
font = dict(family='Arial', weight='bold')
plt.rc('font', **font)
color = dict(neutral='rosybrown', adversive='slategray')

violin(
    ax=ax,
    data=df_format,
    x='group',
    y='value',
    hue='name',
    order=['HC', 'ND'],
    hue_order=['neutral', 'adversive'],
    palette=color,
)

ax.get_legend().remove()
basic_format(ax, '', str(y_name))

# Write the figure as a transparent SVG named after the plotted variable,
# creating the output folder on first use.
save_path = f'config/{var}.svg'
output_dir = os.path.dirname(save_path)
os.makedirs(output_dir, exist_ok=True)
plt.savefig(save_path, format='svg', transparent=True)
plt.show()

##### manipulation check -- cue exposure

In [None]:
# Manipulation check: craving scores before vs. after, per group.
craving_induce = pd.read_csv('datum/craving_indu.csv', encoding='cp1252')

# Recode the numeric labels into readable ones.
craving_induce['type'] = craving_induce['type'].replace({1: 'HC', 2: 'ND'})
craving_induce['time'] = craving_induce['time'].replace({1: 'pre', 2: 'post'})

# Long-format frame consumed by the plotting cell.
df_format = pd.DataFrame({
    'value': craving_induce['crave_score'],
    'name': craving_induce['time'],
    'group': craving_induce['type'],
})

# Split the scores by group and time point with boolean masks.
is_hc = craving_induce['type'] == 'HC'
is_nd = craving_induce['type'] == 'ND'
is_pre = craving_induce['time'] == 'pre'
is_post = craving_induce['time'] == 'post'

pre_HC = craving_induce.loc[is_hc & is_pre, 'crave_score'].tolist()
post_HC = craving_induce.loc[is_hc & is_post, 'crave_score'].tolist()
pre_ND = craving_induce.loc[is_nd & is_pre, 'crave_score'].tolist()
post_ND = craving_induce.loc[is_nd & is_post, 'crave_score'].tolist()

# Remove pairs with an extreme value in either time point.
# NOTE(review): pairing is by position, which assumes the 'pre' and 'post'
# rows of each group list the participants in the same order - confirm.
pre_HC_filtered, post_HC_filtered = filter_any_outliers(pre_HC, post_HC)
pre_ND_filtered, post_ND_filtered = filter_any_outliers(pre_ND, post_ND)

# Explicit lookup tables instead of eval() on variable names.
raw_scores = {
    ('pre', 'HC'): pre_HC,
    ('post', 'HC'): post_HC,
    ('pre', 'ND'): pre_ND,
    ('post', 'ND'): post_ND,
}
filtered_scores = {
    ('pre', 'HC'): pre_HC_filtered,
    ('post', 'HC'): post_HC_filtered,
    ('pre', 'ND'): pre_ND_filtered,
    ('post', 'ND'): post_ND_filtered,
}

ur_name = 'ND'
timing = 'pre'
# Sample size after vs. before outlier removal.
print(len(filtered_scores[(timing, ur_name)]), len(raw_scores[(timing, ur_name)]))
pg.ttest(filtered_scores[('pre', ur_name)], filtered_scores[('post', ur_name)], paired=True)


In [None]:
# Violin plot of the craving scores: groups on the x-axis, split by time point.
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
font = {'family': 'Arial', 'weight': 'bold'}
plt.rc('font', **font)
color = {'pre': '#c1bdb1', 'post': '#3e5751'}

violin(ax=ax, data=df_format, x='group', y='value', hue='name',
       order=['HC', 'ND'], hue_order=['pre', 'post'],
       palette=color)

ax.get_legend().remove()
# The plotted values are craving scores, so label the axis accordingly rather
# than reusing `y_name`, which still describes the variable chosen in the
# score-plot cell.
craving_ylabel = 'craving score (a.u.)'
basic_format(ax, '', craving_ylabel)

save_path = 'config/craving_induce.svg'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='svg', transparent=True)
plt.show()

##### manipulation check --markov task

In [None]:
# Manipulation check on the Markov task: pre vs. post scores of one group,
# within one background condition, after removing extreme pairs.
HC = pd.read_csv('HC_redu.csv', encoding='cp1252')
ND = pd.read_csv('ND_redu.csv', encoding='cp1252')
name = 'ND'
comp_type = 'neu'

# Pick the group's table by name with an explicit lookup instead of eval().
indiv = {'HC': HC, 'ND': ND}[name]
indiv['background'] = indiv['background'].replace({1: 'neutral', 2: 'adversive'})

# Long-format frame consumed by the plotting cell. `name` keeps holding the
# group label; it is no longer rebound to the background column.
df_format = pd.DataFrame({
    'value': indiv['score'],
    'name': indiv['background'],
    'group': indiv['time'],
})

# Split the scores by background and time point with boolean masks.
is_neutral = indiv['background'] == 'neutral'
is_adverse = indiv['background'] == 'adversive'
is_pre = indiv['time'] == 'pre'
is_post = indiv['time'] == 'post'

pre_neu = indiv.loc[is_neutral & is_pre, 'score'].tolist()
post_neu = indiv.loc[is_neutral & is_post, 'score'].tolist()
pre_neg = indiv.loc[is_adverse & is_pre, 'score'].tolist()
post_neg = indiv.loc[is_adverse & is_post, 'score'].tolist()

# Explicit lookup of the requested comparison instead of eval() on names.
scores_by_condition = {
    'neu': (pre_neu, post_neu),
    'neg': (pre_neg, post_neg),
}
pre_raw, post_raw = scores_by_condition[comp_type]

# NOTE(review): pairing is by position, which assumes 'pre' and 'post' rows
# list the participants in the same order - confirm.
pre_filtered, post_filtered = filter_any_outliers(pre_raw, post_raw)

# Sample size after vs. before outlier removal.
print(len(pre_filtered), len(pre_raw))
pg.ttest(pre_filtered, post_filtered, paired=True)

In [None]:
# Violin plot of the Markov-task check: backgrounds on the x-axis, split by
# time point.
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
font = {'family': 'Arial', 'weight': 'bold'}
plt.rc('font', **font)
color = {'pre':'#c1bdb1','post':'#3e5751'}
violin(ax=ax, data=df_format, x='name', y='value', hue='group',
       order=['neutral', 'adversive'], hue_order=['pre', 'post'],
       palette=color, scatter_size=5)

ax.get_legend().remove()
# NOTE(review): `y_name` is whatever the score-plot cell set earlier, so the
# axis label describes that variable rather than the `score` column plotted
# here; replace it once the correct label for this measure is confirmed.
basic_format(ax, '', f'{y_name}')

# Use a file name of its own: saving to f'config/{var}.svg' would overwrite
# the figure written by the score-plot cell, because `var` is unchanged.
save_path = 'config/markov_check.svg'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, format='svg', transparent=True)
plt.show()

In [None]:
# Paired pre vs. post t-test on the Markov-task scores of one group, within
# one background condition (no outlier removal in this cell).
HC = pd.read_csv('HC_redu.csv', encoding='cp1252')
ND = pd.read_csv('ND_redu.csv', encoding='cp1252')
name = 'HC'
comp_type = 'neg'

# Resolve the group's table once with an explicit lookup, rather than calling
# eval(name) several times per row.
group_table = {'HC': HC, 'ND': ND}[name]

# The tables were just re-read, so `background` holds the numeric codes 1 / 2.
is_bg_1 = group_table['background'] == 1
is_bg_2 = group_table['background'] == 2
is_pre = group_table['time'] == 'pre'
is_post = group_table['time'] == 'post'

pre_neu = group_table.loc[is_bg_1 & is_pre, 'score'].tolist()
post_neu = group_table.loc[is_bg_1 & is_post, 'score'].tolist()
pre_neg = group_table.loc[is_bg_2 & is_pre, 'score'].tolist()
post_neg = group_table.loc[is_bg_2 & is_post, 'score'].tolist()

# Explicit lookup of the requested comparison instead of eval() on names.
scores_by_condition = {
    'neu': (pre_neu, post_neu),
    'neg': (pre_neg, post_neg),
}
pre_scores, post_scores = scores_by_condition[comp_type]

pg.ttest(pre_scores, post_scores, paired=True)

##### check whether there is a linear relationship between emotion and learning behavior

In [None]:
# Relate the emotion rating to the best-choice rate within the rows whose
# `type` code is 2. Note that the column named 'w' holds
# `best_choice['best_choice']`; the two tables are combined row by row.
df = pd.DataFrame({
    'emotion': emotion['emotion'],
    'w': best_choice['best_choice'],
    'type': emotion['type'],
    'background': emotion['background'],
})

type_2_rows = df['type'].eq(2)
df_filtered = df.loc[type_2_rows]

# Scatter plot with a fitted regression line (no confidence band).
plt.figure(figsize=(8, 6))
reg_ax = sns.regplot(data=df_filtered, x='emotion', y='w', ci=None)
reg_ax.set_title('Scatter Plot with Regression Line (Emotion vs w)', fontsize=14)
reg_ax.set_xlabel('Emotion', fontsize=12)
reg_ax.set_ylabel('w', fontsize=12)
plt.show()

# Linear regression of 'w' on 'emotion' for the same rows.
predictors = df_filtered[['emotion']]
outcome = df_filtered['w']
pg.linear_regression(predictors, outcome)

##### draw stay probability

In [None]:
# Stay-probability figure for one group x valence x background combination.
stay_prob = pd.read_csv('stay_prob.csv')

# Recode the numeric labels into readable ones.
stay_prob['background'] = stay_prob['background'].replace({1: 'neutral', 2: 'adverse'})
stay_prob['valence'] = stay_prob['valence'].replace({1: 'reward', 2: 'punish'})
stay_prob['probability'] = stay_prob['probability'].replace({1: 'common', 2: 'rare'})
stay_prob['type'] = stay_prob['type'].replace({1: 'HC', 2: 'ND'})

# The meaning of a result code depends on the valence of the trial.
# Combinations not listed here keep their original value.
outcome_labels = {
    ('reward', 1): 'unrewarded',
    ('reward', 2): 'rewarded',
    ('punish', 1): 'unpunished',
    ('punish', 2): 'punished',
}
stay_prob['result'] = [
    outcome_labels.get((valence, code), code)
    for valence, code in zip(stay_prob['valence'], stay_prob['result'])
]

ask_group = 'HC'
ask_valence = 'reward'
ask_bg = 'neutral'
order = ['rewarded', 'unrewarded'] if ask_valence == 'reward' else ['unpunished', 'punished']

# Select the requested combination once and reuse it for every column.
selected = stay_prob[
    (stay_prob['background'] == ask_bg)
    & (stay_prob['valence'] == ask_valence)
    & (stay_prob['type'] == ask_group)
]
res_sp = selected['sp'].tolist()
res_prob = selected['probability'].tolist()
res_lastgain = selected['result'].tolist()

# Frame layout passed to bar_scat: name (previous outcome), value (stay
# probability), group (common / rare transition).
df_format = pd.DataFrame({'value': res_sp, 'name': res_lastgain, 'group': res_prob})

fig, ax = plt.subplots(figsize=(2, 1.5), dpi=300)
font = {'family': 'Arial', 'weight': 'bold'}
plt.rc('font', **font)
color = {'common': '#9f9f9f', 'rare': '#cfcece'}
bar_scat(ax=ax, data=df_format, dot_size=5, color=color, order=order)
basic_format(ax, ' ', 'stay probability')

# Trim the legend to one entry per group so entries are not repeated.
handles, labels = ax.get_legend_handles_labels()
n_groups = len(set(df_format['group']))
ax.legend(handles[:n_groups], labels[:n_groups], loc='upper right')

path = f'config/{ask_group}_{ask_valence}_{ask_bg}.svg'
# Create the output folder first, as the other plotting cells do, so this cell
# also works when it is the first one to save a figure.
os.makedirs(os.path.dirname(path), exist_ok=True)
plt.savefig(path, format='svg', transparent=True)
plt.show()
