In [1]:
from collections import Counter

import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objects as go

# pip install pysankey seaborn
from pySankey.sankey import sankey


In [89]:
CITATIONS_GEXF = "output/citations.gexf"
COAUTHORS_GEXF = "output/coauthors.gexf"


def largest_connected_component(graph):
    """Return an independent copy of the largest connected component of ``graph``.

    Directed graphs are inspected through an undirected view, so for them this is the
    largest *weakly* connected component. The returned graph has the same class as
    ``graph``. When several components tie for largest, the first one reported by
    ``nx.connected_components`` is kept.

    Raises
    ------
    ValueError
        If ``graph`` has no nodes.
    """
    # connected_components is only defined for undirected graphs; a view avoids a copy.
    undirected = graph.to_undirected(as_view=True) if graph.is_directed() else graph
    largest = max(nx.connected_components(undirected), key=len, default=None)
    if largest is None:
        raise ValueError("cannot take the largest component of an empty graph")
    # .copy() detaches the result from the full graph (a subgraph is only a view).
    return graph.subgraph(largest).copy()


# Citations are collapsed to a simple undirected graph; the co-authorship graph keeps
# whatever graph type is stored in the file.
G_citations = largest_connected_component(nx.Graph(nx.read_gexf(CITATIONS_GEXF)))
G_coauthors = largest_connected_component(nx.read_gexf(COAUTHORS_GEXF))

# Authors present in both largest components: only they can be compared across networks.
author_list = set(G_citations.nodes) & set(G_coauthors.nodes)

In [None]:
# Louvain is randomized: fix the seed so community assignments, and every figure
# derived from them, are identical on each "Restart & Run All".
LOUVAIN_SEED = 42
# Modularity resolution per network; values below 1 favour fewer, larger communities.
CITATIONS_RESOLUTION = 1
COAUTHORS_RESOLUTION = 0.2

partition_citations = nx.algorithms.community.louvain.louvain_communities(
    G_citations, resolution=CITATIONS_RESOLUTION, seed=LOUVAIN_SEED
)
partition_coauthors = nx.algorithms.community.louvain.louvain_communities(
    G_coauthors, resolution=COAUTHORS_RESOLUTION, seed=LOUVAIN_SEED
)

In [None]:
# Map every node to its Louvain community id in each network. Co-authorship ids are
# shifted by the number of citation communities so the two sides of the Sankey diagram
# never share a label.
node_community_citations = {
    node: i for i, community in enumerate(partition_citations) for node in community
}
n = len(partition_citations)  # offset applied to co-authorship community ids
node_community_coauthors = {
    node: i + n for i, community in enumerate(partition_coauthors) for node in community
}

# Count shared authors for every (citation community, co-authorship community) pair.
# Iterating in sorted order makes the edge order, and hence the plots, independent of
# PYTHONHASHSEED (iteration order of a set of strings changes between kernels).
pair_counts = Counter(
    (node_community_citations[node], node_community_coauthors[node])
    for node in sorted(author_list)
)
# Edges always run citation community -> co-authorship community; "value" is the
# number of authors following that flow.
links = nx.DiGraph()
links.add_edges_from((a, b, {"value": count}) for (a, b), count in pair_counts.items())

# Keep only flows whose two end communities each carry at least MIN_COMMUNITY_AUTHORS
# shared authors (weighted degree), so the diagram is not swamped by tiny communities.
MIN_COMMUNITY_AUTHORS = 100
degree = nx.degree(links, weight="value")
kept_flows = [
    (source, target, attrs["value"])
    for source, target, attrs in links.edges(data=True)
    if degree[source] >= MIN_COMMUNITY_AUTHORS and degree[target] >= MIN_COMMUNITY_AUTHORS
]
sources = [source for source, _, _ in kept_flows]
targets = [target for _, target, _ in kept_flows]
values = [value for _, _, value in kept_flows]

In [None]:
# Static Sankey (pySankey / matplotlib): citation communities on the left, co-authorship
# communities on the right, both sides weighted by the number of shared authors.
# fontsize=0 effectively hides the numeric community labels.
sankey(
    left=sources,
    leftWeight=values,
    right=targets,
    rightWeight=values,
    fontsize=0,
    aspect=20,
)

In [None]:
# Interactive Sankey of the same flows. Node positions are looked up through a dict
# built once, and labels are sorted so citation communities (lower ids) come first
# and the node order is the same on every run.
node_labels = sorted(set(sources) | set(targets))
node_index = {label: i for i, label in enumerate(node_labels)}

fig = go.Figure(go.Sankey(
    node=dict(
        pad=20,
        thickness=20,
        line=dict(width=0),
        label=node_labels,
    ),
    link=dict(
        source=[node_index[s] for s in sources],
        target=[node_index[t] for t in targets],
        value=values,
    ),
))
fig.update_layout(title_text="Citation communities → co-authorship communities")
fig.show()