In [1]:
# Graph Viz
import matplotlib.pyplot as plt
import networkx as nx

from tqdm import tqdm
import datetime
import requests
import openai
import json
import os
import re

import dash
from dash import dcc, html
import plotly.graph_objs as go
from dash.dependencies import Input, Output
import networkx as nx
import random
import numpy as np

# Works, but very slow, and the child nodes are not colored

In [2]:
# Load the saved structure from disk; the file is parsed as JSON even though
# it carries a .txt extension.
filename = 'models_v1_3_2023-09-07.txt'
# Explicit UTF-8 so decoding does not depend on the platform's default encoding.
with open(filename, 'r', encoding='utf-8') as read_file:
    data_dict = json.load(read_file)

In [3]:
# Show the loaded dictionary as this cell's output
data_dict

{'StructuredData': {'Regression': ['LinearRegression',
   'RidgeRegression',
   'LassoRegression',
   'ElasticNet',
   'SupportVectorRegression',
   'DecisionTreeRegression',
   'RandomForestRegression',
   'AdaBoostRegression',
   'GradientBoostingRegression',
   'XGBoost',
   'LightGBM',
   'CatBoost',
   'ArtificialNeuralNetworks',
   'LongShort-TermMemory'],
  'Classification': ['LogisticRegression',
   'LinearDiscriminantAnalysis',
   'QuadraticDiscriminantAnalysis',
   'SupportVectorMachines',
   'DecisionTreeClassifier',
   'RandomForestClassifier',
   'AdaBoostClassifier',
   'GradientBoostingClassifier',
   'XGBoostClassifier',
   'LightGBMClassifier',
   'CatBoostClassifier',
   'K-NearestNeighbors',
   'NaiveBayesClassifier',
   'ArtificialNeuralNetworks'],
  'Clustering': ['K-MeansClustering',
   'DBSCAN',
   'HierarchicalClustering',
   'SpectralClustering',
   'MeanShift',
   'AffinityPropagation',
   'OPTICS',
   'BIRCH'],
  'DimensionalityReduction': ['PrincipalComponen

In [8]:
# Build the directed hierarchy category -> subcategory -> algorithm with NetworkX
G = nx.DiGraph()

for category, subcategories in data_dict.items():
    # Register the category explicitly so it exists even without subcategories
    G.add_node(category)
    for subcategory, algorithms in subcategories.items():
        # add_edge / add_edges_from create any endpoint not yet in the graph
        G.add_edge(category, subcategory)
        G.add_edges_from((subcategory, algorithm) for algorithm in algorithms)

In [9]:
# Instantiate the Dash application
app = dash.Dash(__name__)

# Page layout: one 3D graph plus a timer whose ticks trigger the update callback
app.layout = html.Div(
    children=[
        dcc.Graph(
            id='graph-3d',
            config=dict(displayModeBar=False),
            style=dict(height='800px'),
        ),
        # NOTE(review): interval is in milliseconds, so 50 means the timer
        # fires 20 times per second.
        dcc.Interval(
            id='interval-component',
            interval=50,
            n_intervals=0,
        ),
    ]
)

In [10]:
# Assign a random RGB color to every node that has outgoing edges
# (i.e. every node that is a parent of something); leaf nodes get no entry.
subcategory_colors = {
    parent: f"rgb({random.randint(0, 255)}, {random.randint(0, 255)}, {random.randint(0, 255)})"
    for parent in G.nodes()
    if G.out_degree(parent) > 0
}


In [11]:
def generate_3d_graph():
    """Build the Plotly traces that draw the module-level graph ``G`` in 3D.

    Node coordinates come from ``nx.spring_layout(G, dim=3)`` and are
    recomputed on every call.

    All nodes are emitted as ONE ``go.Scatter3d`` marker trace and all edges
    as ONE ``go.Scatter3d`` line trace (segments separated by ``None``),
    rather than one trace per node and per edge, which keeps the figure
    payload small and rendering fast.

    Coloring:
      * a node listed in ``subcategory_colors`` uses its own color;
      * any other node inherits the color of its first predecessor that is
        listed in ``subcategory_colors``;
      * if no such predecessor exists the node is drawn in gray.

    Returns:
        list: ``[node_trace, edge_trace]``, suitable as a figure's ``data``.
    """
    # Position nodes in 3D space using NetworkX's spring layout
    pos = nx.spring_layout(G, dim=3)

    node_x, node_y, node_z = [], [], []
    node_labels, node_colors = [], []

    for node, (x, y, z) in pos.items():
        node_x.append(float(x))
        node_y.append(float(y))
        node_z.append(float(z))
        node_labels.append(node)

        color = subcategory_colors.get(node)
        if color is None:
            # Leaf node: reuse the color of a colored parent, else fall back to gray
            color = next(
                (
                    subcategory_colors[parent]
                    for parent in G.predecessors(node)
                    if parent in subcategory_colors
                ),
                "gray",
            )
        node_colors.append(color)

    edge_x, edge_y, edge_z = [], [], []
    for source, target in G.edges():
        source_pos = pos[source]
        target_pos = pos[target]
        # A trailing None ends the segment so consecutive edges are not joined
        edge_x += [float(source_pos[0]), float(target_pos[0]), None]
        edge_y += [float(source_pos[1]), float(target_pos[1]), None]
        edge_z += [float(source_pos[2]), float(target_pos[2]), None]

    node_trace = go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode='markers+text',
        marker=dict(size=8, color=node_colors),
        text=node_labels,
        hoverinfo='text'
    )

    edge_trace = go.Scatter3d(
        x=edge_x,
        y=edge_y,
        z=edge_z,
        mode='lines',
        line=dict(width=2, color="gray"),  # Default color for edges
    )

    return [node_trace, edge_trace]

# Dash callback: rebuild the 3D figure each time the interval counter changes
@app.callback(
    Output('graph-3d', 'figure'),
    Input('interval-component', 'n_intervals'),
)
def update_graph(n_intervals):
    """Return the figure (data + layout) for the ``graph-3d`` component.

    ``n_intervals`` is the value of the triggering Interval component; it is
    received only because it is the callback input and is not otherwise used.
    """
    # Fresh traces on every invocation
    traces = generate_3d_graph()

    # Every axis is drawn without a title and without tick labels
    scene_axes = {
        axis_name: dict(title='', showticklabels=False)
        for axis_name in ('xaxis', 'yaxis', 'zaxis')
    }
    figure_layout = go.Layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=scene_axes,
    )

    return {'data': traces, 'layout': figure_layout}

# Start the Dash server only when this code runs as the main module
if __name__ == '__main__':
    app.run(
        port=2223,
        debug=True,
        jupyter_mode="external",
    )

Dash app running on http://127.0.0.1:2223/


In [None]:
# NOTE(review): this cell held only the dangling expression `app.`, which is a
# SyntaxError under Restart & Run All; there is nothing to execute here.

#### Future potential updates: 
--> Color the nodes the same way as their subcategory

# Other Option

In [28]:
import networkx as nx
from pyvis.network import Network

# Assumes `data_dict` (the value passed to add_nodes_edges below) contains your structured data

# Initialize a directed graph (this rebinds the name G to a new, empty graph)
G = nx.DiGraph()

# Function to add nodes and edges to the graph
def add_nodes_edges(data, parent=None, graph=None):
    """Recursively copy a nested dict/list hierarchy into a graph.

    Every key of ``data`` becomes a node. When ``parent`` is given, an edge
    ``parent -> key`` is added. A dict value is processed recursively with the
    key as the new parent; a list value contributes one node per item plus an
    edge ``key -> item``. Values of any other type are ignored.

    Args:
        data: Mapping of node name to a dict (children), a list (leaf names),
            or any other value (no children).
        parent: Name of the node that owns ``data``, or ``None`` at the top
            level. Compared with ``is not None`` so falsy names such as ``""``
            or ``0`` are still linked to their children.
        graph: Object exposing ``add_node`` and ``add_edge``. Defaults to the
            module-level graph ``G``.
    """
    if graph is None:
        graph = G

    for key, value in data.items():
        # Always register the key, so a childless top-level entry is not lost
        graph.add_node(key)
        if parent is not None:
            graph.add_edge(parent, key)

        if isinstance(value, dict):
            add_nodes_edges(value, key, graph)
        elif isinstance(value, list):
            for item in value:
                graph.add_node(item)
                graph.add_edge(key, item)

# Fill G from the nested dictionary
add_nodes_edges(data_dict)

# Directed PyVis network, created with notebook=False
nt = Network('1000px', '1000px', notebook=False, directed=True)

# Nodes and edges are copied over by hand (from_nx is deliberately not used,
# to avoid potential issues)
for node_name in G.nodes():
    nt.add_node(node_name, label=node_name)

for src, dst in G.edges():
    nt.add_edge(src, dst)

# Enable the physics simulation for a better visualization
nt.toggle_physics(True)

# Write the visualization to an HTML file
nt.save_graph('structured_data_graph.html')


In [30]:
from IPython.display import IFrame

# Embed the saved HTML file in the cell output
IFrame(src='structured_data_graph.html', width=1000, height=800)