In [None]:
import os
import pandas as pd
import pingouin as pg
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tools.fit_bms import *
from tools.draw import *
from tools.model_color import *
from sklearn.preprocessing import StandardScaler
from scipy.stats import shapiro, kruskal
from scipy.special import expit


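    * 性别 (sex): 1 = male (男), 2 = female (女)
    * 类型 (`type`): 1 = ordinary / control (普通; labelled HC in the figures), 2 = smoker (吸烟; labelled ND in the figures)
    * 背景 (`background`): 1 = neutral (中性), 2 = aversive (厌恶)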

In [None]:
# Load the per-subject table used by every figure below. os.path.join builds the
# path with the platform separator; the previous literal 'datum\map_data.xlsx'
# only worked on Windows and relied on '\m' being an (invalid) escape sequence
# that Python leaves untouched.
data = pd.read_excel(os.path.join('datum', 'map_data.xlsx'))

##### age/sex/number of each group

In [None]:
# Descriptive statistics of one column within one design cell
# (type 2 = smoker, background 2 = aversive, per the coding list above).
filtered_data = data[(data['type'] == 2) & (data['background'] == 2)]
var = filtered_data['smoke_quantify']

# Group size and number of non-missing entries. The previous
# `(var != 'nan').sum()` compared a numeric column with the *string* 'nan',
# which is True for every row (real NaNs included), so it silently returned
# the group size; it was also never displayed because it was not the cell's
# last expression.
n_group = len(var)
n_valid = int(var.notna().sum())

# Series.mean / Series.std skip NaN. ddof=0 reproduces the previous np.std
# (population SD); switch to ddof=1 if the sample SD should be reported.
print(f'n={n_group} (non-missing: {n_valid}), '
      f'mean={var.mean():.2f}, sd={var.std(ddof=0):.2f}')

##### Figure 2A craving increase

In [None]:
## prepare dataframe
# Long format for Figure 2A: one row per (subject, time point) for all
# subjects, with column 'pre' labelled "before" and 'ing' labelled "after".
n_subjects = len(data)

# ignore_index gives a clean 0..2n-1 index instead of the row labels twice.
score = pd.concat((data['pre'], data['ing']), ignore_index=True)

# type coding: 2 -> 'ND', anything else -> 'HC'; tiled once per time point.
group = np.where(np.tile(data['type'].to_numpy(), 2) == 2, 'ND', 'HC')
time = ['before'] * n_subjects + ['after'] * n_subjects
subject = [f'subject{i}' for i in range(1, n_subjects + 1)] * 2
new_df = pd.DataFrame({'value': score, 'group': group, 'time': time, 'subject': subject})


In [None]:
# Figure 2A: craving score by group (x) before/after (hue).
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
plt.rc('font', family='Arial', weight='regular')

# Hue colours for the two time points.
palette = dict(before='#c1bdb1', after='#3e5751')
violin(ax, data=new_df, x='group', y='value', hue='time',
       order=['HC', 'ND'], hue_order=['before', 'after'], palette=palette)
ax.get_legend().remove()
basic_format(ax, '', 'craving score')
plt.show()

##### Figure 2B-C craving

In [None]:
## prepare dataframe
# Long format for Figure 2B-C (type == 2 only): one row per (subject, time
# point), with column 'ing' labelled "before" and 'post' labelled "after";
# the background condition becomes the x-axis group.
filtered_data = data[data['type'] == 2]
n_subjects = len(filtered_data)

# ignore_index: otherwise the result repeats the filtered (non-contiguous)
# row labels twice, i.e. a duplicated index.
score = pd.concat((filtered_data['ing'], filtered_data['post']), ignore_index=True)

# background coding: 1 -> 'neutral', anything else -> 'aversive'.
group = np.where(np.tile(filtered_data['background'].to_numpy(), 2) == 1,
                 'neutral', 'aversive')
time = ['before'] * n_subjects + ['after'] * n_subjects
subject = [f'subject{i}' for i in range(1, n_subjects + 1)] * 2
new_df = pd.DataFrame({'value': score, 'group': group, 'time': time, 'subject': subject})

# Preview rather than dumping the whole frame.
new_df.head()

In [None]:
# Figure 2B-C: craving score by background (x) before/after (hue).
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
plt.rc('font', family='Arial', weight='regular')

palette = dict(before='#c1bdb1', after='#3e5751')
violin(ax, data=new_df, x='group', y='value', hue='time',
       order=['neutral', 'aversive'], hue_order=['before', 'after'],
       palette=palette)
ax.get_legend().remove()

basic_format(ax, '', 'craving score')

path = 'new_fig'
os.makedirs(path, exist_ok=True)
# NOTE(review): the preparation cell filters type == 2, but the file name
# below says HC — rename before re-enabling the save.
# plt.savefig(f'{path}/craving_HC.svg', transparent = True, format = 'svg')
plt.show()


##### Figure 3 emotion /best_choice(induce)

In [None]:
## prepare dataframe
# One row per subject for Figure 3: the chosen column split by group (type)
# and background.
param = 'best choice rate_rew'
score = data[param]

# type 2 -> 'ND' else 'HC'; background 2 -> 'aversive' else 'neutral'.
group = np.where(data['type'] == 2, 'ND', 'HC').tolist()
time = np.where(data['background'] == 2, 'aversive', 'neutral').tolist()
subject = [f'subject{i}' for i in range(1, len(data) + 1)]

new_df = pd.DataFrame({'value': score, 'group': group, 'time': time, 'subject': subject})

In [None]:
# Figure 3: `param` by group (x) and background (hue).
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
plt.rc('font', family='Arial', weight='regular')

# RGB colours given on a 0-255 scale and normalised to 0-1 for matplotlib.
rgb255 = {'neutral': (23, 66, 99), 'aversive': (215, 87, 40)}
color = {cond: [c / 255 for c in rgb] for cond, rgb in rgb255.items()}

violin(ax, data=new_df, x='group', y='value', hue='time',
       order=['HC', 'ND'], hue_order=['neutral', 'aversive'], palette=color)
ax.get_legend().remove()
basic_format(ax, '', param)
# plt.yticks([-3, -1.5, 0])

pth = 'new_fig'
os.makedirs(pth, exist_ok=True)
# plt.savefig(f'{pth}/best_choice_rew.svg', transparent = True, format = 'svg')
plt.show()

##### Figure 4 stay probability

In [None]:
# Long format of the four stay-probability columns for type 2 / background 1:
# one row per (subject, transition, outcome).
filtered_data = data[(data['type'] == 2) & (data['background'] == 1)]
n_subj = len(filtered_data)

# Column order: common punished, rare punished, common unpunished, rare unpunished.
conditions = [('common', 'punished'), ('rare', 'punished'),
              ('common', 'unpunished'), ('rare', 'unpunished')]

ifreward, ifcommon = [], []
for transition, outcome in conditions:
    ifreward += [outcome] * n_subj
    ifcommon += [transition] * n_subj
subject = [f'subject{i}' for i in range(1, n_subj + 1)] * len(conditions)
stay_prob = pd.concat(
    [filtered_data[f'{transition} {outcome}'] for transition, outcome in conditions],
    axis=0,
)

new_df = pd.DataFrame({'reward': ifreward, 'transition': ifcommon, 'stay_prob': stay_prob, 'subject': subject})

In [None]:
# Figure 4: stay probability by outcome (x) and transition type (hue), one
# figure per group x background x valence, read from pre-computed CSV files.
groups = ['ND', 'HC']
valences = ['pun', 'rew']
backgrounds = ['neg', 'neu']

# Loop-invariant plot settings, built once instead of inside the triple loop.
param = 'stay probability'
y_name = param
font = {'family': 'Arial', 'weight': 'regular'}
color = {'common': [23/255, 66/255, 99/255], 'rare': [215/255, 87/255, 40/255]}
hue_order = ['common', 'rare']

for group in groups:
    for valence in valences:
        for background in backgrounds:
            # os.path.join instead of a hard-coded backslash so the path also
            # resolves outside Windows.
            csv_path = os.path.join('stay_prob', f'{group}_{background}_{valence}.csv')
            new_df = pd.read_csv(csv_path)

            fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
            plt.rc('font', **font)
            # NOTE(review): 'reward' vs 'unrewarded' looks inconsistent; these
            # labels must match the values of the CSV's 'reward' column exactly.
            order = ['reward', 'unrewarded'] if valence == 'rew' else ['unpunished', 'punished']

            violin(ax, data=new_df, x='reward', y='stay_prob',
                   hue='transition', order=order,
                   hue_order=hue_order, palette=color)

            # Translucent bars (barplot's default estimator) over the violins,
            # without error bars. `ci=None` is deprecated since seaborn 0.12
            # in favour of `errorbar=None`; kept for older installs.
            sns.barplot(ax=ax, data=new_df, x='reward', y='stay_prob',
                        hue='transition', order=order,
                        hue_order=hue_order, palette=color,
                        alpha=0.3, edgecolor='black', ci=None)
            ax.set_ylim(0, 1)
            ax.get_legend().remove()
            basic_format(ax, '', f'{y_name}')
            # plt.savefig(os.path.join('stay_prob', 'figs', f'stay_{group}_{background}_{valence}.svg'), transparent = True, format = 'svg')
            plt.show()
            # Release each figure explicitly; the loop creates eight of them.
            plt.close(fig)

##### Figure 5 w in HC/ND

In [None]:
# Figure 5 data: one row per (subject, valence) for `param` within one group;
# the background condition becomes the hue.
param = 'w'
use_name = 'HC'  # 'HC' or 'ND'
use_value = 1 if use_name == 'HC' else 2  # type coding: 1 = HC, otherwise 2
filtered_data = data[data['type'] == use_value]
n_subjects = len(filtered_data)

# '<param>_rew' rows first, then '<param>_pun'. ignore_index avoids repeating
# the filtered (non-contiguous) row labels twice.
score = pd.concat((filtered_data[f'{param}_rew'], filtered_data[f'{param}_pun']),
                  ignore_index=True)
raw_group = filtered_data['background']

# background coding: 2 -> 'aversive', anything else -> 'neutral'; repeated
# once per valence so it lines up with `score`.
group = ['aversive' if bg == 2 else 'neutral' for bg in raw_group] * 2
time = ['reward'] * n_subjects + ['punishment'] * n_subjects
subject = [f'subject{i}' for i in range(1, n_subjects + 1)] * 2

new_df = pd.DataFrame({'value': score, 'group': group, 'time': time, 'subject': subject})

In [None]:
# Figure 5: `param` for reward vs. punishment (x), split by background (hue).
fig, ax = plt.subplots(figsize=(2, 1.8), dpi=300)
plt.rc('font', family='Arial', weight='regular')

# RGB colours given on a 0-255 scale and normalised to 0-1 for matplotlib.
rgb255 = {'neutral': (23, 66, 99), 'aversive': (215, 87, 40)}
color = {cond: [c / 255 for c in rgb] for cond, rgb in rgb255.items()}

violin(ax, data=new_df, x='time', y='value', hue='group',
       order=['reward', 'punishment'], hue_order=['neutral', 'aversive'],
       palette=color)
ax.get_legend().remove()
basic_format(ax, '', param)
# plt.yticks([-4, -2,  0])
path = 'new_fig'
# plt.savefig(f'{path}/w_HC.svg', transparent = True, format = 'svg')
plt.show()

##### Figure 6 correlation

In [None]:
## prepare dataframe
## Pearson correlation between a smoking measure (smoke_year, or alternatively
## smoke_quantify) and w_rew, for type 2 in background 2.
filtered_data = data.loc[(data['type'] == 2) & (data['background'] == 2)]
x = filtered_data['smoke_year']  # alternative: filtered_data['smoke_quantify']
y = filtered_data['w_rew']

pg.corr(x, y, method='pearson', alternative='two-sided')

In [None]:
# Scatter + linear fit for the correlation computed in the previous cell.
# The x-axis label used to be hard-coded to "Cigarettes per day" even though
# `x` was smoke_year; it is now looked up from the predictor's column name so
# it follows whichever column `x` holds.
# NOTE(review): assumes smoke_quantify = cigarettes per day and
# smoke_year = years of smoking — confirm against the data dictionary.
x_labels = {'smoke_quantify': 'Cigarettes per day', 'smoke_year': 'Years of smoking'}

fig, ax = plt.subplots(figsize=(2, 1.5), dpi=300)
sns.regplot(
    x=x, y=y, ax=ax, color="indianred",
    line_kws={'color': 'indianred', 'alpha': 0.3, 'zorder': 1},
    scatter_kws={'zorder': 1},
)
basic_format(ax, x_labels.get(x.name, x.name), r'w')
# plt.savefig(os.path.join('new_fig', 'w_corr.svg'), transparent = True, format = 'svg')
plt.show()

#### parameter recovery

In [None]:
# load datum
# Parameter-recovery inputs (generating vs. recovered parameters, per the
# file names). os.path.join replaces the Windows-only backslash paths.
real_data = pd.read_csv(os.path.join('param_rev', 'true_parameters.csv'))  # , skiprows = range(1, 800)
recov_data = pd.read_csv(os.path.join('param_rev', 'param_recov.csv'))

In [None]:
# do correlation
# Parameter recovery: Pearson-correlate each true parameter with its recovered
# value and draw one scatter/regression plot per parameter.
# NOTE(review): rows of real_data and recov_data are paired by position; this
# assumes both files list the simulated subjects in the same order — confirm.

# Renamed from `vars`, which shadowed the builtin vars() for the rest of the kernel.
param_names = ['beta1', 'beta2', 'alpha1', 'alpha2', 'lambda', 'p', 'w']

# Output directory created once rather than on every iteration.
pth = 'param_rev/figs'
os.makedirs(pth, exist_ok=True)

recovery_stats = {}
for varname in param_names:
    x = real_data[varname]
    y = recov_data[varname]
    recovery_stats[varname] = pg.corr(x, y, alternative='two-sided', method='pearson')

    fig, ax = plt.subplots(figsize=(4, 3.5), dpi=300)
    sns.regplot(
        x=x, y=y, ax=ax, color="indianred",
        line_kws={'color': 'indianred', 'alpha': 0.3, 'zorder': 1},
        scatter_kws={'zorder': 1},
    )
    basic_format(ax, f'real {varname}', f'recovered {varname}')
    # plt.savefig(f'{pth}/{varname}.svg', transparent = True, format = 'svg')
    plt.show()
    plt.close(fig)

# One summary table (one row per parameter) instead of seven printed frames.
pd.concat(recovery_stats, names=['parameter']).droplevel(1)
