<a href="https://colab.research.google.com/github/ubaidillah-chem/fouling-ml/blob/main/07_mlp_visualizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Arrow, FancyBboxPatch
from typing import Dict, List, Tuple, Union
import re

def _extract_sequential_body(repr_str: str) -> str:
    """Return the text inside the first ``Sequential(...)`` block of a model repr.

    Parentheses are matched by nesting depth instead of by regex, so tuple
    arguments inside a layer (e.g. ``stride=(1,))``) cannot end the block
    early. The whole string is returned when no ``Sequential(`` header exists.
    """
    header = re.search(r'Sequential\(', repr_str)
    if header is None:
        return repr_str

    body_start = header.end()
    depth = 1
    for offset, char in enumerate(repr_str[body_start:]):
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return repr_str[body_start:body_start + offset]
    # Unbalanced input: fall back to everything after the header.
    return repr_str[body_start:]


def _parse_model_repr(repr_str: str, colors: Dict, rename: Dict) -> List[Tuple]:
    """Parse a model repr into ``(label, in_feat, out_feat, color, kind)`` tuples.

    ``kind`` is the original PyTorch class name (or ``'Output'`` for the last
    ``Linear``); it is kept separate from the display label so that renaming a
    layer never changes how it is classified.
    """
    content = _extract_sequential_body(repr_str)

    # Each match is (class name, argument text) for a line like "(0): Linear(...)".
    layer_lines = re.findall(r'\(\d+\): (\w+)\(([\s\S]*?)(?=\)\n|$)', content)

    # The last Linear layer is drawn as the network output.
    linear_indices = [i for i, (raw_type, _) in enumerate(layer_lines)
                      if raw_type == 'Linear']
    last_linear_idx = linear_indices[-1] if linear_indices else -1

    layers = []
    for i, (raw_type, params) in enumerate(layer_lines):
        # Dispatch on the raw class name; the renamed value is for display only.
        display_name = rename.get(raw_type, raw_type)
        prev_out = layers[-1][2] if layers else 0

        if raw_type == 'Linear':
            in_match = re.search(r'in_features=(\d+)', params)
            out_match = re.search(r'out_features=(\d+)', params)
            in_feat = int(in_match.group(1)) if in_match else 0
            out_feat = int(out_match.group(1)) if out_match else 0

            if i == last_linear_idx:
                layers.append((f"Output\n({out_feat})", in_feat, out_feat,
                               colors['Output'], 'Output'))
            else:
                layers.append((f"{display_name}\n({out_feat})", in_feat, out_feat,
                               colors['Linear'], raw_type))
        elif raw_type == 'ReLU':
            layers.append((display_name, prev_out, prev_out,
                           colors['ReLU'], raw_type))
        elif raw_type == 'BatchNorm1d':
            # The feature count is the first argument, positional or keyword.
            num_match = re.search(r'^\s*(?:num_features=)?(\d+)', params)
            num_feat = int(num_match.group(1)) if num_match else prev_out
            layers.append((f"{display_name}\n({num_feat})", num_feat, num_feat,
                           colors['BatchNorm'], raw_type))
        else:
            # Unknown layers keep the previous width and use a custom color
            # when one is configured under their display name.
            layers.append((display_name, prev_out, prev_out,
                           colors.get(display_name, colors['Other']), raw_type))

    return layers


def visualize_mlp(model_repr: str, config: Union[Dict, None] = None):
    """Visualize an MLP architecture from a PyTorch model's string representation.

    The layers are drawn left to right as rounded boxes joined by arrows, with
    a synthetic "Input" box first and the last ``Linear`` layer labelled
    "Output". Each hidden ``Linear`` followed by a ``ReLU`` is enclosed in a
    dashed group box. The figure is shown with ``plt.show()``.

    Args:
        model_repr: String representation of a PyTorch model (from repr(model))
        config: Dictionary containing visualization parameters. Keys override
            the defaults listed in the function body; the nested ``'colors'``
            dict is merged key by key, so a partial override is allowed.
            Other keys, including ``'rename'``, replace the default wholesale.

    Raises:
        ValueError: If no layer lines can be parsed from ``model_repr``.
    """
    # Default configuration
    default_config = {
        # Figure parameters
        'figsize': (14*5, 6*5),
        'x_limits': (0, 18*5),
        'y_limits': (0, 6*5),

        # Layer box parameters
        'box_width': 1.0*5,
        'box_height': 0.5*5,
        'box_rounding': 0.1*5,
        'box_padding': 0.05*5,
        'box_linewidth': 1.5*5,

        # Spacing parameters
        'layer_spacing': 1.6*5,
        'start_x': 1*5,
        'layer_y': 3*5,

        # Arrow parameters
        'arrow_width': 0.08*5,
        'arrow_color': '#555555',
        'arrow_alpha': 0.7,

        # Group box parameters
        'group_box_height': 0.9*5,
        'group_box_rounding': 0.2*5,
        'group_box_padding': 0.15*5,
        'group_box_linewidth': 2*5,
        'group_box_alpha': 0.6,

        # Text parameters
        'fontsize': 10*5,
        'title_fontsize': 14*5,
        'legend_fontsize': 9,

        # Colors
        'colors': {
            'Input': '#FF9AA2',
            'Linear': '#A2E1FF',
            'ReLU': '#B5EAD7',
            'BatchNorm': '#FFDAC1',
            'Output': '#C7CEEA',
            'Other': '#E2F0CB'
        },

        # Title
        'title': "MLP Architecture Visualization",

        # Rename layer type
        'rename': {
            'BatchNorm1d': 'BatchNorm'
        }
    }

    # Merge user config with defaults. Colors are merged per key so that
    # overriding one color does not drop the others the drawing code needs.
    if config:
        merged_colors = {**default_config['colors'], **(config.get('colors') or {})}
        default_config.update(config)
        default_config['colors'] = merged_colors
    config = default_config
    colors = config['colors']

    # Parse the model representation
    layers = _parse_model_repr(model_repr, colors, config['rename'] or {})
    if not layers:
        raise ValueError("Model appears to have no layers")

    # Always add input layer at beginning, sized by the first layer's input width
    input_width = layers[0][1]
    layers.insert(0, (f"Input\n({input_width})", input_width, input_width,
                      colors['Input'], 'Input'))

    def layer_x(index: int) -> float:
        """Return the left edge of the box for the layer at ``index``."""
        return index * config['layer_spacing'] + config['start_x']

    # Create figure and axis
    fig, ax = plt.subplots(figsize=config['figsize'])
    ax.set_xlim(*config['x_limits'])
    ax.set_ylim(*config['y_limits'])
    ax.axis('off')

    # Draw layers
    for i, (name, _in_feat, _out_feat, color, _kind) in enumerate(layers):
        x_pos = layer_x(i)

        # Draw the layer box with rounded corners
        box = FancyBboxPatch((x_pos, config['layer_y'] - config['box_height']/2),
                             config['box_width'], config['box_height'],
                             boxstyle=f"round,pad={config['box_padding']},rounding_size={config['box_rounding']}",
                             linewidth=config['box_linewidth'],
                             edgecolor='#333333',
                             facecolor=color)
        ax.add_patch(box)

        # Add layer name
        ax.text(x_pos + config['box_width']/2, config['layer_y'], name,
                ha='center', va='center', fontsize=config['fontsize'])

        # Draw an arrow from the previous box's padded right edge to this
        # box's padded left edge.
        if i > 0:
            prev_x = layer_x(i - 1) + config['box_width'] + config['box_padding']
            curr_x = x_pos - config['box_padding']
            arrow = Arrow(prev_x, config['layer_y'], curr_x - prev_x, 0,
                          width=config['arrow_width'], color=config['arrow_color'],
                          alpha=config['arrow_alpha'])
            ax.add_patch(arrow)

    # Draw grouping boxes for Linear->ReLU pairs. Matching is done on the
    # layer kind, not the label, so renamed layers are still grouped; the
    # output layer has kind 'Output' and is therefore never grouped.
    group_indices = [
        (i, i + 1)
        for i, (current, following) in enumerate(zip(layers, layers[1:]))
        if 'Linear' in current[4] and 'ReLU' in following[4]
    ]

    for start_idx, end_idx in group_indices:
        x_start = layer_x(start_idx)
        group_width = layer_x(end_idx) + config['box_width'] - x_start
        # Center the group box vertically on the layer row.
        group_box = FancyBboxPatch((x_start, config['layer_y'] - config['group_box_height']/2),
                                   group_width, config['group_box_height'],
                                   boxstyle=f"round,pad={config['group_box_padding']},rounding_size={config['group_box_rounding']}",
                                   linewidth=config['group_box_linewidth'],
                                   linestyle='--',
                                   edgecolor='#555555',
                                   facecolor='none',
                                   alpha=config['group_box_alpha'])
        ax.add_patch(group_box)

    # Add legend
    legend_elements = [
        Rectangle((0, 0), 1, 1, fc=colors[label], label=label)
        for label in ('Input', 'Linear', 'ReLU', 'BatchNorm', 'Output', 'Other')
    ]
    ax.legend(handles=legend_elements, loc='upper right',
              bbox_to_anchor=(1, 1), fontsize=config['legend_fontsize'],
              framealpha=0.9)

    plt.title(config['title'], fontsize=config['title_fontsize'], pad=20, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Example input: the text printed by repr(model) for a small two-layer MLP.
model_repr = """
MLPModel(
  (model): Sequential(
    (0): Linear(in_features=47, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=1, bias=True)
    (3): Softplus(beta=1.0, threshold=20.0)
  )
)
"""

# Override only the arrow width; every other setting keeps its default.
config = dict(arrow_width=0.1)
visualize_mlp(model_repr, config=config)