In [1]:
import pandas as pd
import plotly.graph_objects as go

import pprint as pp

import warnings
# Silence every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation notices from the libraries
# imported above — consider narrowing the filter once the notebook is stable.
warnings.simplefilter("ignore")

In [2]:
# Load the raw EV share table (one row per state, grouped by zone).
df = pd.read_csv("data/evs.csv")
print(df)

# 'Evs' arrives as text such as "14%": drop the percent sign with the
# vectorized string accessor and cast to integer percentage points in one
# chained step instead of a per-element lambda followed by a separate cast.
df['Evs'] = df['Evs'].str.replace('%', '', regex=False).astype(int)

       Zone       State/UT  Evs  870141
0     North          Delhi  14%  125347
1     North        Haryana   3%   24206
2     North         Punjab   1%    8069
3     North  Uttar Pradesh  29%  255700
4     North    Uttarakhand   3%   23428
5      East          Assam   5%   43057
6      East          Bihar   7%   58014
7      East         Odisha   1%    9887
8      East        Tripura   1%    7103
9      East    West Bengal   5%   43384
10    South      Karnataka   8%   72544
11    South         Kerala   1%   11959
12    South     Tamil Nadu   5%   44817
13     West        Gujarat   2%   13063
14     West    Maharashtra   6%   52506
15     West      Rajasthan   5%   46862
16  Central   Chhattisgarh   1%   11881
17  Central      Jharkhand   1%   10954


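Plan:
- Build a map from each zone to its list of states (and the number of states)
- node_labels : iterate over the map's keys (zones), then its values (states)
- link_labels : iterate over the map
- node_x_pos  : iterate over the map's keys, then its values
- node_y_pos  : 
    - zone: position of its middle state
    - state: its own position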

TARGET : 
[0.05, 0.07, 0.09, 0.11, 0.13]
[0.15, 0.17, 0.19, 0.21, 0.23]
[0.25, 0.27, 0.29]
[0.31, 0.33, 0.35]
[0.37, 0.39]


[0.05, 0.3, 0.55, 0.8, 1.05]
[1.3, 1.55, 1.8, 2.05, 2.3]
[2.55, 2.8, 3.05]
[3.3, 3.55, 3.8]
[4.05, 4.3]

In [3]:
# Build the Sankey inputs (node labels/positions and zone -> state links) by
# walking the zones in order of first appearance in `df`.
#
# Node index layout: zones occupy indices 0 .. zone_len-1, states follow in the
# order they are visited (zone by zone), so the k-th state overall has node
# index zone_len + k.
ZONES   = df['Zone'].unique().tolist()
print("ZONES : ", ZONES);print()

zone_len = len(ZONES)
node_len = zone_len          # index of the first state node
state_len = 0                # number of state nodes emitted so far

# Vertical spacing between zone nodes (spread over the full height) and
# between consecutive state nodes.
Y_POS_1 = 1 / max(zone_len - 1, 1)
Y_POS_2 = 1/50

node_labels = ZONES.copy()
node_x_pos = [0.01 for zone in ZONES]
node_y_pos, node_state_y_pos = [], []
link_labels = []
link_values = []
source, target = [], []
ZONE_STATES = {}

for i, zone in enumerate(ZONES):
    zone_rows = df[df['Zone'] == zone]
    states = zone_rows['State/UT'].values.tolist()
    n_states = len(states)

    # One link per state: every link needs its own source entry, and target
    # indices must keep advancing past the states of the previous zones.
    source.extend([i] * n_states)
    target.extend([node_len + state_len + j for j in range(n_states)])

    # Take the values from the same rows as the links so they stay aligned
    # even if the CSV rows are not grouped by zone.
    link_values.extend(zone_rows['Evs'].tolist())

    node_labels.extend(states)
    link_labels.extend([f"{zone}-{state}" for state in states])

    node_x_pos.extend([0.6] * n_states)

    # NOTE(review): with 5 zones the last zone lands at y = 1.01, slightly
    # past the normalized range — confirm this is intended.
    node_y_pos.append(round(i*Y_POS_1+0.01, 2))

    individual_node_state_y_pos = [ round((state_len + j)*Y_POS_2+0.05, 2) for j in range(n_states) ]
    print(state_len, individual_node_state_y_pos)
    node_state_y_pos.extend(individual_node_state_y_pos)

    state_len += n_states

    ZONE_STATES[zone] = { "id" : i, "states" : states }

# Zone y-positions were appended first; state y-positions follow, matching the
# order of node_labels.
node_y_pos.extend(node_state_y_pos)

pp.pprint(ZONE_STATES);print()

print('node_labels : ', len(node_labels), node_labels)
print('link_labels : ', len(link_labels), link_labels);print()
print('source : ', len(source), source)
print('target : ', len(target), target);print()

print('node_x_pos : ', len(node_x_pos), node_x_pos)
print('node_y_pos : ', len(node_y_pos), node_y_pos);print()


ZONES :  ['North', 'East', 'South', 'West', 'Central']

0 [0.05, 0.07, 0.09, 0.11, 0.13]
5 [0.15, 0.17, 0.19, 0.21, 0.23]
10 [0.25, 0.27, 0.29]
13 [0.31, 0.33, 0.35]
16 [0.37, 0.39]
{'Central': {'id': 4, 'states': ['Chhattisgarh', 'Jharkhand']},
 'East': {'id': 1,
          'states': ['Assam', 'Bihar', 'Odisha', 'Tripura', 'West Bengal']},
 'North': {'id': 0,
           'states': ['Delhi',
                      'Haryana',
                      'Punjab',
                      'Uttar Pradesh',
                      'Uttarakhand']},
 'South': {'id': 2, 'states': ['Karnataka', 'Kerala', 'Tamil Nadu']},
 'West': {'id': 3, 'states': ['Gujarat', 'Maharashtra', 'Rajasthan']}}

node_labels :  23 ['North', 'East', 'South', 'West', 'Central', 'Delhi', 'Haryana', 'Punjab', 'Uttar Pradesh', 'Uttarakhand', 'Assam', 'Bihar', 'Odisha', 'Tripura', 'West Bengal', 'Karnataka', 'Kerala', 'Tamil Nadu', 'Gujarat', 'Maharashtra', 'Rajasthan', 'Chhattisgarh', 'Jharkhand']
link_labels :  18 ['North-Delhi', 'No

In [4]:
# Render the zone -> state Sankey diagram (default node/link colors).
NODES = {
    "pad": 20,
    "thickness": 20,
    "line": {"color": "lightslategrey", "width": 0.5},
    "hovertemplate": " ",
    "label": node_labels,
    "x": node_x_pos,
    "y": node_y_pos,
}

LINKS = {
    "source": source,
    "target": target,
    "value": link_values,
    "label": link_labels,
    "hovertemplate": "%{label}",
}

data = go.Sankey(arrangement='snap', node=NODES, link=LINKS)
fig = go.Figure(data)

# Show link values as integers followed by a percent suffix.
fig.update_traces(valueformat='3d', valuesuffix=' %', selector={"type": 'sankey'})

# Title, canvas size and hover-label styling in a single layout update.
fig.update_layout(
    title="EVs - Zone & State",
    font_size=16,
    width=1200,
    height=1000,
    hoverlabel={"bgcolor": "grey", "font_size": 14, "font_family": "Rockwell"},
)
fig.show()

In [5]:
# Build the Sankey inputs directly from `df`: zones are the left-hand nodes,
# states the right-hand nodes, and each row of `df` is one zone -> state link.
NODE_COLORS = ["seagreen", "dodgerblue", "orange", "palevioletred", "darkcyan"]
LINK_COLORS = ["lightgreen", "lightskyblue", "bisque", "pink", "lightcyan"]

ZONES   = df['Zone'].unique().tolist()
STATES  = df['State/UT'].unique().tolist()
NODES   = ZONES + STATES

# Node label -> node index (zones first, then states). The counter must
# advance per node, otherwise every label maps to 0.
NODES_ID = {node: idx for idx, node in enumerate(NODES)}

node_x_pos, node_y_pos      = [], []
node_labels        = NODES
# NOTE(review): only the zone nodes receive an explicit color here; confirm
# that state nodes are meant to keep the default palette.
node_colors        = NODE_COLORS
link_labels        = (df['Zone'] + "-" + df['State/UT']).values.tolist()
link_values        = df['Evs'].values


# Zone nodes: fixed x on the left, evenly spread over the vertical axis.
NODES_ZONES = {}
num = 0
Y_POS = 1 / max(len(ZONES) - 1, 1)
for zone in ZONES:
    NODES_ZONES[zone] = num
    node_x_pos.append(0.01)
    node_y_pos.append(round(num*Y_POS+0.01,2))
    num += 1

df['source'] = df['Zone'].map(NODES_ZONES)
source = df['source'].values

# One color per link, chosen by the link's zone. This keeps colors aligned
# with the links regardless of how many states each zone has; the palette is
# cycled if there are more zones than colors.
link_colors = [LINK_COLORS[NODES_ZONES[zone] % len(LINK_COLORS)] for zone in df['Zone']]


# State nodes: fixed x on the right, stacked with a constant vertical step.
# `num` keeps counting from the zone loop so state indices follow the zones.
NODES_STATES = {}
X_POS = 0.6
Y_POS = 1/50
for state in STATES:
    NODES_STATES[state] = num
    node_x_pos.append(X_POS)
    node_y_pos.append(round((num-len(ZONES))*Y_POS+0.05,2))
    num += 1

df['target'] = df['State/UT'].map(NODES_STATES)
target = df['target'].values


print("node_labels",len(node_labels),node_labels)
print("node_x_pos",len(node_x_pos),node_x_pos)
print("node_y_pos",len(node_y_pos),node_y_pos)
print("source", len(source), source)
print("target", len(target), target)
print("link_labels", len(link_labels), link_labels)
print("link_values", len(link_values), link_values)
print("link_colors", len(link_colors), link_colors)

node_labels 23 ['North', 'East', 'South', 'West', 'Central', 'Delhi', 'Haryana', 'Punjab', 'Uttar Pradesh', 'Uttarakhand', 'Assam', 'Bihar', 'Odisha', 'Tripura', 'West Bengal', 'Karnataka', 'Kerala', 'Tamil Nadu', 'Gujarat', 'Maharashtra', 'Rajasthan', 'Chhattisgarh', 'Jharkhand']
node_x_pos 23 [0.01, 0.01, 0.01, 0.01, 0.01, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6]
node_y_pos 23 [0.01, 0.26, 0.51, 0.76, 1.01, 0.05, 0.07, 0.09, 0.11, 0.13, 0.15, 0.17, 0.19, 0.21, 0.23, 0.25, 0.27, 0.29, 0.31, 0.33, 0.35, 0.37, 0.39]
source 18 [0 0 0 0 0 1 1 1 1 1 2 2 2 3 3 3 4 4]
target 18 [ 5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22]
link_labels 18 ['North-Delhi', 'North-Haryana', 'North-Punjab', 'North-Uttar Pradesh', 'North-Uttarakhand', 'East-Assam', 'East-Bihar', 'East-Odisha', 'East-Tripura', 'East-West Bengal', 'South-Karnataka', 'South-Kerala', 'South-Tamil Nadu', 'West-Gujarat', 'West-Maharashtra', 'West-Rajasthan', 'Central-Chhattisgar

[0.01, 0.26, 0.51, 0.76, 1.01, 0.05, 0.07, 0.09, 0.11, 0.13, 0.15, 0.17, 0.19, 0.21, 0.23, 0.25, 0.27, 0.29, 0.31, 0.33, 0.35, 0.37, 0.39]


In [11]:
# Render the zone -> state Sankey diagram with explicit node and link colors.
node_outline = dict(color="lightslategrey", width=0.5)
NODES = dict(
    pad=20,
    thickness=20,
    line=node_outline,
    hovertemplate=" ",
    label=node_labels,
    x=node_x_pos,
    y=node_y_pos,
    color=node_colors,
)

LINKS = dict(
    source=source,
    target=target,
    value=link_values,
    label=link_labels,
    color=link_colors,
    hovertemplate="%{label}",
)

data = go.Sankey(arrangement='snap', node=NODES, link=LINKS)
fig = go.Figure(data)

# Integer link values with a percent suffix on hover.
fig.update_traces(valueformat='3d', valuesuffix=' %', selector=dict(type='sankey'))

# Figure title/size and hover-label appearance applied together.
hover_style = dict(bgcolor="grey", font_size=14, font_family="Rockwell")
fig.update_layout(
    title="EVs - Zone & State",
    font_size=16,
    width=1200,
    height=600,
    hoverlabel=hover_style,
)
fig.show()