# Graph Embedding Analysis Notebook

This notebook loads the road graph and the Graph Neural Network (GNN) component, generates node embeddings, and visualizes them. The goal is to understand the spatial representations the GNN has learned.

In [None]:
import os

import torch
import networkx as nx
import osmnx as ox
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from src.feature_forge.graph_embedding_features import GraphEmbeddingFeatures, GraphConvolutionalNetwork

# Apply seaborn's "whitegrid" theme globally so every plot below shares one style.
sns.set_theme(style="whitegrid")

## 1. Configuration and Graph Loading

In [None]:
# Paths to the agent config, the OSM processing config, and the trained GNN weights.
CONFIG_PATH = 'conf/rl_agent_params.yaml'
OSM_PROCESSING_CONFIG_PATH = 'conf/osm_processing_config.yaml'
GNN_MODEL_PATH = 'rl_model_registry/gcn_model.pth'

with open(OSM_PROCESSING_CONFIG_PATH, 'r', encoding='utf-8') as f:
    osm_config = yaml.safe_load(f)

graph_path = osm_config['graph_serialization']['output_path']

if not os.path.exists(graph_path):
    # Exploration fallback: download a small drive network (1000 m around the point)
    # so the rest of the notebook can run without the full preprocessing pipeline.
    print(f"Warning: Graph file not found at {graph_path}. Generating a small temporary graph for analysis.")
    G_temp = ox.graph_from_point((35.7, 51.4), dist=1000, network_type='drive')
    G_temp = ox.add_edge_speeds(G_temp)
    G_temp = ox.add_edge_travel_times(G_temp)
    # Use real (hop-based) closeness centrality rather than random placeholder values,
    # so the centrality-coloured plots below reflect actual graph structure.
    nx.set_node_attributes(G_temp, nx.closeness_centrality(G_temp), 'closeness_centrality')
    graph_dir = os.path.dirname(graph_path)
    if graph_dir:  # os.makedirs('') raises, so only create a directory when the path has one
        os.makedirs(graph_dir, exist_ok=True)
    # GML can only hold str/int/float/bool/list/dict values. OSMnx puts shapely objects in
    # the edge `geometry` attribute, so stringize unsupported values instead of failing.
    nx.write_gml(G_temp, graph_path, stringizer=str)
    print("Temporary graph generated.")

graph_embedder = GraphEmbeddingFeatures(CONFIG_PATH, graph_path)

if os.path.exists(GNN_MODEL_PATH):
    graph_embedder.load_model_weights(GNN_MODEL_PATH)
else:
    # Embeddings from an untrained model are still computed, but carry little meaning.
    print(f"Warning: GNN model weights not found at {GNN_MODEL_PATH}. Using untrained model for embeddings.")
    print("Consider running `python src/model_training/train_graph_gcn.py` first.")

## 2. Generate Node Embeddings

In [None]:
# Convert the NetworkX graph to PyG format. node_map maps graph node id -> row index
# in the embedding matrix.
pyg_data, node_map = graph_embedder.preprocess_graph_for_pyg(graph_embedder.graph)
# detach() guards against a returned tensor that still requires grad, which would make .numpy() raise.
node_embeddings = graph_embedder.generate_embeddings(pyg_data).detach().cpu().numpy()

print(f"Generated {len(node_embeddings)} node embeddings with dimension {node_embeddings.shape[1]}.")
if node_map:
    # Look up the sample node's row via node_map rather than assuming the first key is row 0.
    sample_node_id = next(iter(node_map))
    print(f"Sample embedding for node {sample_node_id}:\n{node_embeddings[node_map[sample_node_id], :5]}...")

## 3. Dimensionality Reduction for Visualization (PCA & t-SNE)

In [None]:
if node_embeddings.shape[0] > 1:
    n_nodes, n_dims = node_embeddings.shape
    # node_map maps graph node id -> row in node_embeddings. List the ids in row order so the
    # labels line up with the t-SNE output rows regardless of dict insertion order.
    ordered_node_ids = sorted(node_map, key=node_map.get)

    # PCA for initial reduction. n_components may not exceed min(n_samples, n_features),
    # so cap it by the node count as well (small graphs would otherwise raise).
    pca = PCA(n_components=min(n_nodes, n_dims, 50))
    pca_result = pca.fit_transform(node_embeddings)
    print(f"PCA explained variance ratio: {pca.explained_variance_ratio_.sum():.2f}")

    # t-SNE for 2D visualization; perplexity must be strictly less than the number of samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n_nodes - 1))
    tsne_result = tsne.fit_transform(pca_result)

    embedding_df = pd.DataFrame(tsne_result, columns=['TSNE-1', 'TSNE-2'])
    embedding_df['node_id'] = ordered_node_ids

    # Per-node attributes used for colouring (centrality) and marker size (degree).
    graph = graph_embedder.graph
    node_attr_df = pd.DataFrame({
        'node_id': list(graph.nodes()),
        'degree': [graph.degree(n) for n in graph.nodes()],
        'centrality': [graph.nodes[n].get('closeness_centrality', 0) for n in graph.nodes()],
    }).set_index('node_id')

    embedding_df = embedding_df.join(node_attr_df, on='node_id')

    plt.figure(figsize=(10, 8))
    sns.scatterplot(x='TSNE-1', y='TSNE-2', hue='centrality', size='degree', sizes=(20, 400),
                    data=embedding_df, palette='viridis', alpha=0.7)
    plt.title('Node Embeddings Visualization (t-SNE)')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    # Place the legend outside the axes so it does not cover points.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.tight_layout()
    plt.show()
else:
    print("Not enough nodes for dimensionality reduction and visualization.")

## 4. Embedding Similarity and Spatial Coherence (Conceptual)

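This section picks a reference node, ranks every node by the cosine similarity of its embedding to the reference embedding, and plots the most similar nodes on the street map. This lets you check whether embedding similarity tracks physical proximity or functional similarity. A natural extension is to compare embeddings of nodes in high-traffic vs. low-traffic areas.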

In [None]:
if node_embeddings.shape[0] > 1:
    from sklearn.metrics.pairwise import cosine_similarity

    graph = graph_embedder.graph
    # node_map maps graph node id -> row in node_embeddings. List the ids in row order
    # so they align with the similarity vector computed below.
    ordered_node_ids = sorted(node_map, key=node_map.get)

    # Pick a reference node at random. The seed makes reruns inspect the same node;
    # change it to look at a different one.
    rng = np.random.default_rng(seed=0)
    ref_node_id = ordered_node_ids[int(rng.integers(len(ordered_node_ids)))]
    ref_embedding = node_embeddings[node_map[ref_node_id]].reshape(1, -1)

    # Cosine similarity of the reference embedding to every node (the reference included).
    similarities = cosine_similarity(ref_embedding, node_embeddings).flatten()

    similarity_df = pd.DataFrame({
        'node_id': ordered_node_ids,
        'similarity': similarities,
    }).sort_values('similarity', ascending=False)

    # Drop the reference node itself so "top 5" really means five *other* nodes.
    top_others_df = similarity_df[similarity_df['node_id'] != ref_node_id].head(5)
    print(f"\nTop 5 most similar nodes to Node {ref_node_id}:\n")
    print(top_others_df)

    # Draw the street network, then overlay the reference node (red) and its most similar
    # nodes (blue). close=False keeps the figure open so plt.show() can display it.
    fig, ax = ox.plot_graph(graph, show=False, close=False, figsize=(10, 8))
    for node_id in [ref_node_id, *top_others_df['node_id']]:
        is_ref = node_id == ref_node_id
        node_data = graph.nodes[node_id]
        ax.scatter(node_data['x'], node_data['y'],
                   color='red' if is_ref else 'blue', s=100 if is_ref else 50,
                   alpha=0.8, edgecolors='black')
    ax.set_title(f'Nodes similar to {ref_node_id}')
    plt.show()

else:
    print("Not enough nodes for similarity analysis.")