In [1]:
import pandas as pd

# Load the gradient table from the CSV file into a DataFrame.
df = pd.read_csv(filepath_or_buffer='data/gradients.csv')

In [2]:
# Display the frame read from the CSV (last expression of the cell is rendered).
df

Unnamed: 0,subject,epoch,region,7net,17net,g1,g2,g3,ecc
0,1,baseline,7Networks_LH_Vis_1,Vis,DorsAttnA,0.516958,-0.186538,1.501042,1.598490
1,1,baseline,7Networks_LH_Vis_2,Vis,VisCent,0.636391,1.407541,0.977047,1.827782
2,1,baseline,7Networks_LH_Vis_3,Vis,DorsAttnA,0.396966,0.630318,1.806822,1.954351
3,1,baseline,7Networks_LH_Vis_4,Vis,VisCent,2.380421,2.049884,0.085447,3.142568
4,1,baseline,7Networks_LH_Vis_5,Vis,DefaultC,0.408135,0.698659,0.977117,1.268644
...,...,...,...,...,...,...,...,...,...
109291,46,late,Right Putamen,Striatum,StriatumRight,-0.861413,-0.664291,-0.450307,1.177324
109292,46,late,Right Pallidum,Striatum,StriatumRight,-0.530126,-0.774942,-0.468807,1.049451
109293,46,late,Right Hippocampus,Striatum,StriatumRight,-0.938860,-0.247830,-1.223909,1.562316
109294,46,late,Right Amygdala,Striatum,StriatumRight,-1.114832,-0.355354,-1.374354,1.804986


In [3]:
# Average every remaining numeric column within each
# (region, epoch, 7net, 17net) group. The 'subject' column is averaged along
# with the rest and then discarded, so the result holds one row per group.
df_mean = (
    df
    .groupby(['region', 'epoch', '7net', '17net'])
    .mean()
    .drop(columns='subject')
    .reset_index()
)

In [4]:
# Display the group-averaged frame (last expression of the cell is rendered).
df_mean

Unnamed: 0,region,epoch,7net,17net,g1,g2,g3,ecc
0,7Networks_LH_Cont_Cing_1,baseline,Cont,SalVentAttnB,-1.568082,-0.292681,-0.214813,2.021287
1,7Networks_LH_Cont_Cing_1,early,Cont,SalVentAttnB,-1.635978,0.037227,-0.567208,2.298056
2,7Networks_LH_Cont_Cing_1,late,Cont,SalVentAttnB,-1.657511,0.038653,-0.614793,2.245801
3,7Networks_LH_Cont_Cing_2,baseline,Cont,ContC,0.483634,-1.805169,-0.963203,2.474936
4,7Networks_LH_Cont_Cing_2,early,Cont,ContC,0.398380,-1.651304,-1.331622,2.381022
...,...,...,...,...,...,...,...,...
3031,Right Putamen,early,Striatum,StriatumRight,-0.689255,-0.110267,-0.145195,1.166278
3032,Right Putamen,late,Striatum,StriatumRight,-0.939216,-0.178111,-0.185569,1.369881
3033,Right Thalamus,baseline,Striatum,StriatumRight,-0.216845,-0.441681,-0.152570,1.117756
3034,Right Thalamus,early,Striatum,StriatumRight,-0.242206,-0.229092,-0.132197,1.109874


In [7]:
# Keep only the rows of the averaged frame whose epoch is 'baseline'.
is_baseline = df_mean['epoch'].eq('baseline')
dfm_baseline = df_mean.loc[is_baseline]

In [8]:
import plotly.express as pex
# import plotly.offline as pyo
# pyo.init_notebook_mode()

In [9]:
import seaborn as sns

def yeo_cmap(as_palette=False, networks=7):
    """Return the colors assigned to the Yeo networks.

    Adapted from
    https://github.com/danjgale/adaptation-manifolds/blob/main/adaptman/analyses/plotting.py

    Parameters
    ----------
    as_palette : bool, default False
        If False, return a dict mapping each network name to an ``(R, G, B)``
        tuple of integers on the 0-255 scale. If True, return the result of
        ``sns.color_palette`` built from the same colors in the same order.
    networks : int, default 7
        ``17`` selects the 17-network table; any other value selects the
        7-network table.

    Returns
    -------
    dict or seaborn color palette
        See ``as_palette``.
    """
    if networks == 17:
        cmap = {
            'VisCent': (120, 18, 136),
            'VisPeri': (255, 0, 2),
            'SomMotA': (70, 130, 181),
            'SomMotB': (43, 204, 165),
            'DorsAttnA': (74, 156, 61),
            'DorsAttnB': (0, 118, 17),
            'SalVentAttnA': (196, 58, 251),
            'SalVentAttnB': (255, 153, 214),
            'TempPar': (9, 41, 250),
            'ContA': (230, 148, 36),
            'ContB': (136, 50, 75),
            'ContC': (119, 140, 179),
            'DefaultA': (255, 254, 1),
            'DefaultB': (205, 62, 81),
            'DefaultC': (0, 0, 132),
            'LimbicA': (224, 248, 166),
            'LimbicB': (126, 135, 55)
        }
    else:
        cmap = {
            'Vis': (119, 20, 140),
            'SomMot': (70, 126, 175),
            'DorsAttn': (0, 117, 7),
            'SalVentAttn': (195, 59, 255),
            'Limbic': (219, 249, 165),
            'Cont': (230, 149, 33),
            'Default': (205, 65, 80)
        }

    if not as_palette:
        return cmap

    # seaborn hands RGB tuples to matplotlib, which only accepts channels in
    # [0, 1]; the table above is on the 0-255 scale, so rescale it here.
    # The dict return above is left on the 0-255 scale on purpose.
    unit_rgb = [tuple(channel / 255 for channel in rgb) for rgb in cmap.values()]
    return sns.color_palette(unit_rgb)

In [10]:
# Plotly takes discrete colors as a mapping from category value to a CSS color string, e.g.:
# color_discrete_map={'Vis': 'rgb(0,255,0)'}

In [11]:
# Start from the 7-network color table and append a black entry for Striatum,
# then turn each (R, G, B) tuple into the 'rgb(R, G, B)' string form.
network_rgb = {**yeo_cmap(networks=7), 'Striatum': (0, 0, 0)}
cmap7 = {network: f'rgb{rgb}' for network, rgb in network_rgb.items()}

In [23]:
# 3D scatter of the baseline region means in (g1, g2, g3) space, one color per
# 7-network label.
ax = pex.scatter_3d(x='g1',y='g2',z='g3', color='7net',
                    data_frame=dfm_baseline, opacity=.7,
                    color_discrete_map=cmap7,
                    # category_orders expects list values; a dict_keys view
                    # has no .index/.append, so materialize it as a list.
                    category_orders={'7net': list(cmap7)}
                    )
ax.update_layout(margin=dict(l=0, r=0, b=0, t=0))  # remove the outer margins
ax.update_traces(marker_size=3)
ax.show()
ax.write_image("fig1.svg")  # static export of the same figure

In [22]:
# 3D scatter of the baseline region means in (g1, g2, g3) space, colored by
# the continuous 'ecc' column on the viridis scale.
my_plot = pex.scatter_3d(
    data_frame=dfm_baseline,
    x='g1',
    y='g2',
    z='g3',
    color='ecc',
    opacity=0.7,
    color_continuous_scale='viridis',
)
# Both update_* calls return the figure itself, so they can be chained.
my_plot.update_layout(margin={'l': 0, 'r': 0, 'b': 0, 't': 0}).update_traces(marker_size=3)
my_plot.show()
my_plot.write_image("fig2.svg")