# Interactive Tensor and Neural Network Operations Visualization

This notebook provides interactive visualizations for understanding tensor operations and their applications in neural networks.

## Contents
1. Tensor Visualization Tools
2. Neural Network Layer Operations
3. Attention Mechanism Visualization
4. Backpropagation Flow
5. Interactive Network Architecture Builder
6. Exercises

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import ipywidgets as widgets
from IPython.display import display, HTML
import seaborn as sns

%matplotlib inline
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and 3.8
# removed the old names (plt.style.use raises OSError for an unknown style);
# fall back to the legacy name on matplotlib versions older than 3.6.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')

def create_custom_colormap(n_colors):
    """Return a seaborn HUSL palette with ``n_colors`` entries.

    HUSL hues are evenly spaced, which keeps neighbouring colors easy to
    tell apart (good contrast between categories).
    """
    palette_name = "husl"
    palette = sns.color_palette(palette_name, n_colors=n_colors)
    return palette

class TensorVisualizer:
    """Matplotlib helpers for inspecting 3-D tensors (anything ``np.asarray`` accepts)."""

    @staticmethod
    def _slice(tensor, index, axis):
        """Return the slice at ``index`` along ``axis`` (negative axes allowed)."""
        return np.take(np.asarray(tensor), index, axis=axis)

    @staticmethod
    def _grid_shape(n_panels, max_cols=4):
        """Return ``(n_rows, n_cols)`` for a grid holding ``n_panels`` subplots.

        Raises:
            ValueError: If ``n_panels`` is less than 1.
        """
        if n_panels < 1:
            raise ValueError("need at least one slice to plot")
        n_cols = min(max_cols, n_panels)
        n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
        return n_rows, n_cols

    @staticmethod
    def plot_tensor_3d(tensor, title="3D Tensor Visualization"):
        """Scatter every element of a 3-D tensor at its (i, j, k) index, colored by value.

        Args:
            tensor: A 3-D array-like.
            title: Figure title.

        Returns:
            ``(fig, ax)``: the new figure and its 3-D axes.

        Raises:
            ValueError: If ``tensor`` is not 3-dimensional.
        """
        data = np.asarray(tensor)
        if data.ndim != 3:
            raise ValueError(
                f"plot_tensor_3d expects a 3-D tensor, got shape {data.shape}")

        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(111, projection='3d')

        # One index grid per axis; raveling all four arrays in C order keeps
        # every coordinate aligned with its value.
        x, y, z = np.indices(data.shape)
        scatter = ax.scatter(x.ravel(), y.ravel(), z.ravel(),
                             c=data.ravel(), cmap='viridis', alpha=0.6)

        ax.set_xlabel('Dimension 1')
        ax.set_ylabel('Dimension 2')
        ax.set_zlabel('Dimension 3')
        fig.colorbar(scatter, ax=ax)
        ax.set_title(title)

        return fig, ax

    @staticmethod
    def plot_tensor_slices(tensor, axis=0):
        """Draw each slice of ``tensor`` along ``axis`` as a heatmap, 4 per row.

        Args:
            tensor: An array-like whose slices along ``axis`` are 2-D images.
            axis: Axis to slice along; negative values count from the end.

        Returns:
            The new matplotlib figure.

        Raises:
            ValueError: If ``tensor`` has no slices along ``axis``.
        """
        data = np.asarray(tensor)
        n_slices = data.shape[axis]
        n_rows, n_cols = TensorVisualizer._grid_shape(n_slices)

        fig = plt.figure(figsize=(15, 3 * n_rows))

        for i in range(n_slices):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(TensorVisualizer._slice(data, i, axis), cmap='viridis')
            plt.colorbar()
            plt.title(f'Slice {i}')

        plt.tight_layout()
        return fig

class NeuralOpsVisualizer:
    """Interactive (ipywidgets) visualizations of core neural-network operations."""

    @staticmethod
    def _softmax(scores, axis=-1):
        """Numerically stable softmax along ``axis``.

        Subtracting the maximum along ``axis`` leaves the result mathematically
        unchanged but keeps ``np.exp`` from overflowing on large scores
        (e.g. at a low temperature).
        """
        shifted = scores - np.max(scores, axis=axis, keepdims=True)
        exp_scores = np.exp(shifted)
        return exp_scores / np.sum(exp_scores, axis=axis, keepdims=True)

    @staticmethod
    def _scaled_dot_product_attention(queries, keys, values, temperature=1.0):
        """Return ``(weights, output)`` of softmax(Q K^T / sqrt(d_k) / temperature) V.

        ``queries``/``keys`` are 2-D ``(seq_length, d_k)`` arrays; rows of
        ``weights`` sum to 1.
        """
        d_k = queries.shape[-1]
        scores = np.dot(queries, keys.T) / np.sqrt(d_k) / temperature
        weights = NeuralOpsVisualizer._softmax(scores, axis=1)
        return weights, np.dot(weights, values)

    @staticmethod
    def _show_panels(panels):
        """Draw ``(image, title, row)`` triples side by side as viridis heatmaps.

        ``row`` is an image row index to outline in red, or ``None`` for no outline.
        """
        plt.figure(figsize=(15, 5))
        for position, (image, title, row) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), position)
            plt.imshow(image, cmap='viridis')
            if row is not None:
                # imshow centers pixel rows on integers, so row r spans r +/- 0.5.
                plt.axhline(row - 0.5, color='red', linewidth=2)
                plt.axhline(row + 0.5, color='red', linewidth=2)
            plt.title(title)
            plt.colorbar()
        plt.tight_layout()
        plt.show()

    @staticmethod
    def visualize_linear_layer(input_size, output_size, batch_size=1):
        """Interactively show ``y = x @ W + b`` for a random input batch.

        The weights, bias and input batch are sampled once, so the slider only
        changes which batch row is outlined in the input and output panels.

        Args:
            input_size: Number of input features (rows of ``W``).
            output_size: Number of output features (columns of ``W``).
            batch_size: Number of random input rows the slider can select.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        weights = np.random.randn(input_size, output_size)
        bias = np.random.randn(output_size)
        # Sampled outside the callback: re-sampling on every slider move made
        # the selected batch index meaningless.
        input_batch = np.random.randn(batch_size, input_size)
        output_batch = np.dot(input_batch, weights) + bias

        def update_plot(batch_idx):
            NeuralOpsVisualizer._show_panels((
                (input_batch, 'Input', batch_idx),
                (weights, 'Weights', None),
                (output_batch, 'Output', batch_idx),
            ))

        widgets.interact(update_plot,
                         batch_idx=widgets.IntSlider(min=0, max=batch_size - 1,
                                                     step=1, value=0))

    @staticmethod
    def visualize_attention(seq_length, d_k):
        """Interactively show self-attention on random Q, K, V as temperature varies.

        Q, K and V are sampled once, so moving the temperature slider changes
        only the temperature, making its sharpening/flattening effect visible.

        Args:
            seq_length: Number of positions (rows of Q, K, V).
            d_k: Feature dimension of Q, K and V.
        """
        queries = np.random.randn(seq_length, d_k)
        keys = np.random.randn(seq_length, d_k)
        values = np.random.randn(seq_length, d_k)

        def compute_attention(temperature):
            weights, output = NeuralOpsVisualizer._scaled_dot_product_attention(
                queries, keys, values, temperature)
            NeuralOpsVisualizer._show_panels((
                (weights, 'Attention Weights', None),
                (values, 'Values', None),
                (output, 'Attention Output', None),
            ))

        widgets.interact(compute_attention,
                         temperature=widgets.FloatSlider(min=0.1, max=2.0,
                                                         step=0.1, value=1.0))

class BackpropVisualizer:
    """Draw a fully connected network whose edges are styled by (simulated) gradients.

    Attributes:
        layer_sizes: Neuron count per layer, input layer first.
        weights: One random ``(n_in, n_out)`` matrix per pair of adjacent layers.
        gradients: Arrays shaped like ``weights``; all zeros until
            ``update_gradients`` is called.
    """

    def __init__(self, layer_sizes):
        self.layer_sizes = layer_sizes
        self.weights = [np.random.randn(n_in, n_out)
                        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
        self.gradients = [np.zeros_like(w) for w in self.weights]

    @staticmethod
    def _edge_style(gradient):
        """Return ``(color, alpha)`` for one edge.

        Red for a positive gradient, blue otherwise; opacity is ``|gradient|``
        capped at 1, so a zero gradient is fully transparent.
        """
        color = 'red' if gradient > 0 else 'blue'
        return color, min(1, abs(gradient))

    def visualize_network(self):
        """Plot neurons as dots and every weight as an edge styled by its gradient."""
        plt.figure(figsize=(15, 8))

        for layer_idx, size in enumerate(self.layer_sizes):
            # zorder=3 draws neurons above the edges (lines default to zorder 2).
            plt.scatter(np.full(size, layer_idx), np.arange(size),
                        s=100, c='blue', alpha=0.5, zorder=3)

            if layer_idx < len(self.weights):
                layer_grads = self.gradients[layer_idx]
                for src in range(size):
                    for dst in range(self.layer_sizes[layer_idx + 1]):
                        color, alpha = self._edge_style(layer_grads[src, dst])
                        plt.plot([layer_idx, layer_idx + 1], [src, dst],
                                 c=color, alpha=alpha)

        plt.title('Neural Network with Gradient Flow')
        plt.grid(True)
        plt.show()

    def update_gradients(self):
        """Replace the gradients with standard-normal noise and redraw the network."""
        self.gradients = [np.random.randn(*w.shape) for w in self.weights]
        self.visualize_network()

# Example usage
tensor_vis = TensorVisualizer()
neural_vis = NeuralOpsVisualizer()
backprop_vis = BackpropVisualizer([4, 6, 4, 2])

# A random 4x4x4 tensor to exercise both tensor plots.
example_tensor = np.random.randn(4, 4, 4)
tensor_vis.plot_tensor_3d(example_tensor)
plt.show()

tensor_vis.plot_tensor_slices(example_tensor)
plt.show()

print("Interactive Linear Layer Visualization:")
neural_vis.visualize_linear_layer(10, 5, batch_size=3)

print("\nInteractive Attention Visualization:")
neural_vis.visualize_attention(seq_length=8, d_k=4)

print("\nBackpropagation Visualization:")
# Gradients start as all zeros, which draws every edge with alpha 0; simulate
# a backward pass first so the gradient flow is actually visible.
backprop_vis.update_gradients()

## Interactive Network Architecture Builder

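Build and visualize custom neural network architectures with different layer types. Every layer is drawn as a column of neurons fully connected to the next column; the selected layer type appears only in the legend label.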

In [None]:
class NetworkArchitectureBuilder:
    """Accumulate layer specs and redraw the whole architecture after each addition.

    Each entry of ``layers`` is a dict with keys ``'type'``, ``'units'`` and
    ``'activation'``. Every layer is drawn as a column of neurons fully
    connected to the next column; the type is shown only in the legend.
    """

    def __init__(self):
        self.layers = []

    def add_layer(self, layer_type, units, activation='relu'):
        """Append a layer spec and redraw the architecture."""
        self.layers.append({
            'type': layer_type,
            'units': units,
            'activation': activation
        })
        self.visualize_architecture()

    @staticmethod
    def _neuron_positions(units, max_units):
        """Return y-coordinates for ``units`` neurons, unit-spaced and vertically centered.

        The tallest layer occupies ``1 .. max_units``, leaving one unit of
        padding below it and one above it for the activation label.
        """
        offset = (max_units - units) / 2 + 1
        return np.arange(units) + offset

    def visualize_architecture(self):
        """Plot the current layers; returns without drawing if no layer was added."""
        if not self.layers:
            return

        plt.figure(figsize=(15, 8))

        max_units = max(layer['units'] for layer in self.layers)
        height = max_units + 2  # one unit of padding above and below the tallest layer
        positions = [self._neuron_positions(layer['units'], max_units)
                     for layer in self.layers]

        for i, (layer, y) in enumerate(zip(self.layers, positions)):
            units = layer['units']
            # zorder=3 keeps neurons above the connection lines.
            plt.scatter([i] * units, y, s=100, zorder=3,
                        label=f"{layer['type']} ({units})")

            # Activation label sits in the top padding row.
            plt.text(i, height - 0.5, layer['activation'],
                     horizontalalignment='center')

            if i + 1 < len(self.layers):
                # Every (src, dst) neuron pair; plot() draws each column of a
                # 2-D array as its own line, so one call draws all edges.
                src, dst = np.meshgrid(y, positions[i + 1], indexing='ij')
                n_edges = src.size
                xs = np.vstack([np.full(n_edges, i), np.full(n_edges, i + 1)])
                ys = np.vstack([src.ravel(), dst.ravel()])
                plt.plot(xs, ys, color='gray', alpha=0.1)

        plt.grid(True)
        plt.title('Neural Network Architecture')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.show()

# Create interactive widgets
layer_type = widgets.Dropdown(
    options=['Dense', 'Conv2D', 'LSTM', 'Attention'],
    value='Dense',
    description='Layer Type:'
)

units = widgets.IntSlider(
    value=32,
    min=1,
    max=128,
    description='Units:'
)

activation = widgets.Dropdown(
    options=['relu', 'tanh', 'sigmoid', 'linear'],
    value='relu',
    description='Activation:'
)

# Create builder instance
builder = NetworkArchitectureBuilder()

# Output produced inside a button callback is not attached to the cell unless it
# is captured; route each redraw into this Output widget instead.
architecture_output = widgets.Output()


def add_layer(button):
    """Button callback: append the selected layer and redraw into ``architecture_output``."""
    with architecture_output:
        # Replace the previous figure rather than stacking a new one per click.
        architecture_output.clear_output(wait=True)
        builder.add_layer(layer_type.value, units.value, activation.value)


add_button = widgets.Button(description='Add Layer')
add_button.on_click(add_layer)

# Display controls
display(widgets.VBox([layer_type, units, activation, add_button, architecture_output]))

## Exercises

1. Create a custom tensor operation and visualize its effect
2. Build and visualize a CNN architecture
3. Implement and visualize different attention mechanisms
4. Create an animation of gradient flow during training

Try implementing these in the cell below:

In [None]:
# Your solution here
