# Interactive Session 2: Deep Networks and Brain Hierarchies

## Exploring Hierarchical Processing

This session explores how deep neural networks can model the hierarchical processing found in the brain. The hands-on examples below focus on visual processing, but the same hierarchical principles also arise in language processing.

### Learning Goals
- Understand hierarchical feature learning
- Explore convolutional neural networks and visual processing
- Visualize feature maps and filters interactively
- Compare artificial and biological hierarchies

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import seaborn as sns
from torchvision import transforms, datasets
from PIL import Image
import cv2

# Seed torch and NumPy so randomness drawn from their global generators
# below is repeatable when the notebook is run top to bottom.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

for banner_line in ("🧠 Interactive Deep Learning for Brain Modeling 🧠",
                    "Exploring hierarchical processing in neural networks..."):
    print(banner_line)

## Interactive Element 1: Convolutional Filters and Edge Detection

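Explore how convolutional filters work. The edge kernels behave like the orientation-selective simple cells of primary visual cortex (V1): each responds most strongly to an edge at a particular orientation. The blur and Laplacian kernels instead resemble center-surround receptive fields.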

In [None]:
def create_sample_image(image_type='checkerboard', size=64, period=8):
    """Create a synthetic square grayscale test image.

    Args:
        image_type: ``'checkerboard'``, ``'circles'`` (a bright ring) or
            ``'stripes'`` (horizontal bars). Any other value -- e.g.
            ``'random'`` -- yields uniform noise from the global NumPy RNG.
        size: Side length of the image in pixels.
        period: Spatial period, in pixels, of the periodic patterns: the side
            of each checkerboard square and the repeat length of the stripes
            (each stripe is ``period // 2`` rows thick). Unused otherwise.

    Returns:
        A ``(size, size)`` float64 array with values in ``[0, 1]``.

    Raises:
        ValueError: If a periodic pattern is requested with ``period < 1``.
    """
    if image_type in ('checkerboard', 'stripes') and period < 1:
        raise ValueError(f'period must be >= 1, got {period}')

    rows, cols = np.indices((size, size))
    if image_type == 'checkerboard':
        # A pixel is on when its tile-row + tile-column index is even.
        mask = (rows // period + cols // period) % 2 == 0
    elif image_type == 'circles':
        # Ring: inside radius size // 4 but outside radius size // 6.
        center = size // 2
        dist_sq = (rows - center) ** 2 + (cols - center) ** 2
        mask = (dist_sq <= (size // 4) ** 2) & (dist_sq > (size // 6) ** 2)
    elif image_type == 'stripes':
        # The first period // 2 rows of every period are on.
        mask = rows % period < period // 2
    else:  # random
        return np.random.rand(size, size)
    return mask.astype(np.float64)

def apply_conv_filter(image, filter_type='horizontal_edge', custom_filter=None):
    """Filter an image with a named 3x3 kernel or a caller-supplied kernel.

    ``cv2.filter2D`` computes a *correlation* -- the kernel is not flipped.
    That is identical to convolution for the symmetric kernels below; for the
    asymmetric ``'diagonal_edge'`` kernel a true convolution would yield the
    negated response.

    Args:
        image: 2-D image array. With ``ddepth=-1`` the result has the same
            depth as ``image``.
        filter_type: Name of a built-in kernel. Ignored when ``custom_filter``
            is provided.
        custom_filter: Optional kernel as any 2-D array-like (nested lists
            are accepted).

    Returns:
        Tuple ``(filtered, kernel)``: the filtered image and the kernel that
        was applied, as a float64 NumPy array.

    Raises:
        KeyError: If ``filter_type`` is unknown and no ``custom_filter`` is given.
    """
    # Kernels loosely modelled on visual-cortex receptive fields: oriented
    # edge detectors and centre-surround style smoothing/edge filters.
    filters = {
        'horizontal_edge': np.array([[-1, -1, -1], [2, 2, 2], [-1, -1, -1]]),
        'vertical_edge': np.array([[-1, 2, -1], [-1, 2, -1], [-1, 2, -1]]),
        'diagonal_edge': np.array([[0, -1, -1], [1, 0, -1], [1, 1, 0]]),
        'gaussian_blur': np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 16,
        'sharpen': np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]),
        'laplacian': np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
    }

    if custom_filter is not None:
        kernel = custom_filter
    else:
        try:
            kernel = filters[filter_type]
        except KeyError:
            raise KeyError(
                f'unknown filter_type {filter_type!r}; '
                f'expected one of {sorted(filters)}'
            ) from None

    # filter2D documents a floating-point kernel; converting here also lets
    # callers pass plain lists and always hands back an array with `.shape`.
    kernel = np.asarray(kernel, dtype=np.float64)

    # Correlation (not flipped convolution) -- see the docstring.
    filtered = cv2.filter2D(image, -1, kernel)

    return filtered, kernel

def interactive_convolution(image_type='checkerboard', filter_type='horizontal_edge', 
                           show_kernel=True):
    """Plot a test image, its filtered version and (optionally) the kernel.

    Args:
        image_type: Test image name forwarded to ``create_sample_image``.
        filter_type: Built-in kernel name forwarded to ``apply_conv_filter``.
        show_kernel: If true, add a third panel showing the kernel weights
            with a colorbar and per-cell value labels.

    Side effects:
        Displays a matplotlib figure and prints a one-line explanation of
        the chosen filter.
    """
    img = create_sample_image(image_type)
    filtered_img, kernel = apply_conv_filter(img, filter_type)

    n_panels = 3 if show_kernel else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    axes[0].imshow(img, cmap='gray')
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(filtered_img, cmap='gray')
    axes[1].set_title(f'After {filter_type.replace("_", " ").title()} Filter')
    axes[1].axis('off')

    if show_kernel:
        # Symmetric limits keep 0 at the neutral midpoint of RdBu. Keep a +/-2
        # floor, but widen it for kernels like 'sharpen' (centre 5) and
        # 'laplacian' (centre -4) whose weights would otherwise be clipped.
        vlim = max(2.0, float(np.abs(kernel).max()))
        im = axes[2].imshow(kernel, cmap='RdBu', vmin=-vlim, vmax=vlim)
        axes[2].set_title('Filter Kernel')

        # Annotate each weight; light text on strongly coloured cells.
        for (row, col), value in np.ndenumerate(kernel):
            axes[2].text(col, row, f'{value:.1f}',
                         ha="center", va="center",
                         color="white" if abs(value) > vlim / 2 else "black")

        fig.colorbar(im, ax=axes[2])

    fig.tight_layout()
    plt.show()

    explanations = {
        'horizontal_edge': "Detects horizontal edges - similar to simple cells in V1 that respond to specific orientations",
        'vertical_edge': "Detects vertical edges - models orientation-selective neurons in visual cortex",
        'diagonal_edge': "Detects diagonal edges - shows how different orientations can be detected",
        'gaussian_blur': "Smooths the image - similar to center-surround receptive fields",
        'sharpen': "Enhances edges and details - amplifies high-frequency components",
        'laplacian': "Edge detection filter - responds to rapid intensity changes"
    }

    print(f"🔍 {explanations.get(filter_type, 'Custom filter applied')}")

# Controls for the convolution demo: test image, kernel choice, and whether
# to draw the kernel panel. widgets.interactive binds each keyword below to
# the interactive_convolution parameter of the same name.
image_widget = widgets.Dropdown(
    description='Image Type:',
    options=['checkerboard', 'circles', 'stripes', 'random'],
    value='checkerboard',
)
filter_widget = widgets.Dropdown(
    description='Filter:',
    options=[
        'horizontal_edge', 'vertical_edge', 'diagonal_edge',
        'gaussian_blur', 'sharpen', 'laplacian',
    ],
    value='horizontal_edge',
)
kernel_widget = widgets.Checkbox(description='Show Kernel', value=True)

conv_controls = {
    'image_type': image_widget,
    'filter_type': filter_widget,
    'show_kernel': kernel_widget,
}
interactive_conv = widgets.interactive(interactive_convolution, **conv_controls)

display(interactive_conv)

## Interactive Element 2: Feature Map Visualization

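Explore how the feature maps change as we go deeper into the network. This small CNN is randomly initialized and untrained, so its filters are not tuned to meaningful features. What reliably grows with depth is each unit's receptive field. Because the network stacks three 3×3 convolutions, a unit sees 3×3 pixels in layer 1, 5×5 in layer 2, and 7×7 in layer 3.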

In [None]:
class SimpleCNN(nn.Module):
    """Three stacked 3x3 convolution + ReLU stages (1 -> 8 -> 16 -> 32 channels).

    Every stage uses ``padding=1`` with the default stride of 1, so each
    stage's feature maps keep the input's height and width.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

    def forward(self, x, return_features=False):
        """Run the input through all three conv + ReLU stages.

        Args:
            x: Input batch for ``conv1`` (1 input channel).
            return_features: If true, return the list of all three stage
                outputs instead of only the last one.

        Returns:
            The ``conv3`` activation, or ``[conv1_out, conv2_out, conv3_out]``.
        """
        activations = []
        for conv in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(conv(x))
            activations.append(x)

        if return_features:
            return activations
        return x

def _tile_feature_maps(maps, grid_size):
    """Tile a ``(C, H, W)`` array into one ``grid_size x grid_size`` mosaic.

    Maps are placed in row-major order; cells beyond ``C`` are left at 0.
    """
    n_maps, height, width = maps.shape
    mosaic = np.zeros((grid_size * height, grid_size * width))
    for idx in range(n_maps):
        row, col = divmod(idx, grid_size)
        mosaic[row * height:(row + 1) * height,
               col * width:(col + 1) * width] = maps[idx]
    return mosaic


def visualize_feature_maps(layer_num=1, feature_map=0, input_type='checkerboard'):
    """Show one feature map and an overview of a whole ``SimpleCNN`` layer.

    The network is built with a fixed seed and is not trained here, so the
    maps reflect random (but repeatable) 3x3 filters.

    Args:
        layer_num: 1-based index of the conv stage to inspect.
        feature_map: Channel shown in the middle panel; clamped into the
            valid range for the chosen layer.
        input_type: Test image name forwarded to ``create_sample_image``.

    Raises:
        IndexError: If ``layer_num`` is not a valid 1-based layer index.

    Side effects:
        Displays a matplotlib figure and prints a short layer summary.
    """
    # Build the weights inside a forked RNG with a fixed seed: every call
    # (i.e. every widget interaction) then sees the same filters, and the
    # notebook's global torch RNG state is restored afterwards. Without this,
    # moving only the feature-map slider would silently re-draw every filter.
    with torch.random.fork_rng():
        torch.manual_seed(42)
        model = SimpleCNN()
    model.eval()

    img = create_sample_image(input_type, 64)
    # (H, W) -> (batch=1, channel=1, H, W), the layout nn.Conv2d expects.
    input_tensor = torch.as_tensor(img, dtype=torch.float32)[None, None]

    with torch.no_grad():
        features = model(input_tensor, return_features=True)

    # Reject 0/negative values explicitly: negative list indexing would
    # otherwise silently show a different layer under a wrong label.
    if not 1 <= layer_num <= len(features):
        raise IndexError(
            f'layer_num must be between 1 and {len(features)}, got {layer_num}'
        )
    selected_features = features[layer_num - 1]
    num_features = selected_features.shape[1]
    feature_map = max(0, min(feature_map, num_features - 1))
    layer_maps = selected_features[0].cpu().numpy()  # (channels, H, W)
    feature_data = layer_maps[feature_map]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img, cmap='gray')
    axes[0].set_title('Input Image')
    axes[0].axis('off')

    axes[1].imshow(feature_data, cmap='viridis')
    axes[1].set_title(f'Layer {layer_num}, Feature Map {feature_map}')
    axes[1].axis('off')

    if num_features <= 16:
        grid_size = int(np.ceil(np.sqrt(num_features)))
        axes[2].imshow(_tile_feature_maps(layer_maps, grid_size), cmap='viridis')
        axes[2].set_title(f'All {num_features} Feature Maps (Layer {layer_num})')
    else:
        # Too many maps to tile legibly: show the channel-wise mean instead.
        axes[2].imshow(layer_maps.mean(axis=0), cmap='viridis')
        axes[2].set_title(f'Mean Activation (Layer {layer_num})')
    axes[2].axis('off')

    fig.tight_layout()
    plt.show()

    layer_info = {
        1: "Layer 1: Detects simple features like edges and basic patterns",
        2: "Layer 2: Combines simple features into more complex patterns", 
        3: "Layer 3: Creates high-level feature representations"
    }

    print(f"📊 {layer_info.get(layer_num, 'Feature visualization')}")
    print(f"Number of feature maps in this layer: {num_features}")
    print(f"Feature map size: {feature_data.shape}")

# Controls for the feature-map demo. conv3 has 32 output channels, so 31 is
# the largest meaningful feature-map index; visualize_feature_maps clamps
# the index for layers with fewer channels.
layer_widget = widgets.IntSlider(description='Layer:', min=1, max=3, step=1, value=1)
feature_widget = widgets.IntSlider(description='Feature Map:', min=0, max=31, step=1, value=0)
input_widget = widgets.Dropdown(
    description='Input:',
    options=['checkerboard', 'circles', 'stripes', 'random'],
    value='checkerboard',
)

feature_controls = {
    'layer_num': layer_widget,
    'feature_map': feature_widget,
    'input_type': input_widget,
}
interactive_features = widgets.interactive(visualize_feature_maps, **feature_controls)

display(interactive_features)

## Biological Connections

### Visual Cortex Hierarchy
The visual cortex processes information in a hierarchical manner:

1. **V1 (Primary Visual Cortex)**: Simple and complex cells detect local edges and their orientations
2. **V2**: Combines V1 features into more complex patterns
3. **V4**: Color and shape processing
4. **IT (Inferotemporal Cortex)**: Object recognition

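A trained CNN's layers are loosely analogous to this hierarchy. The demo network above is untrained and only three layers deep, so treat this mapping as an illustration: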
- **Layer 1**: Edge detection (like V1 simple cells)
- **Layer 2**: Pattern combination (like V2)
- **Layer 3**: Complex feature detection (like V4/IT)

### Discussion Questions

1. **Feature Complexity**: How do the features change as you go from layer 1 to layer 3?

2. **Receptive Fields**: Each unit in a deeper layer responds to a larger portion of the input: 3×3 pixels in layer 1, 5×5 in layer 2, and 7×7 in layer 3. How does this relate to the growth of receptive field sizes from V1 to IT in the brain?

3. **Specialization**: Different feature maps in the same layer detect different patterns. How might this relate to neural specialization in the brain?

4. **Hierarchical Processing**: Can you see how simple features combine to create complex ones? How might this apply to other brain functions beyond vision?