<a href="https://colab.research.google.com/github/tsato-code/colab_notebooks/blob/main/sankey_diagram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from sklearn.preprocessing import LabelEncoder

In [2]:
# Sample data: one row per observed transition, written as
# (brand owned now, brand chosen next).
data = pd.DataFrame(
    [
        ('Maruti', 'Tata'),
        ('Ford', 'Hyundai'),
        ('Maruti', 'Maruti'),
        ('Hyundai', 'Hyundai'),
        ('Tata', 'Tata'),
        ('Mahindra', 'Mahindra'),
        ('Tata', 'Tata'),
        ('Mahindra', 'Ford'),
        ('Tata', 'Hyundai'),
        ('Ford', 'Tata'),
        ('Ford', 'Mahindra'),
        ('Ford', 'Tata'),
        ('Hyundai', 'Tata'),
        ('Tata', 'Maruti'),
        ('Maruti', 'Maruti'),
        ('Maruti', 'Ford'),
        ('Ford', 'Tata'),
    ],
    columns=['current', 'next'],
)
# Preview the first rows as the cell output.
data.head()

Unnamed: 0,current,next
0,Maruti,Tata
1,Ford,Hyundai
2,Maruti,Maruti
3,Hyundai,Hyundai
4,Tata,Tata


In [3]:
# Collect every brand that occurs in either column (order of first appearance).
cat = pd.concat([data['current'], data['next']]).unique()
display(cat)

# Fit the LabelEncoder ONCE on the full brand set so that a single
# brand -> integer code table is shared by both columns.
le = LabelEncoder()
encoded = le.fit_transform(cat)
decoded = le.inverse_transform(encoded)

# Inspect the codes and their round-trip back to brand names.
# Note: `decoded` is in the same order as `cat`, NOT in code order.
display(encoded)
display(decoded)

# Shape the data: count how often each (current, next) transition occurs.
data_cnt = data.groupby(['current', 'next']).size().reset_index()
# Use transform (not fit_transform) so the encoder fitted above is reused;
# refitting per column would shift the codes whenever a brand is missing
# from one of the columns.
source = le.transform(data_cnt['current'])
target = le.transform(data_cnt['next'])
value = list(data_cnt[0])  # size() leaves the count in a column named 0

# Inspect the shaped data.
display(source)
display(target)
display(value)

# Plot the Sankey diagram.
# Sankey link indices refer to positions in the node label list, so the label
# at position i must be the brand whose code is i. `le.classes_` is ordered by
# code; `decoded` (first-appearance order) would attach wrong names to nodes.
fig = go.Figure(data=[go.Sankey(
    node = dict(
        pad = 10,
        thickness = 2,
        line = dict(color = "black", width = 0.5),
        label = list(le.classes_),
        color = "blue"
    ),
    link = dict(
        source = source,
        target = target,
        value = value
    ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

array(['Maruti', 'Ford', 'Hyundai', 'Tata', 'Mahindra'], dtype=object)

array([3, 0, 1, 4, 2])

array(['Maruti', 'Ford', 'Hyundai', 'Tata', 'Mahindra'], dtype=object)

array([0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4])

array([1, 2, 4, 1, 4, 0, 2, 0, 3, 4, 1, 3, 4])

[1, 1, 3, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2]

In [4]:
# Two-column variant: every brand gets one node on the left ("current") and a
# separate node on the right ("next"), so links never loop back onto a node.

# Collect every brand that occurs in either column (order of first appearance).
cat = pd.concat([data['current'], data['next']]).unique()

# Fit the LabelEncoder ONCE on the full brand set so that a single
# brand -> integer code table is shared by both columns.
le = LabelEncoder()
encoded = le.fit_transform(cat)
decoded = le.inverse_transform(encoded)

# Inspect the codes and their round-trip back to brand names.
# Note: `decoded` is in the same order as `cat`, NOT in code order.
display(encoded)
display(decoded)

# Shape the data: count how often each (current, next) transition occurs.
data_cnt = data.groupby(['current', 'next']).size().reset_index()
# Use transform (not fit_transform) so the encoder fitted above is reused;
# refitting per column would shift the codes whenever a brand is missing
# from one of the columns.
source = le.transform(data_cnt['current'])
# Shift target codes past the left-hand block of nodes so they address the
# second copy of each brand in the doubled label list below.
target = le.transform(data_cnt['next']) + len(le.classes_)
value = list(data_cnt[0])  # size() leaves the count in a column named 0

# Inspect the shaped data.
display(source)
display(target)
display(value)

# Plot the Sankey diagram.
# Sankey link indices refer to positions in the node label list, so labels
# must be in code order (`le.classes_`), repeated once for the left nodes and
# once for the right nodes. `decoded` (first-appearance order) would attach
# wrong names to nodes.
fig = go.Figure(data=[go.Sankey(
    node = dict(
        pad = 20,
        thickness = 20,
        line = dict(color = "black", width = 0.5),
        label = list(le.classes_) * 2,
        color = "blue"
    ),
    link = dict(
        source = source,
        target = target,
        value = value
    ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

array([3, 0, 1, 4, 2])

array(['Maruti', 'Ford', 'Hyundai', 'Tata', 'Mahindra'], dtype=object)

array([0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4])

array([6, 7, 9, 6, 9, 5, 7, 5, 8, 9, 6, 8, 9])

[1, 1, 3, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2]