In [1]:
import torch
import plotly.graph_objects as go
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils.convert import from_networkx, to_networkx

In [8]:
file_path = "graphs/cb_nabil.pt"  
graph_data = torch.load(file_path)

if not isinstance(graph_data, Data):
    raise ValueError("The loaded file does not contain a PyTorch Geometric Data object.")

nx_graph = to_networkx(graph_data, to_undirected=True)

pos = nx.spring_layout(nx_graph)  # Force-directed layout

x_nodes = [pos[node][0] for node in nx_graph.nodes()]
y_nodes = [pos[node][1] for node in nx_graph.nodes()]

edge_x = []
edge_y = []
for edge in nx_graph.edges():
    edge_x += [pos[edge[0]][0], pos[edge[1]][0], None]
    edge_y += [pos[edge[0]][1], pos[edge[1]][1], None]

# Create edge traces
edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    line=dict(width=0.5, color="#888"),
    hoverinfo="none",
    mode="lines"
)

# Create node traces
node_trace = go.Scatter(
    x=x_nodes,
    y=y_nodes,
    mode="markers",
    marker=dict(
        size=10,
        color=list(nx_graph.degree()),  # Node color based on degree
        colorscale="YlGnBu",
        showscale=True,
        colorbar=dict(
            title="Node Degree"
        )
    ),
    text=[f"Node {n}" for n in nx_graph.nodes()],  # Node labels
    hoverinfo="text"
)

# Create the final figure
fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    showlegend=False,
                    hovermode="closest",
                    margin=dict(b=0, l=0, r=0, t=0),
                    xaxis=dict(showgrid=False, zeroline=False),
                    yaxis=dict(showgrid=False, zeroline=False)
                ))

# Show the visualization
fig.show()

  graph_data = torch.load(file_path)


ValueError: The loaded file does not contain a PyTorch Geometric Data object.