In [None]:

# PTX source for the forward pass of a three-layer fully connected network.
#
# Entry point: neuralNet. For each layer it computes
#     out[j] = bias[j] + sum_i W[j * in_dim + i] * in[i]
# over 32-bit floats, with the weight matrix stored row-major (one row per
# output neuron). Layers 1 and 2 clamp the result with max(x, 0) (ReLU);
# layer 3 stores the raw sum.
#
# The kernel never reads %tid / %ctaid: every launched thread runs the
# complete pass and writes the same values to hidden1/hidden2/output.
# NOTE(review): presumably intended for a single-thread launch - confirm.
forward_pass = r"""
.version 7.0
.target sm_35
.address_size 64

.visible .entry neuralNet(
    .param .u64 input_ptr,       // Pointer to input vector (e.g. 784 floats)
    .param .u64 weights1_ptr,    // Weights for Layer 1 (hidden1_dim x input_dim)
    .param .u64 bias1_ptr,       // Bias for Layer 1 (hidden1_dim)
    .param .u64 hidden1_ptr,     // Output buffer for Layer 1 (hidden1_dim)
    .param .u64 weights2_ptr,    // Weights for Layer 2 (hidden2_dim x hidden1_dim)
    .param .u64 bias2_ptr,       // Bias for Layer 2 (hidden2_dim)
    .param .u64 hidden2_ptr,     // Output buffer for Layer 2 (hidden2_dim)
    .param .u64 weights3_ptr,    // Weights for Layer 3 (output_dim x hidden2_dim)
    .param .u64 bias3_ptr,       // Bias for Layer 3 (output_dim)
    .param .u64 output_ptr,      // Output buffer for final result (output_dim, e.g. 10)
    .param .u32 input_dim,       // Dimension of input vector (e.g. 784)
    .param .u32 hidden1_dim,     // Number of neurons in Layer 1
    .param .u32 hidden2_dim,     // Number of neurons in Layer 2
    .param .u32 output_dim       // Dimension of output vector (10)
)
{
    // Register roles:
    //   *_base          buffer base addresses (global state space)
    //   j / i           output-neuron index / input-element index
    //   row_off         byte offset of weight row j (64-bit)
    //   vec_off         byte offset of element j in a bias/output vector
    //   addr_w/addr_in  running pointers walked by the inner loop
    .reg .u64  in_base, w1_base, b1_base, h1_base, w2_base, b2_base, h2_base, w3_base, b3_base, out_base;
    .reg .u32  inputDim, hidden1Dim, hidden2Dim, outputDim;
    .reg .s32  i, j;
    .reg .s64  row_off, vec_off;
    .reg .u64  addr_w, addr_in, addr_bias, addr_out;
    .reg .f32  acc, temp, in_val, wt, bias;
    .reg .pred p_exit;

    // Load kernel parameters.
    ld.param.u64   in_base, [input_ptr];
    ld.param.u64   w1_base, [weights1_ptr];
    ld.param.u64   b1_base, [bias1_ptr];
    ld.param.u64   h1_base, [hidden1_ptr];
    ld.param.u64   w2_base, [weights2_ptr];
    ld.param.u64   b2_base, [bias2_ptr];
    ld.param.u64   h2_base, [hidden2_ptr];
    ld.param.u64   w3_base, [weights3_ptr];
    ld.param.u64   b3_base, [bias3_ptr];
    ld.param.u64   out_base, [output_ptr];
    ld.param.u32   inputDim, [input_dim];
    ld.param.u32   hidden1Dim, [hidden1_dim];
    ld.param.u32   hidden2Dim, [hidden2_dim];
    ld.param.u32   outputDim, [output_dim];

    // Pointer parameters arrive as generic addresses; the ld.global/st.global
    // accesses below need addresses in the global state space.
    cvta.to.global.u64 in_base, in_base;
    cvta.to.global.u64 w1_base, w1_base;
    cvta.to.global.u64 b1_base, b1_base;
    cvta.to.global.u64 h1_base, h1_base;
    cvta.to.global.u64 w2_base, w2_base;
    cvta.to.global.u64 b2_base, b2_base;
    cvta.to.global.u64 h2_base, h2_base;
    cvta.to.global.u64 w3_base, w3_base;
    cvta.to.global.u64 b3_base, b3_base;
    cvta.to.global.u64 out_base, out_base;

    // Dimensions are compared as signed 32-bit values, so a zero or negative
    // dimension simply makes the corresponding loop run zero times.

    // ----------------------------------
    // Layer 1: Input -> Hidden1 (ReLU)
    // ----------------------------------
    mov.s32   j, 0;
layer1_loop:
    setp.ge.s32 p_exit, j, hidden1Dim;
    @p_exit bra layer1_end;

    // Start of weight row j: w1_base + (j * inputDim) * 4.
    // The widening multiply keeps the offset in 64 bits so large matrices
    // cannot wrap a 32-bit index.
    mul.wide.s32 row_off, j, inputDim;
    shl.b64      row_off, row_off, 2;
    add.u64      addr_w, w1_base, row_off;
    mov.u64      addr_in, in_base;

    mov.f32   acc, 0f00000000;  // Initialize accumulator for neuron j
    mov.s32   i, 0;
layer1_inner_loop:
    setp.ge.s32 p_exit, i, inputDim;
    @p_exit bra layer1_after_inner;

    // acc += W1[j][i] * input[i]
    ld.global.f32 wt, [addr_w];
    ld.global.f32 in_val, [addr_in];
    mul.f32    temp, wt, in_val;
    add.f32    acc, acc, temp;

    // Advance both pointers by one float.
    add.u64    addr_w, addr_w, 4;
    add.u64    addr_in, addr_in, 4;
    add.s32    i, i, 1;
    bra        layer1_inner_loop;
layer1_after_inner:
    // Byte offset of element j, shared by the bias load and the result store.
    mul.wide.s32 vec_off, j, 4;

    // Add bias: bias for neuron j is at b1_base + (j*4)
    add.u64    addr_bias, b1_base, vec_off;
    ld.global.f32 bias, [addr_bias];
    add.f32    acc, acc, bias;

    // Apply ReLU activation.
    max.f32    acc, acc, 0f00000000;

    // Store result in hidden layer 1: h1_base + (j*4)
    add.u64    addr_out, h1_base, vec_off;
    st.global.f32 [addr_out], acc;

    add.s32    j, j, 1;
    bra        layer1_loop;
layer1_end:

    // ----------------------------------
    // Layer 2: Hidden1 -> Hidden2 (ReLU)
    // ----------------------------------
    mov.s32   j, 0;
layer2_loop:
    setp.ge.s32 p_exit, j, hidden2Dim;
    @p_exit bra layer2_end;

    // Start of weight row j: w2_base + (j * hidden1Dim) * 4.
    mul.wide.s32 row_off, j, hidden1Dim;
    shl.b64      row_off, row_off, 2;
    add.u64      addr_w, w2_base, row_off;
    mov.u64      addr_in, h1_base;

    mov.f32   acc, 0f00000000;  // Accumulator for neuron j in layer 2
    mov.s32   i, 0;
layer2_inner_loop:
    setp.ge.s32 p_exit, i, hidden1Dim;
    @p_exit bra layer2_after_inner;

    // acc += W2[j][i] * hidden1[i]
    ld.global.f32 wt, [addr_w];
    ld.global.f32 in_val, [addr_in];
    mul.f32    temp, wt, in_val;
    add.f32    acc, acc, temp;

    add.u64    addr_w, addr_w, 4;
    add.u64    addr_in, addr_in, 4;
    add.s32    i, i, 1;
    bra        layer2_inner_loop;
layer2_after_inner:
    mul.wide.s32 vec_off, j, 4;

    // Add bias for layer 2 neuron j.
    add.u64    addr_bias, b2_base, vec_off;
    ld.global.f32 bias, [addr_bias];
    add.f32    acc, acc, bias;

    // Apply ReLU activation.
    max.f32    acc, acc, 0f00000000;

    // Store in hidden layer 2: h2_base + (j*4)
    add.u64    addr_out, h2_base, vec_off;
    st.global.f32 [addr_out], acc;

    add.s32    j, j, 1;
    bra        layer2_loop;
layer2_end:

    // ----------------------------------
    // Layer 3: Hidden2 -> Output (linear)
    // ----------------------------------
    mov.s32   j, 0;
layer3_loop:
    setp.ge.s32 p_exit, j, outputDim;
    @p_exit bra layer3_end;

    // Start of weight row j: w3_base + (j * hidden2Dim) * 4.
    mul.wide.s32 row_off, j, hidden2Dim;
    shl.b64      row_off, row_off, 2;
    add.u64      addr_w, w3_base, row_off;
    mov.u64      addr_in, h2_base;

    mov.f32   acc, 0f00000000;  // Accumulator for output neuron j
    mov.s32   i, 0;
layer3_inner_loop:
    setp.ge.s32 p_exit, i, hidden2Dim;
    @p_exit bra layer3_after_inner;

    // acc += W3[j][i] * hidden2[i]
    ld.global.f32 wt, [addr_w];
    ld.global.f32 in_val, [addr_in];
    mul.f32    temp, wt, in_val;
    add.f32    acc, acc, temp;

    add.u64    addr_w, addr_w, 4;
    add.u64    addr_in, addr_in, 4;
    add.s32    i, i, 1;
    bra        layer3_inner_loop;
layer3_after_inner:
    mul.wide.s32 vec_off, j, 4;

    // Add bias for output neuron j.
    add.u64    addr_bias, b3_base, vec_off;
    ld.global.f32 bias, [addr_bias];
    add.f32    acc, acc, bias;

    // Write final output (no activation): out_base + (j*4)
    add.u64    addr_out, out_base, vec_off;
    st.global.f32 [addr_out], acc;

    add.s32    j, j, 1;
    bra        layer3_loop;
layer3_end:

    ret;
}
"""

## Neural Network Architecture

Let the input be:
$$
\mathbf{x} \in \mathbb{R}^{d} \quad \text{with } d=784 \quad (\text{MNIST image flattened})
$$

The network has three layers defined as follows:

### Layer 1 (Input to Hidden1)

- **Weights:**  
  $$
  W^{(1)} \in \mathbb{R}^{n_1 \times d}
  $$
  Each element is \( W^{(1)}_{ij} \) with \( i = 1, \dots, n_1 \) and \( j = 1, \dots, d \).

- **Biases:**  
  $$
  \mathbf{b}^{(1)} \in \mathbb{R}^{n_1}
  $$
  with components \( b^{(1)}_{i} \).

- **Pre-activation:**  
  $$
  \mathbf{z}^{(1)} = W^{(1)} \mathbf{x} + \mathbf{b}^{(1)}
  $$

- **Activation (ReLU):**  
  $$
  \mathbf{a}^{(1)} = \sigma(\mathbf{z}^{(1)}) \quad \text{with } \sigma(z) = \max(0, z)
  $$

### Layer 2 (Hidden1 to Hidden2)

- **Weights:**  
  $$
  W^{(2)} \in \mathbb{R}^{n_2 \times n_1}
  $$
  Each element is \( W^{(2)}_{ij} \) for \( i = 1, \dots, n_2 \) and \( j = 1, \dots, n_1 \).

- **Biases:**  
  $$
  \mathbf{b}^{(2)} \in \mathbb{R}^{n_2}
  $$
  with components \( b^{(2)}_{i} \).

- **Pre-activation:**  
  $$
  \mathbf{z}^{(2)} = W^{(2)} \mathbf{a}^{(1)} + \mathbf{b}^{(2)}
  $$

- **Activation (ReLU):**  
  $$
  \mathbf{a}^{(2)} = \sigma(\mathbf{z}^{(2)}) \quad \text{with } \sigma(z) = \max(0, z)
  $$

### Layer 3 (Hidden2 to Output)

- **Weights:**  
  $$
  W^{(3)} \in \mathbb{R}^{10 \times n_2}
  $$
  Each element is \( W^{(3)}_{ij} \) for \( i = 1, \dots, 10 \) and \( j = 1, \dots, n_2 \).

- **Biases:**  
  $$
  \mathbf{b}^{(3)} \in \mathbb{R}^{10}
  $$
  with components \( b^{(3)}_{i} \).

- **Pre-activation (Logits):**  
  $$
  \mathbf{z}^{(3)} = W^{(3)} \mathbf{a}^{(2)} + \mathbf{b}^{(3)}
  $$

- **Output:**  
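  The vector \(\mathbf{z}^{(3)}\) holds the raw logits (the forward-pass kernel above stores these values without applying any activation). A softmax is usually applied to \(\mathbf{z}^{(3)}\) afterwards to obtain class probabilities: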
  $$
  \hat{\mathbf{y}} = \operatorname{softmax}(\mathbf{z}^{(3)}) \quad \text{where } \hat{y}_i = \frac{e^{z^{(3)}_i}}{\sum_{k=1}^{10} e^{z^{(3)}_k}}
  $$

---

## Loss Function

Assume we use the cross-entropy loss for classification. For a single training example with true label vector \(\mathbf{y}\) (one-hot encoded), the loss is:
$$
L = -\sum_{i=1}^{10} y_i \log(\hat{y}_i)
$$

---

## Backpropagation

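Let the derivative of the loss with respect to any variable \( u \) be denoted by \(\frac{\partial L}{\partial u}\). We define the error term (or delta) of layer \( l \) as the gradient of the loss with respect to that layer's pre-activation: \(\delta^{(l)} = \frac{\partial L}{\partial \mathbf{z}^{(l)}}\).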

### Output Layer (Layer 3)

For the output layer using softmax and cross-entropy, the error is:
$$
\delta^{(3)} = \hat{\mathbf{y}} - \mathbf{y}
$$

The gradients for layer 3 are:

- **Weights:**
  $$
  \frac{\partial L}{\partial W^{(3)}} = \delta^{(3)} \, (\mathbf{a}^{(2)})^T
  $$

- **Biases:**
  $$
  \frac{\partial L}{\partial \mathbf{b}^{(3)}} = \delta^{(3)}
  $$

### Hidden Layer 2 (Layer 2)

The error for layer 2 is computed by backpropagating the error from layer 3 and applying the derivative of the ReLU activation. Recall:
$$
\sigma'(z) = \begin{cases} 
1 & \text{if } z > 0, \\
0 & \text{otherwise}
\end{cases}
$$

Thus,
$$
\delta^{(2)} = \left( \left( W^{(3)} \right)^T \delta^{(3)} \right) \circ \sigma'(\mathbf{z}^{(2)})
$$
where “\(\circ\)” denotes element-wise multiplication.

The gradients for layer 2 are:

- **Weights:**
  $$
  \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} \, (\mathbf{a}^{(1)})^T
  $$

- **Biases:**
  $$
  \frac{\partial L}{\partial \mathbf{b}^{(2)}} = \delta^{(2)}
  $$

### Hidden Layer 1 (Layer 1)

Similarly, backpropagate the error to the first hidden layer:
$$
\delta^{(1)} = \left( \left( W^{(2)} \right)^T \delta^{(2)} \right) \circ \sigma'(\mathbf{z}^{(1)})
$$

The gradients for layer 1 are:

- **Weights:**
  $$
  \frac{\partial L}{\partial W^{(1)}} = \delta^{(1)} \, (\mathbf{x})^T
  $$

- **Biases:**
  $$
  \frac{\partial L}{\partial \mathbf{b}^{(1)}} = \delta^{(1)}
  $$

---

## Summary

### Forward Pass

1. **Layer 1:**
   $$
   \mathbf{z}^{(1)} = W^{(1)} \mathbf{x} + \mathbf{b}^{(1)}, \quad \mathbf{a}^{(1)} = \max(0, \mathbf{z}^{(1)})
   $$

2. **Layer 2:**
   $$
   \mathbf{z}^{(2)} = W^{(2)} \mathbf{a}^{(1)} + \mathbf{b}^{(2)}, \quad \mathbf{a}^{(2)} = \max(0, \mathbf{z}^{(2)})
   $$

3. **Layer 3:**
   $$
   \mathbf{z}^{(3)} = W^{(3)} \mathbf{a}^{(2)} + \mathbf{b}^{(3)}, \quad \hat{\mathbf{y}} = \operatorname{softmax}(\mathbf{z}^{(3)})
   $$

### Backward Pass

1. **Output Layer (Layer 3):**
   $$
   \delta^{(3)} = \hat{\mathbf{y}} - \mathbf{y}
   $$
   $$
   \frac{\partial L}{\partial W^{(3)}} = \delta^{(3)} (\mathbf{a}^{(2)})^T, \quad \frac{\partial L}{\partial \mathbf{b}^{(3)}} = \delta^{(3)}
   $$

2. **Hidden Layer 2 (Layer 2):**
   $$
   \delta^{(2)} = \left( \left( W^{(3)} \right)^T \delta^{(3)} \right) \circ 1_{\{\mathbf{z}^{(2)} > 0\}}
   $$
   $$
   \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} (\mathbf{a}^{(1)})^T, \quad \frac{\partial L}{\partial \mathbf{b}^{(2)}} = \delta^{(2)}
   $$

3. **Hidden Layer 1 (Layer 1):**
   $$
   \delta^{(1)} = \left( \left( W^{(2)} \right)^T \delta^{(2)} \right) \circ 1_{\{\mathbf{z}^{(1)} > 0\}}
   $$
   $$
   \frac{\partial L}{\partial W^{(1)}} = \delta^{(1)} (\mathbf{x})^T, \quad \frac{\partial L}{\partial \mathbf{b}^{(1)}} = \delta^{(1)}
   $$

