In [4]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import seaborn as sns
from matplotlib.gridspec import GridSpec

# =============================
# CONFIGURATION
# =============================
# All paths are relative to the directory the script is launched from.

# Input: one `batch_*` folder per batch, each holding one sub-folder of .png
# frames per buffer type (see load_buffer_images).
BUFFER_DIR = "../cnn_analyze/game_frames"

# Input: one `batch_*` folder per batch, each holding .npy arrays
# (see load_cnn_data).
CNN_DATA_DIR = "../cnn_analyze/cnn_frames"

# Output: generated figures are written below this folder.
OUTPUT_DIR = "visualizations"

# Created at import time; a no-op when the folder already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# =============================
# UTILITY FUNCTIONS
# =============================
def get_available_batches(directory):
    """Return the names of the ``batch_*`` sub-directories of ``directory``.

    Names are ordered by the integer found between the first and second
    underscore (``batch_0012_ep0_step100`` -> 12). A name whose token there
    is not an integer (e.g. ``batch_final``) is not an error; it is placed
    after all numbered batches, alphabetically. Ties are broken by name so
    the order does not depend on ``os.listdir`` ordering.

    Args:
        directory: Folder to scan. Only sub-directories are returned; plain
            files that happen to start with ``batch_`` are ignored.

    Returns:
        Sorted list of directory names (not full paths).

    Raises:
        FileNotFoundError: if ``directory`` does not exist (from ``os.listdir``).
    """
    def sort_key(name):
        # Every candidate starts with 'batch_', so index [1] always exists
        # (it may be an empty string for a bare 'batch_').
        index_token = name.split('_')[1]
        try:
            return (0, int(index_token), name)
        except ValueError:
            return (1, 0, name)

    batches = [
        entry for entry in os.listdir(directory)
        if entry.startswith('batch_') and os.path.isdir(os.path.join(directory, entry))
    ]
    return sorted(batches, key=sort_key)

def load_buffer_images(buffer_dir, batch_name, buffer_type):
    """Load every ``.png`` frame of one buffer type for a batch.

    Args:
        buffer_dir: Root folder containing the per-batch folders.
        batch_name: Batch folder name (e.g. ``batch_0000_ep0_step100``).
        buffer_type: Sub-folder name (e.g. ``'original'``, ``'grayscale'``).

    Returns:
        ``None`` if ``buffer_dir/batch_name/buffer_type`` is not a directory.
        Otherwise a list of numpy arrays, one per ``.png`` file (extension
        matched case-insensitively), in filename-sorted order. The list is
        empty if the folder holds no PNGs.
    """
    buffer_path = os.path.join(buffer_dir, batch_name, buffer_type)
    if not os.path.isdir(buffer_path):
        return None

    # NOTE(review): plain lexicographic sort; this matches temporal order only
    # if frame filenames are zero-padded. Confirm against the frame writer.
    frame_files = sorted(
        name for name in os.listdir(buffer_path)
        if name.lower().endswith('.png') and os.path.isfile(os.path.join(buffer_path, name))
    )

    frames = []
    for frame_file in frame_files:
        # Context manager guarantees the file handle is released; np.array()
        # forces the pixel data to load before the file is closed.
        with Image.open(os.path.join(buffer_path, frame_file)) as img:
            frames.append(np.array(img))
    return frames

def load_cnn_data(cnn_dir, batch_name):
    """Load every ``.npy`` array saved for one batch into a dict.

    Args:
        cnn_dir: Root folder containing the per-batch folders.
        batch_name: Batch folder name.

    Returns:
        ``None`` if ``cnn_dir/batch_name`` is not a directory; otherwise a
        dict mapping each file's name minus its trailing ``.npy`` suffix
        (e.g. ``'conv1'``, ``'info'``) to the loaded array. Other files are
        ignored.
    """
    batch_path = os.path.join(cnn_dir, batch_name)
    if not os.path.isdir(batch_path):
        return None

    suffix = '.npy'
    data = {}
    for file_name in os.listdir(batch_path):
        if not file_name.endswith(suffix):
            continue
        # Strip only the trailing suffix; str.replace would also remove any
        # '.npy' occurring earlier in the name.
        key = file_name[:-len(suffix)]
        # SECURITY: allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code. It is kept because object arrays are
        # expected here (the report code unwraps 'info' with .item()), so
        # only point this at trusted dumps.
        data[key] = np.load(os.path.join(batch_path, file_name), allow_pickle=True)

    return data

def normalize_for_display(array):
    """Min-max scale ``array`` into a ``uint8`` image with values in [0, 255].

    Non-constant input is mapped linearly so its finite minimum becomes 0 and
    its finite maximum becomes 255. Constant input keeps the ``value * 255``
    mapping (so a constant in [0, 1] is unchanged), but the result is
    clipped to [0, 255] so the uint8 cast is always well-defined.

    Edge cases:
        * empty or all-NaN/inf input -> all-zero ``uint8`` array of the same shape;
        * NaN entries -> 0; +inf -> 255; -inf -> 0;
        * integer/bool input is scaled in float64 so subtraction and
          multiplication cannot wrap around in a narrow integer dtype.
    """
    array = np.asarray(array)
    if not np.issubdtype(array.dtype, np.floating):
        array = array.astype(np.float64)

    finite = np.isfinite(array)
    if not finite.any():
        return np.zeros(array.shape, dtype=np.uint8)

    arr_min, arr_max = array[finite].min(), array[finite].max()
    if arr_max > arr_min:
        scaled = (array - arr_min) / (arr_max - arr_min) * 255
    else:
        scaled = array * 255

    # Make the cast defined: replace non-finite values, then clamp to the uint8 range.
    scaled = np.nan_to_num(scaled, nan=0.0, posinf=255.0, neginf=0.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)

# =============================
# VISUALIZATION FUNCTIONS
# =============================
def visualize_input_frames(input_data, save_path):
    """Plot every frame of the CNN input stack in a grid and save the figure.

    Args:
        input_data: Array whose ``input_data[0][i]`` is a 2-D frame. The leading
            axis is dropped and assumed to be a batch of size 1, e.g.
            ``(1, num_frames, H, W)``. TODO: confirm against the dumper.
        save_path: Destination image path.

    Frames are laid out three per row, so any number of frames is handled
    rather than exactly six. Frame ``i`` is labelled ``t-(num_frames-1-i)``,
    so the last frame in the stack is labelled ``t-0``. Nothing is written
    when the stack is empty.
    """
    frames = np.asarray(input_data)[0]  # drop the leading (batch) axis
    num_frames = len(frames)
    if num_frames == 0:
        return

    ncols = min(3, num_frames)
    nrows = -(-num_frames // ncols)  # ceiling division
    # squeeze=False keeps `axes` 2-D even for a single row/column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    fig.suptitle(f'Input Frames to CNN ({num_frames} Sequential Frames)',
                 fontsize=16, fontweight='bold')

    for i, ax in enumerate(axes.flat):
        if i < num_frames:
            ax.imshow(frames[i], cmap='gray')
            ax.set_title(f'Frame {i} (t-{num_frames - 1 - i})', fontsize=12)
        ax.axis('off')  # also blanks any unused trailing cells

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

def visualize_conv_layer(activation, layer_name, save_path, max_filters=32):
    """Save a square grid of per-filter feature maps for one conv layer.

    Args:
        activation: Array whose ``activation[0]`` is indexed as
            ``[filter, H, W]``. The leading axis is dropped and assumed to be a
            batch of size 1. TODO: confirm.
        layer_name: Text used in the figure title.
        save_path: Destination image path.
        max_filters: At most this many leading filters are drawn.

    Each drawn cell shows the min-max-normalized map, titled with the filter
    index and its raw [min, max]. Unused grid cells are left blank.
    """
    maps = activation[0]  # drop the leading (batch) axis
    total = maps.shape[0]
    shown = min(total, max_filters)
    side = int(np.ceil(np.sqrt(shown)))

    # squeeze=False: always a 2-D grid of Axes, even when side == 1.
    fig, grid = plt.subplots(side, side, figsize=(20, 20), squeeze=False)
    fig.suptitle(f'{layer_name} - Feature Maps (showing {shown}/{total} filters)',
                 fontsize=16, fontweight='bold')

    for idx, ax in enumerate(grid.flat):
        if idx < shown:
            fmap = maps[idx]
            ax.imshow(normalize_for_display(fmap), cmap='viridis')
            ax.set_title(f'Filter {idx}\n[{fmap.min():.2f}, {fmap.max():.2f}]',
                         fontsize=8)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

def visualize_filter_statistics(activation, layer_name, save_path):
    """Save a 2x2 figure of per-filter mean, std, max and min activations.

    Args:
        activation: Array whose ``activation[0]`` is indexed by filter first.
            The leading axis is dropped and assumed to be a batch of size 1.
            TODO: confirm.
        layer_name: Text used in the figure title.
        save_path: Destination image path.
    """
    maps = activation[0]  # drop the leading (batch) axis
    positions = range(maps.shape[0])

    # (grid cell, per-filter values, bar colour, panel title, y-axis label)
    panels = [
        ((0, 0), [fm.mean() for fm in maps], 'blue',
         'Mean Activation per Filter', 'Mean Value'),
        ((0, 1), [fm.std() for fm in maps], 'green',
         'Standard Deviation per Filter', 'Std Dev'),
        ((1, 0), [fm.max() for fm in maps], 'red',
         'Max Activation per Filter', 'Max Value'),
        ((1, 1), [fm.min() for fm in maps], 'orange',
         'Min Activation per Filter', 'Min Value'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(f'{layer_name} - Filter Statistics', fontsize=16, fontweight='bold')

    for cell, values, colour, title, ylabel in panels:
        panel = axes[cell]
        panel.bar(positions, values, color=colour, alpha=0.7)
        panel.set_title(title)
        panel.set_xlabel('Filter Index')
        panel.set_ylabel(ylabel)
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

def visualize_activation_heatmap(activation, layer_name, save_path, num_filters=16):
    """Save a heatmap with one row per filter and one column per spatial position.

    Args:
        activation: Array whose ``activation[0]`` is indexed by filter first.
            The leading axis is dropped and assumed to be a batch of size 1.
            TODO: confirm.
        layer_name: Text used in the figure title.
        save_path: Destination image path.
        num_filters: At most this many leading filters become rows.

    Each filter's remaining axes are flattened into a single row.
    """
    maps = activation[0]  # drop the leading (batch) axis
    rows = min(num_filters, maps.shape[0])
    matrix = maps[:rows].reshape(rows, -1)

    fig, ax = plt.subplots(figsize=(20, 8))
    sns.heatmap(matrix, cmap='viridis', ax=ax, cbar_kws={'label': 'Activation Value'})

    heading = f'{layer_name} - Activation Heatmap (First {rows} Filters)'
    ax.set_title(heading, fontsize=14, fontweight='bold')
    ax.set_xlabel('Spatial Position (flattened)')
    ax.set_ylabel('Filter Index')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

def visualize_q_values(q_values, action_taken, save_path):
    """Save a bar chart of Q-values, highlighting the chosen action.

    Args:
        q_values: Q-value array. If it has more than one dimension, its first
            row is plotted (e.g. a ``(1, num_actions)`` output).
        action_taken: Index of the action to highlight (bar coloured red,
            dashed marker line, bold value label). ``None``, or an index
            outside ``[0, num_actions)``, is tolerated: nothing is highlighted
            and the title says so, instead of raising ``IndexError`` midway
            through plotting.
        save_path: Destination image path.
    """
    q_values = np.asarray(q_values)
    q_vals = q_values[0] if q_values.ndim > 1 else q_values
    num_actions = len(q_vals)
    actions = np.arange(num_actions)

    # Only a real, non-negative index into q_vals is highlighted. A negative
    # index would otherwise wrap around in q_vals[...] but match no bar.
    chosen = None
    if action_taken is not None and 0 <= int(action_taken) < num_actions:
        chosen = int(action_taken)

    fig, ax = plt.subplots(figsize=(12, 6))

    colors = ['red' if i == chosen else 'blue' for i in actions]
    bars = ax.bar(actions, q_vals, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)

    if chosen is not None:
        ax.axvline(chosen, color='red', linestyle='--', linewidth=2, alpha=0.5,
                   label=f'Chosen Action: {chosen}')
        title = f'Q-Values for All Actions\n(Chosen: {chosen}, Q={q_vals[chosen]:.4f})'
    else:
        title = f'Q-Values for All Actions\n(no valid chosen action: {action_taken})'

    ax.set_xlabel('Action Index', fontsize=12)
    ax.set_ylabel('Q-Value', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(actions)
    ax.grid(True, alpha=0.3, axis='y')
    if chosen is not None:
        ax.legend()  # only the marker line carries a label

    # Value label at the end of each bar: above positive bars, below negative ones.
    for i, (bar, val) in enumerate(zip(bars, q_vals)):
        ax.text(bar.get_x() + bar.get_width() / 2.0, bar.get_height(),
                f'{val:.3f}',
                ha='center', va='bottom' if val >= 0 else 'top',
                fontsize=10, fontweight='bold' if i == chosen else 'normal')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

def visualize_buffer_comparison(buffer_frames_dict, save_path):
    """Save a grid comparing the frames of each buffer type side by side.

    Args:
        buffer_frames_dict: Mapping of buffer name -> list of frame arrays, or
            ``None`` when that buffer is missing. Missing or empty buffers are
            skipped.
        save_path: Destination image path. Nothing is written if no buffer has
            any frames.

    One row per buffer, one column per frame index. The known processing
    stages come first, in the order 'original', 'normalized', 'downscaled',
    'grayscale'; any other names follow in input order. Buffers may have
    different frame counts: the grid is as wide as the longest buffer and
    missing cells are left blank.
    """
    buffer_types = ['original', 'normalized', 'downscaled', 'grayscale']

    available_buffers = {
        name: frames for name, frames in buffer_frames_dict.items()
        if frames is not None and len(frames) > 0
    }
    row_names = [name for name in buffer_types if name in available_buffers]
    row_names += [name for name in available_buffers if name not in buffer_types]
    if not row_names:
        return

    num_rows = len(row_names)
    num_cols = max(len(available_buffers[name]) for name in row_names)

    # squeeze=False keeps `axes` 2-D, so no special-casing of 1 row / 1 column.
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(3 * num_cols, 3 * num_rows), squeeze=False)
    fig.suptitle('Buffer Comparison Across Processing Steps', fontsize=16, fontweight='bold')

    for i, name in enumerate(row_names):
        frames = available_buffers[name]
        for j in range(num_cols):
            ax = axes[i, j]
            if j < len(frames):
                frame = frames[j]
                ax.imshow(frame, cmap='gray' if np.ndim(frame) == 2 else None)
            # Hide ticks and the frame box but not the axes themselves:
            # ax.axis('off') would also hide the y-label naming the row.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if j == 0:
                ax.set_ylabel(name.capitalize(), fontsize=10, fontweight='bold')
            if i == 0:
                ax.set_title(f'Frame {j}', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

def _read_batch_info(info):
    """Return (action_taken, episode, step) from a saved 'info' entry, each defaulting to 0."""
    if isinstance(info, np.ndarray) and info.size == 1:
        # np.save stores a dict as an object array; .item() unwraps it.
        info = info.item()
    if not isinstance(info, dict):
        return 0, 0, 0
    return info.get('action_taken', 0), info.get('episode', 0), info.get('step', 0)


def create_comprehensive_report(batch_name, cnn_data, buffer_frames, output_dir):
    """Write every visualization available for one batch into ``output_dir/batch_name``.

    Args:
        batch_name: Batch identifier, used as the output sub-folder name.
        cnn_data: Dict as returned by ``load_cnn_data``. The keys consulted
            are 'info', 'input', 'conv1', 'conv2', 'conv3' and 'q_values';
            absent keys are skipped. ``None`` is treated as an empty dict.
        buffer_frames: Dict of buffer name -> list of frames (or ``None``).
            The comparison figure is attempted only if at least one value is
            not ``None``.
        output_dir: Root folder for reports. The batch sub-folder is created
            if needed.

    If 'info' is missing or is not a dict, action, episode and step all
    default to 0.
    """
    batch_output_dir = os.path.join(output_dir, batch_name)
    os.makedirs(batch_output_dir, exist_ok=True)

    if cnn_data is None:  # load_cnn_data returns None for a missing batch folder
        cnn_data = {}
    action_taken, episode, step = _read_batch_info(cnn_data.get('info'))

    print(f"Creating visualizations for {batch_name} (Episode {episode}, Step {step})...")

    # 1. Input frames
    if 'input' in cnn_data:
        print("  - Input frames...")
        visualize_input_frames(cnn_data['input'],
                               os.path.join(batch_output_dir, '1_input_frames.png'))

    # 2-4. Conv layers: (data key, short name, feature-map title, file-name prefix).
    # NOTE(review): the filter/kernel/stride descriptions are fixed strings,
    # not derived from the data; keep them in sync with the network.
    conv_layers = (
        ('conv1', 'Conv1', 'Conv1 (32 filters, 8x8 kernel, stride 4)', '2'),
        ('conv2', 'Conv2', 'Conv2 (64 filters, 4x4 kernel, stride 2)', '3'),
        ('conv3', 'Conv3', 'Conv3 (64 filters, 3x3 kernel, stride 1)', '4'),
    )
    for key, short_name, long_title, prefix in conv_layers:
        if key not in cnn_data:
            continue
        print(f"  - {short_name} activations...")
        activation = cnn_data[key]
        visualize_conv_layer(activation, long_title,
                             os.path.join(batch_output_dir, f'{prefix}_{key}_features.png'))
        visualize_filter_statistics(activation, short_name,
                                    os.path.join(batch_output_dir, f'{prefix}_{key}_statistics.png'))
        visualize_activation_heatmap(activation, short_name,
                                     os.path.join(batch_output_dir, f'{prefix}_{key}_heatmap.png'))

    # 5. Q-values
    if 'q_values' in cnn_data:
        print("  - Q-values...")
        visualize_q_values(cnn_data['q_values'], action_taken,
                           os.path.join(batch_output_dir, '5_q_values.png'))

    # 6. Buffer comparison (skipped when every buffer failed to load)
    if buffer_frames and any(frames is not None for frames in buffer_frames.values()):
        print("  - Buffer comparison...")
        visualize_buffer_comparison(buffer_frames,
                                    os.path.join(batch_output_dir, '0_buffer_comparison.png'))

    print(f"✓ Completed {batch_name}")

def visualize_all_batches():
    """Build a report for every batch present in both CNN_DATA_DIR and BUFFER_DIR.

    Batches are processed in the numeric order produced by
    ``get_available_batches``. Batches found in only one of the two folders
    are listed and skipped. Reports are written below OUTPUT_DIR.
    """
    cnn_batches = get_available_batches(CNN_DATA_DIR)
    buffer_batches = get_available_batches(BUFFER_DIR)

    print(f"Found {len(cnn_batches)} CNN data batches")
    print(f"Found {len(buffer_batches)} buffer batches")

    # Filter rather than re-sort, so the numeric ordering from
    # get_available_batches is kept. sorted() on the names would be
    # lexicographic, e.g. batch_10 before batch_9.
    buffer_set = set(buffer_batches)
    common_batches = [batch for batch in cnn_batches if batch in buffer_set]

    incomplete = sorted(set(cnn_batches) ^ buffer_set)
    if incomplete:
        print(f"Skipping {len(incomplete)} batch(es) missing CNN or buffer data: "
              f"{', '.join(incomplete)}")
    print(f"\nProcessing {len(common_batches)} batches with complete data...\n")

    buffer_types = ('original', 'normalized', 'downscaled', 'grayscale')
    for batch in common_batches:
        cnn_data = load_cnn_data(CNN_DATA_DIR, batch)
        buffer_frames = {
            buffer_type: load_buffer_images(BUFFER_DIR, batch, buffer_type)
            for buffer_type in buffer_types
        }
        create_comprehensive_report(batch, cnn_data, buffer_frames, OUTPUT_DIR)

    print(f"\n{'='*60}")
    print(f"All visualizations saved to: {OUTPUT_DIR}")
    print(f"{'='*60}")

# =============================
# MAIN EXECUTION
# =============================
if __name__ == "__main__":
    # Banner, then process every batch found in both input folders.
    rule = "=" * 60
    print(rule)
    print("CNN Data Visualizer")
    print(rule)
    visualize_all_batches()

CNN Data Visualizer
Found 4 CNN data batches
Found 4 buffer batches

Processing 4 batches with complete data...

Creating visualizations for batch_0000_ep0_step100 (Episode 0, Step 100)...
  - Input frames...
  - Conv1 activations...
  - Conv2 activations...
  - Conv3 activations...
  - Q-values...
  - Buffer comparison...
✓ Completed batch_0000_ep0_step100
Creating visualizations for batch_0001_ep0_step200 (Episode 0, Step 200)...
  - Input frames...
  - Conv1 activations...
  - Conv2 activations...
  - Conv3 activations...
  - Q-values...
  - Buffer comparison...
✓ Completed batch_0001_ep0_step200
Creating visualizations for batch_0002_ep0_step400 (Episode 0, Step 400)...
  - Input frames...
  - Conv1 activations...
  - Conv2 activations...
  - Conv3 activations...
  - Q-values...
  - Buffer comparison...
✓ Completed batch_0002_ep0_step400
Creating visualizations for batch_0003_ep0_step500 (Episode 0, Step 500)...
  - Input frames...
  - Conv1 activations...
  - Conv2 activations...
