# TeAAL specifications on variants of Loops' SpMM implementations (Group-Mapped)

Description: Assign an equal number of work tiles to each group of threads (a warp or a block). Threads within each group process individual work items in parallel.

In this version, each `tile` represents a row of $A$, and each `atom` represents a nonzero entry of $A$. In other words, one thread block is assigned to each row of $A$.
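The tile/atom mapping described above can be emulated sequentially in plain Python. This is a minimal sketch, not Loops' implementation: the function name `group_mapped_spmm`, the CSR argument names, and the round-robin assignment of rows to groups are all assumptions made for illustration.

```python
# Hypothetical sequential emulation of group-mapped scheduling (not Loops' API).

def group_mapped_spmm(row_offsets, col_indices, values, B, num_groups, group_size):
    """Compute Z = A @ B with A in CSR form and B a dense list of rows."""
    num_rows = len(row_offsets) - 1
    num_cols = len(B[0])
    Z = [[0] * num_cols for _ in range(num_rows)]
    for group in range(num_groups):  # groups would run in parallel on a GPU
        # Assumption: tiles (rows of A) are dealt to groups round-robin.
        for row in range(group, num_rows, num_groups):
            start, end = row_offsets[row], row_offsets[row + 1]
            for lane in range(group_size):  # threads of a group run in parallel
                # Atoms (nonzeros of the row) are strided across the group's threads.
                for nz in range(start + lane, end, group_size):
                    k = col_indices[nz]
                    for n in range(num_cols):
                        # On real hardware this shared update needs a reduction or atomics.
                        Z[row][n] += values[nz] * B[k][n]
    return Z

# A = [[1, 0, 2, 0],
#      [0, 0, 0, 0],   <- an all-zero row still occupies a tile
#      [0, 3, 0, 4],
#      [5, 0, 0, 0]]
row_offsets = [0, 2, 2, 4, 5]
col_indices = [0, 2, 1, 3, 0]
values = [1, 2, 3, 4, 5]
B = [[1, 2], [3, 4], [5, 6], [7, 8]]

Z = group_mapped_spmm(row_offsets, col_indices, values, B, num_groups=2, group_size=2)
print(Z)
```

The group that receives row 1 has no atoms to process, which is the load imbalance this scheduling variant accepts in exchange for a simple mapping.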

Template: does not exist (DNE) in Loops

Scheduler (referenced as `config` below): https://github.com/gunrock/loops/blob/main/include/loops/schedule/group_mapped.hxx

## Imports

Import the necessary modules.

In [None]:
# HiFiber boilerplate

from fibertree_bootstrap import *

fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
sys.path.insert(0, "../..")

from src import utils

## Initialization

Initialize the input tensors.

For simplicity, the size of a thread warp is the same as the size of a thread block (`WARP_SIZE = BLOCK_SIZE`). Suppose that each GPU SM processes 1 thread warp/block per cycle.

In [None]:
K = 4
M = 4
N = 4

# Hardware Specification
NUM_SM = 2 # Number of GPU SMs 
WARP_SIZE = 2 # Number of threads per warp
NUM_THREADS = NUM_SM * WARP_SIZE # Total number of threads

print(f"Hardware Specification\n  NUM_SM: {NUM_SM}, WARP_SIZE: {WARP_SIZE}, NUM_THREADS: {NUM_THREADS}")

seed = 1

A_MK = Tensor.fromRandom(rank_ids=["M", "K"], shape=[M, K], seed=seed, density=[0.9, 0.6], name="A") # SpMM, A is sparse
B_KN = Tensor.fromRandom(rank_ids=["K", "N"], shape=[K, N], seed=seed + 1, density=[1, 1], name="B") # SpMM, B is dense

## TeAAL Specifications

Rows of matrix $A$ are partitioned across the SMs, with one thread warp/block per row. A thread warp/block can be assigned to a row that contains only zeros.

Note that the current TeAAL specification does not allow specifying the rank for `opt: slip`. As a result, there is synchronization across the SMs.
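The two partitioning styles used in the specification can be illustrated on a small example. This is a rough sketch in plain Python, not TeAAL or fibertree code: the `partition` helper and the matrix `A_rows` are made up for illustration, and the step count at the end assumes that the SMs working on one `M1` chunk advance through `K1` in lockstep, which is one reading of the synchronization noted above.

```python
NUM_SM = 2     # rows per M partition (shape-based split)
WARP_SIZE = 2  # nonzeros per K partition (occupancy-based split)
M = 4

# Illustrative sparse A: row -> sorted column coordinates of its nonzeros.
# Made-up data, not the tensor produced by Tensor.fromRandom. Row 1 is all zeros.
A_rows = {0: [0, 1, 3], 2: [1], 3: [0, 1, 2, 3]}

def partition(a_rows, num_rows, num_sm, warp_size):
    """Return {m1: {m0: [k0 chunks]}} mimicking the M and K partitioning."""
    tree = {}
    for m1 in range(0, num_rows, num_sm):
        # Shape-based split: every M1 chunk covers num_sm row coordinates.
        chunk = {}
        for m in range(m1, min(m1 + num_sm, num_rows)):
            cols = a_rows.get(m, [])
            # Occupancy-based split: warp_size nonzeros per chunk, last one may be shorter.
            chunk[m] = [cols[i:i + warp_size] for i in range(0, len(cols), warp_size)]
        tree[m1] = chunk
    return tree

parts = partition(A_rows, M, NUM_SM, WARP_SIZE)

# Assumed lockstep model: each M1 chunk takes as many K1 steps as its busiest row.
k1_steps = {m1: max(len(chunks) for chunks in rows.values()) for m1, rows in parts.items()}

print(parts)
print(k1_steps)
```

In the first chunk, the SM holding row 1 has no work while the other SM takes two `K1` steps, so under this model the empty row costs idle time rather than being skipped.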

In [None]:
yaml = """
einsum:
  declaration:
    A: [M, K]
    B: [K, N]
    Z: [M, N]
  expressions:
    - Z[m, n] = A[m, k] * B[k, n]
mapping:
  rank-order:
    A: [M, K]
    B: [K, N]
    Z: [M, N]
  partitioning:
    Z:
      M: [uniform_shape(NUM_SM)]
      K: [uniform_occupancy(A.WARP_SIZE)]
  loop-order:
    Z: [M1, M0, N, K1, K0]
    # M1: Number of row partitions of A
    # M0: Number of rows in each partition = NUM_SM
    # K1: Number of nonzero partitions in a given row
    # K0: Number of nonzeros in each partition = WARP_SIZE (fewer in the last partition if fewer than WARP_SIZE nonzeros remain)
  spacetime:
    Z:
      space: [M0, K0]
      time: [M1, K1, N] 
"""

utils.compile(yaml)

## Check Results

Check that the generated code computes the correct result.

**Note**: Run this cell only after compiling and running the kernel (the cell above).

In [None]:
utils.check_matmul(A_MK, B_KN, Z_MN)
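
The internals of `utils.check_matmul` are not shown here. As a rough, independent illustration of what such a check amounts to, the sketch below computes a reference product over coordinate dictionaries and compares it with a candidate result. The helpers `reference_matmul` and `same_result` and the sample data are hypothetical and are not part of `utils` or fibertree.

```python
def reference_matmul(a, b, n_size):
    """a, b: dicts mapping (row, col) -> value for the nonzero entries."""
    z = {}
    for (m, k), a_val in a.items():
        for n in range(n_size):
            b_val = b.get((k, n), 0)
            if b_val != 0:
                z[(m, n)] = z.get((m, n), 0) + a_val * b_val
    return z

def same_result(z_kernel, z_ref):
    """Compare two sparse results, ignoring explicitly stored zeros."""
    def strip(z):
        return {coord: val for coord, val in z.items() if val != 0}
    return strip(z_kernel) == strip(z_ref)

# Made-up inputs: sparse A (3 x 3) and dense B (3 x 2).
a = {(0, 0): 1, (0, 2): 2, (2, 1): 3}
b = {(k, n): 2 * k + n + 1 for k in range(3) for n in range(2)}

z_ref = reference_matmul(a, b, n_size=2)
print(z_ref)
```

Stripping explicit zeros before comparing matters for sparse outputs, since a kernel may or may not materialize entries for all-zero rows.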