# TinyTPU: A simple TPU simulator driven by JAX

This notebook implements a simple TPU simulator whose matrix unit is modeled cycle by cycle (the vector unit is modeled functionally). The main goal is to illustrate how high-level JAX code can be executed on low-level systolic hardware.

It uses eager execution (interpreting operations one-by-one) rather than XLA-style graph compilation.

**Architecture**

*   Frontend (Interpreter): Routes JAX operations (dot, add, relu) passed to `tpu_call` at runtime, dispatching each one immediately to the simulator.
*   Middleware (Tiling): Slices large matrices to fit the hardware grid and handles data skewing required for systolic data flow.
*   Backend (Hardware Model): Simulates a Systolic Array (weight-stationary) for matrix math and a Vector Unit for element-wise ops.
*   Quantization: Rounds MXU operands and VPU results to the nearest integer to approximate Int8 hardware precision loss (values are not clamped to the Int8 range).

**Usage**: Run the cells in order. The demo cell executes a complete dense layer ($Y = \text{ReLU}(X \cdot W + B)$) on the simulator, and the final cell visualizes the quantization error relative to the FP32 JAX reference.

In [None]:
# @title Hardware simulator: MXU and VPU
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from dataclasses import dataclass

# --- Configuration ---
@dataclass
class Config:
    """Tunable parameters of the simulated chip.

    All fields have defaults, so ``Config()`` yields the standard 4x4 setup.
    """

    # Edge length of the square systolic array; the dispatcher passes it to
    # run_tpu_job, which builds a grid_size x grid_size array of PEs.
    grid_size: int = 4
    # Not referenced by any of the cells shown here; kept for compatibility.
    input_width: int = 8

config = Config(grid_size=4, input_width=8)

# --- Hardware Model ---

class PE:
    """One cell of a weight-stationary systolic array.

    The cell keeps a fixed weight. Each ``compute`` call forwards the incoming
    activation unchanged and adds ``round(activation) * round(weight)`` to the
    incoming partial sum.
    """

    def __init__(self, x, y):
        # Grid coordinates; SystolicArray passes the column as x and the row as y.
        self.x, self.y = x, y
        self.weight = 0.0
        # Input registers: activation arrives from the left, partial sum from above.
        self.a_in = 0.0
        self.b_in = 0.0
        # Output registers, read by the right-hand and lower neighbours.
        self.a_out = 0.0
        self.b_out = 0.0

    def load_weight(self, w):
        """Latch ``w`` as this cell's stationary weight (stored as-is, unclamped)."""
        self.weight = w

    def compute(self):
        """Perform one multiply-accumulate with integer-rounded operands."""
        # Rounding both operands approximates integer (Int8-style) arithmetic;
        # values are not clamped to the Int8 range.
        activation_q = np.round(self.a_in)
        weight_q = np.round(self.weight)
        self.a_out = self.a_in  # activation is forwarded unrounded
        self.b_out = self.b_in + activation_q * weight_q

    def reset(self):
        """Zero the data registers; the stationary weight is left untouched."""
        for register in ('a_in', 'b_in', 'a_out', 'b_out'):
            setattr(self, register, 0.0)

class SystolicArray:
    """Square, weight-stationary systolic array of ``size x size`` PEs.

    ``grid[r][c]`` is the PE at row ``r``, column ``c``. On every ``step``,
    activations move one PE to the right (entering at column 0) and partial
    sums move one PE down (entering row 0 as zero). The sums leaving the
    bottom row are exposed by ``read_out``.
    """

    def __init__(self, size):
        self.size = size
        # Outer list = rows (y), inner list = columns (x).
        self.grid = [[PE(x, y) for x in range(size)] for y in range(size)]

    def load_weights(self, weights):
        """Latch ``weights[r, c]`` into the PE at row ``r``, column ``c``.

        Args:
            weights: array-like of shape ``(size, size)``.

        Raises:
            ValueError: if ``weights`` does not have shape ``(size, size)``.
        """
        weights = np.asarray(weights)
        expected = (self.size, self.size)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if weights.shape != expected:
            raise ValueError(
                f"expected weights of shape {expected}, got {weights.shape}"
            )
        for r, row in enumerate(self.grid):
            for c, pe in enumerate(row):
                pe.load_weight(weights[r, c])

    def reset(self):
        """Zero the data registers of every PE (PE.reset keeps the weights)."""
        for row in self.grid:
            for pe in row:
                pe.reset()

    def step(self, input_slice):
        """Advance the array by one clock cycle.

        Args:
            input_slice: indexable of ``size`` activations fed into column 0
                this cycle, one per row.
        """
        # Read phase: latch each PE's inputs from its neighbours' output
        # registers. Nothing writes a_out/b_out until the execute phase
        # below, so every PE sees its neighbours' previous-cycle outputs,
        # which models a synchronous clock edge without a snapshot copy.
        for r, row in enumerate(self.grid):
            for c, pe in enumerate(row):
                pe.a_in = input_slice[r] if c == 0 else row[c - 1].a_out
                pe.b_in = 0.0 if r == 0 else self.grid[r - 1][c].b_out

        # Execute phase: every PE performs its MAC for this cycle.
        for row in self.grid:
            for pe in row:
                pe.compute()

    def read_out(self):
        """Return the partial sums currently leaving the bottom row, one per column."""
        return np.array([pe.b_out for pe in self.grid[-1]])

class VectorUnit:
    """Functional model of the TPU's Vector Processing Unit (VPU).

    Supported element-wise opcodes are ``'add'`` (uses the first two operands)
    and ``'relu'`` (uses the first operand). Every result is rounded to the
    nearest integer to mimic low-precision storage. No vector lanes or cycle
    timing are modeled.
    """

    # (opcode, kernel) pairs, tried in order; each kernel gets the operand tuple.
    _KERNELS = (
        ('add', lambda ops: ops[0] + ops[1]),
        ('relu', lambda ops: np.maximum(ops[0], 0)),
    )

    def execute(self, opcode, *operands):
        """Run ``opcode`` over ``operands`` and return the rounded result.

        Raises:
            NotImplementedError: for any opcode other than 'add' or 'relu'.
        """
        for name, kernel in self._KERNELS:
            if opcode == name:
                return np.round(kernel(operands))
        raise NotImplementedError(f"VPU opcode '{opcode}' not supported")

In [None]:
# @title Data layout + tiling utilities

def skew_matrix(matrix, drain=None):
    """Lay out ``matrix`` as a time-skewed input stream for the systolic array.

    Row ``r`` of the result is row ``r`` of ``matrix`` delayed by ``r``
    columns (cycles), so each activation reaches its PE in step with the
    partial-sum wavefront. Column ``t`` of the result is the input slice for
    cycle ``t``; trailing zero columns give results time to drain out.

    Args:
        matrix: 2-D array of shape ``(rows, cols)``; assumed to have one row
            per PE row of a square ``rows x rows`` array.
        drain: number of extra zero columns appended beyond ``cols + rows``.
            Defaults to ``max(10, rows)``. This keeps the original 10-cycle
            padding for small arrays while covering the ``rows - 2`` columns
            a square array needs to emit its last result.

    Returns:
        ``np.ndarray`` of shape ``(rows, cols + rows + drain)``.
    """
    matrix = np.asarray(matrix)
    rows, cols = matrix.shape
    if drain is None:
        # For an N x N array the last result leaves at tick cols + 2N - 3, so
        # the stream needs N - 2 columns beyond cols + rows. A fixed +10
        # silently truncates results once N > 12.
        drain = max(10, rows)
    stream = np.zeros((rows, cols + rows + drain))
    for r in range(rows):
        stream[r, r:r + cols] = matrix[r]
    return stream

def deskew_result(stream, rows, cols, latency=None):
    """Rebuild a ``(rows, cols)`` result matrix from the bottom-row output stream.

    Result ``(r, c)`` leaves column ``c`` of the array at tick
    ``r + latency + c``. The ``r`` term comes from the one-cycle skew between
    input rows, ``c`` from the cycles activations need to reach column ``c``,
    and ``latency`` from the partial sum traversing the PE rows.

    Args:
        stream: sequence of per-tick output vectors, indexed ``stream[t][c]``.
        rows: number of rows of the matrix to reconstruct.
        cols: number of columns of the matrix to reconstruct.
        latency: pipeline depth, ``N - 1`` for an array with N PE rows. When
            omitted, it is inferred as ``len(stream[0]) - 1``, assuming a
            square array (as many PE rows as output columns). A fixed value
            of 3 is only correct for a 4x4 grid.

    Returns:
        ``np.ndarray`` of shape ``(rows, cols)``. Entries whose tick lies
        beyond the end of ``stream`` are left as 0.
    """
    res = np.zeros((rows, cols))
    if len(stream) == 0:
        return res
    if latency is None:
        latency = len(stream[0]) - 1

    for r in range(rows):
        for c in range(cols):
            tick = r + latency + c
            if tick < len(stream):
                res[r, c] = stream[tick][c]
    return res

def tile_matmul(lhs, rhs, grid_size):
    """Split ``lhs @ rhs`` into a list of grid-sized tile jobs.

    Output blocks are visited in row-major order. For each block, one job is
    emitted per ``grid_size``-wide slice of the shared K dimension, so the
    caller must sum all jobs that share a ``'dest'``. Edge tiles are simply
    the (possibly smaller) remaining slices.

    Returns:
        list of dicts with keys ``'a'`` (lhs tile), ``'b'`` (rhs tile) and
        ``'dest'`` (top-left ``(row, col)`` of the output block).
    """
    rows, inner = lhs.shape
    inner_rhs, cols = rhs.shape
    assert inner == inner_rhs

    return [
        {
            'a': lhs[r0:r0 + grid_size, k0:k0 + grid_size],
            'b': rhs[k0:k0 + grid_size, c0:c0 + grid_size],
            'dest': (r0, c0),
        }
        for r0 in range(0, rows, grid_size)
        for c0 in range(0, cols, grid_size)
        for k0 in range(0, inner, grid_size)
    ]

# --- Execution Engine ---

def run_tpu_job(lhs, rhs, grid_size=4):
    """Compute ``lhs @ rhs`` on a simulated ``grid_size x grid_size`` systolic array.

    The product is tiled with ``tile_matmul``. Each tile's partial product is
    computed cycle by cycle and accumulated into the output. Edge tiles
    smaller than the grid (shapes not divisible by ``grid_size``) are
    zero-padded to a full tile before loading, and the padding is cropped
    off again on write-back. The zeros contribute nothing to the sums.

    Args:
        lhs: 2-D array of shape ``(M, K)``.
        rhs: 2-D array of shape ``(K, N)``.
        grid_size: edge length of the square systolic array.

    Returns:
        ``np.ndarray`` of shape ``(M, N)``.
    """
    lhs = np.asarray(lhs)
    rhs = np.asarray(rhs)
    program = tile_matmul(lhs, rhs, grid_size)

    # Pre-allocate output memory.
    global_mem = np.zeros((lhs.shape[0], rhs.shape[1]))
    chip = SystolicArray(grid_size)

    # For a full tile, the last result leaves the array at tick
    # (g-1) [row skew] + (g-1) [pipeline depth] + (g-1) [column], so at least
    # 3g - 2 cycles must be simulated regardless of the skew buffer's padding.
    min_cycles = 3 * grid_size - 2
    zero_slice = np.zeros(grid_size)

    for op in program:
        a_tile = _pad_tile(op['a'], grid_size)
        b_tile = _pad_tile(op['b'], grid_size)

        # Clear in-flight activations/partial sums from the previous tile.
        # Real hardware would stream tiles back to back; a reset keeps the sim simple.
        for row in chip.grid:
            for pe in row:
                pe.reset()
        chip.load_weights(b_tile)

        # Weight-stationary dataflow: PE row r holds rhs row r, so it must
        # receive lhs column r. Hence stream the transposed lhs tile, skewed.
        inp_stream = skew_matrix(a_tile.T)
        n_stream = inp_stream.shape[1]

        raw_out = []
        for t in range(max(n_stream, min_cycles)):
            chip.step(inp_stream[:, t] if t < n_stream else zero_slice)
            raw_out.append(chip.read_out())

        # Write back only the real (unpadded) part of the tile.
        partial = deskew_result(np.array(raw_out), grid_size, grid_size)
        r, c = op['dest']
        h, w = op['a'].shape[0], op['b'].shape[1]
        global_mem[r:r + h, c:c + w] += partial[:h, :w]

    return global_mem


def _pad_tile(tile, grid_size):
    """Zero-pad a 2-D ``tile`` on the bottom/right to ``(grid_size, grid_size)``."""
    tile = np.asarray(tile)
    padded = np.zeros((grid_size, grid_size))
    padded[:tile.shape[0], :tile.shape[1]] = tile
    return padded

In [None]:
# @title Dispatcher

import jax
import jax.numpy as jnp

# Initialize hardware units
vpu = VectorUnit()

def tpu_call(func, *args):
    """Route a JAX/NumPy operation to the simulated hardware unit that runs it.

    Dispatch is by the callable's ``__name__``:

    * ``dot`` / ``matmul`` -> systolic array (MXU) via ``run_tpu_job``
    * ``add``              -> vector unit (VPU)
    * ``relu``             -> vector unit (VPU)

    Args:
        func: the operation being emulated (e.g. ``jnp.dot``). It is only
            inspected for its name, never called.
        *args: array-like operands; each is copied into a NumPy array.

    Returns:
        ``np.ndarray`` result produced by the simulator.

    Raises:
        NotImplementedError: if ``func`` is not a supported op, including
            callables without a ``__name__`` such as ``functools.partial``.
    """
    # getattr: some callables (e.g. functools.partial) have no __name__; they
    # should hit the "not supported" error rather than an AttributeError.
    name = getattr(func, '__name__', None)
    arrs = [np.array(x) for x in args]

    # 1. Matrix multiplication -> Systolic Array (MXU). For 2-D operands,
    #    matmul and dot are the same operation.
    if name in ('dot', 'matmul'):
        print(f"[Dispatcher] Routing {arrs[0].shape} x {arrs[1].shape} to Systolic Array...")
        return run_tpu_job(arrs[0], arrs[1], config.grid_size)

    # 2. Addition -> Vector Unit (VPU)
    if name == 'add':
        print("[Dispatcher] Routing 'add' to Vector Unit...")
        return vpu.execute('add', arrs[0], arrs[1])

    # 3. ReLU -> Vector Unit (VPU). jax.nn.relu is matched by name because
    #    jax.numpy has no relu function.
    if name == 'relu':
        print("[Dispatcher] Routing 'relu' to Vector Unit...")
        return vpu.execute('relu', arrs[0])

    label = name if name is not None else repr(func)
    raise NotImplementedError(f"Op '{label}' not supported on TinyTPU")

In [None]:
# @title Demo: Running a dense layer
print("\n--- Running Dense Layer (Relu(XW + B)) ---")

# 1. Inputs: three independent 8x8 matrices drawn uniformly and scaled by 5.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
X, W, B = (jax.random.uniform(subkey, (8, 8)) * 5 for subkey in (k1, k2, k3))

# 2. Eager execution, one simulator call per op. Each call is independent;
#    only the intermediate arrays carry data from the MXU to the VPU.
step1 = tpu_call(jnp.dot, X, W)        # X @ W on the systolic array
step2 = tpu_call(jnp.add, step1, B)    # bias add on the vector unit
output = tpu_call(jax.nn.relu, step2)  # activation on the vector unit

print("\nFinal Output Snapshot (Top Left):")
print(output[:4, :4])

# 3. Compare against the same layer computed directly by JAX.
ref = jax.nn.relu(jnp.add(jnp.dot(X, W), B))
mse = np.mean((output - ref)**2)
print(f"\nLayer MSE: {mse:.4f}")

In [None]:
# @title Visualization

import matplotlib.pyplot as plt

def visualize_results(reference, tpu_output, title_prefix=""):
    """Plot the reference, the TinyTPU output, and their difference side by side.

    The two value panels share one colour scale, so equal colours mean equal
    values across them. The error panel uses a range symmetric about zero,
    so the neutral centre of the diverging ``coolwarm`` map marks zero error.

    Args:
        reference: 2-D array of reference values (e.g. the JAX result).
        tpu_output: 2-D array from the simulator, same shape as ``reference``.
        title_prefix: text prepended to the two value-panel titles.
    """
    reference = np.asarray(reference, dtype=float)
    tpu_output = np.asarray(tpu_output, dtype=float)
    error = reference - tpu_output

    # Shared limits for the value panels; symmetric limits for the error panel.
    vmin = min(reference.min(), tpu_output.min())
    vmax = max(reference.max(), tpu_output.max())
    err_lim = float(np.abs(error).max()) or 1.0  # avoid a zero-width range

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # (data, title, colormap, lower limit, upper limit) for each panel.
    plots = [
        (reference, f"{title_prefix} Reference (JAX)".strip(), "viridis", vmin, vmax),
        (tpu_output, f"{title_prefix} TinyTPU (Int8)".strip(), "viridis", vmin, vmax),
        (error, "Quantization Error (Diff)", "coolwarm", -err_lim, err_lim),
    ]

    for ax, (data, title, cmap, lo, hi) in zip(axes, plots):
        im = ax.imshow(data, cmap=cmap, vmin=lo, vmax=hi)
        ax.set_title(title)
        # fraction/pad chosen so the colorbar roughly matches the image height.
        fig.colorbar(im, ax=ax, fraction=0.042, pad=0.03)

    plt.tight_layout()
    plt.show()

# --- Run Visualization for the Dense Layer ---
# We use the variables 'ref' and 'output' from the previous cell
visualize_results(ref, output, title_prefix="Dense Layer")