# In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

# Read the all-levels table and drop its 'Unnamed: 0' column.
df = pd.read_csv(
    "https://raw.githubusercontent.com/smart-stats/ds4bio_book/main/book/assetts/kirby21AllLevels.csv"
).drop(columns=["Unnamed: 0"])
df.head()

# load in the hierarchy information
url = "https://raw.githubusercontent.com/bcaffo/MRIcloudT1volumetrics/master/inst/extdata/multilevel_lookup_table.txt"
multilevel_lookup = (
    pd.read_csv(url, sep="\t")
    .drop(columns=["Level5"])
    .rename(
        columns={
            "modify": "roi",
            "modify.1": "level4",
            "modify.2": "level3",
            "modify.3": "level2",
            "modify.4": "level1",
        }
    )
)
# Keep only the renamed columns, ordered from the roi up to level1.
multilevel_lookup = multilevel_lookup[["roi", "level4", "level3", "level2", "level1"]]
multilevel_lookup.head()


# Subject to plot.  NOTE: this module-level name shadows the builtin `id`.
id = 127

# Rows for this subject with type == 1 and level == 5, keeping only the
# region name and its volume.
subjectData = df.loc[
    (df["type"] == 1) & (df["level"] == 5) & (df["id"] == id),
    ["roi", "volume"],
]

# Merge data with multilevel data, then add a constant root label ("ICV")
# and each region's share of the summed volume.
subjectData = (
    pd.merge(subjectData, multilevel_lookup, on="roi")
    .assign(icv="ICV")
    .assign(comp=lambda merged: merged["volume"] / np.sum(merged["volume"]))
)
subjectData.head()

# Create tables by level groups.
# Only the numeric columns are summed, so every table is laid out as
# (parent, child, volume, comp).  popSankey reads the link value by position
# (fourth column); on pandas >= 2.0 a bare .sum() also aggregates the text
# columns, which shifts that position away from `comp`.
tb1 = (
    subjectData.groupby(['icv', 'level1'])[['volume', 'comp']]
    .sum()
    .reset_index()
)
tb2 = (
    subjectData.groupby(['level1', 'level2'])[['volume', 'comp']]
    .sum()
    .reset_index()
)
tb3 = (
    subjectData.groupby(['level2', 'level3'])[['volume', 'comp']]
    .sum()
    .reset_index()
)
tb4 = (
    subjectData.groupby(['level3', 'level4'])[['volume', 'comp']]
    .sum()
    .reset_index()
)

# Create label, source, target, values for Sankey diagram
def popSankey(df, value_col=3):
    """Build a Sankey figure from a table with one row per link.

    The first column of ``df`` gives each link's source node label and the
    second column its target node label.  Nodes are numbered in order of
    first appearance: every distinct first-column label, followed by any
    second-column label that has not been seen yet.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per link.  Columns 0 and 1 hold hashable node labels.
    value_col : int or str, default 3
        Column holding the link weight.  An ``int`` is a column position,
        a ``str`` is a column name.  The default reads the fourth column,
        as this function always did.

    Returns
    -------
    plotly.graph_objects.Figure
        A figure wrapping a single ``go.Sankey`` trace.

    Raises
    ------
    ValueError
        If ``value_col`` is a string that is not a column of ``df``.
    """
    if isinstance(value_col, str):
        value_pos = list(df.columns).index(value_col)
    else:
        value_pos = value_col

    # Plain tuples are indexed by position.  The Series yielded by
    # iterrows() carries the column names as its index, and integer keys on
    # such a Series are a deprecated positional fallback in recent pandas.
    rows = list(df.itertuples(index=False, name=None))

    # Label -> node id.  Dict insertion order doubles as the label list, so
    # membership checks are O(1) instead of scanning a list.
    node_ids = {}
    for pos in (0, 1):
        for row in rows:
            node_ids.setdefault(row[pos], len(node_ids))

    link = dict(
        source=[node_ids[row[0]] for row in rows],
        target=[node_ids[row[1]] for row in rows],
        value=[row[value_pos] for row in rows],
    )
    node = dict(label=list(node_ids), pad=50, thickness=5)
    return go.Figure(go.Sankey(link=link, node=node))

# One Sankey figure per grouped table.
fg1, fg2, fg3, fg4 = [popSankey(table) for table in (tb1, tb2, tb3, tb4)]

# Display subject's data as a Sankey diagram
for figure in (fg1, fg2, fg3, fg4):
    figure.show()

# To html
for figure, html_name in zip(
    (fg1, fg2, fg3, fg4),
    ("FigureOne.html", "FigureTwo.html", "FigureThree.html", "FigureFour.html"),
):
    figure.write_html(html_name)

  # Rebuild the level tables (same grouping as the tables created above).
# The statements sit at column 0: an indented top-level statement is an
# "unexpected indent" error.  Only the numeric columns are summed, so each
# table is laid out as (parent, child, volume, comp) on every pandas version.
tb1 = subjectData.groupby(['icv', 'level1'])[['volume', 'comp']].sum().reset_index()
tb2 = subjectData.groupby(['level1', 'level2'])[['volume', 'comp']].sum().reset_index()
tb3 = subjectData.groupby(['level2', 'level3'])[['volume', 'comp']].sum().reset_index()
tb4 = subjectData.groupby(['level3', 'level4'])[['volume', 'comp']].sum().reset_index()
