# GPU Puzzles
- by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp)

![](https://github.com/srush/GPU-Puzzles/raw/main/cuda.png)

GPU architectures are critical to machine learning, and seem to be
becoming even more important every day. However, you can be an expert
in machine learning without ever touching GPU code. It is hard to gain
intuition working through abstractions. 

This notebook is an attempt to teach beginner GPU programming in a
completely interactive fashion. Instead of providing text with
concepts, it throws you right into coding and building GPU
kernels. The exercises use NUMBA which directly maps Python
code to CUDA kernels. It looks like Python but is basically
identical to writing low-level CUDA code. 
In a few hours, I think you can go from basics to
understanding the real algorithms that power 99% of deep learning
today. If you do want to read the manual, it is here:

[NUMBA CUDA Guide](https://numba.readthedocs.io/en/stable/cuda/index.html)

I recommend doing these in Colab, as it is easy to get started.  Be
sure to make your own copy, turn on GPU mode in the settings (`Runtime / Change runtime type`, then set `Hardware accelerator` to `GPU`), and
then get to coding.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/GPU-Puzzles/blob/main/GPU_puzzlers.ipynb)

(If you are into this style of puzzle, also check out my [Tensor
Puzzles](https://github.com/srush/Tensor-Puzzles) for PyTorch.)

In [None]:
!pip install -qqq git+https://github.com/chalk-diagrams/planar git+https://github.com/danoneata/chalk@srush-patch-1
!wget -q https://github.com/srush/GPU-Puzzles/raw/main/robot.png https://github.com/srush/GPU-Puzzles/raw/main/lib.py

In [None]:
import numba
import numpy as np
import warnings
from lib import CudaProblem, Coord

In [None]:
warnings.filterwarnings(
    action="ignore", category=numba.NumbaPerformanceWarning, module="numba"
)

## Puzzle 1: Map

Implement a "kernel" (GPU function) that adds 10 to each position of vector `a`
and stores it in vector `out`.  You have 1 thread per position.

**Warning** This code looks like Python but it is really CUDA! You cannot use
standard python tools like list comprehensions or ask for Numpy properties
like shape or size (if you need the size, it is given as an argument).
The puzzles only require doing simple operations, basically
+, *, simple array indexing, for loops, and if statements.
You are allowed to use local variables. 
If you get an
error it is probably because you did something fancy :). 

*Tip: Think of the function `call` as being run 1 time for each thread.
The only difference is that `cuda.threadIdx.x` changes each time.*

##  Puzzle 1: Map - Deep Dive Solution Explanation

---

###  Problem Statement
Implement a kernel that adds 10 to each position of vector `a` and stores the result in `out`.

---

###  Core Concept: GPU Parallel Execution Model

**The Fundamental Shift in Thinking:**
- **CPU**: Write a loop that processes elements one-by-one
- **GPU**: Write code for ONE element, run it on THOUSANDS of threads simultaneously

```
CPU Sequential:                    GPU Parallel:
┌─────────────────────┐           ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ for i in range(4):  │           │ T0  │ │ T1  │ │ T2  │ │ T3  │
│   out[i]=a[i]+10    │           │out[0]│ │out[1]│ │out[2]│ │out[3]│
│ (4 sequential ops)  │           └─────┘ └─────┘ └─────┘ └─────┘
└─────────────────────┘              └───────────┬───────────┘
                                         ALL AT ONCE
```

---

###  Line-by-Line Deep Analysis

#### `local_i = cuda.threadIdx.x`

**What is threadIdx.x?**
Every CUDA thread has a built-in 3D identifier: `(threadIdx.x, threadIdx.y, threadIdx.z)`. For 1D arrays, we use only `.x`.

```
When kernel launches with 4 threads:

Thread 0: cuda.threadIdx.x == 0
Thread 1: cuda.threadIdx.x == 1  
Thread 2: cuda.threadIdx.x == 2
Thread 3: cuda.threadIdx.x == 3

This is AUTOMATIC - you don't set it, you READ it.
It's how each thread knows "I am thread #N"
```

#### `out[local_i] = a[local_i] + 10`

Each thread:
1. **Reads** from global memory: `a[local_i]`
2. **Computes**: adds 10
3. **Writes** to global memory: `out[local_i]`

---

###  Complete Execution Trace

```
Configuration: SIZE=4, a=[0,1,2,3]

╔════════════════════════════════════════════════════════════════╗
║                    PARALLEL EXECUTION                           ║
║            All 4 threads execute SIMULTANEOUSLY                 ║
╠════════════════════════════════════════════════════════════════╣
║                                                                 ║
║  THREAD 0                      THREAD 1                        ║
║  ┌──────────────────────┐     ┌──────────────────────┐        ║
║  │ threadIdx.x = 0      │     │ threadIdx.x = 1      │        ║
║  │ local_i = 0          │     │ local_i = 1          │        ║
║  │                      │     │                      │        ║
║  │ 1. Read a[0] → 0     │     │ 1. Read a[1] → 1     │        ║
║  │ 2. Compute 0+10=10   │     │ 2. Compute 1+10=11   │        ║
║  │ 3. Write out[0]=10   │     │ 3. Write out[1]=11   │        ║
║  └──────────────────────┘     └──────────────────────┘        ║
║                                                                 ║
║  THREAD 2                      THREAD 3                        ║
║  ┌──────────────────────┐     ┌──────────────────────┐        ║
║  │ threadIdx.x = 2      │     │ threadIdx.x = 3      │        ║
║  │ local_i = 2          │     │ local_i = 3          │        ║
║  │                      │     │                      │        ║
║  │ 1. Read a[2] → 2     │     │ 1. Read a[3] → 3     │        ║
║  │ 2. Compute 2+10=12   │     │ 2. Compute 3+10=13   │        ║
║  │ 3. Write out[2]=12   │     │ 3. Write out[3]=13   │        ║
║  └──────────────────────┘     └──────────────────────┘        ║
║                                                                 ║
╚════════════════════════════════════════════════════════════════╝

RESULT: out = [10, 11, 12, 13]
```

---

###  Key Insight: Same Code, Different Data

```python
# This is the SAME code that ALL threads run:
local_i = cuda.threadIdx.x
out[local_i] = a[local_i] + 10

# But each thread gets a DIFFERENT value for threadIdx.x!
# That's what makes it parallel.
```
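
To see the "same code, different data" idea without a GPU, here is a minimal CPU sketch. The `launch` helper and the fake `cuda` object are hypothetical stand-ins made up for this illustration (they are not part of Numba), and a real launch runs the threads in parallel rather than in a loop.

```python
from types import SimpleNamespace


def launch(kernel, num_threads, *args):
    """Hypothetical stand-in for a one-block launch: call `kernel` once per thread id."""
    for tid in range(num_threads):
        # Each call sees a different threadIdx.x, exactly like each GPU thread would.
        fake_cuda = SimpleNamespace(threadIdx=SimpleNamespace(x=tid))
        kernel(fake_cuda, *args)


def map_kernel(cuda, out, a):
    local_i = cuda.threadIdx.x
    out[local_i] = a[local_i] + 10


a = [0, 1, 2, 3]
out = [0] * len(a)
launch(map_kernel, len(a), out, a)
print(out)  # [10, 11, 12, 13]
```

The kernel body never changes between calls. Only the thread index does.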

---

###  Memory Access Pattern: Coalesced Access

When adjacent threads access adjacent memory addresses, the GPU can combine these accesses into a small number of wide memory transactions:

```
Coalesced (GOOD):          Non-coalesced (BAD):
T0→a[0] T1→a[1] T2→a[2]    T0→a[0] T1→a[7] T2→a[3]
    ↓                           ↓
ONE memory transaction      MULTIPLE transactions
```

Our solution is perfectly coalesced!
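
A toy model can make the difference concrete. The sketch below assumes, purely for illustration, that one memory transaction fetches 4 consecutive elements. Real hardware works on fixed-size byte segments per group of threads, so treat `elems_per_segment` as a made-up parameter.

```python
def count_transactions(indices, elems_per_segment=4):
    """Count distinct segments touched in a toy model where one
    transaction fetches `elems_per_segment` consecutive elements."""
    return len({i // elems_per_segment for i in indices})


coalesced = [0, 1, 2, 3]   # thread t reads a[t]
strided = [0, 4, 8, 12]    # thread t reads a[4 * t]

print(count_transactions(coalesced))  # 1
print(count_transactions(strided))    # 4
```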

---

###  Key Takeaways

1. **threadIdx.x** = unique ID for each thread within its block (0 to N-1)
2. **Write code for ONE thread** - GPU runs it on all threads
3. **Use thread ID as array index** - creates the parallelism
4. **No loops needed** - parallelism replaces iteration


In [None]:
def map_spec(a):
    return a + 10


def map_test(cuda):
    def call(out, a) -> None:
        local_i = cuda.threadIdx.x
        out[local_i] = a[local_i] + 10
        
    return call


SIZE = 4
out = np.zeros((SIZE,))
a = np.arange(SIZE)
problem = CudaProblem(
    "Map", map_test, [a], out, threadsperblock=Coord(SIZE, 1), spec=map_spec
)
problem.show()

In [None]:
problem.check()

## Puzzle 2 - Zip

Implement a kernel that adds together each position of `a` and `b` and stores it in `out`.
You have 1 thread per position.

##  Puzzle 2: Zip - Deep Dive Solution Explanation

---

###  Problem Statement
Add corresponding elements of vectors `a` and `b`, store in `out`: `out[i] = a[i] + b[i]`

---

###  Core Concept: Element-wise Operations

This is one of the most common GPU patterns: each output depends only on the inputs at the same index.

```
a = [0, 1, 2, 3]
    +  +  +  +   (element-wise)
b = [0, 1, 2, 3]
    =  =  =  =
out=[0, 2, 4, 6]
```

---

###  Thread Execution Trace

```
╔════════════════════════════════════════════════════════════════╗
║  Thread 0                      Thread 1                        ║
║  ┌──────────────────────┐     ┌──────────────────────┐        ║
║  │ Read a[0] → 0        │     │ Read a[1] → 1        │        ║
║  │ Read b[0] → 0        │     │ Read b[1] → 1        │        ║
║  │ Compute: 0+0=0       │     │ Compute: 1+1=2       │        ║
║  │ Write out[0]=0       │     │ Write out[1]=2       │        ║
║  └──────────────────────┘     └──────────────────────┘        ║
║                                                                 ║
║  Thread 2                      Thread 3                        ║
║  ┌──────────────────────┐     ┌──────────────────────┐        ║
║  │ Read a[2] → 2        │     │ Read a[3] → 3        │        ║
║  │ Read b[2] → 2        │     │ Read b[3] → 3        │        ║
║  │ Compute: 2+2=4       │     │ Compute: 3+3=6       │        ║
║  │ Write out[2]=4       │     │ Write out[3]=6       │        ║
║  └──────────────────────┘     └──────────────────────┘        ║
╚════════════════════════════════════════════════════════════════╝
```

---

###  Why No Synchronization Needed?

Each thread operates on completely independent data:
- Thread 0 only touches index 0
- Thread 1 only touches index 1
- No thread reads what another writes

**No data dependencies = no synchronization required!**

---

###  Key Takeaways

1. **Same index pattern** - Thread i accesses index i in ALL arrays
2. **Multiple reads, one write** - Can read from many arrays, write to one
3. **Perfect parallelism** - Threads are 100% independent


In [None]:
def zip_spec(a, b):
    return a + b


def zip_test(cuda):
    def call(out, a, b) -> None:
        local_i = cuda.threadIdx.x
        out[local_i]=a[local_i]+b[local_i]

    return call


SIZE = 4
out = np.zeros((SIZE,))
a = np.arange(SIZE)
b = np.arange(SIZE)
problem = CudaProblem(
    "Zip", zip_test, [a, b], out, threadsperblock=Coord(SIZE, 1), spec=zip_spec
)
problem.show()

In [None]:
problem.check()

## Puzzle 3 - Guards

Implement a kernel that adds 10 to each position of `a` and stores it in `out`.
You have more threads than positions.

##  Puzzle 3: Guards - Deep Dive Solution Explanation

---

###  Problem Statement
Add 10 to each element of `a`, but handle the case where **more threads are launched than data elements exist**.

---

###  Core Concept: Why Guards Are Essential

```
The Problem:
├── Array size:    4 elements (indices 0,1,2,3)
├── Threads:       8 threads (indices 0,1,2,3,4,5,6,7)
└── Extra threads: 4 threads that have NO DATA to work on!

WITHOUT guard:
Thread 4 executes: out[4] = a[4] + 10
                        ↓
                   a[4] DOESN'T EXIST! 
                   → Undefined behavior
                   → Crash or memory corruption
```

---

###  Guard Condition Execution

```
╔════════════════════════════════════════════════════════════════╗
║                    THE GUARD IN ACTION                          ║
╠════════════════════════════════════════════════════════════════╣
║                                                                 ║
║  Thread 0: local_i=0   0 < 4 ?  YES  → out[0]=a[0]+10        ║
║  Thread 1: local_i=1   1 < 4 ?  YES  → out[1]=a[1]+10        ║
║  Thread 2: local_i=2   2 < 4 ?  YES  → out[2]=a[2]+10        ║
║  Thread 3: local_i=3   3 < 4 ?  YES  → out[3]=a[3]+10        ║
║  ─────────────────────────────────────────────────────────────  ║
║  Thread 4: local_i=4   4 < 4 ?  NO   → (does nothing)        ║
║  Thread 5: local_i=5   5 < 4 ?  NO   → (does nothing)        ║
║  Thread 6: local_i=6   6 < 4 ?  NO   → (does nothing)        ║
║  Thread 7: local_i=7   7 < 4 ?  NO   → (does nothing)        ║
║                                                                 ║
╚════════════════════════════════════════════════════════════════╝
```

---

###  What Happens Without a Guard?

```
Memory layout:

Valid memory:        Invalid (unallocated):
a[0] a[1] a[2] a[3]  ??? ??? ??? ???
 ↑    ↑    ↑    ↑     ↑   ↑   ↑   ↑
T0   T1   T2   T3    T4  T5  T6  T7 (try to access)
                      ↓
              UNDEFINED BEHAVIOR:
              • Read garbage
              • Crash
              • Corrupt other data
              • Wrong results
```
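
Here is a small CPU sketch of the same situation: 8 thread ids, 4 elements. One caveat is that plain Python raises an `IndexError` on the bad access, whereas on a GPU the out-of-bounds access is undefined behavior and may fail silently. The function names are made up for this example.

```python
def unguarded(tid, out, a, size):
    out[tid] = a[tid] + 10


def guarded(tid, out, a, size):
    if tid < size:  # threads with no data simply do nothing
        out[tid] = a[tid] + 10


SIZE, THREADS = 4, 8
a = list(range(SIZE))

out = [0] * SIZE
for tid in range(THREADS):
    guarded(tid, out, a, SIZE)
print(out)  # [10, 11, 12, 13]

first_bad_thread = None
scratch = [0] * SIZE
for tid in range(THREADS):
    try:
        unguarded(tid, scratch, a, SIZE)
    except IndexError:
        first_bad_thread = tid  # Python complains; a GPU might not
        break
print(first_bad_thread)  # 4
```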

---

###  Key Takeaways

1. **ALWAYS use guards** - Even if you think sizes match
2. **Pass size as parameter** - Kernel needs boundary information
3. **Guard before array access** - `if idx < size:` then access
4. **Guards are cheap** - One comparison saves crashes


In [None]:
def map_guard_test(cuda):
    def call(out, a, size) -> None:
        local_i = cuda.threadIdx.x
        # Guard: only threads that map to a valid position do any work.
        if local_i < size:
            out[local_i] = a[local_i] + 10
    return call


SIZE = 4
out = np.zeros((SIZE,))
a = np.arange(SIZE)
problem = CudaProblem(
    "Guard",
    map_guard_test,
    [a],
    out,
    [SIZE],
    threadsperblock=Coord(8, 1),
    spec=map_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 4 - Map 2D

Implement a kernel that adds 10 to each position of `a` and stores it in `out`.
Input `a` is 2D and square. You have more threads than positions.

##  Puzzle 4: Map 2D - Deep Dive Solution Explanation

---

###  Problem Statement
Add 10 to each element of a 2D matrix using 2D thread indexing.

---

###  Core Concept: 2D Thread Organization

```
2D Thread Block (3×3):

           threadIdx.x (row)
              0     1     2
           ┌─────┬─────┬─────┐
         0 │(0,0)│(1,0)│(2,0)│
threadIdx  ├─────┼─────┼─────┤
   .y    1 │(0,1)│(1,1)│(2,1)│
(column)   ├─────┼─────┼─────┤
         2 │(0,2)│(1,2)│(2,2)│
           └─────┴─────┴─────┘

Each thread has TWO coordinates: (threadIdx.x, threadIdx.y)
```

---

###  Thread-to-Matrix Mapping

```
2×2 Matrix:             3×3 Thread Block, T(i,j) with i=threadIdx.x, j=threadIdx.y:
┌────┬────┐             ┌──────┬──────┬──────┐
│a00 │a01 │             │T(0,0)│T(1,0)│T(2,0)│ ← T(2,0) guarded (i=2)
├────┼────┤             ├──────┼──────┼──────┤
│a10 │a11 │             │T(0,1)│T(1,1)│T(2,1)│ ← T(2,1) guarded (i=2)
└────┴────┘             ├──────┼──────┼──────┤
                        │T(0,2)│T(1,2)│T(2,2)│ ← all guarded (j=2)
                        └──────┴──────┴──────┘

Active: T(0,0), T(1,0), T(0,1), T(1,1) = 4 threads
Guarded: 5 threads (where i≥2 OR j≥2)
```

---

###  Why Guard BOTH Dimensions?

```python
if local_i < size and local_j < size:
```

Thread (2,0): `local_i=2, local_j=0`
- Check `local_i < 2`: 2 < 2 → FALSE → Correctly guarded

Thread (0,2): `local_i=0, local_j=2`  
- Check `local_j < 2`: 2 < 2 → FALSE → Correctly guarded

If we only checked ONE dimension, we'd miss the other!
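
As a quick sanity check, the sketch below enumerates the 3×3 block on the CPU and sorts threads into active and idle. It also shows which threads would slip through if only `local_i` were checked. This is plain Python for illustration, not kernel code.

```python
SIZE = 2   # matrix is SIZE x SIZE
BLOCK = 3  # block is BLOCK x BLOCK threads

active, idle = [], []
for local_j in range(BLOCK):
    for local_i in range(BLOCK):
        if local_i < SIZE and local_j < SIZE:
            active.append((local_i, local_j))
        else:
            idle.append((local_i, local_j))

# What a one-dimensional guard (only `local_i < SIZE`) would let through.
only_i_checked = [
    (i, j) for j in range(BLOCK) for i in range(BLOCK) if i < SIZE
]

print(active)               # [(0, 0), (1, 0), (0, 1), (1, 1)]
print(len(idle))            # 5
print(len(only_i_checked))  # 6, including the out-of-bounds (0, 2) and (1, 2)
```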

---

###  Key Takeaways

1. **threadIdx.x for rows, threadIdx.y for columns** (or vice versa, be consistent!)
2. **Guard BOTH dimensions**: `if i < size and j < size`
3. **9 threads for 4 elements = 5 idle threads** - that's OK!


In [None]:
def map_2D_test(cuda):
    def call(out, a, size) -> None:
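        local_i = cuda.threadIdx.x
        local_j = cuda.threadIdx.y
        # Guard both dimensions: the 3x3 block has more threads than the 2x2 input.
        if local_i < size and local_j < size:
            out[local_i, local_j] = a[local_i, local_j] + 10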
    return call


SIZE = 2
out = np.zeros((SIZE, SIZE))
a = np.arange(SIZE * SIZE).reshape((SIZE, SIZE))
problem = CudaProblem(
    "Map 2D", map_2D_test, [a], out, [SIZE], threadsperblock=Coord(3, 3), spec=map_spec
)
problem.show()

In [None]:
problem.check()

## Puzzle 5 - Broadcast

Implement a kernel that adds `a` and `b` and stores it in `out`.
Inputs `a` and `b` are vectors. You have more threads than positions.

##  Puzzle 5: Broadcast - Deep Dive Solution Explanation

---

###  Problem Statement
Add a column vector and row vector using broadcasting to produce a 2D matrix.

---

###  Core Concept: Broadcasting

Broadcasting "stretches" arrays to match shapes:

```
Column a (2×1):    Row b (1×2):       Result (2×2):
    ┌───┐          ┌───┬───┐         ┌─────────┬─────────┐
    │ 0 │    +     │ 0 │ 1 │    =    │ 0+0 = 0 │ 0+1 = 1 │
    ├───┤          └───┴───┘         ├─────────┼─────────┤
    │ 1 │                            │ 1+0 = 1 │ 1+1 = 2 │
    └───┘                            └─────────┴─────────┘
    
Column "repeats"   Row "repeats"
horizontally       vertically
```

---

###  The Key Insight: Index Patterns

```python
a[local_i, 0]   # Row varies (local_i), Column FIXED at 0
b[0, local_j]   # Row FIXED at 0, Column varies (local_j)

Thread (0,0): a[0,0] + b[0,0] = 0 + 0 = 0
Thread (1,0): a[1,0] + b[0,0] = 1 + 0 = 1  ← Different row of a
Thread (0,1): a[0,0] + b[0,1] = 0 + 1 = 1  ← Different col of b
Thread (1,1): a[1,0] + b[0,1] = 1 + 1 = 2  ← Both different
```

---

###  Data Reuse Pattern

```
a[0,0] is read by: Thread(0,0), Thread(0,1)  ← Entire column of threads
a[1,0] is read by: Thread(1,0), Thread(1,1)  ← Entire column of threads
b[0,0] is read by: Thread(0,0), Thread(1,0)  ← Entire row of threads
b[0,1] is read by: Thread(0,1), Thread(1,1)  ← Entire row of threads

Each element is read multiple times - opportunity for shared memory!
```
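
The index pattern can be checked on the CPU with nested lists standing in for the arrays. This is a sketch under the puzzle's setup (a 2×1 column, a 1×2 row, a 3×3 block of threads), with the two loops playing the role of the thread grid.

```python
SIZE = 2
BLOCK = 3

a = [[0], [1]]   # column vector, shape (2, 1)
b = [[0, 1]]     # row vector, shape (1, 2)
out = [[0] * SIZE for _ in range(SIZE)]

for local_i in range(BLOCK):
    for local_j in range(BLOCK):
        if local_i < SIZE and local_j < SIZE:
            # Column index of `a` and row index of `b` are pinned to 0.
            out[local_i][local_j] = a[local_i][0] + b[0][local_j]

print(out)  # [[0, 1], [1, 2]]
```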

---

###  Key Takeaways

1. **Fix one index, vary the other** - Creates the broadcast pattern
2. **Column vector**: `a[row, 0]` - column always 0
3. **Row vector**: `b[0, col]` - row always 0
4. **Same data read by multiple threads** - Optimization opportunity


In [None]:
def broadcast_test(cuda):
    def call(out, a, b, size) -> None:
        local_i = cuda.threadIdx.x
        local_j = cuda.threadIdx.y
        # One thread per output cell; guard both dimensions.
        if local_i < size and local_j < size:
            out[local_i, local_j] = a[local_i, 0] + b[0, local_j]
    return call


SIZE = 2
out = np.zeros((SIZE, SIZE))
a = np.arange(SIZE).reshape(SIZE, 1)
b = np.arange(SIZE).reshape(1, SIZE)
problem = CudaProblem(
    "Broadcast",
    broadcast_test,
    [a, b],
    out,
    [SIZE],
    threadsperblock=Coord(3, 3),
    spec=zip_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 6 - Blocks

Implement a kernel that adds 10 to each position of `a` and stores it in `out`.
You have fewer threads per block than the size of `a`.

*Tip: A block is a group of threads. The number of threads per block is limited, but we can
have many different blocks. Variable `cuda.blockIdx` tells us what block we are in.*

##  Puzzle 6: Blocks - Deep Dive Solution Explanation

---

###  Problem Statement
Add 10 to each element using MULTIPLE BLOCKS for arrays larger than one block.

---

###  Core Concept: Block Hierarchy

```
Grid (all blocks):
┌─────────────┬─────────────┬─────────────┐
│  BLOCK 0    │  BLOCK 1    │  BLOCK 2    │
│ blockIdx=0  │ blockIdx=1  │ blockIdx=2  │
├─────────────┼─────────────┼─────────────┤
│ T0 T1 T2 T3 │ T0 T1 T2 T3 │ T0 T1 T2 T3 │
└─────────────┴─────────────┴─────────────┘
  handles       handles       handles
  a[0:4]        a[4:8]        a[8:9]
```

---

###  The Global Index Formula

```
i = blockIdx.x * blockDim.x + threadIdx.x

Example: blockDim=4, 3 blocks

Block 0 (blockIdx.x = 0):
  T0: i = 0*4 + 0 = 0
  T1: i = 0*4 + 1 = 1
  T2: i = 0*4 + 2 = 2
  T3: i = 0*4 + 3 = 3

Block 1 (blockIdx.x = 1):
  T0: i = 1*4 + 0 = 4
  T1: i = 1*4 + 1 = 5
  T2: i = 1*4 + 2 = 6
  T3: i = 1*4 + 3 = 7

Block 2 (blockIdx.x = 2):
  T0: i = 2*4 + 0 = 8  ← Valid (8 < 9)
  T1: i = 2*4 + 1 = 9  ← GUARDED (9 ≥ 9)
  T2: i = 2*4 + 2 = 10 ← GUARDED
  T3: i = 2*4 + 3 = 11 ← GUARDED
```
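
The table above can be reproduced with a few lines of plain Python. The helper names `blocks_needed` and `global_index` are made up for this sketch. The block count uses the usual ceiling-division trick.

```python
def blocks_needed(size, block_dim):
    """Smallest number of blocks whose threads cover `size` elements."""
    return (size + block_dim - 1) // block_dim


def global_index(block_idx, block_dim, thread_idx):
    return block_idx * block_dim + thread_idx


SIZE, BLOCK_DIM = 9, 4
num_blocks = blocks_needed(SIZE, BLOCK_DIM)

active, guarded = [], []
for b in range(num_blocks):
    for t in range(BLOCK_DIM):
        i = global_index(b, BLOCK_DIM, t)
        (active if i < SIZE else guarded).append(i)

print(num_blocks)  # 3
print(active)      # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(guarded)     # [9, 10, 11]
```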

---

###  Visual Mapping

```
Array:    [0] [1] [2] [3] [4] [5] [6] [7] [8]
           │   │   │   │   │   │   │   │   │
Thread:   T0  T1  T2  T3  T0  T1  T2  T3  T0
Block:    └─── Block 0 ───┘ └─── Block 1 ───┘ └─ Block 2 ─┘
```

---

###  Key Takeaways

1. **Global index**: `blockIdx.x * blockDim.x + threadIdx.x`
2. **Blocks are independent** - Cannot communicate directly
3. **Always guard** - Last block often has extra threads
4. **Formula for # blocks**: `(size + blockDim - 1) // blockDim`


In [None]:
def map_block_test(cuda):
    def call(out, a, size) -> None:
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        if i < size:
            out[i] = a[i] + 10
    return call


SIZE = 9
out = np.zeros((SIZE,))
a = np.arange(SIZE)
problem = CudaProblem(
    "Blocks",
    map_block_test,
    [a],
    out,
    [SIZE],
    threadsperblock=Coord(4, 1),
    blockspergrid=Coord(3, 1),
    spec=map_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 7 - Blocks 2D

Implement the same kernel in 2D.  You have fewer threads per block
than the size of `a` in both directions.

##  Puzzle 7: Blocks 2D - Deep Dive Solution Explanation

---

###  Problem Statement  
Add 10 to each element of a 2D matrix using 2D blocks.

---

###  2D Index Calculation

Same formula, applied to BOTH dimensions:

```
Row:    i = blockIdx.x * blockDim.x + threadIdx.x
Column: j = blockIdx.y * blockDim.y + threadIdx.y

Example: 5×5 matrix, 3×3 blocks, 2×2 grid

Block (1,1), Thread (1,0):
  i = 1*3 + 1 = 4
  j = 1*3 + 0 = 3
  → handles element a[4,3]
```

---

###  Block Coverage

```
         Columns 0-2      Columns 3-5
        (blockIdx.y=0)   (blockIdx.y=1)
       ┌──────────────┬──────────────┐
       │  Block(0,0)  │  Block(0,1)  │ Rows 0-2
       │   9 threads  │  6 active    │ (blockIdx.x=0)
       ├──────────────┼──────────────┤
       │  Block(1,0)  │  Block(1,1)  │ Rows 3-5
       │  6 active    │  4 active    │ (blockIdx.x=1)
       └──────────────┴──────────────┘
                       
36 total threads, 25 active, 11 guarded
```
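
The per-block counts in the diagram can be verified by brute force. The sketch below walks every (block, thread) pair for the 5×5 case on the CPU and counts how many pass the guard.

```python
SIZE, BLOCK, GRID = 5, 3, 2  # 5x5 matrix, 3x3 threads per block, 2x2 blocks

active_per_block = {}
for bx in range(GRID):
    for by in range(GRID):
        count = 0
        for tx in range(BLOCK):
            for ty in range(BLOCK):
                i = bx * BLOCK + tx
                j = by * BLOCK + ty
                if i < SIZE and j < SIZE:
                    count += 1
        active_per_block[(bx, by)] = count

total_threads = (GRID * BLOCK) ** 2
total_active = sum(active_per_block.values())

print(active_per_block)              # {(0, 0): 9, (0, 1): 6, (1, 0): 6, (1, 1): 4}
print(total_threads, total_active)   # 36 25
```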

---

###  Key Takeaways

1. **Same formula for each dimension**
2. **2D grid of 2D blocks** - Natural for matrices
3. **Guard both dimensions**: `if i < size and j < size`


In [None]:
def map_block2D_test(cuda):
    def call(out, a, size) -> None:
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        j = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y
        if i < size and j < size:
            out[i,j]=a[i,j]+10
    return call


SIZE = 5
out = np.zeros((SIZE, SIZE))
a = np.ones((SIZE, SIZE))

problem = CudaProblem(
    "Blocks 2D",
    map_block2D_test,
    [a],
    out,
    [SIZE],
    threadsperblock=Coord(3, 3),
    blockspergrid=Coord(2, 2),
    spec=map_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 8 - Shared

Implement a kernel that adds 10 to each position of `a` and stores it in `out`.
You have fewer threads per block than the size of `a`.

**Warning**: Each block can only have a *constant* amount of shared
 memory that threads in that block can read and write to. This needs
 to be a literal python constant not a variable. After writing to
 shared memory you need to call `cuda.syncthreads` to ensure that
 threads do not cross.

(This example does not really need shared memory or syncthreads, but it is a demo.)

##  Puzzle 8: Shared Memory - Deep Dive Solution Explanation

---

###  Problem Statement
Use shared memory to load data cooperatively before computing.

---

###  Core Concept: Memory Hierarchy

```
Speed:  Registers > Shared Memory > Global Memory
           ↑            ↑              ↑
        ~1 cycle   tens of cycles  hundreds of cycles

Shared memory latency is roughly an order of magnitude lower than global memory (exact numbers vary by GPU).
```

---

###  Why syncthreads() Is Critical

```
WITHOUT syncthreads():

Thread 0: shared[0]=a[0] ─────────────── val=shared[0] ────→
Thread 1: ───────────── shared[1]=a[1] ──────────── val=shared[1]
Thread 2: shared[2]=a[2] ── val=shared[1] ←─ DANGER! Thread 2 
                                             might read shared[1]
                                             before Thread 1 writes it!

WITH syncthreads():

Thread 0: shared[0]=a[0] ─┐
Thread 1: shared[1]=a[1] ─┼──BARRIER──┬─ val=shared[0]
Thread 2: shared[2]=a[2] ─┤   wait    ├─ val=shared[1]
Thread 3: shared[3]=a[3] ─┘   here    └─ val=shared[2]

ALL writes complete before ANY reads!
```
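
The same write, barrier, read pattern can be tried on the CPU with Python's standard `threading.Barrier`, which here plays the role of `cuda.syncthreads()`. This is only an analogy (OS threads, not GPU threads), and the neighbor read is an assumption added to make the barrier matter.

```python
import threading

TPB = 4
a = [10, 20, 30, 40]
shared = [None] * TPB
out = [None] * TPB
barrier = threading.Barrier(TPB)  # stands in for cuda.syncthreads()


def worker(local_i):
    shared[local_i] = a[local_i]      # phase 1: write my own slot
    barrier.wait()                    # nobody continues until all slots are written
    neighbor = (local_i - 1) % TPB
    out[local_i] = shared[neighbor]   # phase 2: read a slot another thread wrote


threads = [threading.Thread(target=worker, args=(i,)) for i in range(TPB)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(out)  # [40, 10, 20, 30]
```

Without the `barrier.wait()` line, a thread could reach phase 2 before its neighbor has written, and `out` could contain `None`.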

---

###  Execution Trace

```
Block 0:
┌────────────────────────────────────────────────────────────────┐
│ PHASE 1: Cooperative Load (all threads work in parallel)       │
│                                                                 │
│   T0: shared[0] = a[0]    T2: shared[2] = a[2]                │
│   T1: shared[1] = a[1]    T3: shared[3] = a[3]                │
│                                                                 │
│   shared[] = [a[0], a[1], a[2], a[3]]                         │
├────────────────────────────────────────────────────────────────┤
│ syncthreads() - ALL threads wait here                          │
├────────────────────────────────────────────────────────────────┤
│ PHASE 2: Compute (all threads read their element)              │
│                                                                 │
│   T0: val=shared[0], out[0]=val+10                            │
│   T1: val=shared[1], out[1]=val+10                            │
│   T2: val=shared[2], out[2]=val+10                            │
│   T3: val=shared[3], out[3]=val+10                            │
└────────────────────────────────────────────────────────────────┘
```

---

###  Key Properties of Shared Memory

```
1. PER-BLOCK: Each block has its own separate shared memory
   Block 0: shared[] = [_, _, _, _]
   Block 1: shared[] = [_, _, _, _]  ← Different array!

2. LIMITED SIZE: ~48KB per block (varies by GPU)

3. REQUIRES SYNC: Must use syncthreads() between write and read phases

4. DECLARED STATICALLY: Size must be known at compile time
```

---

###  Key Takeaways

1. **Shared memory is fast** - Use it for data accessed multiple times
2. **syncthreads() is required** whenever a thread reads shared memory that another thread wrote - Prevents race conditions
3. **Each block has separate shared memory** - No cross-block sharing
4. **Pattern: Load → Sync → Compute → Sync (if needed) → Store**


In [None]:
TPB = 4
def shared_test(cuda):
    def call(out, a, size) -> None:
        shared = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < size:
            shared[local_i] = a[i]
        # Keep the barrier outside the guard so every thread in the block reaches it.
        cuda.syncthreads()

        if i < size:
            out[i] = shared[local_i] + 10

    return call


SIZE = 8
out = np.zeros(SIZE)
a = np.ones(SIZE)
problem = CudaProblem(
    "Shared",
    shared_test,
    [a],
    out,
    [SIZE],
    threadsperblock=Coord(TPB, 1),
    blockspergrid=Coord(2, 1),
    spec=map_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 9 - Pooling

Implement a kernel that sums together the last 3 positions of `a` and stores it in `out`.
You have 1 thread per position. You only need 1 global read and 1 global write per thread.

*Tip: Remember to be careful about syncing.*

##  Puzzle 9: Pooling - Deep Dive Solution Explanation

---

###  Problem Statement
Compute sliding window sum: `out[i] = a[i] + a[i-1] + a[i-2]` (where indices exist)

---

###  Core Concept: Neighbor Access Pattern

Each output needs MULTIPLE input values - not just its own index!

```
Input:  [0, 1, 2, 3, 4, 5, 6, 7]
        
out[0] = a[0]                    = 0
out[1] = a[0] + a[1]             = 0 + 1 = 1
out[2] = a[0] + a[1] + a[2]      = 0 + 1 + 2 = 3
out[3] = a[1] + a[2] + a[3]      = 1 + 2 + 3 = 6
out[4] = a[2] + a[3] + a[4]      = 2 + 3 + 4 = 9
...
```

---

###  Why Shared Memory Matters Here

```
Without shared memory:
Thread 2 needs: a[0], a[1], a[2]  ← 3 global memory reads
Thread 3 needs: a[1], a[2], a[3]  ← 3 global memory reads
Thread 4 needs: a[2], a[3], a[4]  ← 3 global memory reads
                  ↑    ↑
            Same elements read multiple times!

With shared memory:
1. Each thread loads ONE element to shared (8 global reads total)
2. Threads read from FAST shared memory for neighbors
```
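
To put numbers on this, the sketch below counts global reads for the 8-element example both ways and computes the expected pooling output. It runs on the CPU. `pool_reference` is a name made up here and mirrors the windowing rule from the problem statement.

```python
a = [0, 1, 2, 3, 4, 5, 6, 7]

# Without shared memory: thread i reads min(i + 1, 3) elements from global memory.
naive_global_reads = sum(min(i + 1, 3) for i in range(len(a)))

# With shared memory: each thread loads exactly one element from global memory.
shared_global_reads = len(a)


def pool_reference(values):
    """Sum of each element and up to two predecessors."""
    return [sum(values[max(i - 2, 0): i + 1]) for i in range(len(values))]


print(naive_global_reads)   # 21
print(shared_global_reads)  # 8
print(pool_reference(a))    # [0, 1, 3, 6, 9, 12, 15, 18]
```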

---

###  Execution Trace for Thread 3

```
local_i = 3

Loop iteration j=0: neigh_idx = 3-0 = 3, 3≥0 holds, val += shared[3]
Loop iteration j=1: neigh_idx = 3-1 = 2, 2≥0 holds, val += shared[2]
Loop iteration j=2: neigh_idx = 3-2 = 1, 1≥0 holds, val += shared[1]

val = shared[3] + shared[2] + shared[1]
    = a[3] + a[2] + a[1]
    = 3 + 2 + 1 = 6

out[3] = 6
```

---

###  Boundary Handling

```
Thread 0 (local_i = 0):
  j=0: neigh_idx = 0,  0≥0 is true   → val += shared[0]
  j=1: neigh_idx = -1, -1≥0 is false → skip
  j=2: neigh_idx = -2, -2≥0 is false → skip
  
  Only adds shared[0], correctly handles left boundary!
```

---

###  Key Takeaways

1. **Neighbor access** - Each thread reads nearby elements
2. **Shared memory** - Load once, read multiple times
3. **Boundary checks** - Handle edge cases explicitly


In [None]:
def pool_spec(a):
    out = np.zeros(*a.shape)
    for i in range(a.shape[0]):
        out[i] = a[max(i - 2, 0) : i + 1].sum()
    return out


TPB = 8
def pool_test(cuda):
    def call(out, a, size) -> None:
        shared = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < size:
            shared[local_i] = a[i]
        cuda.syncthreads()

        if i < size:
            val=0
            for j in range(3):
                neigh_idx = local_i - j
                if neigh_idx >= 0:
                    val += shared[neigh_idx]

            out[i] = val

    return call


SIZE = 8
out = np.zeros(SIZE)
a = np.arange(SIZE)
problem = CudaProblem(
    "Pooling",
    pool_test,
    [a],
    out,
    [SIZE],
    threadsperblock=Coord(TPB, 1),
    blockspergrid=Coord(1, 1),
    spec=pool_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 10 - Dot Product

Implement a kernel that computes the dot-product of `a` and `b` and stores it in `out`.
You have 1 thread per position. You only need 2 global reads and 1 global write per thread.

*Note: For this problem you don't need to worry about number of shared reads. We will
 handle that challenge later.*

##  Puzzle 10: Dot Product - Deep Dive Solution Explanation

---

###  Problem Statement
Compute dot product: `out = Σ a[i] × b[i]`

---

###  Core Concept: Reduction

A reduction combines many values into ONE result. This requires cooperation between threads!

```
a = [0, 1, 2, 3, 4, 5, 6, 7]
b = [0, 1, 2, 3, 4, 5, 6, 7]

Products: [0, 1, 4, 9, 16, 25, 36, 49]
Sum:      0+1+4+9+16+25+36+49 = 140
```

---

###  Two-Phase Execution

```
╔════════════════════════════════════════════════════════════════╗
║ PHASE 1: PARALLEL MULTIPLY                                      ║
║ All 8 threads work simultaneously                               ║
╠════════════════════════════════════════════════════════════════╣
║ T0: shared[0] = 0×0 = 0     T4: shared[4] = 4×4 = 16           ║
║ T1: shared[1] = 1×1 = 1     T5: shared[5] = 5×5 = 25           ║
║ T2: shared[2] = 2×2 = 4     T6: shared[6] = 6×6 = 36           ║
║ T3: shared[3] = 3×3 = 9     T7: shared[7] = 7×7 = 49           ║
╠════════════════════════════════════════════════════════════════╣
║ syncthreads() - wait for all products                          ║
╠════════════════════════════════════════════════════════════════╣
║ PHASE 2: SEQUENTIAL SUM (Thread 0 only)                        ║
║                                                                 ║
║ T0: total = 0+1+4+9+16+25+36+49 = 140                          ║
║     out[0] = 140                                                ║
║                                                                 ║
║ T1-T7: (do nothing, wait)                                      ║
╚════════════════════════════════════════════════════════════════╝
```

---

###  Limitation: Sequential Sum

This solution uses O(n) time for the sum (single thread).

**Better approach: Parallel Reduction** (covered in a later puzzle)
- O(log n) time using tree-based summing
- All threads participate in reduction
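
For intuition, here is a CPU sketch of one common tree-shaped reduction. It is not necessarily the exact scheme a later puzzle uses, and it assumes the input length is a power of two. The inner loop's iterations are independent, so on a GPU each would be one thread, with a barrier between steps.

```python
def tree_reduce(values):
    """Sum `values` by repeatedly folding the upper half onto the lower half.

    Assumes len(values) is a power of two. Returns (total, number_of_steps).
    """
    buf = list(values)
    steps = 0
    stride = len(buf) // 2
    while stride > 0:
        for t in range(stride):       # independent work items: one per "thread"
            buf[t] += buf[t + stride]
        stride //= 2                  # a barrier would go here on a GPU
        steps += 1
    return buf[0], steps


products = [0, 1, 4, 9, 16, 25, 36, 49]
total, steps = tree_reduce(products)
print(total, steps)  # 140 3
```

Eight values are summed in 3 steps instead of 8 sequential additions, which is where the O(log n) comes from.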

---

###  Key Takeaways

1. **Map-Reduce pattern**: Parallel multiply, reduce to sum
2. **Shared memory**: Stores intermediate products
3. **Single thread finishes**: Simple but not optimal for large arrays


In [None]:
def dot_spec(a, b):
    return a @ b

TPB = 8
def dot_test(cuda):
    def call(out, a, b, size) -> None:
        shared = cuda.shared.array(TPB, numba.float32)
    
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x
        
        if i < size:
            shared[local_i] = a[i]*b[i]
        else:
            shared[local_i] = 0
        
        cuda.syncthreads()

        if local_i==0:
            total=0
            for j in range(size):
                total += shared[j]

            out[0] = total
        
    return call


SIZE = 8
out = np.zeros(1)
a = np.arange(SIZE)
b = np.arange(SIZE)
problem = CudaProblem(
    "Dot",
    dot_test,
    [a, b],
    out,
    [SIZE],
    threadsperblock=Coord(SIZE, 1),
    blockspergrid=Coord(1, 1),
    spec=dot_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 11 - 1D Convolution

Implement a kernel that computes a 1D convolution between `a` and `b` and stores it in `out`.
You need to handle the general case. You only need 2 global reads and 1 global write per thread.

##  Puzzle 11: 1D Convolution - Deep Dive Solution Explanation

---

###  Problem Statement
Compute 1D convolution: `out[i] = Σⱼ a[i+j] × b[j]` for `j = 0 … len(b)-1`, dropping any term where `i+j` runs past the end of `a`.

---

###  Core Concept: Convolution

```
Input a:  [0, 1, 2, 3, 4, 5]
Kernel b: [0, 1, 2]

out[0] = a[0]×b[0] + a[1]×b[1] + a[2]×b[2] = 0×0 + 1×1 + 2×2 = 5
out[1] = a[1]×b[0] + a[2]×b[1] + a[3]×b[2] = 1×0 + 2×1 + 3×2 = 8
out[2] = a[2]×b[0] + a[3]×b[1] + a[4]×b[2] = 2×0 + 3×1 + 4×2 = 11
...
```
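
The worked example above stops at `out[2]`. The following sketch fills in the rest with a plain Python reference loop. `conv_reference` is a hypothetical helper written for this illustration; it follows the same boundary rule as the puzzle's spec (terms that run past the end of `a` are skipped).

```python
import numpy as np

def conv_reference(a, b):
    """Reference 1D convolution: out[i] = sum_j a[i + j] * b[j]."""
    out = np.zeros(len(a))
    for i in range(len(a)):
        for j in range(len(b)):
            if i + j < len(a):  # skip terms that fall off the end of a
                out[i] += a[i + j] * b[j]
    return out

result = conv_reference(np.arange(6), np.arange(3))
print(result.tolist())  # [5.0, 8.0, 11.0, 14.0, 5.0, 0.0]
```

Note how the last two outputs use fewer terms: `out[4]` only sees `a[4]` and `a[5]`, and `out[5]` only sees `a[5]`.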

---

###  Understanding the Halo Region

```
Block handles indices 0-7 (TPB=8), kernel size 3

Thread 7 needs: a[7], a[8], a[9]
                       ↑    ↑
                    BEYOND block's main region!

shared_a layout (size TPB + MAX_CONV = 12):
┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
│a0 │a1 │a2 │a3 │a4 │a5 │a6 │a7 │a8 │a9 │ 0 │ 0 │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘
 └────────── Main (TPB=8) ──────────┘└ Halo ─┘

Halo = extra elements at the end needed for convolution
```

---

###  Halo Loading Logic

```python
if local_i < b_size - 1:           # Only first kernel_size-1 threads
    if i + TPB < a_size:           # If element exists
        shared_a[local_i + TPB] = a[i + TPB]

# For kernel size 3, threads 0 and 1 load the halo:
#   Thread 0: shared_a[8] = a[8]
#   Thread 1: shared_a[9] = a[9]
```
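
As a rough check of the halo logic, the sketch below replays one block on the CPU. `simulate_conv_block` is a hypothetical name, and the two `for local_i` loops stand in for the load phase and the compute phase that `cuda.syncthreads()` would separate on the GPU. The sizes match Test 2 (15 inputs, kernel length 4).

```python
import numpy as np

TPB = 8       # threads per block, as in the puzzle
MAX_CONV = 4  # maximum kernel length, as in the puzzle

def simulate_conv_block(a, b, block_idx):
    """CPU sketch of one block of the halo-based 1D convolution."""
    a_size, b_size = len(a), len(b)
    shared_a = np.zeros(TPB + MAX_CONV, dtype=np.float32)
    shared_b = np.zeros(MAX_CONV, dtype=np.float32)
    shared_b[:b_size] = b

    # Load phase: one main element per thread, plus one halo element
    # for each of the first b_size - 1 threads.
    for local_i in range(TPB):
        i = block_idx * TPB + local_i
        if i < a_size:
            shared_a[local_i] = a[i]
        if local_i < b_size - 1 and i + TPB < a_size:
            shared_a[local_i + TPB] = a[i + TPB]

    # Compute phase: every read now comes from shared memory.
    out = {}
    for local_i in range(TPB):
        i = block_idx * TPB + local_i
        if i < a_size:
            out[i] = sum(
                float(shared_a[local_i + j] * shared_b[j])
                for j in range(b_size)
                if i + j < a_size
            )
    return shared_a, out

a = np.arange(15)
b = np.arange(4)
shared_a, out0 = simulate_conv_block(a, b, block_idx=0)
print(shared_a.tolist())  # main region a[0..7], halo a[8..10], one unused slot
print(out0[7])            # 56.0 = 7*0 + 8*1 + 9*2 + 10*3
```

Thread 7 of block 0 could not compute its output without the halo, since `a[8]`, `a[9]` and `a[10]` belong to block 1's main region.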

---

###  Key Takeaways

1. **Halo region**: Extra elements beyond block boundary
2. **Halo size**: kernel_size - 1
3. **Shared memory size**: TPB + halo
4. **Only some threads load halo**: First (kernel_size - 1) threads


In [None]:
def conv_spec(a, b):
    out = np.zeros(a.shape[0])
    b_len = b.shape[0]  # avoid shadowing the built-in len()
    for i in range(a.shape[0]):
        out[i] = sum(a[i + j] * b[j] for j in range(b_len) if i + j < a.shape[0])
    return out


MAX_CONV = 4
TPB = 8
TPB_MAX_CONV = TPB + MAX_CONV
def conv_test(cuda):
    def call(out, a, b, a_size, b_size) -> None:
        shared_a = cuda.shared.array(TPB_MAX_CONV, numba.float32)
        shared_b = cuda.shared.array(MAX_CONV, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x

        if i < a_size:
            shared_a[local_i] = a[i]
        else:
            shared_a[local_i] = 0
            
        if local_i < MAX_CONV:
            if local_i < b_size:
                shared_b[local_i] = b[local_i]
            else:
                shared_b[local_i] = 0

        if local_i < b_size - 1:
            if i + TPB < a_size:
                shared_a[local_i + TPB] = a[i + TPB]

        cuda.syncthreads()

        if i < a_size:
            tmp=0.0
            for j in range(b_size):
                if i + j < a_size:
                    tmp += shared_a[local_i + j]*shared_b[j]

            out[i] = tmp
        
    return call


# Test 1

SIZE = 6
CONV = 3
out = np.zeros(SIZE)
a = np.arange(SIZE)
b = np.arange(CONV)
problem = CudaProblem(
    "1D Conv (Simple)",
    conv_test,
    [a, b],
    out,
    [SIZE, CONV],
    Coord(1, 1),
    Coord(TPB, 1),
    spec=conv_spec,
)
problem.show()

In [None]:
problem.check()

Test 2

In [None]:
out = np.zeros(15)
a = np.arange(15)
b = np.arange(4)
problem = CudaProblem(
    "1D Conv (Full)",
    conv_test,
    [a, b],
    out,
    [15, 4],
    Coord(2, 1),
    Coord(TPB, 1),
    spec=conv_spec,
)
problem.show()

In [None]:
problem.check()

## Puzzle 12 - Prefix Sum

Implement a kernel that computes a sum over `a` and stores it in `out`.
If the size of `a` is greater than the block size, only store the sum of
each block.

We will do this using the [parallel prefix sum](https://en.wikipedia.org/wiki/Prefix_sum) algorithm in shared memory.
That is, each step of the algorithm should sum together half the remaining numbers.
Follow this diagram:

![](https://user-images.githubusercontent.com/35882/178757889-1c269623-93af-4a2e-a7e9-22cd55a42e38.png)

##  Puzzle 12: Sum Reduction - Deep Dive Solution Explanation

---

###  Problem Statement
Sum all elements using PARALLEL reduction - O(log n) instead of O(n)!

---

###  Core Concept: Tree Reduction

```
Sequential sum: 8 steps for 8 elements
Parallel reduction: 3 steps for 8 elements (log₂8 = 3)

Step 1:  [0, 1, 2, 3, 4, 5, 6, 7]
           ↓  ↓  ↓  ↓
          4 threads add pairs at stride 4:
          cache[0]+=cache[4], cache[1]+=cache[5], ...
Result:  [4, 6, 8, 10, -, -, -, -]

Step 2:  [4, 6, 8, 10, -, -, -, -]
           ↓  ↓
          2 threads add pairs at stride 2:
Result:  [12, 16, -, -, -, -, -, -]

Step 3:  [12, 16, -, -, -, -, -, -]
            ↓
          1 thread adds pair at stride 1:
Result:  [28, -, -, -, -, -, -, -]

Final sum: 28 = 0+1+2+3+4+5+6+7 
```

---

###  Step-by-Step Execution

```
Initial cache: [0, 1, 2, 3, 4, 5, 6, 7]

╔═══════════════════════════════════════════════════════════════════╗
║ STRIDE = 4                                                         ║
║ Active threads: 0, 1, 2, 3 (local_i < 4)                          ║
╠═══════════════════════════════════════════════════════════════════╣
║ T0: cache[0] += cache[4]  →  0 + 4 = 4                            ║
║ T1: cache[1] += cache[5]  →  1 + 5 = 6                            ║
║ T2: cache[2] += cache[6]  →  2 + 6 = 8                            ║
║ T3: cache[3] += cache[7]  →  3 + 7 = 10                           ║
║                                                                    ║
║ cache: [4, 6, 8, 10, 4, 5, 6, 7]                                  ║
╠═══════════════════════════════════════════════════════════════════╣
║ syncthreads(), stride = 2                                         ║
╠═══════════════════════════════════════════════════════════════════╣
║ STRIDE = 2                                                         ║
║ Active threads: 0, 1 (local_i < 2)                                ║
╠═══════════════════════════════════════════════════════════════════╣
║ T0: cache[0] += cache[2]  →  4 + 8 = 12                           ║
║ T1: cache[1] += cache[3]  →  6 + 10 = 16                          ║
║                                                                    ║
║ cache: [12, 16, 8, 10, 4, 5, 6, 7]                                ║
╠═══════════════════════════════════════════════════════════════════╣
║ syncthreads(), stride = 1                                         ║
╠═══════════════════════════════════════════════════════════════════╣
║ STRIDE = 1                                                         ║
║ Active threads: 0 only (local_i < 1)                              ║
╠═══════════════════════════════════════════════════════════════════╣
║ T0: cache[0] += cache[1]  →  12 + 16 = 28                         ║
║                                                                    ║
║ cache: [28, 16, 8, 10, 4, 5, 6, 7]                                ║
╠═══════════════════════════════════════════════════════════════════╣
║ stride = 0, exit loop                                              ║
║ out[blockIdx.x] = cache[0] = 28                                   ║
╚═══════════════════════════════════════════════════════════════════╝
```
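
The trace above can be reproduced with a small CPU sketch. `simulate_tree_reduction` is a hypothetical helper and assumes the input length is a power of two. Within one step the indices being read (`local_i + stride`) never overlap the indices being written (`local_i`), so a sequential loop gives the same result as threads running in lockstep.

```python
import numpy as np

def simulate_tree_reduction(values):
    """CPU sketch of the stride-halving reduction for one block.

    Assumes len(values) is a power of two.
    """
    cache = np.array(values, dtype=np.float32)
    history = []
    stride = len(cache) // 2
    while stride > 0:
        # Simulated threads 0..stride-1 are active in this step.
        for local_i in range(stride):
            cache[local_i] += cache[local_i + stride]
        history.append((stride, cache.tolist()))
        stride //= 2  # on the GPU, cuda.syncthreads() follows each step
    return float(cache[0]), history

total, history = simulate_tree_reduction(list(range(8)))
for stride, snapshot in history:
    print(stride, snapshot)
print(total)  # 28.0
```

Each printed snapshot should match the `cache:` line in the corresponding box of the trace.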

---

###  Performance: O(log n) vs O(n)

```
For n = 1024 elements:
  Sequential: 1024 additions (by one thread)
  Parallel:   10 steps (log₂1024 = 10), using 512→256→...→1 threads

For n = 1,000,000:
  Sequential: 1,000,000 steps
  Parallel:   20 steps!  (~50,000x fewer sequential steps, given enough threads)
```
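
The step counts above can be checked with a few lines of Python. `reduction_steps` is a hypothetical helper; it uses integer `bit_length` instead of a floating-point logarithm to get the number of halving rounds. This counts rounds only and ignores the fact that a real block holds a limited number of threads.

```python
def reduction_steps(n):
    """Return (sequential_steps, parallel_rounds) for summing n values."""
    sequential = n
    # Number of halving rounds: ceil(log2(n)), computed with integers.
    parallel = (n - 1).bit_length()
    return sequential, parallel

for n in (8, 1024, 1_000_000):
    seq, par = reduction_steps(n)
    print(n, seq, par, seq // par)
# 8 8 3 2
# 1024 1024 10 102
# 1000000 1000000 20 50000
```

The parallel version does not do less arithmetic. For a power-of-two `n` it still performs `n - 1` additions in total; what shrinks is the number of rounds that must happen one after another.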

---

###  Key Takeaways

1. **Halving stride**: `stride //= 2` each iteration
2. **Halving active threads**: `if local_i < stride`
3. **O(log n) steps**: The critical path shrinks from n additions to log₂(n) rounds
4. **syncthreads in loop**: Essential for correctness!


In [None]:
TPB = 8
def sum_spec(a):
    out = np.zeros((a.shape[0] + TPB - 1) // TPB)
    for j, i in enumerate(range(0, a.shape[-1], TPB)):
        out[j] = a[i : i + TPB].sum()
    return out


def sum_test(cuda):
    def call(out, a, size: int) -> None:
        cache = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x
        
        if i < size:
            cache[local_i] = a[i]
        else:
            cache[local_i] = 0.0

        cuda.syncthreads()

        stride = cuda.blockDim.x // 2
        while stride > 0 :
            if local_i < stride:
                cache[local_i] += cache[local_i + stride]

            stride //= 2
            cuda.syncthreads()


        if local_i == 0:
            out[cuda.blockIdx.x] = cache[0]
        

    return call

# Test 1

SIZE = 8
out = np.zeros(1)
inp = np.arange(SIZE)
problem = CudaProblem(
    "Sum (Simple)",
    sum_test,
    [inp],
    out,
    [SIZE],
    Coord(1, 1),
    Coord(TPB, 1),
    spec=sum_spec,
)
problem.show()

In [None]:
problem.check()

Test 2

In [None]:
SIZE = 15
out = np.zeros(2)
inp = np.arange(SIZE)
problem = CudaProblem(
    "Sum (Full)",
    sum_test,
    [inp],
    out,
    [SIZE],
    Coord(2, 1),
    Coord(TPB, 1),
    spec=sum_spec,
)
problem.show()

In [None]:
problem.check()

In [None]:
import numpy as np

def cpu_block_sum(a, TPB):
    # Calculate how many blocks we will have
    num_blocks = (len(a) + TPB - 1) // TPB
    out = np.zeros(num_blocks)
    
    for i in range(num_blocks):
        # Extract the segment (block)
        start = i * TPB
        end = start + TPB
        segment = a[start:end]
        
        # Sum the segment and store it
        out[i] = sum(segment)
        
    return out

# Example usage:
inp = np.arange(15)
tpb = 8
print("1D Block Sum:", cpu_block_sum(inp, tpb))


## Puzzle 13 - Axis Sum

Implement a kernel that computes a sum over each row of `a` (that is, along the last axis) and stores it in `out`.

##  Puzzle 13: Axis Sum - Deep Dive Solution Explanation

---

###  Problem Statement
Sum along the last axis of a 2D matrix - each row reduced independently.

---

###  Core Concept: Batched Reduction

```
Input (4×6):                    Output (4×1):
Row 0: [0,  1,  2,  3,  4,  5]  →  sum = 15
Row 1: [6,  7,  8,  9, 10, 11]  →  sum = 51
Row 2: [12,13, 14, 15, 16, 17]  →  sum = 87
Row 3: [18,19, 20, 21, 22, 23]  →  sum = 123

Each row is an INDEPENDENT reduction!
```

---

###  Block Assignment

```
blockspergrid = (1, BATCH) = (1, 4)

4 blocks, each handling one row:

Block (0, 0): batch=0, handles row 0, outputs to out[0, 0]
Block (0, 1): batch=1, handles row 1, outputs to out[1, 0]
Block (0, 2): batch=2, handles row 2, outputs to out[2, 0]
Block (0, 3): batch=3, handles row 3, outputs to out[3, 0]

All 4 blocks run in PARALLEL - 4 reductions at once!
```

---

###  Key Line: Using blockIdx.y for Batch

```python
batch = cuda.blockIdx.y

cache[local_i] = a[batch, i]      # Different row per block
out[batch, cuda.blockIdx.x] = ... # Different output row per block
```
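
A CPU sketch of the batched version, assuming each row fits in one block (`size <= TPB`) as in the test below. `simulate_axis_sum` is a hypothetical name; the outer loop over `batch` plays the role of `cuda.blockIdx.y`, and the inner code is the same stride-halving reduction as in Puzzle 12.

```python
import numpy as np

TPB = 8  # threads per block, as in the puzzle

def simulate_axis_sum(a):
    """CPU sketch: one simulated block per row, tree reduction inside."""
    batch_count, size = a.shape
    assert size <= TPB, "this sketch assumes one block per row"
    out = np.zeros((batch_count, 1))
    for batch in range(batch_count):      # stands in for cuda.blockIdx.y
        cache = np.zeros(TPB, dtype=np.float32)
        cache[:size] = a[batch, :size]    # out-of-range slots stay 0
        stride = TPB // 2
        while stride > 0:
            for local_i in range(stride):
                cache[local_i] += cache[local_i + stride]
            stride //= 2
        out[batch, 0] = cache[0]          # one write per block
    return out

inp = np.arange(24).reshape((4, 6))
print(simulate_axis_sum(inp).tolist())  # [[15.0], [51.0], [87.0], [123.0]]
```

On the GPU the four rows are reduced at the same time; here they simply run one after another, which gives the same sums.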

---

###  Key Takeaways

1. **blockIdx.y as batch index** - Different rows = different blocks
2. **Same reduction code** - Just parameterized by batch
3. **Rows are independent** - No inter-block communication needed


In [None]:
TPB = 8
def sum_spec(a):
    out = np.zeros((a.shape[0], (a.shape[1] + TPB - 1) // TPB))
    for j, i in enumerate(range(0, a.shape[-1], TPB)):
        out[..., j] = a[..., i : i + TPB].sum(-1)
    return out


def axis_sum_test(cuda):
    def call(out, a, size: int) -> None:
        cache = cuda.shared.array(TPB, numba.float32)
        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        local_i = cuda.threadIdx.x
        batch = cuda.blockIdx.y

        if i < size:
            cache[local_i] = a[batch, i]
        else:
            cache[local_i] = 0.0
        
        cuda.syncthreads()

        stride = cuda.blockDim.x // 2
        while stride > 0 :
            if local_i < stride:
                cache[local_i] += cache[local_i + stride]

            stride //= 2
            cuda.syncthreads()

        # Write the block's result once, after the reduction loop has finished.
        if local_i == 0:
            out[batch, cuda.blockIdx.x] = cache[0]

    return call


BATCH = 4
SIZE = 6
out = np.zeros((BATCH, 1))
inp = np.arange(BATCH * SIZE).reshape((BATCH, SIZE))
problem = CudaProblem(
    "Axis Sum",
    axis_sum_test,
    [inp],
    out,
    [SIZE],
    Coord(1, BATCH),
    Coord(TPB, 1),
    spec=sum_spec,
)
problem.show()

In [None]:
problem.check()

In [None]:
def cpu_axis_sum(a, TPB):
    batch_size = a.shape[0]
    num_columns = a.shape[1]
    num_blocks = (num_columns + TPB - 1) // TPB
    
    # Initialize output matrix
    out = np.zeros((batch_size, num_blocks))
    
    for b in range(batch_size):
        for i in range(num_blocks):
            start = i * TPB
            end = start + TPB
            
            # Sum the segment of the specific row 'b'
            segment = a[b, start:end]
            out[b, i] = sum(segment)
            
    return out

# Example usage:
BATCH, SIZE = 2, 6
inp_2d = np.arange(BATCH * SIZE).reshape((BATCH, SIZE))
print("2D Axis Sum:\n", cpu_axis_sum(inp_2d, 8))

## Puzzle 14 - Matrix Multiply!

Implement a kernel that multiplies square matrices `a` and `b` and
stores the result in `out`.

*Tip: The most efficient algorithm here will copy a block into
 shared memory before computing each of the individual row-column
 dot products. This is easy to do if the matrix fits in shared
 memory.  Do that case first. Then update your code to compute
 a partial dot-product and iteratively move the part you
 copied into shared memory.* You should be able to do the hard case
 in 6 global reads.

##  Puzzle 14: Matrix Multiplication - Deep Dive Solution Explanation

---

###  Problem Statement
Compute C = A × B using tiled matrix multiplication with shared memory.

---

###  Core Concept: Tiled Matrix Multiplication

**The Problem with Naive Matmul:**
```
For each C[i,j]: Load entire row of A, entire column of B
For N×N matrices: 2N global reads per element × N² elements = 2N³ global memory reads!

Tiled approach: Load small tiles, reuse from shared memory
```

---

###  The Tiling Strategy

```
For C[i,j] = Σₖ A[i,k] × B[k,j]

Instead of loading entire row/column:
1. Load TPB×TPB tile of A (around row i)
2. Load TPB×TPB tile of B (around column j)
3. Compute partial products
4. Move to next tile, repeat
5. Accumulate all partial products

Tiles slide along the K dimension:

    A                    B                    C
┌───────────────┐   ┌───────────────┐   ┌───────────┐
│ T0 │ T1 │ T2 │   │ T0 │          │   │           │
├────┼────┼────┤   ├────┤          │   │   C[i,j]  │
│    │    │    │   │ T1 │          │   │     =     │
├────┼────┼────┤   ├────┤          │   │  Σ tiles  │
│    │    │    │   │ T2 │          │   │           │
└───────────────┘   └───────────────┘   └───────────┘
  ↑ Row i tiles       ↑ Column j tiles
```

---

###  Execution for One Output Element

```
Thread at (i=0, j=0) computing C[0,0]:

Tile 0:
  Load A[0, 0:3] into a_shared[0, :]
  Load B[0:3, 0] into b_shared[:, 0]
  tmp += a_shared[0,0]*b_shared[0,0]
       + a_shared[0,1]*b_shared[1,0]
       + a_shared[0,2]*b_shared[2,0]

Tile 1:
  Load A[0, 3:6] into a_shared[0, :]
  Load B[3:6, 0] into b_shared[:, 0]
  tmp += a_shared[0,0]*b_shared[0,0]
       + a_shared[0,1]*b_shared[1,0]
       + a_shared[0,2]*b_shared[2,0]

... continue for all tiles ...

C[0,0] = tmp (complete dot product)
```
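
The same accumulation can be sketched on the CPU with NumPy. `tiled_matmul` is a hypothetical helper, not the puzzle solution: it zero-pads the inputs up to a multiple of the tile size (mirroring the `else: ... = 0` branches of the kernel) and lets each simulated block accumulate one tile-by-tile product per step along the K dimension.

```python
import numpy as np

def tiled_matmul(a, b, tile=3):
    """CPU sketch of tiled matmul for square matrices of equal size."""
    size = a.shape[0]
    num_tiles = (size + tile - 1) // tile
    padded = num_tiles * tile

    # Zero-pad so every tile is full; padding contributes 0 to each sum.
    a_pad = np.zeros((padded, padded))
    b_pad = np.zeros((padded, padded))
    a_pad[:size, :size] = a
    b_pad[:size, :size] = b

    out = np.zeros((padded, padded))
    for bi in range(num_tiles):          # block row
        for bj in range(num_tiles):      # block column
            acc = np.zeros((tile, tile))
            for t in range(num_tiles):   # slide along the K dimension
                a_shared = a_pad[bi * tile:(bi + 1) * tile, t * tile:(t + 1) * tile]
                b_shared = b_pad[t * tile:(t + 1) * tile, bj * tile:(bj + 1) * tile]
                acc += a_shared @ b_shared  # partial dot products for this tile
            out[bi * tile:(bi + 1) * tile, bj * tile:(bj + 1) * tile] = acc
    return out[:size, :size]

a = np.arange(64).reshape((8, 8))
b = np.arange(64).reshape((8, 8)).T
print(np.allclose(tiled_matmul(a, b, tile=3), a @ b))  # True
```

With `size = 8` and `tile = 3` there are three tiles per dimension, the last one partly padding, which is the same situation as Test 2.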

---

###  Why Tiling Is Efficient

```
Naive: For N×N matmul
  Each element needs 2N global reads
  Total: 2N³ global memory accesses

Tiled (tile size T):
  Each tile loaded once, used T times
  Total: 2N³/T global memory accesses
  
For T=32: 32x reduction in memory traffic!
```
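
The counts above are easy to tabulate. The sketch below uses a hypothetical helper, `global_reads`, and assumes `n` is a multiple of the tile size and that each thread loads exactly one element of `A` and one of `B` per tile step. It counts reads over the whole grid, not per thread.

```python
def global_reads(n, tile):
    """Return (naive_reads, tiled_reads) for an n x n matmul."""
    assert n % tile == 0, "this sketch assumes n is a multiple of tile"
    # Naive: each of the n*n outputs reads a full row and a full column.
    naive = 2 * n * n * n
    # Tiled: each of the n*n threads loads 2 elements per tile step,
    # and there are n // tile steps along the K dimension.
    tiled = 2 * n * n * (n // tile)
    return naive, tiled

naive, tiled = global_reads(1024, 32)
print(naive, tiled, naive // tiled)  # 2147483648 67108864 32
```

The ratio is always the tile size, which is why larger tiles help until shared memory runs out.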

---

###  Key Takeaways

1. **Tiles slide along K dimension** - Partial products accumulate
2. **Load cooperatively** - Each thread loads one element of each tile
3. **Compute locally** - Multiply from shared memory (fast!)
4. **Two syncthreads()** - After load AND after compute
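5. **Foundation of optimized libraries** - Tiling is the basic idea behind high-performance GEMM kernels (e.g. in CUTLASS and cuBLAS), which add many further optimizations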


In [None]:
def matmul_spec(a, b):
    return a @ b


TPB = 3
def mm_oneblock_test(cuda):
    def call(out, a, b, size: int) -> None:
        a_shared = cuda.shared.array((TPB, TPB), numba.float32)
        b_shared = cuda.shared.array((TPB, TPB), numba.float32)

        i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
        j = cuda.blockIdx.y * cuda.blockDim.y + cuda.threadIdx.y
        local_i = cuda.threadIdx.x
        local_j = cuda.threadIdx.y

        tmp=0.0
        for tile_idx in range((size + TPB - 1)//TPB):
        
            if i < size and (tile_idx * TPB + local_j) < size:
                a_shared[local_i, local_j] = a[i, tile_idx*TPB + local_j]
            else:
                a_shared[local_i, local_j] = 0

            
            if j < size and (tile_idx * TPB + local_i) < size:
                b_shared[local_i, local_j] = b[tile_idx*TPB + local_i, j]
            else:
                b_shared[local_i, local_j] = 0

            cuda.syncthreads()

            for k in range(TPB):
                tmp += a_shared[local_i, k]*b_shared[k, local_j]

            cuda.syncthreads()

        if i < size and j < size:
            out[i, j] = tmp

    return call

# Test 1

SIZE = 2
out = np.zeros((SIZE, SIZE))
inp1 = np.arange(SIZE * SIZE).reshape((SIZE, SIZE))
inp2 = np.arange(SIZE * SIZE).reshape((SIZE, SIZE)).T

problem = CudaProblem(
    "Matmul (Simple)",
    mm_oneblock_test,
    [inp1, inp2],
    out,
    [SIZE],
    Coord(1, 1),
    Coord(TPB, TPB),
    spec=matmul_spec,
)
problem.show(sparse=True)

In [None]:
problem.check()

Test 2

In [None]:
SIZE = 8
out = np.zeros((SIZE, SIZE))
inp1 = np.arange(SIZE * SIZE).reshape((SIZE, SIZE))
inp2 = np.arange(SIZE * SIZE).reshape((SIZE, SIZE)).T

problem = CudaProblem(
    "Matmul (Full)",
    mm_oneblock_test,
    [inp1, inp2],
    out,
    [SIZE],
    Coord(3, 3),
    Coord(TPB, TPB),
    spec=matmul_spec,
)
problem.show(sparse=True)

In [None]:
problem.check()

In [None]:
def cpu_matmul(a, b):
    # Reference triple-loop matmul: c[i, j] = sum over k of a[i, k] * b[k, j]
    c = np.zeros((a.shape[0], b.shape[1]))
    for i in range(a.shape[0]):
        for j in range(b.shape[1]):
            for k in range(a.shape[1]):
                c[i, j] += a[i, k] * b[k, j]
    return c


a = np.array([[1, 1], [2, 3]])
b = np.array([[1, 2, 3], [3, 4, 5]])

print(cpu_matmul(a, b))