# In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from IPython.display import HTML

def _edge_style(color_a, color_b, baseline='lightblue'):
    """Return ``(color, width)`` for an edge given its two endpoint colors.

    An endpoint counts as activated when its color differs from *baseline*.
    The edge is drawn darker and thicker the more endpoints are activated:
    none -> ('lightgray', 2), one -> ('gray', 3), both -> ('darkgray', 4).
    """
    activated = (color_a != baseline) + (color_b != baseline)
    styles = (('lightgray', 2), ('gray', 3), ('darkgray', 4))
    return styles[activated]


def create_complex_gnn_animation():
    """Create an animated signal-propagation visualization on an 8-node graph.

    Node 0 is the source; nodes 1-3 are its direct neighbors (inner ring) and
    nodes 4-7 are the second-order neighbors (outer ring). Each animation
    frame redraws the graph with the node colors/sizes of one stage and with
    edge color/width derived from how many endpoints are activated.

    Returns:
        tuple: ``(anim, fig)`` where ``anim`` is the
        ``matplotlib.animation.FuncAnimation`` and ``fig`` the figure it
        draws on. Keep a reference to ``anim`` for as long as the animation
        is needed (rendering or saving).
    """
    # Build the 8-node graph.
    G = nx.Graph()
    G.add_edges_from([
        (0, 1), (0, 2), (0, 3),  # Node 0 connected to 1, 2, 3
        (1, 4), (1, 5),          # Node 1 connected to 4, 5
        (2, 5), (2, 6),          # Node 2 connected to 5, 6
        (3, 6), (3, 7),          # Node 3 connected to 6, 7
        (4, 5),                  # Cross connection
        (5, 6),                  # Cross connection
        (6, 7),                  # Cross connection
        (4, 7),                  # Long-range connection
    ])

    # Fixed seed so the layout (and therefore every frame) is reproducible.
    pos = nx.spring_layout(G, seed=2, k=3, iterations=100)

    # Per-stage colors/sizes are lists indexed by node id. Pin the drawing
    # order explicitly so the i-th color/size always lands on node i instead
    # of relying on the order in which add_edges_from happened to insert nodes.
    node_ids = sorted(G.nodes())
    edge_list = list(G.edges())

    def make_stage(title, colors, sizes):
        """Expand (source, inner ring, outer ring) triples to per-node lists."""
        src_color, inner_color, outer_color = colors
        src_size, inner_size, outer_size = sizes
        return {
            'title': title,
            'colors': [src_color] + [inner_color] * 3 + [outer_color] * 4,
            'sizes': [src_size] + [inner_size] * 3 + [outer_size] * 4,
        }

    # One entry per animation frame: activation spreads outward from node 0,
    # saturates, then decays back to the baseline.
    stages = [
        make_stage('Initial: Node 0 activated',
                   ('red', 'lightblue', 'lightblue'), (1500, 800, 800)),
        make_stage('Layer 1: Direct neighbors (1,2,3) receive signal',
                   ('red', 'orange', 'lightblue'), (1500, 1200, 800)),
        make_stage('Layer 2: Second-order neighbors (4,5,6,7) activated',
                   ('red', 'orange', 'coral'), (1500, 1200, 1000)),
        make_stage('Layer 3: Cross-connections strengthen signals',
                   ('darkred', 'red', 'orange'), (1500, 1300, 1200)),
        make_stage('Final: All nodes fully influenced',
                   ('darkred', 'darkred', 'red'), (1500, 1400, 1300)),
        make_stage('Decay phase 1: Outer nodes fade first',
                   ('red', 'red', 'orange'), (1400, 1300, 1100)),
        make_stage('Decay phase 2: Inner ring fades',
                   ('red', 'orange', 'coral'), (1300, 1100, 900)),
        make_stage('Decay phase 3: Only source remains active',
                   ('orange', 'coral', 'lightblue'), (1200, 900, 800)),
        make_stage('Reset: All nodes return to baseline',
                   ('lightblue', 'lightblue', 'lightblue'), (800, 800, 800)),
    ]

    # Set up the figure
    fig, ax = plt.subplots(figsize=(12, 10))

    def animate(frame):
        """Redraw the whole graph for stage number *frame*."""
        ax.clear()

        stage = stages[frame]
        node_colors = stage['colors']

        # Edge appearance depends on how many of its endpoints are activated.
        styles = [_edge_style(node_colors[u], node_colors[v])
                  for u, v in edge_list]
        edge_colors = [color for color, _ in styles]
        edge_widths = [width for _, width in styles]

        nx.draw_networkx_edges(G, pos, ax=ax, edgelist=edge_list,
                               edge_color=edge_colors,
                               width=edge_widths, alpha=0.7)

        nx.draw_networkx_nodes(G, pos, ax=ax, nodelist=node_ids,
                               node_color=node_colors,
                               node_size=stage['sizes'],
                               alpha=0.9)

        nx.draw_networkx_labels(G, pos, ax=ax, font_size=14,
                                font_weight='bold', font_color='white')

        # NOTE(review): stage['title'] is defined for every stage but is not
        # displayed; only the time step is shown. Confirm whether the
        # descriptive title should appear here as well.
        ax.set_title(f"Time: {frame + 1}",
                     fontsize=16, fontweight='bold', pad=20)
        ax.axis('off')

    # One frame per stage, 1200 ms apart, looping.
    anim = animation.FuncAnimation(fig, animate, frames=len(stages),
                                   interval=1200, repeat=True)

    return anim, fig

# Build the animation and keep module-level references to it and its figure;
# `anim` is reused further down when saving the GIF.
anim, fig = create_complex_gnn_animation()

# Render the animation as an HTML/JavaScript player for Jupyter.
# NOTE: as a bare expression this is only displayed when it is the last
# statement of a notebook cell.
HTML(anim.to_jshtml())
# Only run this cell if you want to save the animation as GIF
def save_animation_as_gif(anim, filename='complex_gnn_propagation.gif', fps=0.8):
    """Save an animation to disk as a GIF using the 'pillow' writer.

    Args:
        anim: Object exposing ``save(filename, writer=..., fps=...)``, such as
            the animation returned by ``create_complex_gnn_animation``.
        filename: Path of the GIF file to write.
        fps: Frames per second of the saved GIF. The default of 0.8
            (1.25 s per frame) is the value this function has always used.

    Returns:
        None. A confirmation message is printed once saving has finished.
    """
    anim.save(filename, writer='pillow', fps=fps)
    print(f"GIF saved as '{filename}'")

# Save the animation as a GIF. This call is active: comment out the line
# below to skip writing the file.
save_animation_as_gif(anim)
