In [None]:
# Create a new notebook: 02b_Graph_Visualization.ipynb

# 1. Import necessary packages
import json
import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import seaborn as sns
import torch
from nilearn import plotting
from torch_geometric.utils import to_networkx

from config.paths import *

# Set visualization style
plt.style.use('default')
sns.set_palette("viridis")

# 2. Load the graph dataset
# The .pt file holds pickled graph objects (presumably PyTorch Geometric Data
# instances — they are used via .edge_index/.edge_attr/.x/.y below). Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which may refuse to
# restore such custom classes, so unpickling is requested explicitly.
# SECURITY: weights_only=False can execute arbitrary code while unpickling;
# only load dataset files produced by this project's own preprocessing.
graph_dataset_path = os.path.join(PROCESSED_DIR, 'oasis_graph_dataset.pt')
graph_list = torch.load(graph_dataset_path, weights_only=False)

# Every later section indexes graph_list[0]; fail early with a clear message
# instead of an IndexError deep in the plotting code.
if len(graph_list) == 0:
    raise ValueError(f"No graphs found in {graph_dataset_path}")

print(f"Loaded {len(graph_list)} graphs")
print(f"Sample graph: {graph_list[0]}")

# 3. Load the dataset info (metadata such as 'num_nodes', used below)
dataset_info_path = os.path.join(PROCESSED_DIR, 'dataset_info.json')
with open(dataset_info_path, 'r', encoding='utf-8') as f:
    dataset_info = json.load(f)

print("Dataset info:")
for key, value in dataset_info.items():
    print(f"{key}: {value}")

# 4. Visualize the template connectivity matrix
print("\nVisualizing template connectivity matrix...")

# Densify the first graph's sparse edge list into a (num_nodes x num_nodes)
# weight matrix, treated here as the "template".
# NOTE(review): this assumes every graph shares the same connectivity
# structure — confirm against the preprocessing that built the dataset.
template_graph = graph_list[0]
num_nodes = dataset_info['num_nodes']

# Convert to CPU NumPy arrays up front: Tensor.numpy() raises for CUDA tensors
# and for tensors that require grad, and indexing a NumPy array with plain
# integer arrays avoids relying on torch/NumPy interop.
template_src, template_dst = template_graph.edge_index.detach().cpu().numpy()
if template_graph.edge_attr is None:
    # Unweighted graph: mark every present edge with weight 1.
    template_weights = np.ones(template_src.shape[0])
else:
    template_weights = template_graph.edge_attr.detach().cpu().numpy()
    # Accept weights shaped (E,) or (E, 1); anything else has no single
    # scalar strength per edge and cannot be drawn as one matrix.
    if template_weights.ndim == 2 and template_weights.shape[1] == 1:
        template_weights = template_weights[:, 0]
    if template_weights.shape != template_src.shape:
        raise ValueError(
            "Expected one scalar weight per edge, got edge_attr of shape "
            f"{tuple(template_graph.edge_attr.shape)}"
        )

adjacency_matrix = np.zeros((num_nodes, num_nodes))
adjacency_matrix[template_src, template_dst] = template_weights

plt.figure(figsize=(12, 10))
plt.imshow(adjacency_matrix, cmap='hot', interpolation='nearest')
plt.title('Template Structural Connectivity Matrix')
plt.colorbar(label='Connection Strength')
plt.xlabel('Brain Region')
plt.ylabel('Brain Region')
plt.savefig(os.path.join(RESULTS_DIR, 'connectivity_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

# 5. Visualize a subset of the graph using NetworkX
print("\nVisualizing graph structure...")

# Use the first graph in the dataset as the example to draw
sample_graph = graph_list[0]

# Restrict the drawing to nodes with index < subgraph_nodes so the plot stays
# readable; clamp to the graph size so a graph with fewer than 20 nodes does
# not raise IndexError when node features are read below.
subgraph_nodes = min(20, sample_graph.num_nodes)
subgraph_edge_mask = (sample_graph.edge_index[0] < subgraph_nodes) & (sample_graph.edge_index[1] < subgraph_nodes)
subgraph_edge_index = sample_graph.edge_index[:, subgraph_edge_mask]
subgraph_edge_attr = sample_graph.edge_attr[subgraph_edge_mask]

# Build an undirected NetworkX graph. Values are read with .item(), so this
# assumes one scalar feature per node and one scalar weight per edge
# (e.g. x of shape (N, 1) and edge_attr of shape (E, 1)) — TODO confirm.
G = nx.Graph()
G.add_nodes_from(
    (i, {'feature': sample_graph.x[i].item()}) for i in range(subgraph_nodes)
)
# nx.Graph stores one edge per node pair: when both (i, j) and (j, i) appear,
# the weight of the one listed later is kept.
G.add_weighted_edges_from(
    (int(src), int(dst), weight.item())
    for src, dst, weight in zip(subgraph_edge_index[0], subgraph_edge_index[1], subgraph_edge_attr)
)

# Draw the graph
plt.figure(figsize=(15, 10))
pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout
node_colors = [G.nodes[n]['feature'] for n in G.nodes()]
edge_weights = [G.edges[e]['weight'] for e in G.edges()]

# Draw nodes (colored by feature) and edges (colored by weight)
nodes = nx.draw_networkx_nodes(G, pos, node_color=node_colors,
                               node_size=200, cmap=plt.cm.viridis)
edges = nx.draw_networkx_edges(G, pos, width=2, alpha=0.3,
                               edge_color=edge_weights, edge_cmap=plt.cm.plasma)

# Add labels
nx.draw_networkx_labels(G, pos, font_size=8)

# Add colorbars. With no edges among the drawn nodes there is nothing to map
# to colors (and networkx may not return a usable collection), so the edge
# colorbar is skipped in that case.
plt.colorbar(nodes, label='Node Feature (Gray Matter Density)')
if edge_weights:
    plt.colorbar(edges, label='Edge Weight (Connection Strength)')

plt.title(f'Brain Graph Structure (First {subgraph_nodes} Regions)')
plt.axis('off')
plt.savefig(os.path.join(RESULTS_DIR, 'graph_structure.png'), dpi=300, bbox_inches='tight')
plt.show()

# 6. Visualize node features distribution
print("\nVisualizing node features distribution...")

# Pool every node-feature value from every graph into one flat array
# (a single concatenate instead of extending a Python list value by value).
all_node_features = np.concatenate(
    [graph.x.detach().cpu().numpy().ravel() for graph in graph_list]
)

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(all_node_features, bins=50, alpha=0.7, color='skyblue')
plt.title('Distribution of Node Features')
plt.xlabel('Gray Matter Density')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
# Per-graph mean node feature, paired with each graph's label y
mean_features = [graph.x.mean().item() for graph in graph_list]
labels = [graph.y.item() for graph in graph_list]

# Separate by class: label 0 is plotted as NonDemented, 1 as Demented;
# graphs with any other label value are not plotted.
mean_features_0 = [m for m, y_val in zip(mean_features, labels) if y_val == 0]
mean_features_1 = [m for m, y_val in zip(mean_features, labels) if y_val == 1]

# One shared set of bin edges for both classes. With an independent bins=20
# per call, each histogram gets its own edges and the overlay is misleading.
shared_bins = np.histogram_bin_edges(mean_features, bins=20)
plt.hist(mean_features_0, alpha=0.7, label='NonDemented', bins=shared_bins)
plt.hist(mean_features_1, alpha=0.7, label='Demented', bins=shared_bins)
plt.title('Mean Node Features by Class')
plt.xlabel('Mean Gray Matter Density')
plt.ylabel('Frequency')
plt.legend()

plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'node_features_distribution.png'), dpi=300, bbox_inches='tight')
plt.show()

# 7. Visualize the brain regions using Nilearn
print("\nVisualizing brain regions...")

from nilearn import datasets
from nilearn.plotting import find_parcellation_cut_coords

# Load the AAL atlas (downloaded/cached by nilearn)
aal_atlas = datasets.fetch_atlas_aal()
atlas_filename = aal_atlas.maps

# One 3D cut coordinate per atlas label, used as node positions below
coords = find_parcellation_cut_coords(atlas_filename)

# plot_connectome needs exactly one coordinate per adjacency row/column; report
# a mismatch clearly instead of letting nilearn fail with a generic error.
# NOTE(review): this also assumes graph node i corresponds to the i-th atlas
# region in nilearn's ordering — confirm against the graph-building step.
if len(coords) != adjacency_matrix.shape[0]:
    raise ValueError(
        f"AAL atlas yields {len(coords)} region coordinates but the adjacency "
        f"matrix has {adjacency_matrix.shape[0]} nodes; cannot place nodes on the brain."
    )

# Plot the brain regions (only the strongest 10% of edges are drawn)
plotting.plot_connectome(adjacency_matrix, coords, edge_threshold='90%',
                         title='Brain Regions and Connections',
                         node_size=20, colorbar=True)
plt.savefig(os.path.join(RESULTS_DIR, 'brain_regions.png'), dpi=300, bbox_inches='tight')
plt.show()

# 8. Create a 3D interactive visualization (if in Jupyter)
# Best-effort export: the HTML view is optional, so any failure is reported
# and the script carries on with the remaining analysis.
try:
    connectome_view = plotting.view_connectome(adjacency_matrix, coords, edge_threshold='90%')
    connectome_view.save_as_html(os.path.join(RESULTS_DIR, '3d_brain_connectome.html'))
except Exception as err:
    print(f"Could not create 3D visualization: {err}")
else:
    print("3D interactive visualization saved as HTML")

# 9. Analyze graph properties
print("\nAnalyzing graph properties...")

# Only a prefix of the dataset is analysed to keep this section quick;
# increase to cover more graphs.
max_graphs_for_metrics = 10

# Per-graph metrics
avg_degrees = []
avg_clustering = []

for graph in graph_list[:max_graphs_for_metrics]:
    # Undirected, unweighted NetworkX view of the graph. All nodes are added
    # first so isolated nodes still count toward the averages; nx.Graph
    # collapses (i, j)/(j, i) pairs into one edge. Bulk-adding the edge list
    # avoids a per-edge .item() round trip through Python.
    G = nx.Graph()
    G.add_nodes_from(range(graph.num_nodes))
    G.add_edges_from(graph.edge_index.t().tolist())

    # Calculate metrics
    avg_degrees.append(np.mean([degree for _, degree in G.degree()]))
    avg_clustering.append(nx.average_clustering(G))

plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.hist(avg_degrees, bins=20, alpha=0.7, color='lightgreen')
plt.title('Average Degree Distribution')
plt.xlabel('Average Degree')
plt.ylabel('Frequency')

plt.subplot(1, 2, 2)
plt.hist(avg_clustering, bins=20, alpha=0.7, color='orange')
plt.title('Average Clustering Coefficient')
plt.xlabel('Clustering Coefficient')
plt.ylabel('Frequency')

plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, 'graph_properties.png'), dpi=300, bbox_inches='tight')
plt.show()

# Final report: list every figure this script writes to RESULTS_DIR.
banner = "=" * 60
print("\n" + banner)
print("GRAPH VISUALIZATION COMPLETE")
print(banner)
print("Visualizations saved to the results directory:")

saved_figures = [
    ("Connectivity matrix", 'connectivity_matrix.png'),
    ("Graph structure", 'graph_structure.png'),
    ("Node features distribution", 'node_features_distribution.png'),
    ("Brain regions", 'brain_regions.png'),
    ("Graph properties", 'graph_properties.png'),
]
for description, figure_name in saved_figures:
    print(f"- {description}: {os.path.join(RESULTS_DIR, figure_name)}")

# The HTML export is best-effort, so it is only listed when the file exists.
html_path = os.path.join(RESULTS_DIR, '3d_brain_connectome.html')
if os.path.exists(html_path):
    print(f"- 3D interactive visualization: {html_path}")