In [1]:
# @title Import libraries and setup environment
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from dataclasses import dataclass
from typing import List, Tuple, Optional
from IPython.display import HTML

# ensure plots render inline
%matplotlib inline

In [2]:
# @title Define global configuration parameters
@dataclass
class TPUConfig:
    """
    Tunable settings for the TinyTPU systolic-array simulation.

    Every field has a default, so ``TPUConfig()`` yields a ready-to-use
    configuration; override individual fields by keyword as needed.
    """
    # Side length N of the square N x N grid of processing elements.
    array_size: int = 4
    # Width of the input registers. Not consumed yet; reserved for
    # future quantization support.
    input_width: int = 8
    # Simulation delay setting (units are not defined here -- TODO confirm).
    clock_speed: float = 1.0

    # --- visualization ---
    # Colormap name used when plotting.
    cmap: str = 'viridis'

# Build the notebook-wide configuration from the TPUConfig defaults and
# report the resulting grid dimensions.
config = TPUConfig()
print(f"TinyTPU initialized with {config.array_size}x{config.array_size} grid.")

TinyTPU initialized with 4x4 grid.


In [3]:
# @title Define the Processing Element (PE)
class ProcessingElement:
    """
    A single multiply-accumulate (MAC) cell of the systolic array.
    Architecture: Weight-Stationary.

    Each cell keeps one weight in place while data streams through it:
    - weight:        held locally until load_weight() replaces it
    - in_a / out_a:  activation entering from the left, forwarded to the right
    - in_b / out_b:  partial sum entering from the top, forwarded downwards

    The cell only fills its own out_* registers; moving those values into a
    neighbour's in_* registers is not done by this class.
    """
    def __init__(self, x: int, y: int):
        # Position of this cell inside the grid, stored as given
        self.x, self.y = x, y

        # Stationary operand; stays 0.0 until load_weight() is called
        self.weight = 0.0

        # Input registers: activation (from Left), partial sum (from Top)
        self.in_a = self.in_b = 0.0

        # Output registers: activation (to Right), partial sum (to Bottom)
        self.out_a = self.out_b = 0.0

    def load_weight(self, w: float):
        """Stores ``w`` as this cell's stationary weight."""
        self.weight = w

    def compute(self):
        """
        Runs one MAC step using the current input registers.

        out_b = in_b + in_a * weight   (updated partial sum, heading down)
        out_a = in_a                   (activation forwarded unchanged, heading right)

        The input registers and the weight are left as they were.
        """
        # Multiply first, then accumulate onto the sum arriving from above
        self.out_b = self.in_b + self.in_a * self.weight

        # The activation passes through untouched
        self.out_a = self.in_a

    def reset(self):
        """Zeroes all four data registers. The stored weight is not modified."""
        self.in_a, self.in_b = 0.0, 0.0
        self.out_a, self.out_b = 0.0, 0.0

In [4]:
# @title Define the Systolic Array Grid
class SystolicArray:
    """
    Owns the square grid of PEs and applies grid-wide operations to it.
    """
    def __init__(self, size: int):
        self.size = size

        # Build the grid one row at a time.
        # Indexing is grid[row][col]; each PE receives x = col and y = row.
        rows = []
        for row_idx in range(size):
            rows.append([ProcessingElement(col_idx, row_idx) for col_idx in range(size)])
        self.grid = rows

    def load_weights(self, weights: np.ndarray):
        """
        Distributes a weight matrix over the grid, one entry per PE:
        the PE at grid[row][col] receives weights[row, col].

        The matrix must have shape (size, size); any other shape trips
        the assertion below.
        """
        assert weights.shape == (self.size, self.size), "Weight shape mismatch"

        n = self.size
        for row_idx in range(n):
            pe_row = self.grid[row_idx]
            for col_idx in range(n):
                pe_row[col_idx].load_weight(weights[row_idx, col_idx])

    def reset(self):
        """Calls reset() on every PE, row by row."""
        for pe in (cell for pe_row in self.grid for cell in pe_row):
            pe.reset()

# Instantiate the notebook-wide array, sized from the global config
# (config.array_size gives the side length of the square grid), and
# report the resulting dimensions.
tpu = SystolicArray(size=config.array_size)
print(f"Created {config.array_size}x{config.array_size} Systolic Array.")

Created 4x4 Systolic Array.
