In [1]:
import torch
import plotly.graph_objects as go
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils.convert import from_networkx, to_networkx

In [2]:
graphs = torch.load("graphs.pt")

G = to_networkx(graphs[0])

pos = nx.spring_layout(G)

x_coords = [pos[node][0] for node in G.nodes()]
y_coords = [pos[node][1] for node in G.nodes()]

edge_x = []
edge_y = []
for edge in G.edges():
    edge_x.extend([pos[edge[0]][0], pos[edge[1]][0], None])
    edge_y.extend([pos[edge[0]][1], pos[edge[1]][1], None])

edge_trace = go.Scatter(
    x=edge_x, y=edge_y,
    line=dict(width=1, color='#888'),
    hoverinfo='none',
    mode='lines')

node_trace = go.Scatter(
    x=x_coords, y=y_coords,
    mode='markers',
    hoverinfo='text',
    marker=dict(
        size=10,
    )
)

node_trace.text = [f"Node {node}" for node in G.nodes()]

fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    title='Interactive Graph Visualization',
                    titlefont_size=16,
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=0, l=0, r=0, t=40),
                    xaxis=dict(showgrid=False, zeroline=False),
                    yaxis=dict(showgrid=False, zeroline=False)
                ))

fig.show()

  graphs = torch.load("graphs.pt")
