# 3D plot
In this notebook, we plot each region's first three principal-component (gradient) scores, `g1`–`g3`, averaged over all subjects, as points in 3D space.

In [159]:
# load gradients data

In [160]:
import pandas as pd

# one row per subject x epoch x region, with network labels (7net/17net),
# gradient scores g1-g3 and eccentricity (ecc)
df = pd.read_csv(filepath_or_buffer='data/gradients.csv')

In [161]:
df.head(n=5)  # first five rows of the raw per-subject table

Unnamed: 0,subject,epoch,region,7net,17net,g1,g2,g3,ecc
0,1,baseline,7Networks_LH_Vis_1,Vis,DorsAttnA,0.516958,-0.186538,1.501042,1.59849
1,1,baseline,7Networks_LH_Vis_2,Vis,VisCent,0.636391,1.407541,0.977047,1.827782
2,1,baseline,7Networks_LH_Vis_3,Vis,DorsAttnA,0.396966,0.630318,1.806822,1.954351
3,1,baseline,7Networks_LH_Vis_4,Vis,VisCent,2.380421,2.049884,0.085447,3.142568
4,1,baseline,7Networks_LH_Vis_5,Vis,DefaultC,0.408135,0.698659,0.977117,1.268644


In [162]:
# average each region's values over subjects, separately for every epoch (the subject column is dropped)
# df.groupby(level=['epoch', 'region', 'subject']).mean().droplevel('subject')

In [163]:
# Average the gradient scores (g1-g3) and eccentricity over subjects for every
# region x epoch. 7net/17net are grouping keys so each region keeps its network
# labels. Only the value columns are aggregated, so `subject` (or any other
# stray numeric column) is never averaged. dropna=False keeps regions whose
# network label might be missing (NaN) instead of silently dropping them,
# which groupby does by default.
df_mean = (
    df.groupby(['region', 'epoch', '7net', '17net'], dropna=False)[['g1', 'g2', 'g3', 'ecc']]
    .mean()
    .reset_index()
)

In [164]:
df_mean.head(n=5)  # first five rows of the subject-averaged table

Unnamed: 0,region,epoch,7net,17net,g1,g2,g3,ecc
0,7Networks_LH_Cont_Cing_1,baseline,Cont,SalVentAttnB,-1.568082,-0.292681,-0.214813,2.021287
1,7Networks_LH_Cont_Cing_1,early,Cont,SalVentAttnB,-1.635978,0.037227,-0.567208,2.298056
2,7Networks_LH_Cont_Cing_1,late,Cont,SalVentAttnB,-1.657511,0.038653,-0.614793,2.245801
3,7Networks_LH_Cont_Cing_2,baseline,Cont,ContC,0.483634,-1.805169,-0.963203,2.474936
4,7Networks_LH_Cont_Cing_2,early,Cont,ContC,0.39838,-1.651304,-1.331622,2.381022


Color each region by its network label (`7net` or `17net`).

In [165]:
# Network -> colour, written as the 'rgb(r, g, b)' strings that plotly's
# color_discrete_map accepts. Insertion order is reused as the legend order.
cmap7 = {
    network: f'rgb{rgb}'
    for network, rgb in (
        ('Vis', (119, 20, 140)),
        ('SomMot', (70, 126, 175)),
        ('DorsAttn', (0, 117, 7)),
        ('SalVentAttn', (195, 59, 255)),
        ('Limbic', (219, 249, 165)),
        ('Cont', (230, 149, 33)),
        ('Default', (205, 65, 80)),
        ('Striatum', (0, 0, 0)),    # color Striatum as black
    )
}

In [166]:
# this is how you can feed in color codes to plotly
# color_discrete_map={'Vis': 'rgb(0,255,0)'}

## colored by 7 networks
Plot the `baseline` epoch, with each region colored by its 7-network label.

In [167]:
import plotly.express as pex

In [168]:
dfm_baseline = df_mean.loc[df_mean['epoch'].eq('baseline')]  # baseline rows only

In [169]:
# Baseline gradients in 3D, one colour per 7-network; legend follows cmap7's order.
# plotly.express returns a Figure (not axes), hence the name `fig`.
fig = (
    pex.scatter_3d(dfm_baseline, x='g1', y='g2', z='g3', color='7net',
                   opacity=.7,
                   color_discrete_map=cmap7,
                   category_orders={'7net': list(cmap7)})
    .update_layout(margin=dict(l=0, r=0, b=0, t=0))
    .update_traces(marker_size=3)
)
fig.show()
# fig.write_image("fig1.svg")

## colored by eccentricity

In [170]:
# Same baseline view, coloured continuously by eccentricity.
fig = (
    pex.scatter_3d(dfm_baseline, x='g1', y='g2', z='g3', color='ecc',
                   opacity=.7,
                   color_continuous_scale='viridis')
    .update_layout(margin=dict(l=0, r=0, b=0, t=0))
    .update_traces(marker_size=3)
)
fig.show()
# fig.write_image("fig2.svg")

# traverse between epochs
Show how each region moves in gradient space from one epoch to the next (here: baseline → early).

In [171]:
import plotly.graph_objects as go

In [172]:
# For every region, draw a tiny blue marker at the first epoch of `pair`, a
# larger orange marker at the second epoch, and a blue line joining them.
# uncomment the pair you want to plot:
pair = ('baseline', 'early')
# pair = ('early', 'late')

# Index once, outside the loop: calling set_index per region rebuilt the
# whole frame for each of the ~1000 regions.
df_mean_idx = df_mean.set_index(['region', 'epoch'])

fig = go.Figure()

for r in df_mean.region.unique():
    region_df = df_mean_idx.loc[r]          # this region's rows, indexed by epoch
    start, end = region_df.loc[pair[0]], region_df.loc[pair[1]]

    # first epoch of the pair (tail)
    fig.add_trace(go.Scatter3d(x=[start.g1], y=[start.g2], z=[start.g3],
                               mode='markers',
                               marker=dict(size=.1, color='blue', opacity=0.7),
                               name=pair[0]))

    # second epoch of the pair (head)
    fig.add_trace(go.Scatter3d(x=[end.g1], y=[end.g2], z=[end.g3],
                               mode='markers',
                               marker=dict(size=3, color='orange', opacity=0.7),
                               name=pair[1]))

    # line from the first to the second epoch
    fig.add_trace(go.Scatter3d(x=[start.g1, end.g1],
                               y=[start.g2, end.g2],
                               z=[start.g3, end.g3],
                               mode='lines', line=dict(color='blue', width=2.5)))

fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), showlegend=False)
fig.show()
# fig.write_image(f"{pair[0][0]}2{pair[1][0]}.svg")

## Plot lines only for significant regions (corrected p < 0.05)

In [173]:
# load the ANOVA results (all regions, not only significant ones; the first
# three CSV columns form the row index). Significance is judged on 'p-corr'.
dfa = pd.read_csv(filepath_or_buffer='data/dfa.csv', index_col=list(range(3)))

In [174]:
(len(dfa), len(dfa.columns))  # (rows, columns), i.e. dfa.shape

(1012, 13)

In [175]:
# shape of the subset whose corrected p-value is below 0.05
dfa.loc[dfa['p-corr'].lt(0.05)].shape

(208, 13)

In [176]:
# setup color map as input for go.Scatter3d

In [177]:
def str2rgb(s):
    """Parse a colour string like 'rgb(119, 20, 140)' into the int tuple (119, 20, 140).

    Every occurrence of 'rgb', '(' and ')' is removed and the remainder is split
    on commas; int() tolerates the spaces around each component.
    """
    for junk in ('rgb', '(', ')'):
        s = s.replace(junk, '')
    return tuple(int(part) for part in s.split(','))


str2rgb('rgb(119, 20, 140)')

(119, 20, 140)

In [178]:
# Same network colours as hex codes, e.g. 'rgb(119, 20, 140)' -> '#77148c',
# for use as marker/line colours in go.Scatter3d.
cmap7_hex = {}
for network, rgb_str in cmap7.items():
    cmap7_hex[network] = '#%02x%02x%02x' % str2rgb(rgb_str)
cmap7_hex

{'Vis': '#77148c',
 'SomMot': '#467eaf',
 'DorsAttn': '#007507',
 'SalVentAttn': '#c33bff',
 'Limbic': '#dbf9a5',
 'Cont': '#e69521',
 'Default': '#cd4150',
 'Striatum': '#000000'}

In [179]:
# For every region draw its position at the first epoch of `pair` (small dot,
# coloured by 7-network). For regions whose corrected p-value ('p-corr') is
# below 0.05, also draw the position at the second epoch and a connecting line.
# uncomment the pair you want to plot:
pair = ('baseline', 'early')
# pair = ('early', 'late')

# Computed once, outside the loop:
# - df_mean indexed by (region, epoch); re-indexing per region rebuilt the
#   whole frame for each of the ~1000 regions;
# - the significant regions, read from dfa's first index level (the level
#   that dfa.loc[r, ...] matched). A region absent from dfa now counts as
#   not significant instead of raising KeyError. A region with several dfa
#   rows no longer triggers "truth value of an array is ambiguous".
df_mean_idx = df_mean.set_index(['region', 'epoch'])
significant_regions = set(dfa[dfa['p-corr'] < 0.05].index.get_level_values(0))

fig = go.Figure()

for r in df_mean.region.unique():   # for all 1012 regions
    region_df = df_mean_idx.loc[r]          # this region's rows, indexed by epoch
    hex_color = cmap7_hex[region_df['7net'].iloc[0]]
    start, end = region_df.loc[pair[0]], region_df.loc[pair[1]]

    # tail of the line (first epoch), drawn for every region
    fig.add_trace(go.Scatter3d(x=[start.g1], y=[start.g2], z=[start.g3],
                               mode='markers',
                               marker=dict(size=.7, color=hex_color, opacity=0.7),
                               name=pair[0]))

    if r in significant_regions:
        # head of the line (second epoch)
        fig.add_trace(go.Scatter3d(x=[end.g1], y=[end.g2], z=[end.g3],
                                   mode='markers',
                                   marker=dict(size=3, color=hex_color, opacity=0.7),
                                   name=pair[1]))

        # line from tail to head, e.g. baseline -> early
        fig.add_trace(go.Scatter3d(x=[start.g1, end.g1],
                                   y=[start.g2, end.g2],
                                   z=[start.g3, end.g3],
                                   mode='lines', line=dict(color=hex_color, width=2)))

fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), showlegend=False)
fig.show()
# fig.write_image(f"{pair[0][0]}2{pair[1][0]}.svg")

## Path from baseline to early to late (segments drawn for significant regions only)

In [180]:
# Path baseline -> early -> late for every region, coloured by 7-network.
# Every region gets small dots at the first two epochs. Regions whose corrected
# p-value ('p-corr') is below 0.05 also get a larger dot at the third epoch
# and the two connecting segments.
pair = ('baseline', 'early', 'late')  # epochs along the path, in order
fig = go.Figure()

# Computed once, outside the loop:
# - df_mean indexed by (region, epoch); re-indexing per region rebuilt the
#   whole frame for each of the ~1000 regions;
# - the significant regions, read from dfa's first index level (the level
#   that dfa.loc[r, ...] matched). A region absent from dfa now counts as
#   not significant instead of raising KeyError. A region with several dfa
#   rows no longer triggers "truth value of an array is ambiguous".
df_mean_idx = df_mean.set_index(['region', 'epoch'])
significant_regions = set(dfa[dfa['p-corr'] < 0.05].index.get_level_values(0))

for r in df_mean.region.unique():   # for all 1012 regions
    region_df = df_mean_idx.loc[r]          # this region's rows, indexed by epoch
    hex_color = cmap7_hex[region_df['7net'].iloc[0]]
    points = [region_df.loc[epoch] for epoch in pair]

    # 1st and 2nd dots, drawn for every region
    for epoch, pt in zip(pair[:2], points[:2]):
        fig.add_trace(go.Scatter3d(x=[pt.g1], y=[pt.g2], z=[pt.g3],
                                   mode='markers',
                                   marker=dict(size=.7, color=hex_color, opacity=0.7),
                                   name=epoch))

    if r in significant_regions:
        # 3rd dot: larger marker at the end of the path
        last = points[2]
        fig.add_trace(go.Scatter3d(x=[last.g1], y=[last.g2], z=[last.g3],
                                   mode='markers',
                                   marker=dict(size=3, color=hex_color, opacity=0.7),
                                   name=pair[2]))

        # two segments: 1st -> 2nd epoch, then 2nd -> 3rd epoch
        for a, b in zip(points[:-1], points[1:]):
            fig.add_trace(go.Scatter3d(x=[a.g1, b.g1],
                                       y=[a.g2, b.g2],
                                       z=[a.g3, b.g3],
                                       mode='lines', line=dict(color=hex_color, width=2)))

fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), showlegend=False)
fig.show()
# fig.write_image("b2e2l.svg")