In [1]:
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
from networkx.algorithms import community
# Load the samples (indexed by the "sample" column) and compute the pairwise
# Spearman correlation between the remaining columns.
data_df = pd.read_csv("train_data.csv", index_col="sample")
corr = data_df.corr('spearman')

# Create graph and find communities
# The correlation matrix is interpreted as a weighted adjacency matrix, so
# each correlation value becomes the 'weight' attribute of an edge.
# NOTE(review): the diagonal of a correlation matrix is non-zero, so this
# presumably also creates a self-loop on every node -- confirm that is intended.
G = nx.from_pandas_adjacency(corr)
partition = community.louvain_communities(G, weight='weight')

# Generate positions in 3D
pos = nx.spring_layout(G, dim=3)

# Create a community map: node -> index of its community within `partition`.
# The loop variable is deliberately NOT called `community`: that name is the
# imported networkx module used above, and rebinding it would break any
# re-execution of the louvain_communities call without a fresh import.
community_map = {
    node: idx
    for idx, members in enumerate(partition)
    for node in members
}

# Function to plot the 3D network graph
def plot_3d_network(G, pos, community_map, lim=0.3, excluded_communities=None):
    """Build a 3D Plotly figure of ``G`` with nodes coloured by community.

    Args:
        G: networkx graph whose edges carry a ``'weight'`` attribute.
        pos: mapping ``node -> (x, y, z)`` giving each node's coordinates.
        community_map: mapping ``node -> community index``.
        lim: minimum absolute edge weight an edge needs in order to be drawn.
        excluded_communities: iterable of community indices whose nodes are
            dropped before plotting. ``None`` (the default) excludes nothing.

    Returns:
        A ``go.Figure`` holding one line trace for the edges and one marker
        trace for the nodes.

    Raises:
        KeyError: if a remaining node is missing from ``pos`` or
            ``community_map``, or an edge has no ``'weight'`` attribute.
    """
    # A None default avoids sharing one mutable list between calls; a set
    # makes the per-node membership test constant-time.
    excluded = set(excluded_communities or ())

    # Filter nodes and edges based on excluded communities
    nodes_to_remove = [
        node for node, comm in community_map.items() if comm in excluded
    ]
    G_filtered = G.copy()
    G_filtered.remove_nodes_from(nodes_to_remove)

    # Update positions
    pos_filtered = {node: pos[node] for node in G_filtered.nodes()}

    # Edge trace: every edge contributes (start, end, None) per axis; the
    # None entry breaks the line so consecutive edges are not joined.
    edge_x = []
    edge_y = []
    edge_z = []
    for u, v, attrs in G_filtered.edges(data=True):
        # Written as "not >=" so that a NaN weight (undefined correlation)
        # is skipped as well: every comparison with NaN is False.
        if not abs(attrs['weight']) >= lim:
            continue
        # A self-loop would only add a zero-length, invisible segment.
        if u == v:
            continue
        x0, y0, z0 = pos_filtered[u]
        x1, y1, z1 = pos_filtered[v]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        edge_z.extend((z0, z1, None))

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Node trace
    node_x = []
    node_y = []
    node_z = []
    node_text = []
    node_color = []
    for node in G_filtered.nodes():
        x, y, z = pos_filtered[node]
        comm = community_map[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(f"{node} (Community {comm})")
        node_color.append(comm)

    node_trace = go.Scatter3d(
        x=node_x, y=node_y, z=node_z,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='Jet',
            color=node_color,
            size=10,
            colorbar=dict(
                thickness=15,
                # Nested title.side replaces the deprecated `titleside`.
                title=dict(text='Community', side='right'),
                xanchor='left'
            ),
            line=dict(width=2)),
        text=node_text
    )

    # Create a figure
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # Nested title.font replaces the deprecated `titlefont_size`.
            title=dict(
                text='3D Network graph with community detection',
                font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="",
                showarrow=False,
                xref="paper", yref="paper")],
            scene=dict(
                xaxis=dict(showgrid=False, zeroline=False),
                yaxis=dict(showgrid=False, zeroline=False),
                zaxis=dict(showgrid=False, zeroline=False))
        )
    )

    return fig

# Initialize Dash app
app = Dash(__name__)

# Layout of the app: a heading, a multi-select dropdown listing one entry per
# detected community, a numeric input for the correlation threshold, and the
# graph component that the callback fills in.
app.layout = html.Div(children=[
    html.H1("3D Network Graph with Community Detection"),
    dcc.Dropdown(
        id='community-dropdown',
        multi=True,
        placeholder="Select communities to exclude",
        options=[
            {'label': f'Community {i}', 'value': i}
            for i, _ in enumerate(partition)
        ],
    ),
    html.P("Min Corr:"),
    dcc.Input(id="min corr", type="number", value=0.5),
    dcc.Graph(id='network-graph'),
])

# Callback to update the graph based on selected communities
@app.callback(
    Output('network-graph', 'figure'),
    Input('community-dropdown', 'value'),
    Input("min corr", "value")
)
def update_graph(excluded_communities, min_corr):
    """Rebuild the network figure from the current control values.

    Args:
        excluded_communities: value of the community dropdown; ``None`` is
            treated as "exclude nothing".
        min_corr: value of the threshold input; ``None`` falls back to 0.5.

    Returns:
        The figure produced by ``plot_3d_network`` for the module-level
        graph, positions and community map.
    """
    # Explicit None checks (not truthiness) so a threshold of 0 is honoured.
    excluded = [] if excluded_communities is None else excluded_communities
    threshold = 0.5 if min_corr is None else min_corr
    return plot_3d_network(G, pos, community_map, threshold, excluded)

# Run the app
# The server is started only when this file is executed as the main module.
# NOTE(review): the recorded output after this cell reports that
# 'http://127.0.0.1:8050' was already in use and suggests passing a different
# port to run_server -- check for a previously started server instance.
if __name__ == '__main__':
    app.run_server(debug=True)


OSError: Address 'http://127.0.0.1:8050' already in use.
    Try passing a different port to run_server.