In [None]:
# Forward pass of the kernel

import torch
import triton
import triton.language as tl

# Inner loop of the attention forward pass: walks over a range of K/V blocks and folds each
# one into the running softmax statistics (m_i, l_i) and the output accumulator (O_block).
#
# Arguments:
#   O_block         running output accumulator; rescaled and added to on every iteration
#   l_i             running row-wise sum of exponentials
#   m_i             running row-wise maximum of the scaled scores
#   Q_block         query tile that is multiplied against every K tile
#   K_block_ptr     block pointer into K, advanced along its second axis (K[HEAD_DIM, Seq_Len])
#   V_block_ptr     block pointer into V, advanced along its first axis (V[Seq_Len, HEAD_DIM])
#   block_index_q   index of the query block handled by the calling program
#   softmax_scale   factor multiplied into Q @ K before the exponential
#   STAGE           selects which K/V range is visited (see the branch below)
#   offs_q/offs_kv  row / column offsets used to build the mask when STAGE == 2
#   SEQ_LEN         upper bound of the K/V range in the final `else` branch
#
# Returns the updated (O_block, l_i, m_i).
@triton.jit
def _attn_fwd_inner(
    O_block,
    l_i,
    m_i,
    Q_block,
    K_block_ptr,
    V_block_ptr,
    block_index_q,
    softmax_scale,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
    offs_q: tl.constexpr,
    offs_kv: tl.constexpr,
    SEQ_LEN: tl.constexpr,
):
    # Range [lo, hi) of key/value positions handled by this stage
    if STAGE == 1: # Left of the diagonal
        # Every key position before the first query of this block
        lo, hi = 0, block_index_q * BLOCK_SIZE_Q
    elif STAGE == 2: # Exactly along the diagonal
        # Only the block in which the transition between non-masked and masked keys happens
        lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
        # Compiler hint: lo is a multiple of BLOCK_SIZE_Q
        lo = tl.multiple_of(lo, BLOCK_SIZE_Q) 
    else:
        # Full key range; the caller uses this stage for non-causal attention
        lo, hi = 0, SEQ_LEN

    # Skip both pointers ahead to the first key/value position of the range
    K_block_ptr = tl.advance(K_block_ptr, (0, lo)) 
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # Loop over the K/V blocks and update the accumulator
    for start_kv in range(lo, hi, BLOCK_SIZE_KV):
        # Compiler hint: start_kv is a multiple of BLOCK_SIZE_KV, which enables optimizations
        start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)

        # -- compute qk ----
        K_block = tl.load(K_block_ptr)
        QK_block = tl.dot(Q_block, K_block)

        if STAGE == 2: 
            # mask is True where the query index is >= the key index; every other entry
            # gets -1.0e6 added so that it contributes ~0 after the exponential
            mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
            QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
            QK_block -= m_ij[:, None]
        else:
            # New running maximum: the larger of the old maximum and this block's scaled row maximum
            m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
            QK_block = QK_block * softmax_scale - m_ij[:, None]

        # Exponential of each shifted score, i.e. exp(qk_ij - m_ij)
        P_block = tl.math.exp(QK_block)
        # Row-wise sum of this block's exponentials
        l_ij = tl.sum(P_block, 1)

        # Correction factor for everything accumulated under the previous maximum
        alpha = tl.math.exp(m_i - m_ij) # previous estimate - current estimate
        # Rescale the previous l_i and add the new l_ij
        l_i = l_i * alpha + l_ij

        V_block = tl.load(V_block_ptr)
        # Cast the probabilities to float16 before the second matrix product
        P_block = P_block.to(tl.float16)
        # This computes the following: O_new = P x V + O_old * alpha
        O_block = O_block * alpha[:, None]
        O_block = tl.dot(P_block, V_block, O_block) # O_block += P_block @ V_block

        # The new maximum becomes the running maximum for the next iteration
        m_i = m_ij

        # Move to the next block of K and V
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0)) # V[Seq_Len, HEAD_DIM]
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV)) # K[HEAD_DIM, Seq_Len]
    return O_block, l_i, m_i


# Forward attention kernel. Each program handles one block of BLOCK_SIZE_Q queries
# (program axis 0) of one (batch, head) pair (program axis 1). It streams the K/V blocks
# through `_attn_fwd_inner`, then writes the normalised output rows to O and the
# per-row logsumexp to M.
#
# Arguments:
#   Q, K, V, O       [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM]
#   M                [BATCH_SIZE, NUM_HEADS, SEQ_LEN], receives m_i + log(l_i) per query
#   softmax_scale    factor applied to Q @ K^T
#   stride_*         element strides of the corresponding tensor dimension
#   STAGE            3 for causal attention, 1 for non-causal attention
#
# NOTE: a single `qkv_offset`, built from the Q strides, is used for Q, K, V and O, so the
# four tensors are assumed to share their batch and head strides.
@triton.jit
def _attn_fwd(
    Q,  # [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM] # Q[index_batch, index_head, :, :]
    K,  # [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM] # K[index_batch, index_head, :, :]
    V,  # [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM]
    softmax_scale,
    M,  # [BATCH_SIZE, NUM_HEADS, SEQ_LEN]
    O,  # [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM]
    stride_Q_batch,
    stride_Q_head,
    stride_Q_seq,
    stride_Q_dim,
    stride_K_batch,
    stride_K_head,
    stride_K_seq,
    stride_K_dim,
    stride_V_batch,
    stride_V_head,
    stride_V_seq,
    stride_V_dim,
    stride_O_batch,
    stride_O_head,
    stride_O_seq,
    stride_O_dim,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
):
    tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM)

    # Which block of the sequence length this program processes
    block_index_q = tl.program_id(0)

    # Which (batch, head) pair this program processes; each program owns exactly one
    index_batch_head = tl.program_id(1)
    # Batch index (each batch element has NUM_HEADS heads)
    index_batch = index_batch_head // NUM_HEADS
    # Position of the head inside that batch element
    index_head = index_batch_head % NUM_HEADS

    # Offset of the (SEQ_LEN, HEAD_DIM) slice selected by (index_batch, index_head)
    qkv_offset = (
        index_batch.to(tl.int64) * stride_Q_batch
        + index_head.to(tl.int64) * stride_Q_head
    )

    # Q[index_batch, index_head, block_index_q * BLOCK_SIZE_Q :, :]
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    # V[index_batch, index_head, :, :]
    # Fix: the strides used to be passed as `shape` and `strides` was missing entirely.
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        order=(1, 0),
    )

    # K[index_batch, index_head, :, :], read transposed.
    # A block pointer does not select everything in the slice: it only addresses the
    # number of elements given by `block_shape`, so it can be thought of as a tensor of
    # pointers with that shape.
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(
            stride_K_dim,
            stride_K_seq,
        ),  # strides inverted w.r.t. Q, which transposes the matrix
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        order=(0, 1),
    )

    # O[index_batch, index_head, block_index_q * BLOCK_SIZE_Q :, :]: the rows this program writes
    O_block_ptr = tl.make_block_ptr(
        base=O + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    # offs_q: sequence positions of the queries processed by this program.
    # Example with BLOCK_SIZE_Q = 4: program 0 -> [0, 1, 2, 3], program 3 -> [12, 13, 14, 15]
    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    # offs_kv: positions inside one K/V block. Unlike offs_q there is no per-program
    # shift, because every program iterates over the K/V blocks itself.
    # Example with BLOCK_SIZE_KV = 4: [0, 1, 2, 3]
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)

    # m_i: running maximum of each row, one per query
    m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")
    # l_i: running sum of each row, one per query (the +1.0 keeps the later log stable)
    l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0
    # O_block: accumulator for this program's rows of the output
    O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)

    # The Q block is loaded once and reused for every K/V block
    Q_block = tl.load(Q_block_ptr)

    # STAGE is 3 for causal attention, 1 otherwise
    if STAGE == 1 or STAGE == 3:
        # Non-causal: inner stage 3 (all keys).
        # Causal: inner stage 1 (keys to the left of the diagonal).
        O_block, l_i, m_i = _attn_fwd_inner(
            O_block,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            softmax_scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            4 - STAGE,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )

    if STAGE == 3:
        # Causal only: the diagonal block, where the mask has to be applied
        O_block, l_i, m_i = _attn_fwd_inner(
            O_block,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            block_index_q,
            softmax_scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            2,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )

    # Epilogue. Fix: this used to be nested inside `if STAGE == 3`, so non-causal
    # launches never normalised or stored their result.
    m_i += tl.math.log(l_i)  # logsumexp, needed by the backward pass
    O_block = O_block / l_i[:, None]
    m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, O_block.to(O.type.element_ty))


# Backward-pass preprocessing kernel: for every query row it computes the sum over the
# head dimension of dO * O and stores it in Di.
#
# Program axis 0 selects a block of BLOCK_SIZE_Q rows, program axis 1 selects the
# (batch, head) pair. The pointer arithmetic below hard-codes a row stride of HEAD_DIM
# and a unit element stride, i.e. it treats O and dO as contiguous
# [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM] tensors.
@triton.jit
def _attn_bwd_preprocess(
   O,
   dO,
   Di , # [BATCH_SIZE, NUM_HEADS, SEQ_LEN]
   SEQ_LEN,
   BLOCK_SIZE_Q: tl.constexpr, # Example: 4
   HEAD_DIM: tl.constexpr,
):
   # Axis 0: index of the block of "O" rows this program works with
   block_index_q = tl.program_id(0)

   # Skip the rows that are processed by the programs with a lower block index
   # index: 0 * 4 + [0, 1, 2, 3] = [0, 1, 2, 3] | index: 1 * 4 + [0, 1, 2, 3] = [4, 5, 6, 7]
   offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) # Rows of O handled by this program
   # Axis 1: which head of which batch element this program works with
   index_batch_head = tl.program_id(1) 
   # All HEAD_DIM positions of each row are loaded
   offs_dim = tl.arange(0, HEAD_DIM)
   
   # The offsets are built by hand (no block pointer)
   # Load a single block of BLOCK_SIZE_Q rows of O
   O_block = tl.load( # O [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM]
      O
      + index_batch_head * HEAD_DIM * SEQ_LEN # start of the selected (batch, head) slice
      + offs_q[:, None] * HEAD_DIM
      + offs_dim[None, :]
   ) # [BLOCK_SIZE_Q, HEAD_DIM]

   # Load a single block of BLOCK_SIZE_Q rows of dO, converted to float32
   dO_block = tl.load( # dO [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM]
      dO
      + index_batch_head * HEAD_DIM * SEQ_LEN
      +  offs_q[:, None] * HEAD_DIM
      + offs_dim[None, :]
   ).to(tl.float32)

   # Di block: element-wise product of dO and O, summed along each row
   Di_block = tl.sum(dO_block * O_block, axis=1) # Shape: [BLOCK_SIZE_Q,]
   # Store the Di block, one value per query row
   Di_block_ptrs = Di + index_batch_head * SEQ_LEN + offs_q
   tl.store(Di_block_ptrs, Di_block)

# Backward-pass kernel, still a work in progress: at the moment it only derives the
# (batch, head) indices from program axis 2 and the element offset of that slice.
#
# NOTE(review): the name looks like a typo (`_attn_bed_dq` vs. `_attn_bwd_*`) and the
# parameter list carries dK/dV rather than dQ; confirm the intended kernel before
# finishing the body. The name is left unchanged here.
@triton.jit
def _attn_bed_dq(
   Q,
   K,
   V,
   softmax_scale,
   dO,
   dK,
   dV,
   M,
   D,
   stride_batch,
   stride_head,
   stride_seq,
   stride_dim,
   NUM_HEADS,
   SEQ_LEN,
   BLOCK_Q: tl.constexpr,
   BLOCK_KV: tl.constexpr,
   HEAD_DIM: tl.constexpr,
   STAGE: tl.constexpr,
):
   # Axis 2 of the launch grid enumerates the (batch, head) pairs
   index_batch_head = tl.program_id(2)
   index_batch = index_batch_head // NUM_HEADS
   index_head = index_batch_head % NUM_HEADS
   # Element offset of the selected [SEQ_LEN, HEAD_DIM] slice, widened to int64.
   # Fix: the expression was left unfinished (`... * index_batch + )`), a SyntaxError
   # that prevented the whole cell from running.
   offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to(tl.int64)
   # TODO: the rest of the kernel is not written yet


class TritonAttention(torch.autograd.Function):
  """Autograd wrapper around the Triton attention kernels of this notebook.

  Use it as ``TritonAttention.apply(Q, K, V, casual, softmax_scale)`` with Q, K and V of
  shape [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM].
  """

  @staticmethod
  def forward(ctx, Q, K, V, casual, softmax_scale):
    """Launch `_attn_fwd` and return the attention output.

    Args:
      ctx: autograd context; receives the tensors and settings needed by `backward`.
      Q, K, V: [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM] tensors.
      casual: truthy for causal (masked) attention.
      softmax_scale: factor applied to Q @ K^T.

    Returns:
      O, a tensor shaped like Q.
    """
    HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1]
    HEAD_DIM_V = V.shape[-1]

    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape

    # Fix: `HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_V` only tested HEAD_DIM_V for truthiness
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V

    # Fix: `_attn_fwd` has no autotuner, so its BLOCK_SIZE_* constexprs must be passed
    # explicitly (the grid lambda below also reads BLOCK_SIZE_Q from the launch arguments).
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_KV = min(32, HEAD_DIM)  # the kernel asserts BLOCK_SIZE_KV <= HEAD_DIM
    # The kernel loads whole blocks without boundary checks
    assert SEQ_LEN % BLOCK_SIZE_Q == 0, "SEQ_LEN must be a multiple of BLOCK_SIZE_Q"
    assert BLOCK_SIZE_Q % BLOCK_SIZE_KV == 0, "BLOCK_SIZE_Q must be a multiple of BLOCK_SIZE_KV"

    O = torch.empty_like(Q)
    stage = 3 if casual else 1

    grid = lambda args: (
        # ceil(SEQ_LEN / BLOCK_SIZE_Q) = number of query blocks
        triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),  # which group of queries
        BATCH_SIZE * NUM_HEADS,  # which head of which batch element
        1,  # Z axis of the launch grid
    )

    # Number of parallel programs: BATCH_SIZE * NUM_HEADS * NUM_BLOCKS_Q

    # M is the logsumexp for the backward pass, one value per query
    M = torch.empty(
        (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
    )

    _attn_fwd[grid](
        Q=Q,  # Query
        K=K,  # Key
        V=V,  # Value
        softmax_scale=softmax_scale,  # 1/sqrt(HEAD_DIM)
        M=M,  # logsumexp buffer (L in the pseudo code of the paper)
        O=O,  # Output
        stride_Q_batch=Q.stride(0),
        stride_Q_head=Q.stride(1),
        stride_Q_seq=Q.stride(2),
        stride_Q_dim=Q.stride(3),
        stride_K_batch=K.stride(0),
        stride_K_head=K.stride(1),
        stride_K_seq=K.stride(2),
        stride_K_dim=K.stride(3),
        stride_V_batch=V.stride(0),
        stride_V_head=V.stride(1),
        stride_V_seq=V.stride(2),
        stride_V_dim=V.stride(3),
        stride_O_batch=O.stride(0),
        stride_O_head=O.stride(1),
        stride_O_seq=O.stride(2),
        stride_O_dim=O.stride(3),
        BATCH_SIZE=Q.shape[0],
        NUM_HEADS=Q.shape[1],
        SEQ_LEN=Q.shape[2],
        HEAD_DIM=HEAD_DIM_K,
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_KV=BLOCK_SIZE_KV,
        STAGE=stage,
    )

    ctx.save_for_backward(Q, K, V, O, M)
    ctx.grid = grid
    ctx.softmax_scale = softmax_scale
    ctx.HEAD_DIM = HEAD_DIM_K
    ctx.casual = casual
    # Fix: this used to return the integer 0 instead of the attention output
    return O

  @staticmethod
  def backward(ctx, dO):
    """Compute the gradients of the attention output w.r.t. Q, K and V.

    Args:
      ctx: autograd context filled by `forward`.
      dO: gradient of the loss w.r.t. the forward output; must be contiguous and share
        the strides of Q, K, V and O.

    Returns:
      (dQ, dK, dV, None, None): one entry per `forward` input; `casual` and
      `softmax_scale` receive no gradient.
    """
    Q, K, V, O, M = ctx.saved_tensors

    assert dO.is_contiguous()
    assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride()
    dQ = torch.empty_like(Q)
    dK = torch.empty_like(K)
    dV = torch.empty_like(V)

    BATCH_SIZE, NUM_HEADS, SEQ_LEN = Q.shape[:3]
    NUM_WARPS, NUM_STAGES = 4, 3
    BLOCK_SIZE_MICRO, BLOCK_SIZE_MACRO = 32, 128

    # Precompute all the Di elements: one program per block of rows and per (batch, head)
    preprocess_grid = (SEQ_LEN // BLOCK_SIZE_MICRO, BATCH_SIZE * NUM_HEADS)
    Di = torch.empty_like(M)  # Shape: (BATCH_SIZE, NUM_HEADS, SEQ_LEN)

    _attn_bwd_preprocess[preprocess_grid](
        O=O,
        dO=dO,
        Di=Di,
        SEQ_LEN=SEQ_LEN,
        BLOCK_SIZE_Q=BLOCK_SIZE_MICRO,
        HEAD_DIM=ctx.HEAD_DIM,
    )

    grid = (SEQ_LEN // BLOCK_SIZE_MACRO, 1, BATCH_SIZE * NUM_HEADS)

    stage = 3 if ctx.casual else 1

    # Fix a block of K/V and iterate over all the Q blocks.
    # NOTE(review): `_attn_bkwd_dk_dv` is not defined in the visible part of this notebook,
    # and the keyword `dk=` differs in case from `dQ=`/`dV=`; confirm both against the
    # kernel signature once it exists.
    _attn_bkwd_dk_dv[grid](
        Q=Q,
        K=K,
        V=V,
        softmax_scale=ctx.softmax_scale,
        dO=dO,
        dQ=dQ,
        dk=dK,
        dV=dV,
        M=M,
        Di=Di,
        stride_batch=Q.stride(0),
        stride_head=Q.stride(1),
        stride_seq=Q.stride(2),
        stride_dim=Q.stride(3),
        NUM_HEADS=NUM_HEADS,
        SEQ_LEN=SEQ_LEN,
        BLOCK_Q=BLOCK_SIZE_MACRO,
        BLOCK_KV=BLOCK_SIZE_MICRO,
        HEAD_DIM=ctx.HEAD_DIM,
        STAGE=stage,
        num_warps=NUM_WARPS,
        num_stages=NUM_STAGES,
    )

    # Fix: backward must return one gradient per forward input.
    # NOTE(review): no kernel visible here fills dQ yet, so it is still uninitialised
    # memory at this point; a dQ kernel has to be launched before this return is trusted.
    return dQ, dK, dV, None, None

     


def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, casual, dtype=torch.float16):
  """Compare TritonAttention against a plain PyTorch attention implementation.

  Random Q, K and V are created on the GPU, the output and the three input gradients
  are computed with both implementations, and the results must agree within an
  absolute tolerance of 1e-2.

  Args:
    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM: shape of Q, K and V.
    casual: truthy to test causal (lower-triangular) attention.
    dtype: dtype of the random inputs.

  Raises:
    AssertionError: if the output or any gradient differs by more than the tolerance.
  """
  Q = (
      torch.empty(
          (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
      )
      .normal_(mean=0.0, std=0.5)
      .requires_grad_()
  )
  K = (
      torch.empty(
          (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
      )
      .normal_(mean=0.0, std=0.5)
      .requires_grad_()
  )
  V = (
      torch.empty(
          (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda"
      )
      .normal_(mean=0.0, std=0.5)
      .requires_grad_()
  )

  softmax_scale = 1 / (HEAD_DIM**0.5)  # Q K^T / sqrt(HEAD_DIM)
  # Upstream gradient for both backward passes (fix: was `torch.randn_like(0)`)
  dO = torch.randn_like(Q)

  # Reference implementation (naive PyTorch attention)
  MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device="cuda"))
  # Fix: `torch.mathmul` does not exist
  P = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
  if casual:
    P[:, :, MASK == 0] = float("-inf")
  # Fix: was `torch.softmax(P.float("inf"))`: wrong cast call and missing `dim`
  P = torch.softmax(P.float(), dim=-1).half()
  ref_O = torch.matmul(P, V)
  # Fix: a non-scalar output needs an explicit upstream gradient
  ref_O.backward(dO)
  ref_dV, V.grad = V.grad.clone(), None
  ref_dK, K.grad = K.grad.clone(), None
  ref_dQ, Q.grad = Q.grad.clone(), None

  # Triton implementation
  tri_out = TritonAttention.apply(Q, K, V, casual, softmax_scale).half()
  tri_out.backward(dO)
  tri_dV, V.grad = V.grad.clone(), None
  tri_dK, K.grad = K.grad.clone(), None
  tri_dQ, Q.grad = Q.grad.clone(), None

  # Compare on absolute difference only (fix: `rtrol` typo left `rtol` undefined)
  rtol = 0.0
  atol = 1e-2
  assert torch.allclose(ref_O, tri_out, atol=atol, rtol=rtol)
  assert torch.allclose(ref_dK, tri_dK, atol=atol, rtol=rtol)
  assert torch.allclose(ref_dV, tri_dV, atol=atol, rtol=rtol)
  assert torch.allclose(ref_dQ, tri_dQ, atol=atol, rtol=rtol)

In [None]:
import torch
import triton
import triton.language as tl


# Work-in-progress retype of the forward kernel: so far it only derives the program
# indices and builds the Q, V and K block pointers.
#
# NOTE(review): the signature and body are those of a launch kernel (compare the
# `_attn_fwd[grid](...)` call in this cell), yet the function is named `_attn_fwd_inner`.
# When this cell runs after the first one, this definition replaces the earlier helper
# of the same name. Confirm the intended name before finishing the body.
@triton.jit
def _attn_fwd_inner(
    Q,
    K,
    V,
    softmax_scale,
    M,
    O,
    stride_Q_batch, stride_Q_head, stride_Q_seq, stride_Q_dim,
    stride_K_batch, stride_K_head, stride_K_seq, stride_K_dim,
    stride_V_batch, stride_V_head, stride_V_seq, stride_V_dim,
    stride_O_batch, stride_O_head, stride_O_seq, stride_O_dim,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
):
    tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM)

    # Which block of the sequence length this program processes
    block_index_q = tl.program_id(0)

    # Which (batch, head) pair this program processes; each program owns exactly one
    index_batch_head = tl.program_id(1)
    # Batch index (each batch element has NUM_HEADS heads)
    index_batch = index_batch_head // NUM_HEADS
    # Position of the head inside that batch element
    index_head = index_batch_head % NUM_HEADS

    # Offset of the (SEQ_LEN, HEAD_DIM) slice selected by (index_batch, index_head)
    qkv_offset = (
        index_batch.to(tl.int64) * stride_Q_batch
        + index_head.to(tl.int64) * stride_Q_head
    )

    # Q[index_batch, index_head, block_index_q * BLOCK_SIZE_Q :, :]
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(block_index_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    # V[index_batch, index_head, :, :]
    # Fix: the strides used to be passed as `shape` and `strides` was missing entirely.
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        order=(1, 0),
    )

    # K[index_batch, index_head, :, :], read transposed by swapping the two strides.
    # Fix: the call stopped at a bare `str`, a SyntaxError (positional argument after
    # keyword arguments); it is completed here to mirror the first cell's K pointer.
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_K_dim, stride_K_seq),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        order=(0, 1),
    )
    # TODO: the remainder of the kernel is not written yet


    


class TritonAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton forward attention kernel.

    Use it as ``TritonAttention.apply(Q, K, V, casual, softmax_scale)`` with Q, K and V
    of shape [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM]. This version of the class
    defines no `backward`.
    """

    @staticmethod
    def forward(ctx, Q, K, V, casual, softmax_scale):
        """Launch `_attn_fwd` and return the attention output.

        Args:
            ctx: autograd context; receives the tensors and settings for a backward pass.
            Q, K, V: [BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM] tensors.
            casual: truthy for causal (masked) attention.
            softmax_scale: factor applied to Q @ K^T.

        Returns:
            O, a tensor shaped like Q.
        """
        HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1]
        HEAD_DIM_V = V.shape[-1]

        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape

        # Fix: `HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_V` only tested HEAD_DIM_V for truthiness
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V

        # Fix: the launch below used to omit the BLOCK_SIZE_* constexprs that the kernel
        # signature requires and that the grid lambda reads.
        BLOCK_SIZE_Q = 64
        BLOCK_SIZE_KV = min(32, HEAD_DIM)  # the kernel asserts BLOCK_SIZE_KV <= HEAD_DIM
        # The kernel loads whole blocks without boundary checks
        assert SEQ_LEN % BLOCK_SIZE_Q == 0, "SEQ_LEN must be a multiple of BLOCK_SIZE_Q"
        assert BLOCK_SIZE_Q % BLOCK_SIZE_KV == 0, "BLOCK_SIZE_Q must be a multiple of BLOCK_SIZE_KV"

        O = torch.empty_like(Q)
        stage = 3 if casual else 1

        grid = lambda args: (
            # ceil(SEQ_LEN / BLOCK_SIZE_Q) = number of query blocks
            triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),  # which group of queries
            BATCH_SIZE * NUM_HEADS,  # which head of which batch element
            1,  # Z axis of the launch grid
        )

        # Number of parallel programs: BATCH_SIZE * NUM_HEADS * NUM_BLOCKS_Q

        # M is the logsumexp for the backward pass, one value per query
        M = torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
        )

        _attn_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            softmax_scale=softmax_scale,
            M=M,
            O=O,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            BATCH_SIZE=Q.shape[0],
            NUM_HEADS=Q.shape[1],
            SEQ_LEN=Q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_SIZE_Q=BLOCK_SIZE_Q,
            BLOCK_SIZE_KV=BLOCK_SIZE_KV,
            STAGE=stage,
        )

        ctx.save_for_backward(Q, K, V, O, M)
        ctx.grid = grid
        ctx.softmax_scale = softmax_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.casual = casual
        # Fix: this used to return the integer 0 instead of the attention output
        return O
    
    



def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, casual, dtype=torch.float16):
    """Check TritonAttention against a plain PyTorch attention implementation.

    Draws random Q, K, V and an upstream gradient on the GPU, runs forward and backward
    through both implementations and asserts that the output and the three input
    gradients agree within an absolute tolerance of 1e-2.
    """
    tensor_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)

    def _random_leaf():
        # Leaf tensor drawn from N(0, 0.5^2) that tracks gradients
        leaf = torch.empty(tensor_shape, dtype=dtype, device="cuda")
        leaf.normal_(mean=0.0, std=0.5)
        return leaf.requires_grad_()

    def _take_grads():
        # Clone the accumulated gradients in (V, K, Q) order and reset them
        taken = []
        for leaf in (V, K, Q):
            taken.append(leaf.grad.clone())
            leaf.grad = None
        return taken

    Q = _random_leaf()
    K = _random_leaf()
    V = _random_leaf()

    softmax_scale = 1 / (HEAD_DIM ** 0.5)  # Q K^T / sqrt(HEAD_DIM)
    dO = torch.randn_like(Q)  # upstream gradient shared by both backward passes

    # Reference implementation
    lower_triangle = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device="cuda"))
    scores = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
    if casual:
        scores[:, :, lower_triangle == 0] = float("-inf")
    probs = torch.softmax(scores.float(), dim=-1).half()
    ref_O = torch.matmul(probs, V)
    ref_O.backward(dO)
    ref_dV, ref_dK, ref_dQ = _take_grads()

    # Triton implementation
    tri_out = TritonAttention.apply(Q, K, V, casual, softmax_scale).half()
    tri_out.backward(dO)
    tri_dV, tri_dK, tri_dQ = _take_grads()

    # Compare on absolute difference only
    tolerances = {"atol": 1e-2, "rtol": 0.0}
    checked_pairs = (
        (ref_O, tri_out),
        (ref_dK, tri_dK),
        (ref_dV, tri_dV),
        (ref_dQ, tri_dQ),
    )
    for expected, actual in checked_pairs:
        assert torch.allclose(expected, actual, **tolerances)

if __name__ == "__main__":
    # Run the comparison for causal attention first, then for non-causal attention
    for causal_flag in (True, False):
        test_op(BATCH_SIZE=8, NUM_HEADS=8, SEQ_LEN=1024, HEAD_DIM=64, casual=causal_flag)
    print("PASSED")

In [1]:
from openai import OpenAI
from pydantic import BaseModel
import json

# Define a simple schema
# Pydantic model describing the JSON object the model is asked to produce; its JSON
# schema is what gets sent as `guided_json` with the request below.
# The fields are documented with `#` comments only: a docstring would add a
# "description" entry to the schema produced by `model_json_schema()`.
class CarDescription(BaseModel):
    brand: str  # required string
    model: str  # required string
    car_type: str  # required string

# Build the JSON schema from the Pydantic model
json_schema = CarDescription.model_json_schema()

# Print the schema that will be sent with the request
print("Schema being used:")
print(json.dumps(json_schema, indent=2))
print("\n")

# Client for an OpenAI-compatible endpoint.
# NOTE(review): "EMPTY" looks like a placeholder key for a server that does not check
# credentials; confirm before pointing this at another deployment.
client = OpenAI(
    base_url="https://dev1-trusttgpt-ccep.trustt.com/v1",
    api_key="EMPTY",
)

# Single-turn chat completion; the schema is passed through `extra_body` as `guided_json`
# (presumably a server-side guided-decoding extension; verify against the server docs)
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Generate a JSON with the brand, model, and car_type of the most iconic car from the 90's",
        }
    ],
    extra_body={"guided_json": json_schema},
)

# Print the text of the first returned choice
print("Response:")
print(response.choices[0].message.content)

Schema being used:
{
  "properties": {
    "brand": {
      "title": "Brand",
      "type": "string"
    },
    "model": {
      "title": "Model",
      "type": "string"
    },
    "car_type": {
      "title": "Car Type",
      "type": "string"
    }
  },
  "required": [
    "brand",
    "model",
    "car_type"
  ],
  "title": "CarDescription",
  "type": "object"
}


Response:
{ "brand": "Ford", "model": "Mustang GT", "car_type": "Muscle Car" }


In [2]:
from openai import OpenAI
from pydantic import BaseModel
import json
from typing import List, Dict

# Define the schema for loan offer response
# Pydantic model for one assistant turn of the loan-offer flow. Its JSON schema is sent
# as `guided_json`, and every parsed reply is validated against it.
# Every field is a plain string: the system prompt asks for "yes"/"no" in the flag
# fields, but the schema itself does not restrict the values.
# The fields are documented with `#` comments only: a docstring would add a
# "description" entry to the schema produced by `model_json_schema()`.
class LoanOfferResponse(BaseModel):
    has_customer_selected_tenure: str  # flag: has a tenure been selected
    selected_tenure_value: str  # tenure extracted from the customer's message
    has_customer_selected_loan_amount: str  # flag: has a loan amount been selected
    loan_amount_value: str  # loan amount extracted from the customer's message
    loan_details_verified: str  # flag: has the customer confirmed the details
    assistant_response: str  # text shown to the customer
    other_flag: str  # flag: customer wants to update values or asks a question

# Test parameters that are interpolated into the system prompt below
minimum_amount = 10000
approvedOfferAmount = 500000
applicableTenureList = [12, 24, 36, 48, 60]

# System prompt for the loan-offer flow. It is an f-string: the doubled braces render as
# literal braces, and the three parameters above are substituted into the text.
# Fix: rule 2 had lost its "2.selected_tenure_value:" label, so the tenure rule was not
# attached to any field of the JSON template.
system_prompt = f"""You are TrusttGPT, a JSON Response generator for Pre-Approved Offer for the Loan Origination process . Your role is to find whether customer intent is related to pre-approved loan offer and guide the customer through the process of selecting the loan amount and tenure.

All the values should be in the JSON format.
{{
  "has_customer_selected_tenure": "",
  "selected_tenure_value": "",
  "has_customer_selected_loan_amount":"" ,
  "loan_amount_value": "",
  "loan_details_verified": "",
  "assistant_response": "",
  "other_flag":"" 
}}
Rules to follow:
Important: The values provided by the customer may just be numericals without any text. Just extract the values and set the corresponding values based on the customer inputs.

1.loan_amount_value: Extract the loan amount provided by the customer
2.selected_tenure_value: Extract the tenure provided by the customer. If tenure is in years (single-digit), convert to months (1 year = 12 months). Ensure the final value is in {applicableTenureList}. If not in the list, prompt the customer to select a valid tenure.
3.has_customer_selected_tenure & has_customer_selected_loan_amount: "yes" or "no" based on whether values are selected.
4.loan_details_verified: yes if the customer has verified the details otherwise no.
5.assistant_response: Guide the customer based on their inputs, ensuring they select values
6.other_flag: "yes" or "no" based on whether the customer wants to update the values, change the values .

FLOW:
- The customer can provide the loan amount and tenure values, then set the corresponding ** loan_amount_value ** and ** selected_tenure_value ** to the provided values and ** other_flag ** to "no".
- Customer can provide the loan amount and tenure values in a single message or in multiple messages.
-- the tenure provided by the customer should be in months. If the tenure is explicitly provided in years, only convert it to months and set the ** selected_tenure_value ** to the converted value.
- If the customer provides the loan amount then extract the loan amount, set corresponding ** loan_amount_value **to the extracted value and  ask the customer to provide the tenure value and set ** other_flag ** to "no".
- If the customer provides both loan amount and tenure values, then prompt the customer to verify the details and set ** other_flag ** to "no".
- If the customer verifies the details by saying affirmations like "yes","confirm","proceed" etc, set ** loan_details_verified ** to "yes"
- If the customer wants to update the values, set ** other_flag ** to "yes" and prompt the customer to provide the updated values.
- If the customer wants to know about the offer details or any other information, provide the necessary details based on the approved offer amount from {minimum_amount} to {approvedOfferAmount} and tenure from {applicableTenureList} and set ** other_flag ** to "yes" for Q and A. Strictly do not use the given range to validate the values.
    For example:
        - If the customer ask how much loan amount can i get, can't i get more loan amount, set ** other_flag ** to "yes" and provide the details based on the approved offer details.
        - If the customer asks can i get a loan of << value greater than the approved offer amount >>, set ** other_flag ** to "yes" and prompt the approved offer details from the given amount range.
        - Similarly, if the customer ask about the tenure, set ** other_flag ** to "yes" and provide the details based on the approved tenure list.
        - If the customer opted for loan details updation and provides the updated values, then stricltly do not validate the loan details set ** other_flag ** to "no" and also set the corresponding ** loan_amount_value ** and ** selected_tenure_value ** to the provided values irrespective of the given range.
- If the customer changes the values, then extract the new values and set ** other_flag ** to "no" and prompt the customer according to the flow.
- Strictly do not validate the values, only extract the values regardless of the validity. The validation will be done by the system.
- Do not make up assistance responses on your own. Strictly follow the flow and provide the responses based on the customer inputs.
IMPORTANT:
- ** STRICTLY DO NOT VALIDATE THE VALUES **
- ** STRICTLY SET other_flag to "no" WHEN YOU EXTRACT LOAN AMOUNT AND TENURE VALUES**
- ** STRICTLY SET THE CORRESPONDING VALUES BASED ON THE CUSTOMER INPUTS**
- Strictly go through the rules and follow the flow to guide the customer through the process.
- Do not validate the values, only extract the values provided by the customer. System will validate the values and will provide you the context accordingly. and set ** other_flag ** to "no".
- Do not provide any extra information or details on your own.
"""

# JSON schema of the loan-offer reply, sent as `guided_json` with every request
json_schema = LoanOfferResponse.model_json_schema()

# Client for an OpenAI-compatible endpoint.
# NOTE(review): "EMPTY" looks like a placeholder key for a server that does not check
# credentials; confirm before pointing this at another deployment.
client = OpenAI(
    base_url="https://dev1-trusttgpt-ccep.trustt.com/v1",
    api_key="EMPTY",
)

class ConversationManager:
    """Holds the chat transcript of one loan-offer conversation.

    Attributes:
        conversation_history: messages sent with every request, starting with the
            system prompt.
        previous_responses: the validated `LoanOfferResponse` of every successful turn,
            in order.
    """

    def __init__(self):
        self.conversation_history: List[Dict] = [
            {"role": "system", "content": system_prompt}
        ]
        self.previous_responses = []

    def get_loan_response(self, user_message: str):
        """Send one customer message and return the validated reply.

        The message is appended to the history, the whole history is sent to the model
        together with the `guided_json` schema, and the reply is parsed and validated
        as a `LoanOfferResponse`.

        Args:
            user_message: the customer's message for this turn.

        Returns:
            The validated `LoanOfferResponse`, or None if the request, the JSON parsing
            or the validation failed (the error is printed).
        """
        user_turn = {"role": "user", "content": user_message}
        self.conversation_history.append(user_turn)

        try:
            # Make the request with the full conversation history
            response = client.chat.completions.create(
                model="meta-llama/Meta-Llama-3-8B-Instruct",
                messages=self.conversation_history,
                extra_body={"guided_json": json_schema},
            )

            # Parse the reply and validate it with Pydantic
            response_text = response.choices[0].message.content
            parsed_response = json.loads(response_text)
            validated_response = LoanOfferResponse(**parsed_response)
        except Exception as e:
            # Fix: drop the unanswered user turn again. It used to stay in the history,
            # so every later request resent a user message that has no assistant reply.
            if self.conversation_history and self.conversation_history[-1] is user_turn:
                self.conversation_history.pop()
            print(f"Error: {e}")
            return None

        # Record the successful turn: the parsed reply and the raw assistant message
        self.previous_responses.append(validated_response)
        self.conversation_history.append({
            "role": "assistant",
            "content": response_text
        })

        return validated_response

    def print_conversation_state(self):
        """Print every validated response collected so far, numbered from 1."""
        print("\nConversation State:")
        print("-" * 50)
        for i, response in enumerate(self.previous_responses, 1):
            print(f"\nStep {i}:")
            print(json.dumps(response.model_dump(), indent=2))

# Example usage in the notebook: one shared conversation whose history starts with the
# system prompt. `test_interaction` below sends every message through this instance.
conversation = ConversationManager()

# Helper for sequential interactions with the shared `conversation`
def test_interaction(message: str):
    """Send `message` through the shared conversation and print the reply.

    Returns whatever `conversation.get_loan_response` returned; a falsy result is
    returned without printing a reply.
    """
    separator = "-" * 50
    print(f"\nUser: {message}")
    print(separator)

    reply = conversation.get_loan_response(message)
    if not reply:
        return reply

    print("Latest Response:")
    print(json.dumps(reply.model_dump(), indent=2))
    return reply

In [4]:
# Sequential conversation test: loan amount, then tenure, then verification
_opening_turns = (
    "I want a loan of 200000",
    "I want it for 2 years",
    "Yes, confirm these details",
)
for _turn in _opening_turns:
    test_interaction(_turn)

# Show every response collected so far
conversation.print_conversation_state()

# Follow-up turns: a question about the offer, then a change of the loan amount
_followup_turns = (
    "Can I get more loan amount?",
    "I want to change the loan amount to 300000",
)
for _turn in _followup_turns:
    test_interaction(_turn)

# Show the updated conversation state
conversation.print_conversation_state()


User: I want a loan of 200000
--------------------------------------------------
Latest Response:
{
  "has_customer_selected_tenure": "true",
  "selected_tenure_value": "2",
  "has_customer_selected_loan_amount": "true",
  "loan_amount_value": ">=200000",
  "loan_details_verified": "false",
  "assistant_response": "We've previously reviewed your loan application for a loan amount of 200000 and tenure of 2 years. Based on that, we can offer you a loan with an interest rate of 8.5% per annum. Your estimated monthly installment would be approximately 10000. Would you like to proceed with the loan application?",
  "other_flag": "false"
}

User: I want it for 2 years
--------------------------------------------------
Latest Response:
{
  "has_customer_selected_tenure": "true",
  "selected_tenure_value": "2",
  "has_customer_selected_loan_amount": "true",
  "loan_amount_value": ">=200000",
  "loan_details_verified": "true",
  "assistant_response": "Congratulations! Your loan application has