# TeAAL specifications on variants of Loops' SpGEMM implementations (Thread-Mapped)

Description: Assign a fixed, constant number of work tiles to each thread. Resultant work items from each work tile are processed sequentially within the thread. 
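
Below is a minimal Python sketch of one way a thread-mapped schedule can hand out tiles. It assumes a grid-stride loop: with `T` threads in total, thread `tid` processes tiles `tid`, `tid + T`, `tid + 2T`, and so on. The function names are hypothetical helpers, not the Loops API.

```python
def thread_mapped_tiles(tid, num_tiles, num_threads):
    """Tiles (here, rows of A) visited by thread `tid` under a grid-stride loop."""
    return list(range(tid, num_tiles, num_threads))


def assign_tiles(num_tiles, block_size, grid_size):
    """Map every global thread id to the tiles it processes sequentially."""
    num_threads = block_size * grid_size
    return {tid: thread_mapped_tiles(tid, num_tiles, num_threads)
            for tid in range(num_threads)}


# 4 rows, 2 blocks of 2 threads: one row per thread.
print(assign_tiles(num_tiles=4, block_size=2, grid_size=2))  # → {0: [0], 1: [1], 2: [2], 3: [3]}
# 4 rows, 1 block of 2 threads: each thread walks two rows in turn.
print(assign_tiles(num_tiles=4, block_size=2, grid_size=1))  # → {0: [0, 2], 1: [1, 3]}
```

This notebook uses `M = 4`, `BLOCK_SIZE = 2` and `GRID_SIZE = 2`. Under that configuration, each thread ends up with exactly one row.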

In this version, each `tile` represents a row of $A$ and each `atom` represents a nonzero entry of $A$. In other words, each row of $A$ is assigned to a single thread. For SpGEMM, we follow Gustavson's row-by-row dataflow, as described in [A Programming Model for GPU Load Balancing](https://arxiv.org/abs/2301.04792). Each nonzero $A[m, k]$ in a row of $A$ scales row $k$ of $B$ to produce a partial row of $Z$. The partial rows are then merged (summed) into row $m$ of $Z$.
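
Written as an equation, Gustavson's dataflow builds row $m$ of $Z$ as a weighted sum of rows of $B$, one per nonzero of row $m$ of $A$. This is the same computation expressed by the two TeAAL expressions in the specification cell. There, $T$ holds the partial products before they are summed over $k$:

$$
Z[m, :] = \sum_{k \,:\, A[m, k] \neq 0} A[m, k] \, B[k, :]
\qquad
T[k, m, n] = A[m, k] \, B[k, n], \quad Z[m, n] = \sum_{k} T[k, m, n]
$$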

Template: Does not exist (DNE) in Loops.

Scheduler (referenced as `config` below): https://github.com/gunrock/loops/blob/main/include/loops/schedule/thread_mapped.hxx

GPU Kernel Template (Use it as a reference, don't execute it):

In [None]:
__global__ void __thread_mapped(...) {
  for (auto A_row : config.tiles()) {                        // Loop over the rows of A assigned to this thread.
    for (auto A_nz_idx : config.atoms(A_row)) {              // Loop over the nonzeros of this row of A; let j be the column index.
      for (auto B_nz_idx : get_B_nz(get_A_col(A_nz_idx))) {  // Loop over the nonzeros of row j of B.
        auto B_col = get_B_col(B_nz_idx);
        Z(A_row, B_col) += A_values[A_nz_idx] * B_values[B_nz_idx];  // Accumulate into row A_row of Z (owned by this thread).
      }
    }
  }
}
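
The template can be emulated sequentially in Python to check its loop structure. This sketch rests on several assumptions:

- $A$ and $B$ are in CSR form (`offsets`, `indices`, `values`).
- The template's `get_A_col`, `get_B_nz` and `get_B_col` are given their natural CSR meanings.
- `Z` is a dense list of lists.
- The threads run one after another rather than in parallel. This is safe here because each row of $Z$ is written by exactly one thread.

`to_csr` and `thread_mapped_spgemm` are hypothetical helpers.

```python
def to_csr(dense):
    """Convert a dense matrix (list of lists) into CSR arrays (offsets, indices, values)."""
    offsets, indices, values = [0], [], []
    for row in dense:
        for col, val in enumerate(row):
            if val != 0:
                indices.append(col)
                values.append(val)
        offsets.append(len(indices))
    return offsets, indices, values


def thread_mapped_spgemm(A_dense, B_dense, num_threads):
    """Run the thread-mapped Gustavson template one thread at a time."""
    A_offsets, A_indices, A_values = to_csr(A_dense)
    B_offsets, B_indices, B_values = to_csr(B_dense)
    M, N = len(A_dense), len(B_dense[0])
    Z = [[0] * N for _ in range(M)]  # dense output; each row is owned by one thread

    def get_A_col(A_nz_idx):  # column index j of a nonzero of A
        return A_indices[A_nz_idx]

    def get_B_nz(j):  # positions of the nonzeros of row j of B
        return range(B_offsets[j], B_offsets[j + 1])

    def get_B_col(B_nz_idx):  # column index of a nonzero of B
        return B_indices[B_nz_idx]

    for tid in range(num_threads):                 # threads emulated sequentially
        for A_row in range(tid, M, num_threads):   # config.tiles(): grid-stride over rows
            for A_nz_idx in range(A_offsets[A_row], A_offsets[A_row + 1]):  # config.atoms(A_row)
                for B_nz_idx in get_B_nz(get_A_col(A_nz_idx)):
                    B_col = get_B_col(B_nz_idx)
                    Z[A_row][B_col] += A_values[A_nz_idx] * B_values[B_nz_idx]
    return Z


# A is the matrix from the commented-out fromUncompressed line; B is an arbitrary example.
A = [[0, 0, 8, 0], [4, 1, 0, 10], [0, 4, 2, 0], [9, 7, 0, 0]]
B = [[1, 0, 2, 0], [0, 3, 0, 0], [0, 0, 0, 5], [6, 0, 0, 1]]
print(thread_mapped_spgemm(A, B, num_threads=4))
# → [[0, 0, 0, 40], [64, 3, 8, 10], [0, 12, 0, 10], [9, 21, 18, 0]]
```

The result does not depend on `num_threads`: fewer threads only change how many rows each thread walks.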

## Imports

Import the necessary modules.

In [None]:
# HiFiber boilerplate

from fibertree_bootstrap import *

fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
sys.path.insert(0, "../../")

from src import utils

## Initialization

Initialize the input tensors.

For simplicity, assume that a warp is the same size as a thread block (`WARP_SIZE = BLOCK_SIZE`). Also assume that each GPU SM processes one warp (equivalently, one block) per cycle.

In [None]:
K = 4
M = 4
N = 4

# GPU Kernel Configuration
BLOCK_SIZE = 2 # Number of threads per block
GRID_SIZE = (M + BLOCK_SIZE - 1) // BLOCK_SIZE # Number of thread blocks

print("GPU Kernel Configuration\n"
      f"  BLOCK_SIZE (Number of threads per block): {BLOCK_SIZE}\n"
      f"  GRID_SIZE (Number of thread blocks): {GRID_SIZE}")

seed = 1

# SpGEMM, both A and B are sparse
A_MK = Tensor.fromRandom(rank_ids=["M", "K"], shape=[M, K], seed=seed, density=[0.9, 0.6], name="A")
#A_MK = Tensor.fromUncompressed(rank_ids=["M", "K"], root=[[0,0,8,0],[4,1,0,10],[0,4,2,0],[9,7,0,0]], shape=[M, K], name="A")
B_KN = Tensor.fromRandom(rank_ids=["K", "N"], shape=[K, N], seed=seed, density=[0.9, 0.6], name="B")

## TeAAL Specifications

Rows of matrix $A$ are partitioned across the threads of the SMs. A thread may be assigned a row that contains only zeros, in which case it does no work.
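
The `uniform_shape(BLOCK_SIZE)` partitioning in the specification splits the row rank $M$ into `M1` (one coordinate per thread block) and `M0` (one row per thread inside a block). Below is a plain-Python sketch of the resulting `[M1, M0, K, N]` loop nest. It assumes dense list-of-lists inputs and skips zeros to mimic sparse iteration. Here `m1` is treated as a block index, so that `m = m1 * block_size + m0`. The helper name `partitioned_loop_nest` is hypothetical.

```python
def partitioned_loop_nest(A, B, block_size):
    """Compute Z = A @ B with loop order [M1, M0, K, N]; also record which
    (block, thread) pair owns each row."""
    M, K, N = len(A), len(B), len(B[0])
    grid_size = (M + block_size - 1) // block_size  # number of thread blocks (extent of M1)
    Z = [[0] * N for _ in range(M)]
    owner = {}                                      # row m -> (m1, m0)
    for m1 in range(grid_size):                     # space: thread block
        for m0 in range(block_size):                # space: thread within the block
            m = m1 * block_size + m0
            if m >= M:                              # padding thread with no row to process
                continue
            owner[m] = (m1, m0)
            for k in range(K):                      # time: nonzeros of row m of A
                if A[m][k] == 0:
                    continue
                for n in range(N):                  # time: nonzeros of row k of B
                    if B[k][n] != 0:
                        Z[m][n] += A[m][k] * B[k][n]
    return Z, owner


# 3 rows, 2 threads per block -> 2 blocks; thread (1, 1) has no row.
Z, owner = partitioned_loop_nest([[1, 0], [0, 2], [3, 0]], [[1, 1], [0, 1]], block_size=2)
print(Z)      # → [[1, 1], [0, 2], [3, 3]]
print(owner)  # → {0: (0, 0), 1: (0, 1), 2: (1, 0)}
```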

Note that the current TeAAL specification does not allow specifying the rank to which `opt: slip` applies. As a result, the modeled execution synchronizes all SMs rather than letting them proceed independently.

In [None]:
yaml = f"""
einsum:
  declaration:
    A: [M, K]
    B: [K, N]
    T: [K, M, N]
    Z: [M, N]
  expressions:
    - T[k, m, n] = A[m, k] * B[k, n]
    - Z[m, n] = T[k, m, n]
mapping:
  rank-order:
    A: [M, K]
    B: [K, N]
    T: [K, M, N]
    Z: [M, N]
  partitioning:
    T:
      M: [uniform_shape({BLOCK_SIZE})]
    Z:
      M: [uniform_shape({BLOCK_SIZE})]
  loop-order:
    T: [M1, M0, K, N]
    Z: [M1, M0, K, N]
    # M1: partitions (blocks) of rows of A; extent = GRID_SIZE = number of thread blocks
    # M0: rows within each partition, one per thread; extent = BLOCK_SIZE = threads per block
  spacetime:
    T:
      space: [M1, M0]
      time: [K, N] 
    Z:
      space: [M1, M0]
      time: [K, N] 
"""

utils.compile(yaml)

## Check Results

Check that generated code computes the correct result.

**Note**: Run this cell only after compiling and running the kernel in the cell above, which produces `Z_MN`.

In [None]:
utils.check_matmul(A_MK, B_KN, Z_MN)

## Performance on GPU

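Load balance: Poor, because the work per row of $A$ varies. Row $m$ costs one multiply-add per nonzero $B[k, n]$ with $A[m, k] \neq 0$. Its cost therefore depends both on the NNZ of row $m$ of $A$ and on the NNZ of the rows of $B$ it selects. Threads assigned to light rows finish early and sit idle while the busiest thread of their warp is still working.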
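
Below is a rough way to quantify the imbalance. It assumes the threads of a warp stay occupied until the warp's busiest thread finishes, and it counts one step per multiply-add. This is a simplification that ignores memory latency and the cost of merging partial rows. $A$ is the matrix from the commented-out `fromUncompressed` line, $B$ is an arbitrary example, and the helper names are hypothetical.

```python
def per_row_work(A, B):
    """Multiply-adds needed for each row m of A: sum of nnz(B[k, :]) over nonzero A[m][k]."""
    nnz_B_row = [sum(1 for v in row if v != 0) for row in B]
    return [sum(nnz_B_row[k] for k, v in enumerate(a_row) if v != 0) for a_row in A]


def lockstep_utilization(work, warp_size):
    """Fraction of thread-steps doing useful work when every thread in a warp
    stays occupied until the warp's busiest thread finishes."""
    useful, occupied = 0, 0
    for start in range(0, len(work), warp_size):
        warp = work[start:start + warp_size]
        useful += sum(warp)
        occupied += max(warp) * warp_size
    return useful / occupied if occupied else 1.0


A = [[0, 0, 8, 0], [4, 1, 0, 10], [0, 4, 2, 0], [9, 7, 0, 0]]
B = [[1, 0, 2, 0], [0, 3, 0, 0], [0, 0, 0, 5], [6, 0, 0, 1]]
work = per_row_work(A, B)
print(work)                                     # → [1, 5, 2, 3]
print(lockstep_utilization(work, warp_size=2))  # → 0.6875
```

With `warp_size=2` (matching `BLOCK_SIZE`), the first warp pairs a 1-step row with a 5-step row. That pairing alone wastes 4 of its 10 thread-steps.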

Assuming $A$, $B$, and $Z$ are stored in CSR format, the memory access patterns are:
- A: Uncoalesced access. Threads in a warp walk different rows of $A$, so at each step they read non-adjacent positions of $A$'s index and value arrays.
- B: Uncoalesced access. Each thread reads the row of $B$ selected by the column index of its current nonzero of $A$, so the rows a warp touches at the same time are effectively arbitrary.
- Z: Uncoalesced access. Threads in a warp write to different rows of $Z$.
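
The $A$ bullet above can be made concrete. The sketch below lists which positions of $A$'s CSR `indices`/`values` arrays each thread of one warp reads at the same inner-loop step. It uses the matrix from the commented-out `fromUncompressed` line and, for illustration only, a 4-thread warp. Consecutive positions are the best case for coalescing. The helper names are hypothetical.

```python
def csr_offsets(dense):
    """Row offsets of a dense matrix stored in CSR (indices/values omitted)."""
    offsets = [0]
    for row in dense:
        offsets.append(offsets[-1] + sum(1 for v in row if v != 0))
    return offsets


def warp_access_positions(offsets, rows, step):
    """Position in A's CSR `indices`/`values` arrays read by each thread at the
    given inner-loop step; None once the thread's row has no nonzeros left."""
    positions = []
    for row in rows:
        pos = offsets[row] + step
        positions.append(pos if pos < offsets[row + 1] else None)
    return positions


def is_consecutive(positions):
    """True if the active threads read consecutive array elements."""
    active = [p for p in positions if p is not None]
    return all(b == a + 1 for a, b in zip(active, active[1:]))


A = [[0, 0, 8, 0], [4, 1, 0, 10], [0, 4, 2, 0], [9, 7, 0, 0]]
offsets = csr_offsets(A)  # → [0, 1, 4, 6, 8]
for step in range(2):
    pos = warp_access_positions(offsets, rows=[0, 1, 2, 3], step=step)
    print(step, pos, is_consecutive(pos))
# → 0 [0, 1, 4, 6] False
# → 1 [None, 2, 5, 7] False
```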