# TeAAL specifications on variants of Loops' SpGEMM implementations (Group-Mapped)

Description: Assign an equal number of work tiles to each group of threads (a warp or a block). Threads within each group process individual work items in parallel.

In this version, each `tile` represents a row of $A$, and each `atom` represents a nonzero entry of $A$. In other words, a thread block is assigned to each row of $A$. SpGEMM follows Gustavson's row-row dataflow (each nonzero of a row of $A$ scales the matching row of $B$ to produce a partial row of $Z$, and the partial rows are then merged), as described in [A Programming Model for GPU Load Balancing](https://arxiv.org/abs/2301.04792).
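The row-row dataflow described above can be sketched in plain Python on CSR arrays. This is a minimal sequential sketch, not the Loops or TeAAL implementation; the function name `spgemm_gustavson` and the small example matrices are assumptions made for illustration.

```python
def spgemm_gustavson(a_offsets, a_cols, a_vals, b_offsets, b_cols, b_vals):
    """Row-by-row SpGEMM on CSR inputs; returns one {column: value} dict per row of Z."""
    num_rows = len(a_offsets) - 1
    z_rows = []
    for m in range(num_rows):                                  # one tile = one row of A
        partial = {}                                           # accumulator for row m of Z
        for a_idx in range(a_offsets[m], a_offsets[m + 1]):    # one atom = one nonzero of A
            k = a_cols[a_idx]                                  # column of A selects a row of B
            for b_idx in range(b_offsets[k], b_offsets[k + 1]):
                n = b_cols[b_idx]
                # Merge the scaled row of B into the partial row of Z.
                partial[n] = partial.get(n, 0) + a_vals[a_idx] * b_vals[b_idx]
        z_rows.append(partial)
    return z_rows


# Hypothetical inputs: A = [[1, 0], [0, 2]], B = [[0, 3], [4, 0]]
a_offsets, a_cols, a_vals = [0, 1, 2], [0, 1], [1, 2]
b_offsets, b_cols, b_vals = [0, 1, 2], [1, 0], [3, 4]
z_rows = spgemm_gustavson(a_offsets, a_cols, a_vals, b_offsets, b_cols, b_vals)
print(z_rows)  # [{1: 3}, {0: 8}]
```

The outer loop corresponds to tiles and the middle loop to atoms, which is the work that the schedule distributes across thread groups.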

Template: Does not exist (DNE) in Loops

Scheduler (referenced as `config` below): https://github.com/gunrock/loops/blob/main/include/loops/schedule/group_mapped.hxx

GPU Kernel Template (Use it as a reference, don't execute it):

In [None]:
__global__ void __launch_bounds__(threads_per_block, 2) __group_mapped(...) {
  // Initialize storage and schedule.
  // block_mapped assigns an entire tile of work to a thread block.
  using setup_t = schedule::block_mapped<threads_per_block, index_t, offset_t>;
  using storage_t = typename setup_t::storage_t;
  __shared__ storage_t temporary_storage;

  // Construct the schedule.
  setup_t config(temporary_storage, offsets, A_rows, A_nnz);
  auto p = config.partition();  // Assigns work tiles to each thread block.

  // Loop over the total work; each thread processes individual work items.
  for (auto virtual_atom : config.atom_accessor(p)) {
    auto virtual_tile = config.tile_accessor(virtual_atom, p);

    if (!(config.is_valid_accessor(virtual_tile, p)))
      continue;

    // Perform a binary search to find the tile index.
    auto A_row = config.tile_id(virtual_tile, p);
    auto A_nz_idx = config.atom_id(virtual_atom, A_row, virtual_tile, p);

    // Loop over the nonzero entries of the row of B selected by A's column index and multiply.
    for (auto B_nz_idx : get_B_nz(get_A_col(A_nz_idx))) {
      auto B_col = get_B_col(B_nz_idx);
      Z(A_row, B_col) += A_values[A_nz_idx] * B_values[B_nz_idx];
    }
  }
}
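
The kernel template mentions a binary search that maps a work item back to its tile. A minimal Python sketch of that idea, assuming CSR row offsets, is shown here. The helper `tile_of_atom` and the example offsets are hypothetical, and the actual Loops schedule may organize the search differently.

```python
from bisect import bisect_right


def tile_of_atom(offsets, atom_id):
    """Return the row (tile) whose CSR range [offsets[r], offsets[r + 1]) contains atom_id."""
    # bisect_right finds the first offset strictly greater than atom_id;
    # the row that owns the atom is the one just before it.
    return bisect_right(offsets, atom_id) - 1


offsets = [0, 2, 2, 5, 6]  # hypothetical CSR offsets; row 1 is empty
rows = [tile_of_atom(offsets, atom) for atom in range(offsets[-1])]
print(rows)  # [0, 0, 2, 2, 2, 3]
```

Empty rows never own an atom, so they are skipped without any special handling.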

## Imports

Import the necessary modules.

In [None]:
# HiFiber boilerplate

from fibertree_bootstrap import *

fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
sys.path.insert(0, "../../")

from src import utils

## Initialization

Initialize the input tensors.

For simplicity, the size of a warp is the same as the size of a thread block (`WARP_SIZE = BLOCK_SIZE`). Assume that each GPU SM processes one warp/block per cycle.

In [None]:
K = 4
M = 4
N = 4

# Hardware Specification
NUM_SM = 2 # Number of GPU SMs 
WARP_SIZE = 2 # Number of threads per warp
NUM_THREADS = NUM_SM * WARP_SIZE # Total number of threads

print(f"Hardware Specification\n  NUM_SM: {NUM_SM}, WARP_SIZE: {WARP_SIZE}, NUM_THREADS: {NUM_THREADS}")

seed = 1

# SpGEMM, both A and B are sparse
A_MK = Tensor.fromRandom(rank_ids=["M", "K"], shape=[M, K], seed=seed, density=[0.9, 0.6], name="A")
B_KN = Tensor.fromRandom(rank_ids=["K", "N"], shape=[K, N], seed=seed, density=[0.9, 0.6], name="B")

## TeAAL Specifications

Rows of matrix $A$ are partitioned across the SMs' threads. A thread can be assigned to a row with all zeros. 

Note that the current TeAAL specification does not allow the rank of `opt: slip` to be specified. This means that there is a synchronization across the SMs.

In [None]:
yaml = """
einsum:
  declaration:
    A: [M, K]
    B: [K, N]
    Z: [M, N]
  expressions:
    - Z[m, n] = A[m, k] * B[k, n]
mapping:
  rank-order:
    A: [M, K]
    B: [K, N]
    Z: [M, N]
  partitioning:
    Z:
      M: [uniform_shape(NUM_SM)]
      K: [uniform_occupancy(A.WARP_SIZE)]
  loop-order:
    Z: [M1, M0, K1, K0, N]
    # M1: Iterates over the row partitions of A
    # M0: Rows within each partition; partition size = NUM_SM
    # K1: Iterates over the partitions of nonzero elements within a given row
    # K0: Nonzero elements within each partition; partition size = WARP_SIZE (can be smaller if fewer than WARP_SIZE nonzero elements remain in the row)
  spacetime:
    Z:
      space: [M0, K0]
      time: [M1, K1, N] 
"""

utils.compile(yaml)
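
To make the two partitioning directives in the specification more concrete, the following plain-Python sketch mimics their effect on a small example. The functions `uniform_shape` and `uniform_occupancy` here are stand-ins named after the YAML directives, not TeAAL or fibertree API calls, and `a_rows` is a hypothetical matrix given as row coordinate to nonzero column coordinates.

```python
def uniform_shape(coords, size):
    """Group coordinates by coordinate range: a chunk keyed by c0 holds coords in [c0, c0 + size)."""
    chunks = {}
    for c in coords:
        chunks.setdefault(c // size * size, []).append(c)
    return chunks


def uniform_occupancy(coords, size):
    """Group coordinates so that each chunk holds at most `size` of them."""
    return [coords[i:i + size] for i in range(0, len(coords), size)]


NUM_SM, WARP_SIZE = 2, 2
# Hypothetical A: row 1 has no nonzeros and is therefore absent.
a_rows = {0: [0, 1, 3], 2: [2], 3: [0, 1, 2, 3]}

m_chunks = uniform_shape(sorted(a_rows), NUM_SM)  # M -> (M1, M0)
k_chunks = {m: uniform_occupancy(ks, WARP_SIZE) for m, ks in a_rows.items()}  # K -> (K1, K0)

print(m_chunks)  # {0: [0], 2: [2, 3]}
print(k_chunks)  # {0: [[0, 1], [3]], 2: [[2]], 3: [[0, 1], [2, 3]]}
```

Shape-based partitioning can leave an SM with fewer rows (the first `M` chunk holds only row 0), while occupancy-based partitioning fills every `K` chunk except possibly the last one in a row.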

## Check Results

Check that generated code computes the correct result.

**Note**: Run this cell only after compiling and running the kernel (the cell above).

In [None]:
utils.check_matmul(A_MK, B_KN, Z_MN)

## Performance on GPU

Load Balance: Better load balance than Thread-Mapped, since rows with high NNZ are processed by a group of threads (smoothing out heavy rows). Additionally, load balance across warps is less of a concern, since an SM can simply start processing another warp when one warp finishes earlier than the others.
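
The load-balance claim can be illustrated with a rough step-count model. This sketch makes several simplifying assumptions: rows are handed out round-robin, every atom costs one step, and the lengths of the rows of $B$ are ignored. The function names and the `row_nnz` values are hypothetical.

```python
from math import ceil


def thread_mapped_steps(row_nnz, num_threads):
    """Each thread owns whole rows; the finish time is the busiest thread's atom count."""
    per_thread = [sum(row_nnz[t::num_threads]) for t in range(num_threads)]
    return max(per_thread)


def group_mapped_steps(row_nnz, num_groups, group_size):
    """Each group owns whole rows; its threads split the atoms of a row evenly."""
    per_group = [
        sum(ceil(n / group_size) for n in row_nnz[g::num_groups])
        for g in range(num_groups)
    ]
    return max(per_group)


row_nnz = [8, 1, 1, 1]  # one heavy row, three light rows
t_steps = thread_mapped_steps(row_nnz, num_threads=4)
g_steps = group_mapped_steps(row_nnz, num_groups=2, group_size=2)
print(t_steps, g_steps)  # 8 5
```

With the same four threads in total, the heavy row is shared by two threads under group mapping, so the modeled finish time drops from 8 steps to 5.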

Assuming that $A$, $B$, and $Z$ are stored in CSR format, the memory access pattern is:
- A: Coalesced access, threads in a warp are accessing the same row of $A$.
- B: Uncoalesced access, threads in a warp are accessing rows of $B$ based on column indices of $A$'s nonzero entries.
- Z: Coalesced access, threads in a warp are writing the same row of $Z$.
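
The difference between the access patterns for $A$ and $B$ listed above can be seen by listing the array indices that one warp touches. This is a minimal sketch assuming CSR arrays; `warp_accesses` and the example data are hypothetical.

```python
def warp_accesses(a_offsets, a_cols, b_offsets, row, warp_size):
    """Return (indices into A's arrays, start indices into B's arrays) for the first warp of `row`."""
    start = a_offsets[row]
    end = min(start + warp_size, a_offsets[row + 1])
    a_idx = list(range(start, end))                # consecutive indices: coalesced
    b_idx = [b_offsets[a_cols[i]] for i in a_idx]  # driven by A's column ids: scattered
    return a_idx, b_idx


# Hypothetical CSR data: one row of A with nonzeros in columns 0, 2, and 5.
a_offsets = [0, 3]
a_cols = [0, 2, 5]
b_offsets = [0, 4, 4, 9, 10, 10, 17]
a_idx, b_idx = warp_accesses(a_offsets, a_cols, b_offsets, row=0, warp_size=4)
print(a_idx)  # [0, 1, 2]
print(b_idx)  # [0, 4, 10]
```

Neighboring threads read neighboring entries of $A$, but their starting positions in $B$ depend on the column indices of $A$ and are generally far apart.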