# Interactive Examples for Case Studies in NeuroAI

This notebook provides interactive examples to complement Chapter 20: Case Studies in NeuroAI. You can run these examples directly in the book or launch a Binder session to experiment with the code.

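## Interactive Example 1: PredNet Visualization

This interactive example runs a single-unit, scalar version of PredNet's core loop. At each step it predicts the input, measures the prediction error, and updates its internal representation from that error. Adjust the sliders to see how input noise, prediction strength, and learning rate shape the error signal.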

In [None]:
# Install dependencies if needed.
# `!{...}` is IPython shell-escape syntax, not plain Python: the braces are
# expanded first, so pip runs under `sys.executable` (the interpreter running
# this kernel) and installs into the environment the notebook actually uses.
# `-q` keeps pip's output short.
import sys
!{sys.executable} -m pip install -q matplotlib numpy ipywidgets plotly

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, fixed
from IPython.display import display, clear_output

# Simplified PredNet simulation function
def simulate_prednet(input_noise=0.1, prediction_strength=0.7, learning_rate=0.1, timesteps=100, rng=None):
    """Simulate a single-unit predictive-coding loop on a noisy sine wave.

    At each step the unit predicts the input from its internal representation,
    measures the prediction error, and moves the representation toward the
    input by ``learning_rate * error``. This is a toy scalar illustration of
    the error-driven updates used in PredNet, not the full model.

    Parameters
    ----------
    input_noise : float
        Standard deviation of the Gaussian noise added to the sine input.
    prediction_strength : float
        Gain applied to the representation when forming the prediction.
    learning_rate : float
        Step size of the error-driven representation update.
    timesteps : int
        Number of simulation steps.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the input noise (anything with a ``normal(loc, scale)``
        method). Defaults to the global ``np.random`` state; pass a seeded
        generator for reproducible runs.

    Returns
    -------
    tuple of np.ndarray
        ``(inputs, predictions, errors)``, each of length ``timesteps``.
    """
    rng = np.random if rng is None else rng

    inputs = np.zeros(timesteps)
    predictions = np.zeros(timesteps)
    errors = np.zeros(timesteps)
    representation = 0.0

    for t in range(timesteps):
        # Input signal: slow sine wave plus Gaussian noise.
        inputs[t] = 0.5 * np.sin(0.1 * t) + rng.normal(0, input_noise)

        # The prediction is formed before this step's error is known.
        predictions[t] = prediction_strength * representation

        errors[t] = inputs[t] - predictions[t]

        # Error-driven update: shift the representation to reduce future error.
        representation += learning_rate * errors[t]

    return inputs, predictions, errors

# Plotting function
def plot_prednet_simulation(inputs, predictions, errors):
    """Draw the input, prediction and error traces as three stacked panels."""
    # (series, legend label, extra line style, panel title) for each row.
    panels = (
        (inputs, 'Input', {}, 'Input Signal'),
        (predictions, 'Prediction', {'color': 'orange'}, 'Predicted Signal'),
        (errors, 'Error', {'color': 'red'}, 'Prediction Error'),
    )

    plt.figure(figsize=(12, 8))
    for position, (series, label, style, title) in enumerate(panels, start=1):
        plt.subplot(3, 1, position)
        plt.plot(series, label=label, **style)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Interactive widget: each slider is bound to the same-named parameter below.
@widgets.interact(
    input_noise=widgets.FloatSlider(min=0.0, max=0.5, step=0.05, value=0.1, description='Input Noise:'),
    prediction_strength=widgets.FloatSlider(min=0.0, max=1.0, step=0.1, value=0.7, description='Pred. Strength:'),
    learning_rate=widgets.FloatSlider(min=0.01, max=0.5, step=0.01, value=0.1, description='Learning Rate:')
)
def interactive_prednet(input_noise, prediction_strength, learning_rate):
    """Re-run the toy PredNet simulation and redraw it for the current slider values."""
    traces = simulate_prednet(input_noise, prediction_strength, learning_rate)
    plot_prednet_simulation(*traces)

## Interactive Example 2: Prioritized Experience Replay

This interactive example demonstrates how prioritized experience replay can improve learning efficiency in reinforcement learning.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import clear_output
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random

# Simple environment
class SimpleGridworld:
    """Deterministic square gridworld: start at cell (0, 0), reach (size-1, size-1).

    States are flattened cell indices ``row * size + col``. Moves that would
    leave the grid are clamped, so bumping into a wall leaves the agent in
    place. Every step costs -0.01 except the one that lands on the goal, which
    yields +1.0 and reports ``done=True``.
    """

    def __init__(self, size=5):
        self.size = size
        self.goal_pos = (size - 1, size - 1)
        self.reset()  # places the agent on the start cell

    def reset(self):
        """Put the agent back on (0, 0) and return the start state index."""
        self.agent_pos = (0, 0)
        return self.get_state()

    def get_state(self):
        """Flatten the agent's (row, col) position into one integer."""
        row, col = self.agent_pos
        return row * self.size + col

    def step(self, action):
        """Apply ``action`` (0=up, 1=right, 2=down, 3=left); return ``(state, reward, done)``."""
        d_row, d_col = [(-1, 0), (0, 1), (1, 0), (0, -1)][action]
        row, col = self.agent_pos
        last = self.size - 1

        # Clamp into [0, last] so moving off the grid is a no-op.
        self.agent_pos = (
            max(0, min(last, row + d_row)),
            max(0, min(last, col + d_col)),
        )

        done = self.agent_pos == self.goal_pos
        reward = -0.01
        if done:
            reward = 1.0
        return self.get_state(), reward, done

# Simple replay buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of experiences with uniform random sampling."""

    def __init__(self, capacity=1000):
        self.buffer = []
        self.capacity = capacity

    def add(self, experience):
        """Append ``experience``, first evicting the oldest entry once the buffer is full."""
        is_full = len(self.buffer) >= self.capacity
        if is_full:
            del self.buffer[0]
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return up to ``batch_size`` entries drawn uniformly without replacement."""
        count = min(batch_size, len(self.buffer))
        return random.sample(self.buffer, count)

# Prioritized replay buffer
class PrioritizedReplayBuffer:
    """Fixed-capacity FIFO experience store sampled with P(i) proportional to priority_i**alpha.

    Sampling is with replacement, in the style of proportional prioritized
    replay (Schaul et al., 2016). Importance-sampling weights are not computed.
    """

    def __init__(self, capacity=1000, alpha=0.6):
        self.buffer = []
        self.priorities = []
        self.capacity = capacity
        self.alpha = alpha  # Priority exponent: 0 -> uniform, 1 -> fully proportional

    def add(self, experience, priority=None):
        """Store ``experience`` with ``priority`` (default 1.0), evicting the oldest entry if full."""
        if priority is None:
            priority = 1.0  # Default priority

        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
            self.priorities.pop(0)

        self.buffer.append(experience)
        self.priorities.append(priority)

    def sample(self, batch_size, return_indices=False):
        """Draw ``min(batch_size, len(buffer))`` experiences by priority.

        Parameters
        ----------
        batch_size : int
            Requested number of samples.
        return_indices : bool, optional
            When True, also return the buffer positions that were drawn so the
            caller can refresh exactly those entries via ``update_priorities``.

        Returns
        -------
        list or tuple
            The sampled experiences, or ``(samples, indices)`` when
            ``return_indices`` is True.
        """
        if not self.buffer:
            return ([], []) if return_indices else []

        scaled = np.asarray(self.priorities, dtype=float) ** self.alpha
        total = scaled.sum()
        # All-zero (or non-finite) priorities would give 0/0 probabilities and
        # make np.random.choice raise; fall back to uniform sampling instead.
        probs = scaled / total if np.isfinite(total) and total > 0 else None

        count = min(batch_size, len(self.buffer))
        indices = np.random.choice(len(self.buffer), count, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        if return_indices:
            return samples, [int(idx) for idx in indices]
        return samples

    def update_priorities(self, indices, priorities):
        """Overwrite the priority at each position in ``indices``; out-of-range positions are ignored."""
        for idx, priority in zip(indices, priorities):
            # Reject negatives too: they would silently wrap to the end of the list.
            if 0 <= idx < len(self.priorities):
                self.priorities[idx] = priority

# Simple Q-learning agent
class QLearningAgent:
    """Tabular Q-learning agent that also learns from replayed experiences.

    ``learn`` applies the Q-learning update to a real transition and stores it
    in the replay buffer; ``replay`` re-applies the update to a sampled batch.
    With ``prioritized=True`` transitions are stored with priority
    ``|TD error| + _PRIORITY_EPS`` and the sampled entries' priorities are
    refreshed after every replay.
    """

    # Added to |TD error| so no stored transition ends up with zero priority
    # (a zero-priority entry is never sampled again).
    _PRIORITY_EPS = 0.01

    def __init__(self, state_dim, action_dim, prioritized=False, alpha=0.6):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.q_table = np.zeros((state_dim, action_dim))
        self.prioritized = prioritized

        if prioritized:
            self.replay_buffer = PrioritizedReplayBuffer(alpha=alpha)
        else:
            self.replay_buffer = ReplayBuffer()

    def select_action(self, state, epsilon=0.1):
        """Epsilon-greedy action: random with probability ``epsilon``, else argmax of Q."""
        if np.random.random() < epsilon:
            return np.random.randint(self.action_dim)
        else:
            return np.argmax(self.q_table[state])

    def _td_error(self, state, action, reward, next_state, done, gamma):
        """Return the one-step Q-learning TD error for a transition."""
        q_value = self.q_table[state, action]
        next_q_value = np.max(self.q_table[next_state]) if not done else 0
        td_target = reward + gamma * next_q_value
        return td_target - q_value

    def learn(self, state, action, reward, next_state, done, learning_rate=0.1, gamma=0.99):
        """Update Q from one real transition, store it for replay, and return its TD error."""
        td_error = self._td_error(state, action, reward, next_state, done, gamma)

        experience = (state, action, reward, next_state, done)
        if self.prioritized:
            self.replay_buffer.add(experience, abs(td_error) + self._PRIORITY_EPS)
        else:
            self.replay_buffer.add(experience)

        self.q_table[state, action] += learning_rate * td_error
        return td_error

    def replay(self, batch_size=32, learning_rate=0.1, gamma=0.99):
        """Re-apply the Q-learning update to a batch sampled from the replay buffer."""
        if self.prioritized:
            # Ask for the drawn positions so the priorities of the experiences
            # that were actually replayed are the ones that get refreshed.
            experiences, indices = self.replay_buffer.sample(batch_size, return_indices=True)
        else:
            experiences = self.replay_buffer.sample(batch_size)
            indices = None

        new_priorities = []
        for state, action, reward, next_state, done in experiences:
            td_error = self._td_error(state, action, reward, next_state, done, gamma)
            self.q_table[state, action] += learning_rate * td_error
            new_priorities.append(abs(td_error) + self._PRIORITY_EPS)

        if indices:
            self.replay_buffer.update_priorities(indices, new_priorities)

# Function to run experiment
def run_experiment(prioritized=False, alpha=0.6, episodes=100, epsilon=0.1, learning_rate=0.1, gamma=0.99,
                   grid_size=5, max_steps=100, replay_every=5, replay_batch_size=16, seed=None):
    """Train a tabular Q-learning agent on a gridworld and record per-episode statistics.

    Every environment step applies an online Q-learning update. In addition,
    every ``replay_every`` steps (starting at step 0) it applies a replay
    update on a sampled batch.

    Parameters
    ----------
    prioritized : bool
        Use the prioritized replay buffer instead of uniform replay.
    alpha : float
        Priority exponent, only used by the prioritized buffer.
    episodes : int
        Number of training episodes.
    epsilon : float
        Exploration rate for epsilon-greedy action selection.
    learning_rate, gamma : float
        Q-learning step size and discount factor.
    grid_size : int
        Side length of the square gridworld (default 5).
    max_steps : int
        Maximum number of steps per episode (default 100).
    replay_every : int
        Replay period in steps; must be >= 1 (default 5).
    replay_batch_size : int
        Number of experiences per replay update (default 16).
    seed : int, optional
        If given, seeds both Python's ``random`` module and NumPy's global RNG
        before training, so a run can be repeated exactly.

    Returns
    -------
    tuple of lists
        ``(rewards, steps, mean_abs_td_errors)``, one entry per episode.
    """
    if seed is not None:
        # Uniform replay samples via `random`; exploration and prioritized
        # sampling use `np.random`. Seed both.
        random.seed(seed)
        np.random.seed(seed)

    env = SimpleGridworld(size=grid_size)
    state_dim = env.size * env.size
    action_dim = 4  # up, right, down, left

    agent = QLearningAgent(state_dim, action_dim, prioritized, alpha)

    rewards_history = []
    steps_history = []
    td_errors = []

    for _ in range(episodes):
        state = env.reset()
        total_reward = 0
        steps = 0
        episode_td_errors = []

        while True:
            action = agent.select_action(state, epsilon)
            next_state, reward, done = env.step(action)

            # Learn from this experience.
            td_error = agent.learn(state, action, reward, next_state, done, learning_rate, gamma)
            episode_td_errors.append(abs(td_error))

            # Learn from the replay buffer periodically.
            if steps % replay_every == 0:
                agent.replay(batch_size=replay_batch_size, learning_rate=learning_rate, gamma=gamma)

            state = next_state
            total_reward += reward
            steps += 1

            if done or steps >= max_steps:
                break

        rewards_history.append(total_reward)
        steps_history.append(steps)
        td_errors.append(np.mean(episode_td_errors) if episode_td_errors else 0)

    return rewards_history, steps_history, td_errors

# Plotting function
def plot_experiment_results(standard_results, prioritized_results):
    """Plot reward, episode length and mean |TD error| per episode for both replay variants.

    Each argument is the ``(rewards, steps, td_errors)`` tuple returned by
    ``run_experiment``. Each method keeps a single colour across all three
    panels and gets one legend entry. Clicking that entry toggles the method
    in every panel.
    """
    std_rewards, std_steps, std_errors = standard_results
    pri_rewards, pri_steps, pri_errors = prioritized_results

    # (subplot title, y-axis title, standard series, prioritized series) per row.
    rows = (
        ('Total Reward per Episode', 'Reward', std_rewards, pri_rewards),
        ('Steps per Episode', 'Steps', std_steps, pri_steps),
        ('Mean TD Error per Episode', 'TD Error', std_errors, pri_errors),
    )
    # Fixed colours (Plotly's first two defaults). Without them Plotly assigns
    # a new colour to each of the six traces, so a method changes colour from
    # panel to panel.
    styles = (('Standard Replay', '#636EFA'), ('Prioritized Replay', '#EF553B'))

    fig = make_subplots(rows=3, cols=1,
                        subplot_titles=tuple(title for title, _, _, _ in rows))

    for row, (_, y_title, std_series, pri_series) in enumerate(rows, start=1):
        for (name, color), series in zip(styles, (std_series, pri_series)):
            fig.add_trace(
                go.Scatter(
                    y=series, mode='lines', name=name,
                    line=dict(color=color),
                    legendgroup=name,        # one legend toggle controls all panels
                    showlegend=(row == 1),   # list each method once, not three times
                ),
                row=row, col=1
            )
        fig.update_yaxes(title_text=y_title, row=row, col=1)

    fig.update_layout(height=800, width=800, title_text="Standard vs Prioritized Experience Replay")
    fig.update_xaxes(title_text="Episode", row=3, col=1)

    fig.show()

# Interactive experiment runner: each slider is bound to the same-named parameter below.
@widgets.interact(
    alpha=widgets.FloatSlider(min=0.1, max=1.0, step=0.1, value=0.6, description='Alpha:'),
    episodes=widgets.IntSlider(min=10, max=200, step=10, value=50, description='Episodes:'),
    learning_rate=widgets.FloatSlider(min=0.01, max=0.5, step=0.01, value=0.1, description='Learning Rate:')
)
def interactive_replay_experiment(alpha, episodes, learning_rate):
    """Train once with uniform and once with prioritized replay, then plot both runs."""
    shared = dict(episodes=episodes, learning_rate=learning_rate)

    print("Running experiment with Standard Replay...")
    standard_results = run_experiment(prioritized=False, **shared)

    print("Running experiment with Prioritized Replay...")
    prioritized_results = run_experiment(prioritized=True, alpha=alpha, **shared)

    print("Plotting results...")
    plot_experiment_results(standard_results, prioritized_results)

## Interactive Example 3: Vision Transformer Attention Visualization

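This interactive example illustrates the idea behind attention in Vision Transformers. The image is split into patches, and a chosen focus patch assigns an attention weight to every patch. Here the weights are simulated from patch similarity plus random noise rather than computed by a trained model.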

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display, HTML
import random

# Simulate attention patterns for Vision Transformer
def create_sample_image(size=16, seed=None):
    """Create a sample binary image with a pattern"""
    if seed is not None:
        np.random.seed(seed)
        
    # Create a blank image
    img = np.zeros((size, size))
    
    # Add a random shape (square, circle, or line)
    shape_type = np.random.choice(['square', 'circle', 'line'])
    
    if shape_type == 'square':
        # Add a square
        x, y = np.random.randint(0, size//2, 2)
        w, h = np.random.randint(3, size//2, 2)
        img[x:x+w, y:y+h] = 1
    
    elif shape_type == 'circle':
        # Add a circle
        center_x, center_y = np.random.randint(size//4, 3*size//4, 2)
        radius = np.random.randint(2, size//4)
        
        for i in range(size):
            for j in range(size):
                if (i - center_x)**2 + (j - center_y)**2 < radius**2:
                    img[i, j] = 1
    
    else:  # line
        # Add a line
        start_x, start_y = np.random.randint(0, size, 2)
        end_x, end_y = np.random.randint(0, size, 2)
        
        # Simple line drawing algorithm
        length = max(abs(end_x - start_x), abs(end_y - start_y))
        for i in range(length):
            x = int(start_x + (end_x - start_x) * i / length)
            y = int(start_y + (end_y - start_y) * i / length)
            if 0 <= x < size and 0 <= y < size:
                img[x, y] = 1
    
    return img

def split_image_into_patches(image, patch_size=4):
    """Cut ``image`` into non-overlapping ``patch_size`` x ``patch_size`` tiles.

    Tiles are visited row by row. Any partial tile along the bottom or right
    edge is dropped.

    Returns
    -------
    (list of np.ndarray, list of tuple)
        The tiles (slices of ``image``) and the (row, col) of each tile's
        top-left corner, in matching order.
    """
    height, width = image.shape
    corners = [
        (top, left)
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
        if top + patch_size <= height and left + patch_size <= width
    ]
    tiles = [image[top:top + patch_size, left:left + patch_size] for top, left in corners]
    return tiles, corners

def generate_attention_map(patches, focus_patch_idx, attention_strength=0.8, noise=0.2):
    """Simulate the attention weights one patch assigns to every patch.

    The raw score of patch ``i`` is
    ``noise * U(0, 1) + attention_strength * (1 - mean(|patch_i - focus_patch|))``.
    The scores are then divided by their sum so they add up to 1. This is a
    hand-made heuristic for illustration, not attention from a trained model.

    Parameters
    ----------
    patches : list of np.ndarray
        Equally shaped patches; presumably binary 0/1 tiles (similarity can go
        negative for other value ranges).
    focus_patch_idx : int
        Index of the querying patch.
    attention_strength : float
        Weight of the similarity term.
    noise : float
        Scale of the uniform random term (drawn from the global NumPy RNG).

    Returns
    -------
    np.ndarray
        Weights of shape ``(len(patches),)`` summing to 1. If the raw scores
        sum to zero (e.g. ``noise=0`` and ``attention_strength=0``), returns
        uniform weights instead of NaNs from 0/0.
    """
    num_patches = len(patches)
    # The noise is drawn first, so a fixed global seed gives the same map.
    scores = np.random.rand(num_patches) * noise

    focus = patches[focus_patch_idx]
    similarity = np.array([1 - np.mean(np.abs(patch - focus)) for patch in patches])
    scores = scores + attention_strength * similarity

    total = np.sum(scores)
    if np.isfinite(total) and total > 0:
        return scores / total
    return np.full(num_patches, 1.0 / num_patches)

def visualize_attention(image, patches, patch_positions, patch_size, focus_patch_idx, attention_scores):
    """Show the image with its patch grid and focus patch next to the per-patch attention map.

    ``patches`` is accepted for interface compatibility but is not used. Patch
    placement comes from ``patch_positions`` (top-left (row, col) per patch),
    and ``attention_scores[idx]`` fills the area of patch ``idx``.
    """
    fig, axs = plt.subplots(1, 2, figsize=(15, 7))
    h, w = image.shape

    # imshow centres pixel (r, c) on the integer coordinate (c, r), so pixel
    # edges, and therefore patch edges, lie at half-integer positions.
    edge = -0.5

    axs[0].imshow(image, cmap='gray')

    # Draw the patch grid on pixel boundaries, including the outer border.
    for i in range(0, h + 1, patch_size):
        axs[0].axhline(i + edge, color='blue', alpha=0.3)
    for j in range(0, w + 1, patch_size):
        axs[0].axvline(j + edge, color='blue', alpha=0.3)

    # Outline the focus patch on the same pixel boundaries.
    focus_i, focus_j = patch_positions[focus_patch_idx]
    rect = plt.Rectangle((focus_j + edge, focus_i + edge), patch_size, patch_size,
                         edgecolor='red', facecolor='none', linewidth=2)
    axs[0].add_patch(rect)
    # 0-based, matching the "Focus Patch" slider value.
    axs[0].set_title(f"Original Image (Focus on Patch {focus_patch_idx})")

    # Float canvas so scores are not truncated if `image` has an integer dtype.
    attention_map = np.zeros(image.shape, dtype=float)
    for idx, (i, j) in enumerate(patch_positions):
        attention_map[i:i+patch_size, j:j+patch_size] = attention_scores[idx]

    im = axs[1].imshow(attention_map, cmap='hot')
    axs[1].set_title("Attention Map")
    fig.colorbar(im, ax=axs[1], shrink=0.8, label='Attention Score')

    plt.tight_layout()
    plt.show()

# Interactive visualization: each slider is bound to the same-named parameter below.
@widgets.interact(
    seed=widgets.IntSlider(min=1, max=10, step=1, value=1, description='Image Seed:'),
    focus_patch=widgets.IntSlider(min=0, max=15, step=1, value=0, description='Focus Patch:'),
    attention_strength=widgets.FloatSlider(min=0.1, max=1.0, step=0.1, value=0.8, description='Attention Strength:'),
    noise=widgets.FloatSlider(min=0.0, max=0.5, step=0.05, value=0.2, description='Noise Level:')
)
def interactive_attention_visualization(seed, focus_patch, attention_strength, noise):
    """Regenerate the 16x16 sample image and redraw its simulated attention map for the current sliders."""
    patch_size = 4
    image = create_sample_image(size=16, seed=seed)
    patches, patch_positions = split_image_into_patches(image, patch_size=patch_size)

    # Clamp the slider value in case there are fewer patches than the slider allows.
    focus_patch_idx = min(focus_patch, len(patches) - 1)

    attention_scores = generate_attention_map(patches, focus_patch_idx, attention_strength, noise)
    visualize_attention(image, patches, patch_positions, patch_size, focus_patch_idx, attention_scores)

## Interactive Example 4: Glossary - Interactive Neural Mechanisms

Below is an interactive glossary of key terms from Chapter 20. Hover over a term to see a tooltip explaining the underlying neural mechanism and its AI implementation.

In [None]:
from IPython.display import HTML, display

# Create interactive glossary with popups
# Interactive glossary rendered as raw HTML. Each term is an inline-block whose
# absolutely positioned `.term-definition` span is revealed by the `:hover`
# CSS rule, so no JavaScript is needed.
glossary_html = """
<style>
.glossary-term {
    color: #0366d6;
    cursor: pointer;
    font-weight: bold;
    text-decoration: underline;
    position: relative;
    display: inline-block;
}

.glossary-term .term-definition {
    visibility: hidden;
    width: 350px;
    background-color: #f8f9fa;
    color: #333;
    text-align: left;
    border-radius: 6px;
    padding: 10px;
    position: absolute;
    z-index: 1;
    bottom: 125%;
    left: 50%;
    margin-left: -175px;
    box-shadow: 0px 0px 15px rgba(0,0,0,0.2);
    transition: opacity 0.3s;
    opacity: 0;
    font-weight: normal;
    text-decoration: none;
    font-size: 0.9em;
    border: 1px solid #ddd;
}

.glossary-term:hover .term-definition {
    visibility: visible;
    opacity: 1;
}
</style>

<h3>Interactive Glossary for Case Studies in NeuroAI</h3>

<div style="background-color: #f5f5f5; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <p>Hover over each term to see its definition and neural-AI connections.</p>
    
    <div class="glossary-term">Predictive Coding
        <span class="term-definition">
            <strong>Predictive Coding</strong><br>
            <em>Neural Mechanism:</em> The brain continually generates predictions about incoming sensory information and learns from prediction errors.<br>
            <em>AI Implementation:</em> PredNet architecture implements hierarchical predictive processing with explicit representation of prediction errors between layers.<br>
            <em>Chapter Reference:</em> Chapter 20.2
        </span>
    </div>
    <br><br>
    
    <div class="glossary-term">PredNet
        <span class="term-definition">
            <strong>PredNet</strong><br>
            <em>Description:</em> A deep learning architecture implementing hierarchical predictive coding, with explicit computation of prediction errors at each layer.<br>
            <em>Neural Parallel:</em> Mirrors the predictive processing in visual cortex, with both bottom-up and top-down information flow.<br>
            <em>Advantages:</em> Superior performance in video prediction and sample-efficient learning.<br>
            <em>Chapter Reference:</em> Chapter 20.2.2
        </span>
    </div>
    <br><br>
    
    <div class="glossary-term">Prioritized Experience Replay (PER)
        <span class="term-definition">
            <strong>Prioritized Experience Replay (PER)</strong><br>
            <em>Neural Mechanism:</em> The hippocampus selectively consolidates behaviorally relevant experiences during sleep and rest.<br>
            <em>AI Implementation:</em> In reinforcement learning, experiences with high TD error (surprising outcomes) are sampled more frequently during training.<br>
            <em>Benefits:</em> In the original study (Schaul et al., 2016), prioritized replay learned faster and reached higher scores than uniform replay on most of the Atari games tested.<br>
            <em>Chapter Reference:</em> Chapter 20.3
        </span>
    </div>
    <br><br>
    
    <div class="glossary-term">Vision Transformer (ViT)
        <span class="term-definition">
            <strong>Vision Transformer (ViT)</strong><br>
            <em>Neural Mechanism:</em> Visual attention in humans allows selective processing of relevant information while filtering distractions.<br>
            <em>AI Implementation:</em> Divides images into patches, processes them with self-attention mechanisms to capture relationships between distant image regions.<br>
            <em>Key Innovation:</em> Demonstrates that attention-based models can outperform CNNs on image classification tasks when pre-trained on sufficient data.<br>
            <em>Chapter Reference:</em> Chapter 20.4
        </span>
    </div>
    <br><br>
    
    <div class="glossary-term">Latent Factor Analysis via Dynamical Systems (LFADS)
        <span class="term-definition">
            <strong>Latent Factor Analysis via Dynamical Systems (LFADS)</strong><br>
            <em>Neural Principle:</em> High-dimensional neural activity often reflects low-dimensional latent dynamics.<br>
            <em>AI Technique:</em> Uses variational auto-encoders with recurrent neural networks to model underlying dynamics in neural population recordings.<br>
            <em>Applications:</em> Single-trial neural decoding, extracting dynamics from noisy spike recordings, improved BMI control.<br>
            <em>Chapter Reference:</em> Chapter 20.5
        </span>
    </div>
    <br><br>
    
    <div class="glossary-term">Attention Mechanism
        <span class="term-definition">
            <strong>Attention Mechanism</strong><br>
            <em>Neural Basis:</em> The brain's ability to selectively focus on important stimuli while ignoring distractions.<br>
            <em>AI Implementation:</em> Computational technique that allows models to weight the importance of different input elements.<br>
            <em>Example:</em> Self-attention in Vision Transformers allows each image patch to attend to all other patches.<br>
            <em>Chapter Reference:</em> Chapters 11 and 20.4
        </span>
    </div>
    <br><br>
</div>
"""

display(HTML(glossary_html))

## Connect to Binder

You can run this notebook interactively on Binder without any local installation.

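<!-- TODO: replace "yourusername" in the link below with the GitHub account that hosts NeuroAI-Handbook; until then the Binder link will not resolve. -->
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/yourusername/NeuroAI-Handbook/main?filepath=book/part6/ch20_interactive.ipynb)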

## Summary

These interactive examples demonstrate key concepts from Chapter 20: Case Studies in NeuroAI:

1. **PredNet Visualization**: Explore how predictive coding works by adjusting parameters and seeing how prediction errors change.
2. **Prioritized Experience Replay**: Compare standard and prioritized replay methods in reinforcement learning.
3. **Vision Transformer Attention**: Visualize a simulated ViT-style attention map to see how one image patch can attend to patches across the whole image.
4. **Interactive Glossary**: Hover over key terms to see detailed explanations and neural-AI connections.

These examples help bridge theoretical concepts with practical implementations, making the case studies more concrete and accessible.