In [None]:
# Neural Network Visualization in Jupyter Notebook
# This notebook demonstrates different methods to create neural network architecture diagrams

import numpy as np
import matplotlib.pyplot as plt
import io
import IPython.display as display
from IPython.display import HTML, Markdown
import os
import subprocess
import sys


import graphviz






In [None]:
import matplotlib.pyplot as plt

def create_unet_architecture(output_filename='unet_architecture_3d.png', dpi=300):
    """
    Creates a U-Net architecture visualization.

    - The input layer is drawn as three separate 3D boxes (one per channel),
      while subsequent layers are represented by a single 3D block.
    - Black arrows connect the centers of consecutive layers. Each center is
      returned by ``draw_block`` from the same geometry used to draw the
      block, so the arrows always line up with what is on the canvas.
    - Colored arrows mark the "attention" connections between encoder and
      decoder layers.
    - The figure is saved as a PNG and also returned.

    Parameters:
      output_filename (str): Path of the PNG file to write.
      dpi (int): Resolution (dots per inch) used when saving.

    Returns:
      matplotlib.figure.Figure: The figure containing the diagram.
    """
    # Create the figure and axes for plotting.
    fig, ax = plt.subplots(figsize=(20, 12))
    ax.set_xlim(0, 18)  # x-axis span for positioning the blocks and arrows
    ax.set_ylim(0, 10)  # y-axis span
    ax.axis('off')      # Hide the axes for a cleaner appearance

    def draw_block(x, y, width, height, depth, colors, label, num_blocks=1, gap=0.2):
        """
        Draws one or more 3D-like blocks side by side and labels the group.

        Parameters:
          x (float): Starting x-coordinate of the group.
          y (float): y-coordinate of the bottom of the blocks.
          width (float): The width of each block.
          height (float): The height of each block.
          depth (float): Depth value controlling the 3D perspective effect;
            a value <= 0 draws only the flat front face.
          colors (list): Colors for the blocks, cycled if there are fewer
            colors than blocks.
          label (str): Text label placed above the group.
          num_blocks (int): Number of blocks to draw side-by-side.
          gap (float): Horizontal gap between blocks (if num_blocks>1).

        Returns:
          tuple: (center_x, center_y) of the group's front faces, used as the
          anchor point for connecting arrows.
        """
        # Total width of the front faces plus the gaps between them.
        group_width = num_blocks * width + (num_blocks - 1) * gap
        # Diagonal shift of the back edges that fakes the perspective.
        offset = depth * 0.3

        for i in range(num_blocks):
            xpos = x + i * (width + gap)  # Shift the x position for each block.
            color = colors[i % len(colors)]  # Cycle through colors if fewer than num_blocks

            # Front face (base rectangle) with a black outline.
            ax.add_patch(plt.Rectangle((xpos, y), width, height,
                                       facecolor=color, edgecolor='black'))

            if depth > 0:
                # Top face: the front face's top edge pushed up and to the right.
                ax.add_patch(plt.Polygon([
                    (xpos, y + height),                             # upper left corner
                    (xpos + width, y + height),                     # upper right corner
                    (xpos + width + offset, y + height + offset),   # shifted upper right
                    (xpos + offset, y + height + offset),           # shifted upper left
                ], facecolor=color, edgecolor='black', alpha=0.7))

                # Right face: the front face's right edge pushed up and to the right.
                ax.add_patch(plt.Polygon([
                    (xpos + width, y),                              # bottom right
                    (xpos + width, y + height),                     # top right
                    (xpos + width + offset, y + height + offset),   # shifted top right
                    (xpos + width + offset, y + offset),            # shifted bottom right
                ], facecolor=color, edgecolor='black', alpha=0.5))

        # Place a single label above the whole group. Add the perspective
        # offset so the text clears the top face instead of overlapping it.
        center_x = x + group_width / 2
        ax.text(center_x, y + height + offset + 0.2, label, ha='center', fontsize=12)

        return (center_x, y + height / 2)

    # ------------------------------------------------------------------------
    # Draw the layers of the U-Net. Each call returns the center of the layer
    # it drew; those centers are the endpoints of the arrows below.
    # ------------------------------------------------------------------------
    # Input layer: 3 thin blocks for the 3 channels (e.g. RGB 32x32 images),
    # starting at x=1, y=7, each 0.2 wide and 2 tall, with a 0.2 gap.
    center_input = draw_block(1, 7, 0.2, 2, 0.5, ['#FF0000', '#00FF00', '#0000FF'],
                              'Input\n32×32×3', num_blocks=3, gap=0.2)

    # Encoder path.
    center_conv1 = draw_block(3, 6, 0.7, 1.6, 0.7, ['#FFA07A'], 'Convolution 1\n30×30×32')
    center_pool1 = draw_block(4.5, 5, 0.6, 1.2, 0.6, ['#ADD8E6'], 'Max Pool 1\n15×15×32')
    center_conv2 = draw_block(6, 4, 0.5, 0.8, 0.5, ['#FFA07A'], 'Convolution 2\n13×13×64')
    center_pool2 = draw_block(7.5, 3, 0.4, 0.4, 0.4, ['#ADD8E6'], 'Max Pool2\n6×6×64')

    # Bottleneck.
    center_bottleneck = draw_block(9, 1, 0.3, 0.3, 0.3, ['#FFA07A'], 'Bottleneck\n3×3×128')

    # Decoder path.
    center_up1 = draw_block(10.5, 3, 0.4, 0.4, 0.4, ['#ADD8E6'], 'Up Sample 1\n6×6×64')
    center_conv3 = draw_block(12, 4, 0.5, 0.8, 0.5, ['#FFA07A'], 'Convolution 3\n13×13×64')
    center_up2 = draw_block(13.5, 5, 0.6, 1.2, 0.6, ['#ADD8E6'], 'Up Sample 2\n15×15×32')
    center_conv4 = draw_block(15, 6, 0.7, 1.6, 0.7, ['#FFA07A'], 'Convolution 4\n30×30×32')

    # Output layer: a single block.
    center_output = draw_block(17, 7, 0.2, 2, 0.5, ['#808080'], 'Output\n32×32×1', num_blocks=1)

    # Layer centers in data-flow order.
    centers = [
        center_input, center_conv1, center_pool1, center_conv2,
        center_pool2, center_bottleneck, center_up1, center_conv3,
        center_up2, center_conv4, center_output,
    ]

    # Black arrows between each pair of consecutive layers.
    for start, end in zip(centers, centers[1:]):
        ax.annotate(
            '', xy=end, xytext=start,
            arrowprops=dict(arrowstyle="->", color='black', lw=2)
        )

    # Colored "attention" arrows, each linking an encoder layer to the
    # decoder layer with the same labelled size.
    attention_connections = [
        (center_conv1, center_conv4, 'red'),
        (center_pool1, center_up2, 'blue'),
        (center_conv2, center_conv3, 'green'),
        (center_pool2, center_up1, 'purple'),
    ]

    for start, end, color in attention_connections:
        ax.annotate(
            '', xy=end, xytext=start,
            arrowprops=dict(arrowstyle="->", color=color, lw=2)
        )

    # Set the title of the visualization.
    ax.set_title('U-Net Architecture Visualization', fontsize=16)

    # Save this figure explicitly rather than whatever pyplot considers "current".
    fig.savefig(output_filename, dpi=dpi)
    print(f" Saved architecture visualization as {output_filename}")

    return fig

# Build the diagram (the call also writes the PNG to disk), then render it inline.
fig = create_unet_architecture()
plt.show()