In [None]:
# @title 1. Import libraries and setup environment
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from dataclasses import dataclass
from typing import List, Tuple, Optional
from IPython.display import HTML

# ensure plots render inline
%matplotlib inline

In [None]:
# @title 2. Define global configuration parameters
@dataclass
class TPUConfig:
    """Global configuration parameters for the TinyTPU systolic-array simulator."""
    array_size: int = 4        # N: the array is an N x N grid of processing elements
    input_width: int = 8       # input register width in bits (currently unused; prep for quantization)
    clock_speed: float = 1.0   # simulation delay (currently unused)

    # Visualization settings
    cmap: str = 'viridis'      # matplotlib colormap used by the animation

    def __post_init__(self):
        # A non-positive size would silently build an empty PE grid and the
        # failure would surface much later, far from its cause.
        if self.array_size < 1:
            raise ValueError(f"array_size must be >= 1, got {self.array_size}")

config = TPUConfig()
print(f"TinyTPU initialized with {config.array_size}x{config.array_size} grid.")

In [None]:
# @title 3. Define the Processing Element (PE)
class ProcessingElement:
    """
    One multiply-accumulate (MAC) cell of a weight-stationary systolic array.

    - The weight is loaded once and stays in place.
    - Activations arrive on `in_a` (from the left) and are forwarded via `out_a` (to the right).
    - Partial sums arrive on `in_b` (from above) and are forwarded via `out_b` (downwards).
    """
    def __init__(self, x: int, y: int):
        # Grid coordinates of this PE (x = column, y = row).
        self.x, self.y = x, y
        self.weight = 0.0

        # Input latches (left / top) followed by output registers (right / bottom).
        self.in_a, self.in_b = 0.0, 0.0
        self.out_a, self.out_b = 0.0, 0.0

    def load_weight(self, w: float):
        """Store `w` as this PE's stationary weight."""
        self.weight = w

    def compute(self):
        """
        Run one MAC step.

        The partial sum from above is extended by this PE's product
        (out_b = in_b + in_a * weight), and the activation is forwarded unchanged.
        """
        self.out_a, self.out_b = self.in_a, self.in_b + self.in_a * self.weight

    def reset(self):
        """Zero all data registers; the stationary weight is left untouched."""
        self.in_a = self.in_b = self.out_a = self.out_b = 0.0

In [None]:
# @title 4. Define the Systolic Array Grid
class SystolicArray:
    """
    Owns the square grid of Processing Elements and the operations that touch all of them.
    """
    def __init__(self, size: int):
        self.size = size
        # Row-major grid: self.grid[row][col]; each PE receives (x=col, y=row).
        self.grid = []
        for row_idx in range(size):
            self.grid.append([ProcessingElement(col_idx, row_idx) for col_idx in range(size)])

    def load_weights(self, weights: np.ndarray):
        """
        Copy a (size, size) weight matrix into the grid: weights[y, x] goes to grid[y][x].
        """
        assert weights.shape == (self.size, self.size), "Weight shape mismatch"
        for row_idx, row in enumerate(self.grid):
            for col_idx, pe in enumerate(row):
                pe.load_weight(weights[row_idx, col_idx])

    def reset(self):
        """Reset the data registers of every PE in the grid."""
        every_pe = (pe for row in self.grid for pe in row)
        for pe in every_pe:
            pe.reset()

# Build an array whose dimensions follow the global config
tpu = SystolicArray(config.array_size)
print("Created {0}x{0} Systolic Array.".format(config.array_size))

In [None]:
# @title 5. Data Pre-processing: Input Skewing
def skew_input_matrix(matrix: np.ndarray, drain_cycles: Optional[int] = None) -> np.ndarray:
    """
    Skews a matrix into the diagonal 'wavefront' that a systolic array consumes.

    Row r of `matrix` is delayed by r cycles: ``skewed[r, r + c] == matrix[r, c]``.
    Column t of the result is the input slice fed to the array's left edge on
    cycle t, so row r of `matrix` becomes the stream for array row r.

    Note: in a weight-stationary array computing ``A @ W`` (PE[k][x] holds
    W[k, x]), array row k must receive column k of A, so pass ``A.T`` here.

    Args:
        matrix: 2-D array of shape (rows, cols); `rows` must equal the array size.
        drain_cycles: number of zero cycles appended after the skewed data, so the
            last wavefront can cross the array and leave the bottom edge.
            Defaults to max(10, rows). A fixed 10 is too short for square
            arrays of 13x13 and larger, which would truncate the last results.

    Returns:
        Array of shape (rows, cols + rows + drain_cycles).
    """
    rows, cols = matrix.shape
    if drain_cycles is None:
        # After the final value enters the bottom row it must still travel across
        # the remaining columns of a square rows x rows array, so the drain has to
        # grow with the array. The floor of 10 keeps small arrays' cycle counts.
        drain_cycles = max(10, rows)
    padded_width = cols + rows + drain_cycles

    skewed_buffer = np.zeros((rows, padded_width))

    for r in range(rows):
        # Shift row r to the right by r steps (its r-cycle delay)
        skewed_buffer[r, r : r + cols] = matrix[r, :]

    return skewed_buffer

# Visualization of what skewing does: row r is shifted right by r cycles.
demo_size = config.array_size
test_mat = np.ones((demo_size, demo_size))
print("Original:\n", test_mat)
# The skewed data spans the first 2N - 1 columns; show a few drain columns after it
# (10 columns for the default N = 4).
print("\nSkewed Input Buffer (Visualized):\n", skew_input_matrix(test_mat)[:, : 2 * demo_size + 2])

In [None]:
# @title 6. Orchestration Controller
# We extend SystolicArray by subclassing it under the same name: the new class
# inherits __init__, load_weights and reset from the earlier definition and adds
# the clock-cycle methods (run_cycle, get_outputs).

class SystolicArray(SystolicArray): # Same name: inherits __init__/load_weights/reset, adds clocking
    def run_cycle(self, input_slice: np.ndarray):
        """
        Advance the whole array by exactly one clock tick.

        Args:
            input_slice: 1D sequence of length `self.size`; element y is the value
                         presented at the left edge of row y on this tick.
        """
        assert len(input_slice) == self.size

        # Freeze the outputs produced on the previous tick. Every PE latches from
        # this frozen copy, so the update order below cannot leak a value produced
        # during the current tick (all PEs behave as if they update at once).
        wires_a = [[pe.out_a for pe in row] for row in self.grid]
        wires_b = [[pe.out_b for pe in row] for row in self.grid]

        for row_idx, row in enumerate(self.grid):
            for col_idx, pe in enumerate(row):
                # Activations move rightwards; column 0 is fed from the external input.
                pe.in_a = wires_a[row_idx][col_idx - 1] if col_idx > 0 else input_slice[row_idx]
                # Partial sums move downwards; row 0 starts every sum from zero.
                pe.in_b = wires_b[row_idx - 1][col_idx] if row_idx > 0 else 0.0

        # Every PE now performs its MAC on the freshly latched inputs.
        for pe in (cell for row in self.grid for cell in row):
            pe.compute()

    def get_outputs(self):
        """Partial sums leaving the bottom edge of the array, one per column."""
        bottom_row = self.grid[-1]
        return np.array([pe.out_b for pe in bottom_row])

# Rebuild the array so it is an instance of the class with the clocking methods
tpu = SystolicArray(config.array_size)
print("Systolic Array updated with clock cycle logic.")

In [None]:
# @title 7. Visualization Engine
def run_and_visualize(array, input_matrix, weights, interval=200):
    """
    Runs a full weight-stationary matrix multiply and animates the partial sums.

    Computes ``input_matrix @ weights``. The weights are loaded into the PEs
    (PE[k][x] holds weights[k, x]), so array row k must see every A[i, k]. Column k
    of `input_matrix` is therefore skewed and streamed into array row k.

    Args:
        array: a SystolicArray (or subclass) providing reset/load_weights/
            run_cycle/get_outputs, `size` and `grid`.
        input_matrix: left operand A, shape (M, array.size).
        weights: right operand W, shape (array.size, array.size).
        interval: delay between animation frames, in milliseconds.

    Returns:
        (ani, raw_outputs): the FuncAnimation, and an array of shape
        (total_cycles, array.size) holding the bottom-row outputs after each cycle
        (pass it to `deskew_output` to recover the product).
    """
    # 1. Setup
    array.reset()
    array.load_weights(weights)
    # Feed the COLUMNS of A into the array rows. Skewing the rows of A directly
    # would compute A.T @ W, which a symmetric test input (e.g. a diagonal matrix)
    # cannot reveal.
    skewed_inputs = skew_input_matrix(np.asarray(input_matrix).T)

    # 2. Execution loop: one column of the skewed buffer per clock cycle
    total_cycles = skewed_inputs.shape[1]
    output_result = []
    history_accum = []  # per-cycle snapshot of every PE's out_b, indexed [y, x]

    print(f"Simulating {total_cycles} cycles...")

    for t in range(total_cycles):
        array.run_cycle(skewed_inputs[:, t])
        output_result.append(array.get_outputs())
        history_accum.append(
            np.array([[pe.out_b for pe in row] for row in array.grid], dtype=float)
        )

    # 3. Animation logic. Colour limits and the text-colour threshold depend on the
    # whole history, so compute them once instead of once per annotated cell.
    # Including 0 in the range keeps negative partial sums from being clipped.
    vmax = max(0.0, float(np.max(history_accum)))
    vmin = min(0.0, float(np.min(history_accum)))
    text_threshold = (vmin + vmax) / 2

    fig, ax = plt.subplots(figsize=(5, 5))

    def update(frame):
        ax.clear()
        data = history_accum[frame]
        im = ax.imshow(data, cmap=config.cmap, vmin=vmin, vmax=vmax)

        # Annotate values, switching text colour for readability on dark cells
        for y in range(array.size):
            for x in range(array.size):
                val = data[y, x]
                color = "w" if val > text_threshold else "k"
                ax.text(x, y, f"{val:.1f}", ha="center", va="center", color=color)

        ax.set_title(f"Cycle {frame}: Partial Sum Wavefront")
        ax.set_xticks([])
        ax.set_yticks([])
        return [im]

    ani = animation.FuncAnimation(fig, update, frames=total_cycles, interval=interval, blit=False)
    # Close this specific figure so Jupyter does not also render a static copy.
    plt.close(fig)

    return ani, np.array(output_result)

print("Visualization Engine ready.")

In [None]:
# @title 8. Execute Simulation
# 1. Define inputs, sized from the global config so they always fit `tpu`.
demo_n = config.array_size

# A: a diagonal matrix with distinct entries (1, 2, ..., N), so each row's
# wavefront is easy to follow in the animation. NOTE: it is symmetric (A == A.T),
# so it cannot detect a transposed input feed; that needs a non-symmetric input.
input_a = np.diag(np.arange(1.0, demo_n + 1))

# B (Weights): simple gradient -- every row is [1, 2, ..., N]
weights_b = np.tile(np.arange(1.0, demo_n + 1), (demo_n, 1))

# 2. Run Simulation
# Using interval=500 to slow down the playback
print(f"Computing A ({input_a.shape}) x B ({weights_b.shape})...")
anim, raw_outputs = run_and_visualize(tpu, input_a, weights_b, interval=500)

# 3. Post-Process Outputs
def deskew_output(raw_stream, n_rows, n_cols, array_height: Optional[int] = None):
    """
    Reassembles the result matrix from the per-cycle bottom-row outputs.

    Result element (r, c) is produced at the top row on cycle r + c: r is the input
    skew, and c is the time the activation needs to travel right to column c. It
    then takes array_height - 1 further cycles to reach the bottom row, so it is
    read from ``raw_stream[r + c + array_height - 1][c]``.

    Args:
        raw_stream: per-cycle bottom-row outputs, shape (cycles, array_width),
            e.g. the second value returned by run_and_visualize.
        n_rows, n_cols: shape of the result to reassemble.
        array_height: number of PE rows. Defaults to the width of `raw_stream`.
            Simulated arrays are square, so this always matches the array that
            produced the stream, even when its size differs from config.array_size.

    Returns:
        (n_rows, n_cols) float array. Entries whose cycle lies beyond the end of
        `raw_stream` are left at 0.
    """
    stream = np.asarray(raw_stream, dtype=float)
    result = np.zeros((n_rows, n_cols))
    if stream.size == 0:
        # No cycles were simulated: there is nothing to read.
        return result

    if array_height is None:
        # SystolicArray grids are size x size: outputs per cycle == array height.
        array_height = stream.shape[1]

    # A value entering row 0 at cycle T exits the bottom row at cycle T + height - 1.
    base_latency = array_height - 1
    n_cycles = stream.shape[0]

    for r in range(n_rows):
        for c in range(n_cols):
            target_cycle = r + c + base_latency
            if target_cycle < n_cycles:
                result[r, c] = stream[target_cycle, c]
    return result

final_result = deskew_output(raw_outputs, n_rows=input_a.shape[0], n_cols=weights_b.shape[1])

# 4. Render the animation (the cell's last expression is displayed as HTML)
HTML(data=anim.to_jshtml())

In [None]:
# @title 9. Verification
reference = input_a @ weights_b  # NumPy ground truth for the same product

print("--- Reference (Numpy) ---", reference, sep="\n")
print("\n--- Simulator Output ---", final_result, sep="\n")

# Tolerance-based comparison (np.allclose defaults), not exact equality
print("\nSUCCESS: Simulation matches Numpy!" if np.allclose(reference, final_result)
      else "\nMISMATCH: Check skew/timing logic.")

In [None]:
# @title 10. Define Quantized Arithmetic
class QuantizedPE(ProcessingElement):
    """
    A Processing Element that simulates signed Int8 arithmetic on its inputs.

    Before the multiply, the activation and the weight are rounded to the nearest
    integer (np.round, i.e. round-half-to-even) and saturated to [QMIN, QMAX], as
    an int8 cast would. The partial sum is not rounded, mirroring a wider
    accumulator (e.g. Int32).
    """
    QMIN = -128  # smallest signed 8-bit value
    QMAX = 127   # largest signed 8-bit value

    def compute(self):
        # 1. Quantize inputs: round to an integer, then saturate to the int8 range
        #    (rounding alone would let out-of-range values through unchanged).
        q_in_a = np.clip(np.round(self.in_a), self.QMIN, self.QMAX)
        q_weight = np.clip(np.round(self.weight), self.QMIN, self.QMAX)

        # 2. Perform Integer MAC
        product = q_in_a * q_weight

        # 3. Add to the partial sum. Accumulators are usually higher precision
        #    (e.g. Int32), so only the product is quantized, not 'in_b'.
        new_partial_sum = self.in_b + product

        # 4. Latch: forward the raw activation (each PE quantizes its own copy)
        self.out_a = self.in_a
        self.out_b = new_partial_sum

# Quantized array controller: identical to SystolicArray but spawns QuantizedPEs
class QuantizedSystolicArray(SystolicArray):
    def __init__(self, size: int):
        self.size = size
        # Spawn QuantizedPEs instead of standard PEs (grid[row][col], PE(x=col, y=row))
        self.grid = [[QuantizedPE(x, y) for x in range(size)] for y in range(size)]

print("Quantized Hardware definitions ready.")

In [None]:
# @title 11. Quantization Accuracy Test
# 1. Setup inputs with FLOATING POINT values drawn uniformly from [0.0, 5.0),
#    sized from the global config so the array, inputs and deskew all agree.
np.random.seed(42)
quant_size = config.array_size
input_a_float = np.random.rand(quant_size, quant_size) * 5.0
weights_b_float = np.random.rand(quant_size, quant_size) * 5.0

# 2. Run FP32 Simulation (Standard)
print("Running FP32 Simulation...")
tpu_fp32 = SystolicArray(size=quant_size)
_, raw_fp32 = run_and_visualize(tpu_fp32, input_a_float, weights_b_float, interval=100)
result_fp32 = deskew_output(raw_fp32, quant_size, quant_size)

# 3. Run Int8 Simulation (Quantized)
print("Running Int8 Simulation...")
tpu_int8 = QuantizedSystolicArray(size=quant_size)
_, raw_int8 = run_and_visualize(tpu_int8, input_a_float, weights_b_float, interval=100)
result_int8 = deskew_output(raw_int8, quant_size, quant_size)

# 4. Compare
print("\n--- FP32 Result (Subset) ---")
print(result_fp32[:2, :2]) # Just showing a corner to save space

print("\n--- Int8 Result (Subset) ---")
print(result_int8[:2, :2])

# The FP32 path must reproduce NumPy (within float tolerance). Checking it on this
# random, non-symmetric data catches input-layout bugs such as a transposed feed,
# which the diagonal (symmetric) demo matrix in Section 8 cannot reveal.
print(f"\nFP32 simulator matches NumPy: {np.allclose(result_fp32, input_a_float @ weights_b_float)}")

# Calculate Error
mse = np.mean((result_fp32 - result_int8) ** 2)
print(f"\nQuantization MSE Loss: {mse:.4f}")
print("Note: Loss is expected because we are rounding inputs!")