In [None]:
# @title Import libraries and setup environment
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from dataclasses import dataclass
from typing import List, Tuple, Optional
from IPython.display import HTML

# ensure plots render inline
%matplotlib inline

In [None]:
# @title Define global configuration parameters
@dataclass
class TPUConfig:
    """Configuration parameters for the systolic array simulation.

    Every field has a default, so ``TPUConfig()`` yields a usable
    configuration. A single module-level instance, ``config``, is created
    right below and its ``array_size`` is read when the array is built.
    """
    array_size: int = 4        # N: the PE grid is N x N
    input_width: int = 8       # width of input registers (unused for now, prep for quantization)
    clock_speed: float = 1.0   # simulation delay (NOTE(review): units are not shown in this section - confirm)

    # visualization settings
    cmap: str = 'viridis'      # colormap name; presumably handed to matplotlib - not used in the visible cells

# Shared configuration instance used by the rest of the notebook.
config = TPUConfig()
print(f"TinyTPU initialized with {config.array_size}x{config.array_size} grid.")

In [None]:
# @title Define the Processing Element (PE)
class ProcessingElement:
    """
    One multiply-accumulate (MAC) cell of the weight-stationary systolic array.

    Register roles:
    - weight: loaded once and kept in place (stationary)
    - in_a / out_a: activation travelling Left -> Right
    - in_b / out_b: partial sum travelling Top -> Bottom

    The cell never talks to its neighbours directly; whoever owns the grid
    writes ``in_a`` / ``in_b`` and reads ``out_a`` / ``out_b``.
    """
    def __init__(self, x: int, y: int):
        # Position of this cell in the grid.
        self.x, self.y = x, y
        self.weight = 0.0

        # Input registers: activation (from Left) and partial sum (from Top).
        self.in_a = self.in_b = 0.0

        # Output registers: activation (to Right) and partial sum (to Bottom).
        self.out_a = self.out_b = 0.0

    def load_weight(self, w: float):
        """Store ``w`` as this cell's stationary weight."""
        self.weight = w

    def compute(self):
        """
        Run one MAC step on the currently latched inputs.

        out_b <- in_b + in_a * weight   (updated partial sum, goes down)
        out_a <- in_a                   (activation forwarded unchanged, goes right)
        """
        activation = self.in_a

        # Evaluate the new partial sum first so both outputs are written together.
        accumulated = self.in_b + activation * self.weight

        self.out_a, self.out_b = activation, accumulated

    def reset(self):
        """Zero the four data registers; the stationary weight is left as is."""
        for register in ("in_a", "in_b", "out_a", "out_b"):
            setattr(self, register, 0.0)

In [None]:
# @title Define the Systolic Array Grid
class SystolicArray:
    """
    Owner of the square PE grid; offers bulk operations over all cells.
    """
    def __init__(self, size: int):
        self.size = size
        # Row-major layout: grid[y][x], i.e. outer index = row (y), inner index = column (x).
        pe_rows = []
        for y in range(size):
            pe_rows.append([ProcessingElement(x, y) for x in range(size)])
        self.grid = pe_rows

    def load_weights(self, weights: np.ndarray):
        """
        Copy ``weights[y, x]`` into the PE at ``grid[y][x]`` for every cell.

        ``weights`` must have shape (size, size); anything else trips the
        assertion below.
        """
        assert weights.shape == (self.size, self.size), "Weight shape mismatch"
        n = self.size
        for y in range(n):
            pe_row = self.grid[y]
            for x in range(n):
                pe_row[x].load_weight(weights[y, x])

    def reset(self):
        """Reset the data registers of every PE (delegates to each PE's reset)."""
        for cell in (pe for pe_row in self.grid for pe in pe_row):
            cell.reset()

# Build the notebook-level array instance, sized from the global config.
# Note: `tpu` is rebound further down, after the class gains its clock-cycle methods.
tpu = SystolicArray(size=config.array_size)
print(f"Created {config.array_size}x{config.array_size} Systolic Array.")

In [None]:
# @title Data Pre-processing: Input Skewing
def skew_input_matrix(matrix: np.ndarray, drain_cycles: int = 10) -> np.ndarray:
    """
    Skews the input matrix to create the systolic 'wavefront'.

    Row ``r`` of ``matrix`` is the stream of values fed into array row ``r``
    from the left. Row ``r`` is shifted right by ``r`` columns (delayed by
    ``r`` cycles) so that it meets the partial sum arriving from the row
    above; the remaining cells are zero.

    Args:
        matrix: 2D array (or array-like) of shape (rows, cols).
        drain_cycles: Number of extra all-zero columns appended after the
            skewed data so the tail of the wave can leave the array.
            Defaults to 10, the padding that used to be hard-coded.
            NOTE(review): with the one-hop-per-cycle register transfers in
            SystolicArray.run_cycle, an N-row input on an N x N array appears
            to need drain_cycles >= N - 2 for the last column to finish, so
            the default only covers N <= 12 - confirm against the driver loop
            and pass a larger value for bigger arrays.

    Returns:
        A float64 array of shape (rows, cols + rows + drain_cycles), intended
        to be consumed one column per clock cycle.

    Raises:
        ValueError: If ``matrix`` is not 2-dimensional or ``drain_cycles`` is
            negative.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError(f"matrix must be 2D, got {matrix.ndim} dimension(s)")
    if drain_cycles < 0:
        raise ValueError(f"drain_cycles must be >= 0, got {drain_cycles}")

    rows, cols = matrix.shape
    # Total width = data (cols) + maximum skew (rows) + drain time.
    padded_width = cols + rows + drain_cycles

    skewed_buffer = np.zeros((rows, padded_width))

    for r in range(rows):
        # Row r starts r cycles late: place its data at columns [r, r + cols).
        skewed_buffer[r, r : r + cols] = matrix[r, :]

    return skewed_buffer

# Demo of the skew: a 4x4 all-ones matrix makes the staircase of leading zeros easy to see.
# Only the first 10 columns of the padded buffer are printed.
test_mat = np.ones((4, 4))
print("Original:\n", test_mat)
print("\nSkewed Input Buffer (Visualized):\n", skew_input_matrix(test_mat)[:, :10])

In [None]:
# @title Orchestration Controller
# We extend the previously defined SystolicArray by subclassing it under the same name.
# The existing __init__/load_weights/reset are inherited; only the clock-cycle logic is added here.

class SystolicArray(SystolicArray):  # same-name subclass: inherits the earlier __init__/load_weights/reset, adds clocking
    def run_cycle(self, input_slice: np.ndarray):
        """
        Simulate one clock tick for the whole grid.

        Args:
            input_slice: 1D sequence of length ``size``; element ``y`` is the
                         value entering row ``y`` at the left edge this cycle.
        """
        assert len(input_slice) == self.size

        # Phase 1 - sample the wires.
        # Copy every output register before anything is written, so all PEs
        # see their neighbours' values from the previous cycle (parallel read).
        wires_a = [[cell.out_a for cell in pe_row] for pe_row in self.grid]
        wires_b = [[cell.out_b for cell in pe_row] for pe_row in self.grid]

        # Phase 2 - latch the inputs of every PE.
        n = self.size
        for y in range(n):
            for x in range(n):
                cell = self.grid[y][x]

                # Activation: left neighbour's output, or the external input in column 0.
                cell.in_a = wires_a[y][x - 1] if x else input_slice[y]

                # Partial sum: upper neighbour's output, or a fresh 0.0 in the top row.
                cell.in_b = wires_b[y - 1][x] if y else 0.0

        # Phase 3 - every PE performs its MAC on the freshly latched inputs.
        for pe_row in self.grid:
            for cell in pe_row:
                cell.compute()

    def get_outputs(self):
        """Return the ``out_b`` values of the bottom row as a NumPy array (one per column)."""
        bottom_row = self.grid[-1]
        return np.array([cell.out_b for cell in bottom_row])

# Rebind `tpu` to a fresh instance of the extended class: an instance built from the
# earlier definition does not have run_cycle/get_outputs.
tpu = SystolicArray(size=config.array_size)
print("Systolic Array updated with clock cycle logic.")