# 1. Neural Network, CUDA, and PTX – Forward Propagation

In this project we implement a simple neural network with one hidden layer. The network architecture is as follows:
- **Input Layer:** Receives a flattened MNIST image of dimension $784$ (i.e. $28 \times 28$ pixels).
- **Hidden Layer:** Consists of $128$ neurons with ReLU activation.
- **Output Layer:** Contains $10$ neurons (one per digit class) with a linear activation.
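
As a quick sanity check on the sizes listed above, the number of trainable parameters can be counted directly. This is a small illustrative sketch; the helper name `dense_param_count` is hypothetical and not part of the project code.

```python
# Layer sizes taken from the architecture list above.
input_dim, hidden_dim, output_dim = 784, 128, 10

def dense_param_count(n_in, n_out):
    """Weights (n_out x n_in) plus one bias per output neuron."""
    return n_out * n_in + n_out

hidden_params = dense_param_count(input_dim, hidden_dim)
output_params = dense_param_count(hidden_dim, output_dim)
total_params = hidden_params + output_params
print(hidden_params, output_params, total_params)  # 100480 1290 101770
```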

**CUDA and PTX Overview:**  
CUDA is NVIDIA’s parallel computing platform and API for running massively parallel computations on GPUs. The compiler translates CUDA device code into PTX (Parallel Thread Execution), a low-level, assembly-like virtual instruction set for NVIDIA GPUs. When a PTX module is loaded, the GPU driver just-in-time (JIT) compiles it into device-specific binary instructions. Writing PTX by hand exposes this low-level control, which can be used to optimize performance-critical parts of an algorithm.

**Forward Propagation in PTX:**  
In the PTX (or higher-level CUDA C) implementation, forward propagation is performed as follows:
1. **Layer 1 Computation:**  
   Each hidden neuron $i$ computes  
   $$ a^{(1)}_i = \max\Big(0, \sum_{j=1}^{784} W^{(1)}_{ij} \, x_j + b^{(1)}_i \Big) $$
   Each thread processes one sample: for every neuron it loops over the input elements, accumulates the products, adds the bias, and finally applies the ReLU activation ($\max(0,\cdot)$).
2. **Output Layer Computation:**  
   Each output neuron $i$ computes  
   $$ a^{(2)}_i = \sum_{j=1}^{128} W^{(2)}_{ij} \, a^{(1)}_j + b^{(2)}_i $$
   again using loops to perform the dot products.

In PTX, these computations require explicit pointer arithmetic and careful register management, with each thread handling one sample from the mini-batch.
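
For reference, the two forward-propagation steps above can be written in a few lines of NumPy. This is a minimal sketch rather than the project's kernel: the function name `forward`, the random test data, and the `0.01` weight scale are assumptions made for illustration.

```python
import numpy as np

def forward(X, W1, b1, W2, b2):
    """Forward pass for a batch X of shape (N, 784).

    W1: (128, 784), b1: (128,), W2: (10, 128), b2: (10,).
    Returns hidden activations (N, 128) and linear outputs (N, 10).
    """
    a1 = np.maximum(0.0, X @ W1.T + b1)   # ReLU hidden layer
    a2 = a1 @ W2.T + b2                   # linear output layer
    return a1, a2

# Random stand-in data; real inputs would be flattened MNIST images.
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 784)).astype(np.float32)
W1 = (0.01 * rng.standard_normal((128, 784))).astype(np.float32)
b1 = np.zeros(128, dtype=np.float32)
W2 = (0.01 * rng.standard_normal((10, 128))).astype(np.float32)
b2 = np.zeros(10, dtype=np.float32)

a1, a2 = forward(X, W1, b1, W2, b2)
print(a1.shape, a2.shape)  # (4, 128) (4, 10)
```

Each row of `X` corresponds to the work done by one GPU thread in the per-sample scheme described above.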

---

# 2. Backpropagation Computation in PTX

Backpropagation for our network is designed to compute gradients for both layers using a squared-error loss with one-hot targets. The algorithm is as follows:

1. **Output Layer Backpropagation:**  
   The error (delta) at the output is computed by  
   $$ \delta^{(2)} = a^{(2)} - y $$
   where $y$ is the one-hot target vector. The gradients for the output biases are then given by $\delta^{(2)}$, and the gradient for the output weights is computed as  
   $$ \nabla W^{(2)} = \delta^{(2)} \otimes a^{(1)} $$
2. **Hidden Layer Backpropagation:**  
   The error for the hidden layer is computed by backpropagating the output error through the weights and applying the derivative of the ReLU activation:
   $$ \delta^{(1)} = \Big( (W^{(2)})^T \, \delta^{(2)} \Big) \odot \mathbf{1}_{\{a^{(1)} > 0\}} $$
   where $\mathbf{1}_{\{a^{(1)} > 0\}}$ is the indicator function representing the derivative of the ReLU (1 if $a^{(1)}>0$, else 0). The gradient for the hidden biases is $\delta^{(1)}$, and the gradient for the input-to-hidden weights is  
   $$ \nabla W^{(1)} = \delta^{(1)} \otimes x. $$
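
The backpropagation formulas above can be checked numerically. The sketch below assumes the squared-error loss is scaled as $L = \tfrac{1}{2}\lVert a^{(2)} - y \rVert^2$ (so that $\delta^{(2)} = a^{(2)} - y$ exactly), and uses tiny hypothetical layer sizes so that a finite-difference comparison is cheap. The function name `loss_and_grads` is illustrative.

```python
import numpy as np

def loss_and_grads(x, y, W1, b1, W2, b2):
    """Squared-error loss and gradients for one sample.

    Assumes L = 0.5 * ||a2 - y||^2, so that dL/da2 = a2 - y.
    """
    a1 = np.maximum(0.0, W1 @ x + b1)      # hidden activations
    a2 = W2 @ a1 + b2                      # linear outputs
    loss = 0.5 * np.sum((a2 - y) ** 2)

    delta2 = a2 - y                        # output error
    delta1 = (W2.T @ delta2) * (a1 > 0)    # error pushed back through ReLU
    grads = {
        "W2": np.outer(delta2, a1), "b2": delta2,
        "W1": np.outer(delta1, x), "b1": delta1,
    }
    return loss, grads

# Tiny hypothetical sizes (5 inputs, 4 hidden, 3 outputs) keep the check fast.
rng = np.random.default_rng(0)
x, y = rng.standard_normal(5), np.eye(3)[1]
W1, b1 = rng.standard_normal((4, 5)), rng.standard_normal(4)
W2, b2 = rng.standard_normal((3, 4)), rng.standard_normal(3)
loss, grads = loss_and_grads(x, y, W1, b1, W2, b2)

# Central finite difference on one entry of W2 as a sanity check.
eps = 1e-6
W2_plus, W2_minus = W2.copy(), W2.copy()
W2_plus[0, 0] += eps
W2_minus[0, 0] -= eps
numeric = (loss_and_grads(x, y, W1, b1, W2_plus, b2)[0]
           - loss_and_grads(x, y, W1, b1, W2_minus, b2)[0]) / (2 * eps)
print(abs(numeric - grads["W2"][0, 0]))  # should be close to zero
```

A check of this kind is a convenient way to validate a hand-written kernel: the same comparison can be made between GPU gradients and the NumPy gradients.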
   
**PTX Implementation Details:**  
The PTX (or corresponding CUDA C) implementation translates these steps into a series of loops with explicit memory loads and stores. Intermediate delta values are held in temporary storage, either in registers or in local memory. Because PTX registers cannot be indexed dynamically like an array and instruction operand types must match exactly, the PTX implementation of backpropagation ran into several issues and did not work as expected without further debugging and modification.

---

# 3. Training Setup Implementation in CUDA

The training is implemented in two main CUDA kernels:

1. **trainStep Kernel:**  
   - **Input:** A mini-batch of images and one-hot encoded labels.
   - **Forward Pass:** For each sample, the kernel computes the hidden layer activations using $a^{(1)} = \max(0, W^{(1)} x + b^{(1)})$, and then computes the output $a^{(2)} = W^{(2)} a^{(1)} + b^{(2)}$.
   - **Backward Pass:** The kernel computes the output error $\delta^{(2)} = a^{(2)} - y$ and backpropagates this error to compute $\delta^{(1)}$ for the hidden layer. Gradients are calculated as outer products of the error terms with the corresponding activations.
   - **Output:** The gradients for each sample are written into per-sample gradient buffers.

2. **updateWeights Kernel:**  
   - **Input:** The gradient buffers and the current parameter values.
   - **Computation:** This kernel aggregates the per-sample gradients by summing them over the mini-batch, averages them (dividing by the batch size), and updates the parameters using the rule  
     $$ W \leftarrow W - \gamma \cdot \frac{1}{N}\sum_{i=1}^{N} \nabla W_i $$
     where $\gamma$ is the learning rate and $N$ is the batch size.
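
The update rule above amounts to a mean over the per-sample gradient buffer followed by a scaled subtraction. A minimal NumPy sketch, assuming the buffer is laid out as `(N, *W.shape)` (the actual buffer layout in the kernel is not shown here, and the name `update_weights` is illustrative):

```python
import numpy as np

def update_weights(W, per_sample_grads, lr):
    """Average per-sample gradients over the batch and take one SGD step.

    W: parameter array; per_sample_grads: array of shape (N, *W.shape).
    """
    mean_grad = per_sample_grads.mean(axis=0)
    return W - lr * mean_grad

W = np.ones((2, 3), dtype=np.float32)
# Two samples whose gradients are all 1.0 and all 3.0, so the mean is 2.0.
grads = np.stack([np.full((2, 3), g, dtype=np.float32) for g in (1.0, 3.0)])
W_new = update_weights(W, grads, lr=0.1)
print(W_new)  # every entry is 1 - 0.1 * 2 = 0.8
```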

**Overall Training Loop:**  
- The host code sets up GPU memory for inputs, activations, parameters, and gradient buffers.
- For each mini-batch, the data is transferred to the GPU.
- The `trainStep` kernel is launched to perform forward propagation and compute gradients.
- The `updateWeights` kernel is then invoked for each parameter to update the network.
- Periodically, the parameters are copied back to the host and the network’s performance is evaluated using a NumPy forward pass.
- This process continues until the test accuracy exceeds 60% or a maximum number of batches is processed.
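
The control flow of this loop, with its two stopping conditions, can be sketched independently of the GPU code. Everything named here (`train_until`, `step_fn`, `eval_fn`, `eval_every`) is hypothetical; `step_fn` stands in for the kernel launches and `eval_fn` for the NumPy evaluation. The accuracy values in the demo are dummy numbers chosen only to exercise the loop, not measured results.

```python
def train_until(step_fn, eval_fn, batches, target_acc=0.60,
                max_batches=1000, eval_every=50):
    """Run training steps until accuracy exceeds target_acc or the
    batch budget is used up. Returns (batches processed, last accuracy)."""
    acc = 0.0
    processed = 0
    for batch in batches:
        if processed >= max_batches:
            break
        step_fn(batch)                 # forward pass, gradients, update
        processed += 1
        if processed % eval_every == 0:
            acc = eval_fn()            # periodic evaluation on test data
            if acc > target_acc:
                break
    return processed, acc

# Dummy stand-ins: record each batch and replay made-up accuracies.
seen = []
fake_accuracies = iter([0.21, 0.48, 0.63, 0.70])
processed, acc = train_until(seen.append, lambda: next(fake_accuracies),
                             range(10_000))
print(processed, acc)  # 150 0.63
```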

---

# 4. Comparison with PyTorch

To evaluate our custom CUDA implementation, we compared it with a PyTorch implementation that uses an identical two-layer network architecture and batch gradient updates with mean squared error (MSE) loss on one-hot targets.

**PyCUDA Results:**  
The PyCUDA implementation (a two-layer network with one hidden layer of 128 neurons) was trained on the MNIST dataset with a batch size of $64$. Training proceeded until the test accuracy exceeded 60%. Here are some of the recorded test accuracies:


In [13]:

forward_pass = r"""
.version 7.0
.target sm_35
.address_size 64

// -------------------------------------------------------------------
// Kernel: neuralNetBatch
//
// This kernel computes a forward pass for a 3-layer feed-forward neural
// network over a batch of inputs. Each thread processes one sample.
// For each sample, the network performs:
//
//   Layer 1 (Input -> Hidden1):
//     z^(1) = W^(1)*x + b^(1)
//     a^(1) = ReLU(z^(1))
//
//   Layer 2 (Hidden1 -> Hidden2):
//     z^(2) = W^(2)*a^(1) + b^(2)
//     a^(2) = ReLU(z^(2))
//
//   Layer 3 (Hidden2 -> Output):
//     z^(3) = W^(3)*a^(2) + b^(3)
//     (The output here is raw logits; softmax may be applied later.)
//
// The weight matrices and biases are shared by all threads while the input,
// intermediate activations, and outputs are stored in batched arrays.
// -------------------------------------------------------------------
.visible .entry neuralNetBatch(
    .param .u64 input_ptr,       // [batch_size x input_dim] array
    .param .u64 weights1_ptr,    // Layer 1 weights: [hidden1_dim x input_dim]
    .param .u64 bias1_ptr,       // Layer 1 biases: [hidden1_dim]
    .param .u64 hidden1_ptr,     // Layer 1 activations: [batch_size x hidden1_dim]
    .param .u64 weights2_ptr,    // Layer 2 weights: [hidden2_dim x hidden1_dim]
    .param .u64 bias2_ptr,       // Layer 2 biases: [hidden2_dim]
    .param .u64 hidden2_ptr,     // Layer 2 activations: [batch_size x hidden2_dim]
    .param .u64 weights3_ptr,    // Layer 3 weights: [output_dim x hidden2_dim]
    .param .u64 bias3_ptr,       // Layer 3 biases: [output_dim]
    .param .u64 output_ptr,      // Output logits: [batch_size x output_dim]
    .param .u32 batch_size,      // Total number of samples in the batch
    .param .u32 input_dim,       // Dimensionality of input vector
    .param .u32 hidden1_dim,     // Number of neurons in Layer 1
    .param .u32 hidden2_dim,     // Number of neurons in Layer 2
    .param .u32 output_dim       // Number of output neurons
)
{
    // ---------------------------------------------------------------
    // Declare registers.
    // Note: We rename the loaded batch_size to "bs" to avoid duplicate names.
    // ---------------------------------------------------------------
    .reg .u64   in_base, w1_base, b1_base, h1_base, w2_base, b2_base, h2_base, w3_base, b3_base, out_base;
    .reg .u32   bs, inputDim, hidden1Dim, hidden2Dim, outputDim;
    .reg .u32   tid, blockId, blockDim, sampleIdx;
    .reg .u32   i, j, idx, idx_byte;
    .reg .u64   byte_offset;
    .reg .u64   addr_w, addr_bias;
    .reg .f32   acc, temp, in_val, wt, bias;
    .reg .u64   sample_in_ptr, sample_h1_ptr, sample_h2_ptr, sample_out_ptr;
    .reg .pred  p_exit;

    // ---------------------------------------------------------------
    // Load kernel parameters.
    // ---------------------------------------------------------------
    ld.param.u64   in_base,   [input_ptr];
    ld.param.u64   w1_base,   [weights1_ptr];
    ld.param.u64   b1_base,   [bias1_ptr];
    ld.param.u64   h1_base,   [hidden1_ptr];
    ld.param.u64   w2_base,   [weights2_ptr];
    ld.param.u64   b2_base,   [bias2_ptr];
    ld.param.u64   h2_base,   [hidden2_ptr];
    ld.param.u64   w3_base,   [weights3_ptr];
    ld.param.u64   b3_base,   [bias3_ptr];
    ld.param.u64   out_base,  [output_ptr];
    ld.param.u32   bs, [batch_size];       // Load batch size into "bs"
    ld.param.u32   inputDim,  [input_dim];
    ld.param.u32   hidden1Dim,[hidden1_dim];
    ld.param.u32   hidden2Dim,[hidden2_dim];
    ld.param.u32   outputDim, [output_dim];

    // ---------------------------------------------------------------
    // Compute global thread index (sampleIdx) using built-in registers.
    // sampleIdx = blockIdx.x * blockDim.x + threadIdx.x
    // ---------------------------------------------------------------
    mov.u32   tid, %tid.x;
    mov.u32   blockId, %ctaid.x;
    mov.u32   blockDim, %ntid.x;
    mul.lo.u32 sampleIdx, blockId, blockDim;
    add.u32    sampleIdx, sampleIdx, tid;

    // If sampleIdx >= bs then exit this thread.
    setp.ge.u32 p_exit, sampleIdx, bs;
    @p_exit bra exit;

    // ---------------------------------------------------------------
    // Compute sample-specific base pointers (each sample stored contiguously):
    //   sample_in_ptr = in_base + sampleIdx * inputDim * 4
    //   sample_h1_ptr = h1_base + sampleIdx * hidden1Dim * 4
    //   sample_h2_ptr = h2_base + sampleIdx * hidden2Dim * 4
    //   sample_out_ptr = out_base + sampleIdx * outputDim * 4
    // ---------------------------------------------------------------
    mul.lo.u32 idx, sampleIdx, inputDim;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 sample_in_ptr, in_base, byte_offset;

    mul.lo.u32 idx, sampleIdx, hidden1Dim;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 sample_h1_ptr, h1_base, byte_offset;

    mul.lo.u32 idx, sampleIdx, hidden2Dim;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 sample_h2_ptr, h2_base, byte_offset;

    mul.lo.u32 idx, sampleIdx, outputDim;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 sample_out_ptr, out_base, byte_offset;

    // ===============================================================
    // Layer 1: Compute hidden1 = ReLU(W1 * input + b1)
    // ===============================================================
    mov.u32 j, 0;
layer1_loop:
    setp.ge.u32 p_exit, j, hidden1Dim;
    @p_exit bra layer1_end;

    // Initialize accumulator for neuron j to 0.
    mov.f32 acc, 0f00000000;
    mov.u32 i, 0;
layer1_inner_loop:
    setp.ge.u32 p_exit, i, inputDim;
    @p_exit bra layer1_after_inner;

    // Compute weight index: idx = j * inputDim + i.
    mul.lo.u32 idx, j, inputDim;
    add.u32 idx, idx, i;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, w1_base, byte_offset;
    ld.global.f32 wt, [addr_w];

    // Load input: sample_in_ptr + i*4.
    mul.lo.u32 idx_byte, i, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, sample_in_ptr, byte_offset;
    ld.global.f32 in_val, [addr_w];

    // Multiply and accumulate.
    mul.f32 temp, wt, in_val;
    add.f32 acc, acc, temp;
    add.u32 i, i, 1;
    bra layer1_inner_loop;
layer1_after_inner:
    // Load bias for neuron j.
    mul.lo.u32 idx_byte, j, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_bias, b1_base, byte_offset;
    ld.global.f32 bias, [addr_bias];
    add.f32 acc, acc, bias;
    // Apply ReLU: acc = max(0, acc)
    max.f32 acc, acc, 0f00000000;
    // Store activation in hidden1.
    mul.lo.u32 idx_byte, j, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, sample_h1_ptr, byte_offset;
    st.global.f32 [addr_w], acc;
    add.u32 j, j, 1;
    bra layer1_loop;
layer1_end:

    // ===============================================================
    // Layer 2: Compute hidden2 = ReLU(W2 * hidden1 + b2)
    // ===============================================================
    mov.u32 j, 0;
layer2_loop:
    setp.ge.u32 p_exit, j, hidden2Dim;
    @p_exit bra layer2_end;
    mov.f32 acc, 0f00000000;
    mov.u32 i, 0;
layer2_inner_loop:
    setp.ge.u32 p_exit, i, hidden1Dim;
    @p_exit bra layer2_after_inner;
    mul.lo.u32 idx, j, hidden1Dim;
    add.u32 idx, idx, i;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, w2_base, byte_offset;
    ld.global.f32 wt, [addr_w];
    // Load activation from hidden1.
    mul.lo.u32 idx_byte, i, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, sample_h1_ptr, byte_offset;
    ld.global.f32 in_val, [addr_w];
    mul.f32 temp, wt, in_val;
    add.f32 acc, acc, temp;
    add.u32 i, i, 1;
    bra layer2_inner_loop;
layer2_after_inner:
    // Add bias for layer 2 neuron j.
    mul.lo.u32 idx_byte, j, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_bias, b2_base, byte_offset;
    ld.global.f32 bias, [addr_bias];
    add.f32 acc, acc, bias;
    max.f32 acc, acc, 0f00000000;
    // Store activation in hidden2.
    mul.lo.u32 idx_byte, j, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, sample_h2_ptr, byte_offset;
    st.global.f32 [addr_w], acc;
    add.u32 j, j, 1;
    bra layer2_loop;
layer2_end:

    // ===============================================================
    // Layer 3: Compute output = W3 * hidden2 + b3
    // ===============================================================
    mov.u32 j, 0;
layer3_loop:
    setp.ge.u32 p_exit, j, outputDim;
    @p_exit bra layer3_end;
    mov.f32 acc, 0f00000000;
    mov.u32 i, 0;
layer3_inner_loop:
    setp.ge.u32 p_exit, i, hidden2Dim;
    @p_exit bra layer3_after_inner;
    mul.lo.u32 idx, j, hidden2Dim;
    add.u32 idx, idx, i;
    mul.lo.u32 idx_byte, idx, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, w3_base, byte_offset;
    ld.global.f32 wt, [addr_w];
    // Load activation from hidden2.
    mul.lo.u32 idx_byte, i, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, sample_h2_ptr, byte_offset;
    ld.global.f32 in_val, [addr_w];
    mul.f32 temp, wt, in_val;
    add.f32 acc, acc, temp;
    add.u32 i, i, 1;
    bra layer3_inner_loop;
layer3_after_inner:
    // Add bias for output neuron j.
    mul.lo.u32 idx_byte, j, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_bias, b3_base, byte_offset;
    ld.global.f32 bias, [addr_bias];
    add.f32 acc, acc, bias;
    // Store the output logit.
    mul.lo.u32 idx_byte, j, 4;
    cvt.u64.u32 byte_offset, idx_byte;
    add.u64 addr_w, sample_out_ptr, byte_offset;
    st.global.f32 [addr_w], acc;
    add.u32 j, j, 1;
    bra layer3_loop;
layer3_end:
exit:
    ret;
}

"""

The following Python code uses PyCUDA to load the PTX above and run the kernel on random test data.

In [19]:
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np

# -------------------------------
# Define network dimensions:
#   For MNIST: input_dim = 784.
#   Here we choose arbitrary hidden layer sizes.
# -------------------------------
batch_size   = 10   # Number of samples in the batch.
input_dim    = 784
hidden1_dim  = 128
hidden2_dim  = 64
output_dim   = 10

# -------------------------------
# Create random data for testing:
#   Each sample is a row in the input array.
#   We initialize weight matrices and biases with random data.
# -------------------------------
input_host   = np.random.randn(batch_size, input_dim).astype(np.float32)
W1_host      = np.random.randn(hidden1_dim, input_dim).astype(np.float32)
b1_host      = np.random.randn(hidden1_dim).astype(np.float32)
W2_host      = np.random.randn(hidden2_dim, hidden1_dim).astype(np.float32)
b2_host      = np.random.randn(hidden2_dim).astype(np.float32)
W3_host      = np.random.randn(output_dim, hidden2_dim).astype(np.float32)
b3_host      = np.random.randn(output_dim).astype(np.float32)

# Output buffers for intermediate activations and final logits.
hidden1_host = np.empty((batch_size, hidden1_dim), dtype=np.float32)
hidden2_host = np.empty((batch_size, hidden2_dim), dtype=np.float32)
output_host  = np.empty((batch_size, output_dim), dtype=np.float32)

# -------------------------------
# Allocate device memory and copy host data:
# -------------------------------
input_gpu   = cuda.mem_alloc(input_host.nbytes)
W1_gpu      = cuda.mem_alloc(W1_host.nbytes)
b1_gpu      = cuda.mem_alloc(b1_host.nbytes)
hidden1_gpu = cuda.mem_alloc(hidden1_host.nbytes)
W2_gpu      = cuda.mem_alloc(W2_host.nbytes)
b2_gpu      = cuda.mem_alloc(b2_host.nbytes)
hidden2_gpu = cuda.mem_alloc(hidden2_host.nbytes)
W3_gpu      = cuda.mem_alloc(W3_host.nbytes)
b3_gpu      = cuda.mem_alloc(b3_host.nbytes)
output_gpu  = cuda.mem_alloc(output_host.nbytes)

cuda.memcpy_htod(input_gpu, input_host)
cuda.memcpy_htod(W1_gpu, W1_host)
cuda.memcpy_htod(b1_gpu, b1_host)
cuda.memcpy_htod(W2_gpu, W2_host)
cuda.memcpy_htod(b2_gpu, b2_host)
cuda.memcpy_htod(W3_gpu, W3_host)
cuda.memcpy_htod(b3_gpu, b3_host)



# -------------------------------
# Load the PTX module and get the kernel function.
# -------------------------------
module = cuda.module_from_buffer(forward_pass.encode("utf-8"))
kernel = module.get_function("neuralNetBatch")

# -------------------------------
# Configure grid and block dimensions.
#
# Since we want one thread per sample, we set the total number of threads
# equal to batch_size.
# -------------------------------
threads_per_block = min(batch_size, 1024)
blocks = (batch_size + threads_per_block - 1) // threads_per_block

# -------------------------------
# Launch the kernel.
# -------------------------------
kernel(
    input_gpu,
    W1_gpu,
    b1_gpu,
    hidden1_gpu,
    W2_gpu,
    b2_gpu,
    hidden2_gpu,
    W3_gpu,
    b3_gpu,
    output_gpu,
    np.uint32(batch_size),
    np.uint32(input_dim),
    np.uint32(hidden1_dim),
    np.uint32(hidden2_dim),
    np.uint32(output_dim),
    block=(threads_per_block, 1, 1),
    grid=(blocks, 1, 1)
)

# -------------------------------
# Retrieve and print the results.
# -------------------------------
cuda.memcpy_dtoh(hidden1_host, hidden1_gpu)
cuda.memcpy_dtoh(hidden2_host, hidden2_gpu)
cuda.memcpy_dtoh(output_host, output_gpu)

print("Hidden layer 1 activations shape:", hidden1_host.shape)
print("Hidden layer 2 activations shape:", hidden2_host.shape)
print("Output layer (logits) shape:", output_host.shape)
print("Output logits for each sample:")
print(output_host)



Hidden layer 1 activations shape: (10, 128)
Hidden layer 2 activations shape: (10, 64)
Output layer (logits) shape: (10, 10)
Output logits for each sample:
[[-2398.4146     896.48694  -2454.5962    -310.93024    364.0878
    316.3618     952.67584  -1174.3683    -400.76767   -694.3439  ]
 [ -218.6154    1192.3281    -988.25446    918.27496  -1215.4276
    191.75932    537.2199    1248.4154     -92.95179    184.19208 ]
 [-2518.4102    1683.3507   -1737.7667    1169.7318    -601.6049
    210.67848   1875.9261    -862.35266   -549.42114    860.8875  ]
 [    6.430615  3396.2239    -629.7285    2261.9907   -1670.4064
   -264.03015   -484.3882     124.95582   -415.25583   -449.62613 ]
 [  371.7862   -1066.0829   -1795.1741    2092.619     -522.11346
   -756.5303     382.81982   1786.3528    -273.72974    546.8591  ]
 [ -346.25864   -179.3902    -174.4288    2052.9424   -1083.8575
    796.1063     366.92545   1549.8699    -730.0691    1120.9371  ]
 [  510.4109     337.61234    442.06943   -76
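
The GPU results can be cross-checked against a CPU reference. The sketch below is an assumption-laden example, not part of the original notebook: `reference_forward` and `matches_reference` are hypothetical helper names, and the tolerances are loose guesses chosen because float32 accumulation order on the GPU differs from NumPy's.

```python
import numpy as np

def reference_forward(x, W1, b1, W2, b2, W3, b3):
    """CPU version of the three-layer forward pass computed by the kernel."""
    h1 = np.maximum(0.0, x @ W1.T + b1)
    h2 = np.maximum(0.0, h1 @ W2.T + b2)
    out = h2 @ W3.T + b3
    return h1, h2, out

def matches_reference(gpu_arrays, cpu_arrays, rtol=1e-3, atol=1e-2):
    """True if every GPU array agrees with its CPU counterpart."""
    return all(np.allclose(g, c, rtol=rtol, atol=atol)
               for g, c in zip(gpu_arrays, cpu_arrays))

# Self-contained demo with small hypothetical sizes; a float32 copy of the
# reference result plays the role of the arrays copied back from the GPU.
rng = np.random.default_rng(0)
shapes = [(3, 8), (6, 8), (6,), (5, 6), (5,), (4, 5), (4,)]
demo_args = [rng.standard_normal(s) for s in shapes]
cpu = reference_forward(*demo_args)
fake_gpu = tuple(a.astype(np.float32) for a in cpu)
print(matches_reference(fake_gpu, cpu))  # True
```

With the notebook's variables, the comparison would be `matches_reference((hidden1_host, hidden2_host, output_host), reference_forward(input_host, W1_host, b1_host, W2_host, b2_host, W3_host, b3_host))`.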

## Neural Network Architecture

Let the input be:
$$
\mathbf{x} \in \mathbb{R}^{d} \quad \text{with } d=784 \quad (\text{MNIST image flattened})
$$

The network has three layers defined as follows:

### Layer 1 (Input to Hidden1)

- **Weights:**  
  $$
  W^{(1)} \in \mathbb{R}^{n_1 \times d}
  $$
  Each element is $W^{(1)}_{ij}$ with $i = 1, \dots, n_1$ and $j = 1, \dots, d$.

- **Biases:**  
  $$
  \mathbf{b}^{(1)} \in \mathbb{R}^{n_1}
  $$
  with components $b^{(1)}_{i}$.

- **Pre-activation:**  
  $$
  \mathbf{z}^{(1)} = W^{(1)} \mathbf{x} + \mathbf{b}^{(1)}
  $$

- **Activation (ReLU):**  
  $$
  \mathbf{a}^{(1)} = \sigma(\mathbf{z}^{(1)}) \quad \text{with } \sigma(z) = \max(0, z)
  $$

### Layer 2 (Hidden1 to Hidden2)

- **Weights:**  
  $$
  W^{(2)} \in \mathbb{R}^{n_2 \times n_1}
  $$
  Each element is $W^{(2)}_{ij}$ for $i = 1, \dots, n_2$ and $j = 1, \dots, n_1$.

- **Biases:**  
  $$
  \mathbf{b}^{(2)} \in \mathbb{R}^{n_2}
  $$
  with components $b^{(2)}_{i}$.

- **Pre-activation:**  
  $$
  \mathbf{z}^{(2)} = W^{(2)} \mathbf{a}^{(1)} + \mathbf{b}^{(2)}
  $$

- **Activation (ReLU):**  
  $$
  \mathbf{a}^{(2)} = \sigma(\mathbf{z}^{(2)}) \quad \text{with } \sigma(z) = \max(0, z)
  $$

### Layer 3 (Hidden2 to Output)

- **Weights:**  
  $$
  W^{(3)} \in \mathbb{R}^{10 \times n_2}
  $$
  Each element is $W^{(3)}_{ij}$ for $i = 1, \dots, 10$ and $j = 1, \dots, n_2$.

- **Biases:**  
  $$
  \mathbf{b}^{(3)} \in \mathbb{R}^{10}
  $$
  with components $b^{(3)}_{i}$.

- **Pre-activation (Logits):**  
  $$
  \mathbf{z}^{(3)} = W^{(3)} \mathbf{a}^{(2)} + \mathbf{b}^{(3)}
  $$

- **Output:**  
  Often, a softmax is applied to $\mathbf{z}^{(3)}$ to obtain probabilities:
  $$
  \hat{\mathbf{y}} = \operatorname{softmax}(\mathbf{z}^{(3)}) \quad \text{where } \hat{y}_i = \frac{e^{z^{(3)}_i}}{\sum_{k=1}^{10} e^{z^{(3)}_k}}
  $$
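
With logits as large as those printed by the test run (which used unscaled random weights), evaluating $e^{z}$ directly would overflow, since `exp(952.7)` is not representable even in float64. A common remedy, sketched here as an illustration rather than as part of the original code, is to subtract the row maximum before exponentiating, which leaves the softmax unchanged. The logits below are the first row of the printed output.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax along the last axis."""
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max(axis=-1, keepdims=True)   # largest entry becomes 0
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

# First row of the logits printed by the forward-pass test.
logits = np.array([-2398.4146, 896.48694, -2454.5962, -310.93024, 364.0878,
                   316.3618, 952.67584, -1174.3683, -400.76767, -694.3439])
probs = softmax(logits)
print(probs.argmax())  # 6, the position of the largest logit
```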

---

## Loss Function

Assume we use the cross-entropy loss for classification. For a single training example with true label vector $\mathbf{y}$ (one-hot encoded), the loss is:
$$
L = -\sum_{i=1}^{10} y_i \log(\hat{y}_i)
$$
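
In practice the loss is usually computed from the logits directly, using the log-sum-exp form of $\log \hat{y}_i$ so that no probability underflows to zero before the logarithm is taken. A minimal sketch (the function name `cross_entropy_from_logits` is illustrative):

```python
import numpy as np

def cross_entropy_from_logits(z, y):
    """Cross-entropy of one-hot target y against softmax(z), computed stably."""
    z = np.asarray(z, dtype=np.float64)
    shifted = z - z.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())   # log softmax
    return -np.sum(y * log_probs)

# Equal logits give a uniform prediction over 10 classes, so L = log(10).
z = np.zeros(10)
y = np.eye(10)[3]
loss = cross_entropy_from_logits(z, y)
print(loss, np.log(10))
```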

---

## Backpropagation

Let the derivative of the loss with respect to any variable $u$ be denoted by $\frac{\partial L}{\partial u}$. We define the error terms (or deltas) for each layer.

### Output Layer (Layer 3)

For the output layer using softmax and cross-entropy, the error is:
$$
\delta^{(3)} = \hat{\mathbf{y}} - \mathbf{y}
$$
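
For completeness, this result follows from the chain rule applied to the softmax and the cross-entropy loss defined above. The softmax derivative is
$$
\frac{\partial \hat{y}_k}{\partial z^{(3)}_i} = \hat{y}_k \left( \mathbf{1}_{\{k = i\}} - \hat{y}_i \right),
$$
so that
$$
\frac{\partial L}{\partial z^{(3)}_i}
= -\sum_{k=1}^{10} \frac{y_k}{\hat{y}_k} \, \hat{y}_k \left( \mathbf{1}_{\{k = i\}} - \hat{y}_i \right)
= -y_i + \hat{y}_i \sum_{k=1}^{10} y_k
= \hat{y}_i - y_i,
$$
where the last step uses $\sum_{k} y_k = 1$ for a one-hot target.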

The gradients for layer 3 are:

- **Weights:**
  $$
  \frac{\partial L}{\partial W^{(3)}} = \delta^{(3)} \, (\mathbf{a}^{(2)})^T
  $$

- **Biases:**
  $$
  \frac{\partial L}{\partial \mathbf{b}^{(3)}} = \delta^{(3)}
  $$

### Hidden Layer 2 (Layer 2)

The error for layer 2 is computed by backpropagating the error from layer 3 and applying the derivative of the ReLU activation. Recall:
$$
\sigma'(z) = \begin{cases} 
1 & \text{if } z > 0, \\
0 & \text{otherwise}
\end{cases}
$$

Thus,
$$
\delta^{(2)} = \left( \left( W^{(3)} \right)^T \delta^{(3)} \right) \circ \sigma'(\mathbf{z}^{(2)})
$$
where $\circ$ denotes element-wise multiplication.

The gradients for layer 2 are:

- **Weights:**
  $$
  \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} \, (\mathbf{a}^{(1)})^T
  $$

- **Biases:**
  $$
  \frac{\partial L}{\partial \mathbf{b}^{(2)}} = \delta^{(2)}
  $$

### Hidden Layer 1 (Layer 1)

Similarly, backpropagate the error to the first hidden layer:
$$
\delta^{(1)} = \left( \left( W^{(2)} \right)^T \delta^{(2)} \right) \circ \sigma'(\mathbf{z}^{(1)})
$$

The gradients for layer 1 are:

- **Weights:**
  $$
  \frac{\partial L}{\partial W^{(1)}} = \delta^{(1)} \, (\mathbf{x})^T
  $$

- **Biases:**
  $$
  \frac{\partial L}{\partial \mathbf{b}^{(1)}} = \delta^{(1)}
  $$
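
The three-layer backward pass derived above can be written compactly in NumPy and checked against a finite difference. This is a sketch under stated assumptions: the function name `forward_backward`, the tiny layer sizes, and the `0.5` weight scale are all illustrative and do not come from the PTX implementation.

```python
import numpy as np

def forward_backward(x, y, W1, b1, W2, b2, W3, b3):
    """Cross-entropy loss and all gradients for a single sample."""
    z1 = W1 @ x + b1
    a1 = np.maximum(0.0, z1)
    z2 = W2 @ a1 + b2
    a2 = np.maximum(0.0, z2)
    z3 = W3 @ a2 + b3
    shifted = z3 - z3.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())   # log softmax
    loss = -np.sum(y * log_probs)

    d3 = np.exp(log_probs) - y            # delta^(3) = y_hat - y
    d2 = (W3.T @ d3) * (z2 > 0)           # delta^(2)
    d1 = (W2.T @ d2) * (z1 > 0)           # delta^(1)
    grads = {
        "W3": np.outer(d3, a2), "b3": d3,
        "W2": np.outer(d2, a1), "b2": d2,
        "W1": np.outer(d1, x), "b1": d1,
    }
    return loss, grads

# Tiny hypothetical sizes: d = 6, n1 = 5, n2 = 4, 3 output classes.
rng = np.random.default_rng(0)
shapes = [(5, 6), (5,), (4, 5), (4,), (3, 4), (3,)]
params = [0.5 * rng.standard_normal(s) for s in shapes]
x, y = rng.standard_normal(6), np.eye(3)[2]
loss, grads = forward_backward(x, y, *params)

# Central finite difference with respect to b3[0].
eps = 1e-6
plus = [p.copy() for p in params]
minus = [p.copy() for p in params]
plus[5][0] += eps
minus[5][0] -= eps
numeric = (forward_backward(x, y, *plus)[0]
           - forward_backward(x, y, *minus)[0]) / (2 * eps)
print(abs(numeric - grads["b3"][0]))  # should be close to zero
```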

---

## Summary

### Forward Pass

1. **Layer 1:**
   $$
   \mathbf{z}^{(1)} = W^{(1)} \mathbf{x} + \mathbf{b}^{(1)}, \quad \mathbf{a}^{(1)} = \max(0, \mathbf{z}^{(1)})
   $$

2. **Layer 2:**
   $$
   \mathbf{z}^{(2)} = W^{(2)} \mathbf{a}^{(1)} + \mathbf{b}^{(2)}, \quad \mathbf{a}^{(2)} = \max(0, \mathbf{z}^{(2)})
   $$

3. **Layer 3:**
   $$
   \mathbf{z}^{(3)} = W^{(3)} \mathbf{a}^{(2)} + \mathbf{b}^{(3)}, \quad \hat{\mathbf{y}} = \operatorname{softmax}(\mathbf{z}^{(3)})
   $$

### Backward Pass

1. **Output Layer (Layer 3):**
   $$
   \delta^{(3)} = \hat{\mathbf{y}} - \mathbf{y}
   $$
   $$
   \frac{\partial L}{\partial W^{(3)}} = \delta^{(3)} (\mathbf{a}^{(2)})^T, \quad \frac{\partial L}{\partial \mathbf{b}^{(3)}} = \delta^{(3)}
   $$

2. **Hidden Layer 2 (Layer 2):**
   $$
   \delta^{(2)} = \left( \left( W^{(3)} \right)^T \delta^{(3)} \right) \circ 1_{\{\mathbf{z}^{(2)} > 0\}}
   $$
   $$
   \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} (\mathbf{a}^{(1)})^T, \quad \frac{\partial L}{\partial \mathbf{b}^{(2)}} = \delta^{(2)}
   $$

3. **Hidden Layer 1 (Layer 1):**
   $$
   \delta^{(1)} = \left( \left( W^{(2)} \right)^T \delta^{(2)} \right) \circ 1_{\{\mathbf{z}^{(1)} > 0\}}
   $$
   $$
   \frac{\partial L}{\partial W^{(1)}} = \delta^{(1)} (\mathbf{x})^T, \quad \frac{\partial L}{\partial \mathbf{b}^{(1)}} = \delta^{(1)}
   $$



In [34]:
fb_code = r"""
.version 7.0
.target sm_35
.address_size 64

.visible .entry trainStep_fb(
    .param .u64 input_ptr,       // [batch_size x input_dim]
    .param .u64 labels_ptr,      // [batch_size x output_dim] (one-hot)
    .param .u64 hidden1_ptr,     // scratch for layer1 activations [batch_size x hidden1_dim]
    .param .u64 hidden2_ptr,     // scratch for layer2 activations [batch_size x hidden2_dim]
    .param .u64 output_ptr,      // scratch for layer3 outputs [batch_size x output_dim]
    .param .u64 weights1_ptr,    // W1: [hidden1_dim x input_dim]
    .param .u64 bias1_ptr,       // b1: [hidden1_dim]
    .param .u64 weights2_ptr,    // W2: [hidden2_dim x hidden1_dim]
    .param .u64 bias2_ptr,       // b2: [hidden2_dim]
    .param .u64 weights3_ptr,    // W3: [output_dim x hidden2_dim]
    .param .u64 bias3_ptr,       // b3: [output_dim]
    .param .u32 batch_size,
    .param .u32 input_dim,
    .param .u32 hidden1_dim,
    .param .u32 hidden2_dim,
    .param .u32 output_dim,
    .param .f32 learning_rate    // γ
)
{
    // Per-thread local scratch for delta2 (δ^(2)); the fixed size assumes hidden2_dim <= 64.
    .local .align 4 .f32 delta2_array[64];

    // ------------------ Register Declarations ----------------------
    .reg .u64    in_base, lab_base, h1_base, h2_base, out_base;
    .reg .u64    w1_base, b1_base, w2_base, b2_base, w3_base, b3_base;
    .reg .u32    bs, inputDim, hidden1Dim, hidden2Dim, outputDim;
    .reg .f32    gamma;
    .reg .u32    tid, blockId, blockDim, sampleIdx;
    .reg .u32    i, j, k, l, idx, idx_byte;
    .reg .u64    byte_offset;
    .reg .u64    addr;
    .reg .f32    temp, in_val, wt, bias_val, acc;
    .reg .f32    a_val;
    .reg .f32    delta3, delta2, delta1;
    .reg .f32    label_val;
    .reg .f32    a_h1, a_h2;
    .reg .pred   p_exit;
    .reg .u64    local_delta2_ptr;

    // -------------------- Load Kernel Parameters ---------------------
    ld.param.u64   in_base,   [input_ptr];
    ld.param.u64   lab_base,  [labels_ptr];
    ld.param.u64   h1_base,   [hidden1_ptr];
    ld.param.u64   h2_base,   [hidden2_ptr];
    ld.param.u64   out_base,  [output_ptr];
    ld.param.u64   w1_base,   [weights1_ptr];
    ld.param.u64   b1_base,   [bias1_ptr];
    ld.param.u64   w2_base,   [weights2_ptr];
    ld.param.u64   b2_base,   [bias2_ptr];
    ld.param.u64   w3_base,   [weights3_ptr];
    ld.param.u64   b3_base,   [bias3_ptr];
    ld.param.u32   bs,        [batch_size];
    ld.param.u32   inputDim,  [input_dim];
    ld.param.u32   hidden1Dim, [hidden1_dim];
    ld.param.u32   hidden2Dim, [hidden2_dim];
    ld.param.u32   outputDim, [output_dim];
    ld.param.f32   gamma,     [learning_rate];

    // Convert generic pointer parameters to global-space addresses
    // (cvta.to.global maps generic -> global; plain cvta.global maps global -> generic).
    cvta.to.global.u64 in_base, in_base;
    cvta.to.global.u64 lab_base, lab_base;
    cvta.to.global.u64 h1_base, h1_base;
    cvta.to.global.u64 h2_base, h2_base;
    cvta.to.global.u64 out_base, out_base;
    cvta.to.global.u64 w1_base, w1_base;
    cvta.to.global.u64 b1_base, b1_base;
    cvta.to.global.u64 w2_base, w2_base;
    cvta.to.global.u64 b2_base, b2_base;
    cvta.to.global.u64 w3_base, w3_base;
    cvta.to.global.u64 b3_base, b3_base;

    // ------------------- Compute Sample Index ------------------------
    mov.u32 tid, %tid.x;
    mov.u32 blockId, %ctaid.x;
    mov.u32 blockDim, %ntid.x;
    mul.lo.u32 sampleIdx, blockId, blockDim;
    add.u32 sampleIdx, sampleIdx, tid;
    setp.ge.u32 p_exit, sampleIdx, bs;
    @p_exit bra exit;

    // ---------------------- Forward Pass -----------------------------
    // Layer 1: a^(1) = ReLU(W1*x + b1)
    mov.u32 j, 0;
layer1_loop:
    setp.ge.u32 p_exit, j, hidden1Dim;
    @p_exit bra layer1_end;
        mov.f32 acc, 0f00000000;
        mov.u32 i, 0;
    layer1_inner:
        setp.ge.u32 p_exit, i, inputDim;
        @p_exit bra layer1_inner_end;
            // Compute index = j*inputDim + i
            mul.lo.u32 idx, j, inputDim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w1_base, byte_offset;
            ld.global.f32 wt, [addr];
            // Load input: index = sampleIdx*inputDim + i
            mul.lo.u32 idx, sampleIdx, inputDim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, in_base, byte_offset;
            ld.global.f32 in_val, [addr];
            mul.f32 temp, wt, in_val;
            add.f32 acc, acc, temp;
            add.u32 i, i, 1;
            bra layer1_inner;
    layer1_inner_end:
        // Add bias: b1[j]
        mul.lo.u32 idx_byte, j, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, b1_base, byte_offset;
        ld.global.f32 bias_val, [addr];
        add.f32 acc, acc, bias_val;
        // Apply ReLU:
        max.f32 acc, acc, 0f00000000;
        // Store activation in hidden1: index = sampleIdx*hidden1Dim + j
        mul.lo.u32 idx, sampleIdx, hidden1Dim;
        add.u32 idx, idx, j;
        mul.lo.u32 idx_byte, idx, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, h1_base, byte_offset;
        st.global.f32 [addr], acc;
        add.u32 j, j, 1;
        bra layer1_loop;
layer1_end:

    // Layer 2: a^(2) = ReLU(W2 * a^(1) + b2)
    mov.u32 j, 0;
layer2_loop:
    setp.ge.u32 p_exit, j, hidden2Dim;
    @p_exit bra layer2_end;
        mov.f32 acc, 0f00000000;
        mov.u32 i, 0;
    layer2_inner:
        setp.ge.u32 p_exit, i, hidden1Dim;
        @p_exit bra layer2_inner_end;
            // Compute index = j*hidden1Dim + i
            mul.lo.u32 idx, j, hidden1Dim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w2_base, byte_offset;
            ld.global.f32 wt, [addr];
            // Load activation from hidden1: index = sampleIdx*hidden1Dim + i
            mul.lo.u32 idx, sampleIdx, hidden1Dim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, h1_base, byte_offset;
            ld.global.f32 in_val, [addr];
            mul.f32 temp, wt, in_val;
            add.f32 acc, acc, temp;
            add.u32 i, i, 1;
            bra layer2_inner;
    layer2_inner_end:
        // Add bias: b2[j]
        mul.lo.u32 idx_byte, j, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, b2_base, byte_offset;
        ld.global.f32 bias_val, [addr];
        add.f32 acc, acc, bias_val;
        // Apply ReLU:
        max.f32 acc, acc, 0f00000000;
        // Store activation in hidden2: index = sampleIdx*hidden2Dim + j
        mul.lo.u32 idx, sampleIdx, hidden2Dim;
        add.u32 idx, idx, j;
        mul.lo.u32 idx_byte, idx, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, h2_base, byte_offset;
        st.global.f32 [addr], acc;
        add.u32 j, j, 1;
        bra layer2_loop;
layer2_end:

    // Layer 3: a^(3) = W3 * a^(2) + b3   (logits)
    mov.u32 j, 0;
layer3_loop:
    setp.ge.u32 p_exit, j, outputDim;
    @p_exit bra layer3_end;
        mov.f32 acc, 0f00000000;
        mov.u32 i, 0;
    layer3_inner:
        setp.ge.u32 p_exit, i, hidden2Dim;
        @p_exit bra layer3_inner_end;
            // Compute index = j*hidden2Dim + i
            mul.lo.u32 idx, j, hidden2Dim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w3_base, byte_offset;
            ld.global.f32 wt, [addr];
            // Load activation from hidden2: index = sampleIdx*hidden2Dim + i
            mul.lo.u32 idx, sampleIdx, hidden2Dim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, h2_base, byte_offset;
            ld.global.f32 in_val, [addr];
            mul.f32 temp, wt, in_val;
            add.f32 acc, acc, temp;
            add.u32 i, i, 1;
            bra layer3_inner;
    layer3_inner_end:
        // Add bias: b3[j]
        mul.lo.u32 idx_byte, j, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, b3_base, byte_offset;
        ld.global.f32 bias_val, [addr];
        add.f32 acc, acc, bias_val;
        // Store output (logit) in global memory: index = sampleIdx*outputDim + j
        mul.lo.u32 idx, sampleIdx, outputDim;
        add.u32 idx, idx, j;
        mul.lo.u32 idx_byte, idx, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, out_base, byte_offset;
        st.global.f32 [addr], acc;
        add.u32 j, j, 1;
        bra layer3_loop;
layer3_end:

    // ---------------- Backpropagation --------------------------
    // ---------- Layer 3 BP ----------
    // δ^3 = a^(3) - label; update bias b3 and weight matrix W3.
    mov.u32 i, 0;
layer3_bp_loop:
    setp.ge.u32 p_exit, i, outputDim;
    @p_exit bra layer3_bp_end;
        // Load this sample's output a^(3)[i]: index = sampleIdx*outputDim + i
        mul.lo.u32 idx, sampleIdx, outputDim;
        add.u32 idx, idx, i;
        mul.lo.u32 idx_byte, idx, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, out_base, byte_offset;
        ld.global.f32 a_val, [addr];
        // Load this sample's label[i] (assumes labels use the same [batch x outputDim] layout)
        add.u64 addr, lab_base, byte_offset;
        ld.global.f32 label_val, [addr];
        sub.f32 delta3, a_val, label_val;
        // Update bias b3[i]
        mul.f32 temp, gamma, delta3;
        neg.f32 temp, temp;
        mul.lo.u32 idx_byte, i, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, b3_base, byte_offset;
        atom.global.add.f32 [addr], temp;
        // Update weight matrix W3 for each hidden2 neuron j:
        mov.u32 j, 0;
    layer3_bp_inner_loop:
        setp.ge.u32 p_exit, j, hidden2Dim;
        @p_exit bra layer3_bp_inner_end;
            // Load activation from hidden2: index = sampleIdx*hidden2Dim + j
            mul.lo.u32 idx, sampleIdx, hidden2Dim;
            add.u32 idx, idx, j;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, h2_base, byte_offset;
            ld.global.f32 a_h2, [addr];
            mul.f32 temp, delta3, a_h2;
            mul.f32 temp, gamma, temp;
            neg.f32 temp, temp;
            // Update weight W3[i,j]: index = i*hidden2Dim + j
            mul.lo.u32 idx, i, hidden2Dim;
            add.u32 idx, idx, j;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w3_base, byte_offset;
            atom.global.add.f32 [addr], temp;
            add.u32 j, j, 1;
            bra layer3_bp_inner_loop;
    layer3_bp_inner_end:
        add.u32 i, i, 1;
        bra layer3_bp_loop;
layer3_bp_end:

    // ---------- Layer 2 BP ----------
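    // δ^2 = (sum_i W3[i,j]*δ^3[i]) * ReLU'(a^(2)[j]); update b2 and W2.
    // NOTE: W3 was already updated in place by the Layer 3 BP loop above, so this
    // sum reads the post-update weights rather than the ones used in the forward pass.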
    mov.u32 j, 0;
layer2_bp_loop:
    setp.ge.u32 p_exit, j, hidden2Dim;
    @p_exit bra layer2_bp_end;
        mov.f32 delta2, 0f00000000;
        mov.u32 i, 0;
    layer2_bp_inner_loop:
        setp.ge.u32 p_exit, i, outputDim;
        @p_exit bra layer2_bp_inner_end;
            // For W3: index = i*hidden2Dim + j
            mul.lo.u32 idx, i, hidden2Dim;
            add.u32 idx, idx, j;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w3_base, byte_offset;
            ld.global.f32 wt, [addr];
            // Recompute δ^3 for neuron i of this sample: index = sampleIdx*outputDim + i
            mul.lo.u32 idx, sampleIdx, outputDim;
            add.u32 idx, idx, i;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, out_base, byte_offset;
            ld.global.f32 a_val, [addr];
            // Labels are assumed to share the [batch x outputDim] layout, so reuse the offset.
            add.u64 addr, lab_base, byte_offset;
            ld.global.f32 label_val, [addr];
            sub.f32 delta3, a_val, label_val;
            mul.f32 temp, wt, delta3;
            add.f32 delta2, delta2, temp;
            add.u32 i, i, 1;
            bra layer2_bp_inner_loop;
    layer2_bp_inner_end:
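        // Apply ReLU derivative on this sample's a^(2)[j]: index = sampleIdx*hidden2Dim + j
        mul.lo.u32 idx, sampleIdx, hidden2Dim;
        add.u32 idx, idx, j;
        mul.lo.u32 idx_byte, idx, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, h2_base, byte_offset;
        ld.global.f32 a_val, [addr];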
        setp.le.f32 p_exit, a_val, 0f00000000;
        @p_exit mov.f32 delta2, 0f00000000;
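        // Store δ^2[j] in local memory.
        // st.local needs a .local-space address, so take the array address with mov
        // (cvta would yield a generic address) and widen the byte offset to 64 bits.
        mul.lo.u32 idx_byte, j, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        mov.u64 local_delta2_ptr, delta2_array;
        add.u64 addr, local_delta2_ptr, byte_offset;
        st.local.f32 [addr], delta2;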
        // Update bias b2[j]
        mul.f32 temp, gamma, delta2;
        neg.f32 temp, temp;
        mul.lo.u32 idx_byte, j, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, b2_base, byte_offset;
        atom.global.add.f32 [addr], temp;
        // Update W2: for each hidden1 neuron k
        mov.u32 k, 0;
    layer2_bp_inner2_loop:
        setp.ge.u32 p_exit, k, hidden1Dim;
        @p_exit bra layer2_bp_inner2_end;
            // Load activation from hidden1: index = sampleIdx*hidden1Dim + k
            mul.lo.u32 idx, sampleIdx, hidden1Dim;
            add.u32 idx, idx, k;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, h1_base, byte_offset;
            ld.global.f32 a_h1, [addr];
            mul.f32 temp, delta2, a_h1;
            mul.f32 temp, gamma, temp;
            neg.f32 temp, temp;
            // Update weight W2[j,k]: index = j*hidden1Dim + k
            mul.lo.u32 idx, j, hidden1Dim;
            add.u32 idx, idx, k;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w2_base, byte_offset;
            atom.global.add.f32 [addr], temp;
            add.u32 k, k, 1;
            bra layer2_bp_inner2_loop;
    layer2_bp_inner2_end:
        add.u32 j, j, 1;
        bra layer2_bp_loop;
layer2_bp_end:

    // ---------- Layer 1 BP ----------
    // δ^1 = (sum_j W2[j,k]*δ^2[j]) * ReLU'(a^(1)[k]); update b1 and W1.
    mov.u32 k, 0;
layer1_bp_loop:
    setp.ge.u32 p_exit, k, hidden1Dim;
    @p_exit bra layer1_bp_end;
        mov.f32 delta1, 0f00000000;
        mov.u32 j, 0;
    layer1_bp_inner_loop:
        setp.ge.u32 p_exit, j, hidden2Dim;
        @p_exit bra layer1_bp_inner_end;
            // Compute index for W2[j,k]: j*hidden1Dim + k
            mul.lo.u32 idx, j, hidden1Dim;
            add.u32 idx, idx, k;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w2_base, byte_offset;
            ld.global.f32 wt, [addr];
            // Load δ^2[j] from local memory (.local-space address, 64-bit offset)
            mul.lo.u32 idx_byte, j, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            mov.u64 local_delta2_ptr, delta2_array;
            add.u64 addr, local_delta2_ptr, byte_offset;
            ld.local.f32 delta2, [addr];
            mul.f32 temp, wt, delta2;
            add.f32 delta1, delta1, temp;
            add.u32 j, j, 1;
            bra layer1_bp_inner_loop;
    layer1_bp_inner_end:
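        // Apply ReLU derivative on this sample's a^(1)[k]: index = sampleIdx*hidden1Dim + k
        mul.lo.u32 idx, sampleIdx, hidden1Dim;
        add.u32 idx, idx, k;
        mul.lo.u32 idx_byte, idx, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, h1_base, byte_offset;
        ld.global.f32 a_val, [addr];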
        setp.le.f32 p_exit, a_val, 0f00000000;
        @p_exit mov.f32 delta1, 0f00000000;
        // Update bias b1[k]
        mul.f32 temp, gamma, delta1;
        neg.f32 temp, temp;
        mul.lo.u32 idx_byte, k, 4;
        cvt.u64.u32 byte_offset, idx_byte;
        add.u64 addr, b1_base, byte_offset;
        atom.global.add.f32 [addr], temp;
        // Update W1: for each input neuron l
        mov.u32 l, 0;
    layer1_bp_inner2_loop:
        setp.ge.u32 p_exit, l, inputDim;
        @p_exit bra layer1_bp_inner2_end;
            // Load input: index = sampleIdx*inputDim + l
            mul.lo.u32 idx, sampleIdx, inputDim;
            add.u32 idx, idx, l;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, in_base, byte_offset;
            ld.global.f32 in_val, [addr];
            mul.f32 temp, delta1, in_val;
            mul.f32 temp, gamma, temp;
            neg.f32 temp, temp;
            // Update weight W1[k,l]: index = k*inputDim + l
            mul.lo.u32 idx, k, inputDim;
            add.u32 idx, idx, l;
            mul.lo.u32 idx_byte, idx, 4;
            cvt.u64.u32 byte_offset, idx_byte;
            add.u64 addr, w1_base, byte_offset;
            atom.global.add.f32 [addr], temp;
            add.u32 l, l, 1;
            bra layer1_bp_inner2_loop;
    layer1_bp_inner2_end:
        add.u32 k, k, 1;
        bra layer1_bp_loop;
layer1_bp_end:

exit:
    ret;
}
"""
update_code = r"""
.version 7.0
.target sm_35
.address_size 64

.visible .entry trainStep_update(
    .param .u64 dW1_buf,   // [batch_size x (hidden1_dim * input_dim)]
    .param .u64 db1_buf,   // [batch_size x hidden1_dim]
    .param .u64 dW2_buf,   // [batch_size x (hidden2_dim * hidden1_dim)]
    .param .u64 db2_buf,   // [batch_size x hidden2_dim]
    .param .u64 dW3_buf,   // [batch_size x (output_dim * hidden2_dim)]
    .param .u64 db3_buf,   // [batch_size x output_dim]
    .param .u64 weights1_ptr, // W1: [hidden1_dim x input_dim]
    .param .u64 bias1_ptr,    // b1: [hidden1_dim]
    .param .u64 weights2_ptr, // W2: [hidden2_dim x hidden1_dim]
    .param .u64 bias2_ptr,    // b2: [hidden2_dim]
    .param .u64 weights3_ptr, // W3: [output_dim x hidden2_dim]
    .param .u64 bias3_ptr,    // b3: [output_dim]
    .param .u32 batch_size,
    .param .u32 input_dim,
    .param .u32 hidden1_dim,
    .param .u32 hidden2_dim,
    .param .u32 output_dim,
    .param .f32 learning_rate
)
{
    // ---------------------------------------------------------------
    // Declare registers.
    // ---------------------------------------------------------------
    .reg .u64   dW1_base, db1_base, dW2_base, db2_base, dW3_base, db3_base;
    .reg .u64   w1_base, b1_base, w2_base, b2_base, w3_base, b3_base;
    .reg .u32   bs, inputDim, hidden1Dim, hidden2Dim, outputDim;
    .reg .f32   lr;
    .reg .u32   idx, sample;
    .reg .f32   grad_sum, grad_avg, temp;
    .reg .f32   old_weight;
    .reg .u64   byte_offset, addr;
    .reg .u32   i;
    .reg .pred  p_exit;

    // ---------------------------------------------------------------
    // Load parameters.
    // ---------------------------------------------------------------
    ld.param.u64 dW1_base, [dW1_buf];
    ld.param.u64 db1_base, [db1_buf];
    ld.param.u64 dW2_base, [dW2_buf];
    ld.param.u64 db2_base, [db2_buf];
    ld.param.u64 dW3_base, [dW3_buf];
    ld.param.u64 db3_base, [db3_buf];
    ld.param.u64 w1_base, [weights1_ptr];
    ld.param.u64 b1_base, [bias1_ptr];
    ld.param.u64 w2_base, [weights2_ptr];
    ld.param.u64 b2_base, [bias2_ptr];
    ld.param.u64 w3_base, [weights3_ptr];
    ld.param.u64 b3_base, [bias3_ptr];
    ld.param.u32 bs, [batch_size];
    ld.param.u32 inputDim, [input_dim];
    ld.param.u32 hidden1Dim, [hidden1_dim];
    ld.param.u32 hidden2Dim, [hidden2_dim];
    ld.param.u32 outputDim, [output_dim];
    ld.param.f32 lr, [learning_rate];

    // ---------------------------------------------------------------
    // Example: Update W1.
    // Loop over each element in W1 (total elements = hidden1_dim * input_dim).
    mov.u32 idx, 0;
W1_update_loop:
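    // PTX operands cannot be expressions; form hidden1Dim * inputDim in a register first.
    mul.lo.u32 i, hidden1Dim, inputDim;
    setp.ge.u32 p_exit, idx, i;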
    @p_exit bra W1_update_end;
       // Initialize grad_sum = 0.
       mov.f32 grad_sum, 0f00000000;
       mov.u32 sample, 0;
    W1_reduce:
       setp.ge.u32 p_exit, sample, bs;
       @p_exit bra W1_reduce_end;
          // Compute address for element 'idx' in the gradient buffer for sample.
          // Each sample’s gradients are stored contiguously.
          mul.lo.u32 i, sample, hidden1Dim;
          mul.lo.u32 i, i, inputDim;
          add.u32 i, i, idx;
          mul.lo.u32 i, i, 4;
          cvt.u64.u32 byte_offset, i;
          add.u64 addr, dW1_base, byte_offset;
          ld.global.f32 temp, [addr];
          add.f32 grad_sum, grad_sum, temp;
          add.u32 sample, sample, 1;
          bra W1_reduce;
    W1_reduce_end:
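       // Average gradient = grad_sum / bs.
       // Convert bs to f32 with cvt (old_weight is free to use as scratch here).
       cvt.rn.f32.u32 old_weight, bs;
       div.rn.f32 grad_avg, grad_sum, old_weight;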
       // Update weight:
       mul.lo.u32 i, idx, 4;
       cvt.u64.u32 byte_offset, i;
       add.u64 addr, w1_base, byte_offset;
       ld.global.f32 old_weight, [addr];
       mul.f32 temp, lr, grad_avg;
       sub.f32 old_weight, old_weight, temp;
       st.global.f32 [addr], old_weight;
       add.u32 idx, idx, 1;
       bra W1_update_loop;
W1_update_end:
    // Similar update loops must be written for db1, W2, b2, W3, and b3.
    ret;
}
"""

Sadly, the hand-written PTX above doesn't work, so I instead implement it in CUDA C to save time!
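
When a kernel like this misbehaves, a small CPU reference makes it easier to see which gradient is wrong. The sketch below is an assumption of mine rather than part of the original notebook: `reference_gradients` and `squared_error_loss` are hypothetical helper names, and the loss is assumed to be $\frac{1}{2}\lVert \text{out} - \text{label} \rVert^2$, which is the loss that the delta $\delta^3 = a^{(3)} - \text{label}$ corresponds to.

```python
import numpy as np

def forward(x, W1, b1, W2, b2, W3, b3):
    """Forward pass for one sample: two ReLU hidden layers, linear output."""
    h1 = np.maximum(0.0, W1 @ x + b1)
    h2 = np.maximum(0.0, W2 @ h1 + b2)
    out = W3 @ h2 + b3
    return h1, h2, out

def squared_error_loss(x, label, W1, b1, W2, b2, W3, b3):
    """Assumed loss: 0.5 * sum((out - label)^2)."""
    _, _, out = forward(x, W1, b1, W2, b2, W3, b3)
    return 0.5 * np.sum((out - label) ** 2)

def reference_gradients(x, label, W1, b1, W2, b2, W3, b3):
    """Per-sample gradients of the assumed loss, one array per parameter."""
    h1, h2, out = forward(x, W1, b1, W2, b2, W3, b3)
    delta3 = out - label                      # output-layer error
    delta2 = (W3.T @ delta3) * (h2 > 0)       # backprop through ReLU of layer 2
    delta1 = (W2.T @ delta2) * (h1 > 0)       # backprop through ReLU of layer 1
    return {
        "W1": np.outer(delta1, x),  "b1": delta1,
        "W2": np.outer(delta2, h1), "b2": delta2,
        "W3": np.outer(delta3, h2), "b3": delta3,
    }
```

Each entry of the returned dictionary has the shape of one sample's slice of the corresponding per-sample gradient buffer, so it could be compared against the GPU result with `np.allclose` after copying that buffer back to the host.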

In [39]:
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
import numpy as np

# -------------------------------
# Define network dimensions and training parameters.
# -------------------------------
batch_size   = 64    # samples per batch
num_batches  = 10    # number of batches to process
input_dim    = 784   # e.g. flattened MNIST images
hidden1_dim  = 128
hidden2_dim  = 64
output_dim   = 10    # number of classes
learning_rate = 0.01

# -------------------------------
# CUDA Kernels in CUDA C
# -------------------------------
kernel_code = r"""
extern "C" {

//
// Kernel: trainStep
// Each thread processes one sample: it computes the forward pass,
// computes the backpropagated gradients, and writes per-sample gradients
// into gradient buffers.
//
__global__ void trainStep(
    const float *input, const float *labels,
    float *hidden1, float *hidden2, float *output,
    const float *weights1, const float *bias1,
    const float *weights2, const float *bias2,
    const float *weights3, const float *bias3,
    float *grad_dW1, float *grad_db1,
    float *grad_dW2, float *grad_db2,
    float *grad_dW3, float *grad_db3,
    int batch_size, int input_dim, int hidden1_dim, int hidden2_dim, int output_dim)
{
    int sample = blockIdx.x * blockDim.x + threadIdx.x;
    if (sample >= batch_size) return;

    // Pointers to this sample's data.
    const float *x = input + sample * input_dim;
    const float *label = labels + sample * output_dim;
    float *h1 = hidden1 + sample * hidden1_dim;
    float *h2 = hidden2 + sample * hidden2_dim;
    float *out = output + sample * output_dim;

    // Pointers to gradient buffers for this sample.
    float *gradW1 = grad_dW1 + sample * hidden1_dim * input_dim;
    float *gradb1 = grad_db1 + sample * hidden1_dim;
    float *gradW2 = grad_dW2 + sample * hidden2_dim * hidden1_dim;
    float *gradb2 = grad_db2 + sample * hidden2_dim;
    float *gradW3 = grad_dW3 + sample * output_dim * hidden2_dim;
    float *gradb3 = grad_db3 + sample * output_dim;

    // --------------------
    // Forward Pass
    // --------------------
    // Layer 1: h1 = ReLU(W1*x + bias1)
    for (int i = 0; i < hidden1_dim; i++) {
        float sum = bias1[i];
        for (int j = 0; j < input_dim; j++) {
            sum += weights1[i * input_dim + j] * x[j];
        }
        h1[i] = (sum > 0) ? sum : 0;
    }

    // Layer 2: h2 = ReLU(W2*h1 + bias2)
    for (int i = 0; i < hidden2_dim; i++) {
        float sum = bias2[i];
        for (int j = 0; j < hidden1_dim; j++) {
            sum += weights2[i * hidden1_dim + j] * h1[j];
        }
        h2[i] = (sum > 0) ? sum : 0;
    }

    // Output layer: out = W3*h2 + bias3  (logits)
    for (int i = 0; i < output_dim; i++) {
        float sum = bias3[i];
        for (int j = 0; j < hidden2_dim; j++) {
            sum += weights3[i * hidden2_dim + j] * h2[j];
        }
        out[i] = sum;
    }

    // --------------------
    // Backpropagation
    // --------------------
    // Compute delta3 = out - label.
    float delta3[16]; // output_dim is 10 (16 is a safe upper bound)
    for (int i = 0; i < output_dim; i++) {
        delta3[i] = out[i] - label[i];
        gradb3[i] = delta3[i];
    }
    // Gradients for weights3: outer product delta3 * h2.
    for (int i = 0; i < output_dim; i++) {
        for (int j = 0; j < hidden2_dim; j++) {
            gradW3[i * hidden2_dim + j] = delta3[i] * h2[j];
        }
    }

    // Backprop for layer 2: delta2 = (W3^T * delta3) .* ReLU'(h2)
    float delta2[64]; // hidden2_dim is 64.
    for (int i = 0; i < hidden2_dim; i++) {
        float sum = 0;
        for (int j = 0; j < output_dim; j++) {
            sum += weights3[j * hidden2_dim + i] * delta3[j];
        }
        delta2[i] = (h2[i] > 0) ? sum : 0;
        gradb2[i] = delta2[i];
    }
    // Gradients for weights2: outer product delta2 * h1.
    for (int i = 0; i < hidden2_dim; i++) {
        for (int j = 0; j < hidden1_dim; j++) {
            gradW2[i * hidden1_dim + j] = delta2[i] * h1[j];
        }
    }

    // Backprop for layer 1: delta1 = (W2^T * delta2) .* ReLU'(h1)
    float delta1[128]; // hidden1_dim is 128.
    for (int i = 0; i < hidden1_dim; i++) {
        float sum = 0;
        for (int j = 0; j < hidden2_dim; j++) {
            sum += weights2[j * hidden1_dim + i] * delta2[j];
        }
        delta1[i] = (h1[i] > 0) ? sum : 0;
        gradb1[i] = delta1[i];
    }
    // Gradients for weights1: outer product delta1 * x.
    for (int i = 0; i < hidden1_dim; i++) {
        for (int j = 0; j < input_dim; j++) {
            gradW1[i * input_dim + j] = delta1[i] * x[j];
        }
    }
}

//
// Kernel: updateWeights
// For a given parameter array (or bias vector), this kernel sums the gradient
// values over the batch, averages them, and then updates the parameter.
//
__global__ void updateWeights(
    float *param, const float *grad_buffer,
    int count, int batch_size, float learning_rate)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= count) return;
    float sum = 0;
    for (int i = 0; i < batch_size; i++) {
        sum += grad_buffer[i * count + idx];
    }
    float grad_avg = sum / batch_size;

    // Optionally clip the gradient (uncomment the following lines if desired):
    // if (grad_avg > 1.0f) grad_avg = 1.0f;
    // if (grad_avg < -1.0f) grad_avg = -1.0f;

    param[idx] -= learning_rate * grad_avg;
}

} // extern "C"
"""

# -------------------------------
# Compile the CUDA kernels.
# -------------------------------
mod = SourceModule(kernel_code)
trainStep = mod.get_function("trainStep")
updateWeights = mod.get_function("updateWeights")

# -------------------------------
# Utility function to generate a random batch.
# -------------------------------
def generate_batch_data(batch_size, input_dim, output_dim):
    x = np.random.randn(batch_size, input_dim).astype(np.float32)
    labels = np.zeros((batch_size, output_dim), dtype=np.float32)
    for i in range(batch_size):
        lbl = np.random.randint(0, output_dim)
        labels[i, lbl] = 1.0
    return x, labels

# -------------------------------
# Initialize weight matrices and biases with small values.
# -------------------------------
scale = 0.01
weights1 = (np.random.randn(hidden1_dim, input_dim) * scale).astype(np.float32)
bias1    = (np.random.randn(hidden1_dim) * scale).astype(np.float32)
weights2 = (np.random.randn(hidden2_dim, hidden1_dim) * scale).astype(np.float32)
bias2    = (np.random.randn(hidden2_dim) * scale).astype(np.float32)
weights3 = (np.random.randn(output_dim, hidden2_dim) * scale).astype(np.float32)
bias3    = (np.random.randn(output_dim) * scale).astype(np.float32)

# -------------------------------
# Allocate GPU memory for weights and biases.
# -------------------------------
weights1_gpu = cuda.mem_alloc(weights1.nbytes)
bias1_gpu    = cuda.mem_alloc(bias1.nbytes)
weights2_gpu = cuda.mem_alloc(weights2.nbytes)
bias2_gpu    = cuda.mem_alloc(bias2.nbytes)
weights3_gpu = cuda.mem_alloc(weights3.nbytes)
bias3_gpu    = cuda.mem_alloc(bias3.nbytes)
cuda.memcpy_htod(weights1_gpu, weights1)
cuda.memcpy_htod(bias1_gpu, bias1)
cuda.memcpy_htod(weights2_gpu, weights2)
cuda.memcpy_htod(bias2_gpu, bias2)
cuda.memcpy_htod(weights3_gpu, weights3)
cuda.memcpy_htod(bias3_gpu, bias3)

# -------------------------------
# Allocate buffers for activations and outputs.
# -------------------------------
hidden1_gpu = cuda.mem_alloc(batch_size * hidden1_dim * np.float32().nbytes)
hidden2_gpu = cuda.mem_alloc(batch_size * hidden2_dim * np.float32().nbytes)
output_gpu  = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)

# -------------------------------
# Allocate buffers for input and labels (reused per batch).
# -------------------------------
input_gpu   = cuda.mem_alloc(batch_size * input_dim * np.float32().nbytes)
labels_gpu  = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)

# -------------------------------
# Allocate gradient buffers.
# For simplicity we use one set (each of size: batch_size * (parameter size)).
# -------------------------------
grad_dW1_gpu = cuda.mem_alloc(batch_size * hidden1_dim * input_dim * np.float32().nbytes)
grad_db1_gpu = cuda.mem_alloc(batch_size * hidden1_dim * np.float32().nbytes)
grad_dW2_gpu = cuda.mem_alloc(batch_size * hidden2_dim * hidden1_dim * np.float32().nbytes)
grad_db2_gpu = cuda.mem_alloc(batch_size * hidden2_dim * np.float32().nbytes)
grad_dW3_gpu = cuda.mem_alloc(batch_size * output_dim * hidden2_dim * np.float32().nbytes)
grad_db3_gpu = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)

# -------------------------------
# Define grid and block dimensions for trainStep.
# Each thread handles one sample.
# -------------------------------
threads_per_block = 256
grid_dim = ((batch_size + threads_per_block - 1) // threads_per_block, 1, 1)
block_dim = (threads_per_block, 1, 1)

# -------------------------------
# Training loop.
# For each batch: generate data, run forward/backprop kernel, then update weights.
# -------------------------------
for batch in range(num_batches):
    # Generate batch data.
    x_batch, labels_batch = generate_batch_data(batch_size, input_dim, output_dim)
    cuda.memcpy_htod(input_gpu, x_batch)
    cuda.memcpy_htod(labels_gpu, labels_batch)
    
    # Launch the forward/backprop kernel.
    trainStep(
        input_gpu, labels_gpu,
        hidden1_gpu, hidden2_gpu, output_gpu,
        weights1_gpu, bias1_gpu,
        weights2_gpu, bias2_gpu,
        weights3_gpu, bias3_gpu,
        grad_dW1_gpu, grad_db1_gpu,
        grad_dW2_gpu, grad_db2_gpu,
        grad_dW3_gpu, grad_db3_gpu,
        np.int32(batch_size), np.int32(input_dim), np.int32(hidden1_dim),
        np.int32(hidden2_dim), np.int32(output_dim),
        block=block_dim, grid=grid_dim
    )
    
    # For each parameter array, run the update kernel.
    # updateWeights sums over the batch dimension, averages, and updates the parameter.
    threads = 256
    updates = [
        (weights1_gpu, grad_dW1_gpu, hidden1_dim * input_dim),
        (bias1_gpu,    grad_db1_gpu, hidden1_dim),
        (weights2_gpu, grad_dW2_gpu, hidden2_dim * hidden1_dim),
        (bias2_gpu,    grad_db2_gpu, hidden2_dim),
        (weights3_gpu, grad_dW3_gpu, output_dim * hidden2_dim),
        (bias3_gpu,    grad_db3_gpu, output_dim),
    ]
    for param_gpu, grad_gpu, count in updates:
        # One thread per parameter element, rounded up to whole blocks.
        grid = ((count + threads - 1) // threads, 1, 1)
        updateWeights(
            param_gpu, grad_gpu,
            np.int32(count), np.int32(batch_size), np.float32(learning_rate),
            block=(threads, 1, 1), grid=grid
        )
    
    print("Completed batch", batch+1)

# -------------------------------
# Copy updated weights back to host for verification.
# -------------------------------
cuda.memcpy_dtoh(weights1, weights1_gpu)
cuda.memcpy_dtoh(bias1, bias1_gpu)
cuda.memcpy_dtoh(weights2, weights2_gpu)
cuda.memcpy_dtoh(bias2, bias2_gpu)
cuda.memcpy_dtoh(weights3, weights3_gpu)
cuda.memcpy_dtoh(bias3, bias3_gpu)

print("Training complete.")
print("Updated weights and biases:")
print("W1:", weights1)
print("b1:", bias1)
print("W2:", weights2)
print("b2:", bias2)
print("W3:", weights3)
print("b3:", bias3)


Completed batch 1
Completed batch 2
Completed batch 3
Completed batch 4
Completed batch 5
Completed batch 6
Completed batch 7
Completed batch 8
Completed batch 9
Completed batch 10
Training complete.
Updated weights and biases:
W1: [[ 0.00437199 -0.00232274  0.01523101 ...  0.00399005  0.00060943
  -0.00379807]
 [ 0.00448643 -0.0070263   0.01537438 ...  0.00932618  0.00631156
   0.00498502]
 [-0.0053929   0.00048365 -0.0018097  ...  0.00558358  0.00576927
   0.00641882]
 ...
 [ 0.00283663 -0.00596519  0.01930408 ...  0.01084203  0.00051508
   0.00443041]
 [-0.00656595  0.00544863 -0.01342957 ...  0.00244525  0.00281238
   0.00508942]
 [-0.00308127  0.003688    0.00730584 ... -0.00074348 -0.00059138
   0.01014786]]
b1: [ 3.33784777e-03  1.61030708e-04 -2.57293601e-03  1.59598663e-02
 -6.32680440e-03 -6.16802415e-03  9.86100826e-03  7.18621491e-03
  4.82773408e-03  1.93283800e-03 -5.34211705e-03 -8.81280005e-03
 -2.07232051e-02  5.43789798e-03 -2.43395261e-04  1.27090816e-03
  8.52138549

This worked! So we start training now. We'll train until the network reaches 60% accuracy on the data. Yes, I know this isn't best practice, since we should incorporate validation data, but the point of this exercise is to write the network in CUDA, not to build a full machine learning workflow!
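
The stopping rule described above can be sketched independently of the GPU code. This is a minimal sketch under stated assumptions: `train_until`, `train_step`, `evaluate`, and `eval_every` are hypothetical names, and checking accuracy every 100 batches with a cap of 5000 batches is an arbitrary choice; only the 60% target comes from the text.

```python
def train_until(train_step, evaluate, target_accuracy=0.60,
                max_batches=5000, eval_every=100):
    """Call train_step() once per batch and stop when evaluate() hits the target.

    train_step: callable that runs one training batch (placeholder).
    evaluate:   callable returning the current accuracy in [0, 1] (placeholder).
    Returns (batches_run, last_measured_accuracy).
    """
    accuracy = 0.0
    for batch in range(1, max_batches + 1):
        train_step()
        # Evaluation is comparatively expensive, so only do it periodically.
        if batch % eval_every == 0:
            accuracy = evaluate()
            if accuracy >= target_accuracy:
                return batch, accuracy
    # Target never reached: report the cap and the last accuracy we measured.
    return max_batches, accuracy
```

Evaluating only every few batches keeps the host-side accuracy check from dominating the run time, at the cost of possibly training slightly past the point where the target was first reached.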

In [None]:
import time
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

############################################
# Settings and Data Loading (Common to both)
############################################
batch_size   = 64
input_dim    = 784   # MNIST: 28x28 flattened
hidden_dim   = 128   # one hidden layer
output_dim   = 10
learning_rate = 0.001   # try a lower learning rate for stability
max_batches = 5000      # maximum batches to run
scale = 0.01            # weight initialization scale

# Download and prepare MNIST dataset.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Helper: Convert integer labels to one-hot vectors.
def one_hot(labels, num_classes=10):
    one_hot_labels = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    one_hot_labels[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot_labels

# Simple NumPy forward pass (for evaluation) for a two-layer network.
def forward_numpy(x, weights, bias, weights_out, bias_out):
    # hidden layer: ReLU(W*x + b)
    h = np.maximum(0, np.dot(x, weights.T) + bias)
    # output layer: linear activation
    out = np.dot(h, weights_out.T) + bias_out
    return out

# Evaluate accuracy using NumPy forward pass.
def evaluate_model(weights, bias, weights_out, bias_out):
    correct = 0
    total = 0
    for images, labels in test_loader:
        x = images.numpy()  # shape: [batch_size, 784]
        out = forward_numpy(x, weights, bias, weights_out, bias_out)
        preds = np.argmax(out, axis=1)
        total += labels.size(0)
        correct += (preds == labels.numpy()).sum()
    return correct / total

############################################
# 1. PyCUDA Implementation (Two-Layer Network)
############################################
cuda_kernel_code = r"""
extern "C" {

// Kernel: trainStep
// Each thread processes one sample. It performs the forward pass,
// computes backpropagated gradients using squared-error loss, and writes
// per-sample gradients into separate gradient buffers.
__global__ void trainStep(
    const float *input, const float *labels,
    float *hidden, float *output,
    const float *weights, const float *bias,
    const float *weights_out, const float *bias_out,
    float *grad_dW, float *grad_db,
    float *grad_dW_out, float *grad_db_out,
    int batch_size, int input_dim, int hidden_dim, int output_dim)
{
    int sample = blockIdx.x * blockDim.x + threadIdx.x;
    if(sample >= batch_size) return;

    // Pointers for this sample.
    const float *x = input + sample * input_dim;
    const float *label = labels + sample * output_dim;
    float *h = hidden + sample * hidden_dim;
    float *out = output + sample * output_dim;

    // Pointers to per-sample gradient buffers.
    float *gradW = grad_dW + sample * hidden_dim * input_dim;
    float *gradb = grad_db + sample * hidden_dim;
    float *gradW_out = grad_dW_out + sample * output_dim * hidden_dim;
    float *gradb_out = grad_db_out + sample * output_dim;

    // Forward pass: Compute hidden layer activations.
    for (int i = 0; i < hidden_dim; i++) {
        float sum = bias[i];
        for (int j = 0; j < input_dim; j++) {
            sum += weights[i * input_dim + j] * x[j];
        }
        h[i] = (sum > 0) ? sum : 0;  // ReLU activation.
    }
    // Output layer: Compute output.
    for (int i = 0; i < output_dim; i++) {
        float sum = bias_out[i];
        for (int j = 0; j < hidden_dim; j++) {
            sum += weights_out[i * hidden_dim + j] * h[j];
        }
        out[i] = sum;
    }

    // Backpropagation.
    // Compute delta for output: delta_out = out - label.
    float delta_out[16]; // assume output_dim <= 16.
    for (int i = 0; i < output_dim; i++) {
        delta_out[i] = out[i] - label[i];
        gradb_out[i] = delta_out[i];
    }
    // Gradients for output weights: gradW_out = delta_out * h.
    for (int i = 0; i < output_dim; i++) {
        for (int j = 0; j < hidden_dim; j++) {
            gradW_out[i * hidden_dim + j] = delta_out[i] * h[j];
        }
    }
    // Backpropagate to hidden layer:
    // delta_hidden = (weights_out^T * delta_out) * (h > 0 ? 1 : 0).
    float delta_hidden[128]; // assume hidden_dim <= 128.
    for (int i = 0; i < hidden_dim; i++) {
        float sum = 0;
        for (int j = 0; j < output_dim; j++) {
            sum += weights_out[j * hidden_dim + i] * delta_out[j];
        }
        delta_hidden[i] = (h[i] > 0) ? sum : 0;
        gradb[i] = delta_hidden[i];
    }
    // Gradients for input weights: gradW = delta_hidden * x.
    for (int i = 0; i < hidden_dim; i++) {
        for (int j = 0; j < input_dim; j++) {
            gradW[i * input_dim + j] = delta_hidden[i] * x[j];
        }
    }
}

// Kernel: updateWeights
// For a given parameter array and its per-sample gradient buffer,
// sum the gradients over the batch, average them, and update the parameter.
__global__ void updateWeights(
    float *param, const float *grad_buffer,
    int count, int batch_size, float learning_rate)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= count) return;
    float sum = 0;
    for (int i = 0; i < batch_size; i++) {
        sum += grad_buffer[i * count + idx];
    }
    float grad_avg = sum / batch_size;
    param[idx] -= learning_rate * grad_avg;
}

} // extern "C"
"""

# Compile the CUDA kernels.
mod = SourceModule(cuda_kernel_code)
trainStep = mod.get_function("trainStep")
updateWeights = mod.get_function("updateWeights")

# Initialize network parameters for PyCUDA.
weights_cuda = (np.random.randn(hidden_dim, input_dim) * scale).astype(np.float32)
bias_cuda    = (np.random.randn(hidden_dim) * scale).astype(np.float32)
weights_out_cuda = (np.random.randn(output_dim, hidden_dim) * scale).astype(np.float32)
bias_out_cuda    = (np.random.randn(output_dim) * scale).astype(np.float32)

# Allocate GPU memory for parameters.
weights_gpu = cuda.mem_alloc(weights_cuda.nbytes)
bias_gpu    = cuda.mem_alloc(bias_cuda.nbytes)
weights_out_gpu = cuda.mem_alloc(weights_out_cuda.nbytes)
bias_out_gpu    = cuda.mem_alloc(bias_out_cuda.nbytes)
cuda.memcpy_htod(weights_gpu, weights_cuda)
cuda.memcpy_htod(bias_gpu, bias_cuda)
cuda.memcpy_htod(weights_out_gpu, weights_out_cuda)
cuda.memcpy_htod(bias_out_gpu, bias_out_cuda)

# Allocate GPU memory for activations, outputs, inputs, labels, and gradients.
hidden_gpu   = cuda.mem_alloc(batch_size * hidden_dim * np.float32().nbytes)
output_gpu   = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)
input_gpu    = cuda.mem_alloc(batch_size * input_dim * np.float32().nbytes)
labels_gpu   = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)
grad_dW_gpu  = cuda.mem_alloc(batch_size * hidden_dim * input_dim * np.float32().nbytes)
grad_db_gpu  = cuda.mem_alloc(batch_size * hidden_dim * np.float32().nbytes)
grad_dW_out_gpu = cuda.mem_alloc(batch_size * output_dim * hidden_dim * np.float32().nbytes)
grad_db_out_gpu = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)

# Define grid and block dimensions.
threads_per_block = 256
grid_dim = ((batch_size + threads_per_block - 1) // threads_per_block, 1, 1)
block_dim = (threads_per_block, 1, 1)

# PyCUDA Training Loop: Train until test accuracy exceeds 60% or max_batches reached.
print("Starting PyCUDA Training (Two-Layer Network)...")
cuda_start = time.time()
batch_count_cuda = 0
accuracy_cuda = 0
train_iter = iter(train_loader)
while accuracy_cuda < 0.6 and batch_count_cuda < max_batches:
    try:
        images, labels = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        images, labels = next(train_iter)
    x_batch = images.numpy()                     # shape: [batch_size, 784]
    y_batch = one_hot(labels.numpy(), output_dim)  # one-hot: [batch_size, 10]
    cuda.memcpy_htod(input_gpu, x_batch)
    cuda.memcpy_htod(labels_gpu, y_batch)
    # Launch trainStep kernel.
    trainStep(
        input_gpu, labels_gpu,
        hidden_gpu, output_gpu,
        weights_gpu, bias_gpu,
        weights_out_gpu, bias_out_gpu,
        grad_dW_gpu, grad_db_gpu,
        grad_dW_out_gpu, grad_db_out_gpu,
        np.int32(batch_size), np.int32(input_dim), np.int32(hidden_dim), np.int32(output_dim),
        block=block_dim, grid=grid_dim
    )
    # Launch updateWeights kernel for each parameter.
    threads = 256

    # Update input-to-hidden weights.
    count_w = np.int32(hidden_dim * input_dim)
    grid_update = ((int(count_w) + threads - 1) // threads, 1, 1)
    updateWeights(
        weights_gpu, grad_dW_gpu,
        count_w, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    # Update hidden biases.
    count_b = np.int32(hidden_dim)
    grid_update = ((int(count_b) + threads - 1) // threads, 1, 1)
    updateWeights(
        bias_gpu, grad_db_gpu,
        count_b, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    # Update hidden-to-output weights.
    count_w_out = np.int32(output_dim * hidden_dim)
    grid_update = ((int(count_w_out) + threads - 1) // threads, 1, 1)
    updateWeights(
        weights_out_gpu, grad_dW_out_gpu,
        count_w_out, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    # Update output biases.
    count_b_out = np.int32(output_dim)
    grid_update = ((int(count_b_out) + threads - 1) // threads, 1, 1)
    updateWeights(
        bias_out_gpu, grad_db_out_gpu,
        count_b_out, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    batch_count_cuda += 1

    # Every 100 batches, evaluate test accuracy.
    if batch_count_cuda % 100 == 0:
        # Copy parameters back to host.
        cuda.memcpy_dtoh(weights_cuda, weights_gpu)
        cuda.memcpy_dtoh(bias_cuda, bias_gpu)
        cuda.memcpy_dtoh(weights_out_cuda, weights_out_gpu)
        cuda.memcpy_dtoh(bias_out_cuda, bias_out_gpu)
        accuracy_cuda = evaluate_model(weights_cuda, bias_cuda, weights_out_cuda, bias_out_cuda)
        print(f"PyCUDA: Batch {batch_count_cuda}, Test Accuracy: {accuracy_cuda*100:.2f}%")
cuda_end = time.time()
cuda_time = cuda_end - cuda_start
print(f"\nPyCUDA Training Finished in {cuda_time:.4f} seconds over {batch_count_cuda} batches.")
print(f"PyCUDA Test Accuracy: {accuracy_cuda*100:.2f}%")

############################################
# End of PyCUDA Section
############################################


Starting PyCUDA Training (Two-Layer Network)...
PyCUDA: Batch 100, Test Accuracy: 8.09%
PyCUDA: Batch 200, Test Accuracy: 11.60%
PyCUDA: Batch 300, Test Accuracy: 14.89%
PyCUDA: Batch 400, Test Accuracy: 17.08%
PyCUDA: Batch 500, Test Accuracy: 19.42%
PyCUDA: Batch 600, Test Accuracy: 22.38%
PyCUDA: Batch 700, Test Accuracy: 27.20%
PyCUDA: Batch 800, Test Accuracy: 34.27%
PyCUDA: Batch 900, Test Accuracy: 41.09%
PyCUDA: Batch 1000, Test Accuracy: 46.24%
PyCUDA: Batch 1100, Test Accuracy: 48.37%
PyCUDA: Batch 1200, Test Accuracy: 51.23%
PyCUDA: Batch 1300, Test Accuracy: 53.20%
PyCUDA: Batch 1400, Test Accuracy: 55.29%
PyCUDA: Batch 1500, Test Accuracy: 57.03%
PyCUDA: Batch 1600, Test Accuracy: 58.38%
PyCUDA: Batch 1700, Test Accuracy: 60.15%

PyCUDA Training Finished in 48.1470 seconds over 1700 batches.
PyCUDA Test Accuracy: 60.15%
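
The backpropagation formulas used in `trainStep` can be sanity-checked on the host before trusting the GPU results. The sketch below is not part of the original notebook: it re-implements the same per-sample formulas in NumPy (the helper names `kernel_gradients` and `numerical_gradient` are hypothetical), assumes the loss $\tfrac{1}{2}\lVert \text{out} - \text{label} \rVert^2$ implied by `delta_out = out - label`, and compares the result against central finite differences on a tiny network.

```python
import numpy as np

def forward(x, W1, b1, W2, b2):
    """Forward pass for one sample, mirroring the loops in trainStep."""
    h = np.maximum(0.0, W1 @ x + b1)   # hidden layer with ReLU
    out = W2 @ h + b2                  # linear output layer
    return h, out

def loss(x, y, W1, b1, W2, b2):
    """Assumed loss: 0.5 * ||out - y||^2, so that dL/d(out) = out - y."""
    _, out = forward(x, W1, b1, W2, b2)
    return 0.5 * np.sum((out - y) ** 2)

def kernel_gradients(x, y, W1, b1, W2, b2):
    """The formulas trainStep writes into its per-sample gradient buffers."""
    h, out = forward(x, W1, b1, W2, b2)
    delta_out = out - y
    delta_hidden = (W2.T @ delta_out) * (h > 0)
    return np.outer(delta_hidden, x), delta_hidden, np.outer(delta_out, h), delta_out

def numerical_gradient(f, param, eps=1e-6):
    """Central finite differences; perturbs param in place and restores it."""
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        old = param[idx]
        param[idx] = old + eps
        f_plus = f()
        param[idx] = old - eps
        f_minus = f()
        param[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

# Tiny network in float64 so the finite differences are accurate.
rng = np.random.default_rng(0)
n_in, n_hid, n_out = 5, 4, 3
x = rng.normal(size=n_in)
y = np.eye(n_out)[1]                   # one-hot target for class 1
W1, b1 = rng.normal(size=(n_hid, n_in)), rng.normal(size=n_hid)
W2, b2 = rng.normal(size=(n_out, n_hid)), rng.normal(size=n_out)

params = (W1, b1, W2, b2)
analytic = kernel_gradients(x, y, *params)
numeric = [numerical_gradient(lambda: loss(x, y, *params), p) for p in params]
max_err = max(float(np.max(np.abs(a - n))) for a, n in zip(analytic, numeric))
print(f"max abs difference between analytic and numeric gradients: {max_err:.2e}")
```

If the two disagree by more than rounding error, the kernel's formulas (or the assumed loss) are the first place to look.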


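The per-sample gradient buffers dominate this implementation's memory use, and the launch configuration determines how many GPU threads do useful work. As a rough back-of-the-envelope sketch, assuming the settings used in this notebook (`batch_size = 64`, `input_dim = 784`, `hidden_dim = 128`, `output_dim = 10`, 4-byte floats, 256 threads per block), the sizes work out as follows:

```python
# Constants taken from the notebook's settings.
batch_size, input_dim, hidden_dim, output_dim = 64, 784, 128, 10
bytes_per_float = 4
threads_per_block = 256

# Number of gradient values each sample writes, per buffer.
per_sample_counts = {
    "grad_dW": hidden_dim * input_dim,
    "grad_db": hidden_dim,
    "grad_dW_out": output_dim * hidden_dim,
    "grad_db_out": output_dim,
}

# Bytes allocated for each buffer across the whole batch.
buffer_bytes = {name: batch_size * count * bytes_per_float
                for name, count in per_sample_counts.items()}
total_bytes = sum(buffer_bytes.values())

# Launch configuration of trainStep: one thread per sample.
blocks = (batch_size + threads_per_block - 1) // threads_per_block
idle_threads = blocks * threads_per_block - batch_size

for name, size in buffer_bytes.items():
    print(f"{name:12s} {size / 2**20:8.3f} MiB")
print(f"total        {total_bytes / 2**20:8.3f} MiB")
print(f"trainStep launches {blocks} block(s); {idle_threads} threads exit immediately")
```

With one thread per sample, only 64 threads run the full forward and backward loops, each serially, so most of the GPU sits idle during `trainStep`.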
Next, we compare this code with an equivalent implementation in PyTorch.

import time
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim

###############################
# Common Settings and Data Loading
###############################
batch_size    = 64
input_dim     = 784   # 28x28 flattened
hidden_dim    = 128
output_dim    = 10
learning_rate = 0.001
max_batches   = 10000   # safety maximum
scale         = 0.01    # weight initialization scale

# Download MNIST dataset using torchvision.
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
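# drop_last=True keeps every training batch at exactly batch_size samples. The CUDA kernels
# are always launched for batch_size samples, so a short final batch would reuse stale inputs.
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)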
test_loader   = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Helper: convert integer labels to one-hot vectors.
def one_hot(labels, num_classes=10):
    one_hot_labels = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    one_hot_labels[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot_labels

# Simple NumPy forward pass for a two-layer network.
def forward_numpy(x, weights, bias, weights_out, bias_out):
    h = np.maximum(0, np.dot(x, weights.T) + bias)
    out = np.dot(h, weights_out.T) + bias_out
    return out

# Evaluate test accuracy using NumPy.
def evaluate_model(weights, bias, weights_out, bias_out):
    correct = 0
    total = 0
    for images, labels in test_loader:
        x = images.numpy()  # shape: [batch_size, 784]
        out = forward_numpy(x, weights, bias, weights_out, bias_out)
        preds = np.argmax(out, axis=1)
        total += labels.size(0)
        correct += (preds == labels.numpy()).sum()
    return correct / total

###############################
# 1. PyCUDA Implementation
###############################
# Two-layer network CUDA kernels (using squared error loss on one-hot targets).
cuda_kernel_code = r"""
extern "C" {

__global__ void trainStep(
    const float *input, const float *labels,
    float *hidden, float *output,
    const float *weights, const float *bias,
    const float *weights_out, const float *bias_out,
    float *grad_dW, float *grad_db,
    float *grad_dW_out, float *grad_db_out,
    int batch_size, int input_dim, int hidden_dim, int output_dim)
{
    int sample = blockIdx.x * blockDim.x + threadIdx.x;
    if(sample >= batch_size) return;

    // Pointers for this sample.
    const float *x = input + sample * input_dim;
    const float *label = labels + sample * output_dim;
    float *h = hidden + sample * hidden_dim;
    float *out = output + sample * output_dim;

    // Per-sample gradient pointers.
    float *gradW = grad_dW + sample * hidden_dim * input_dim;
    float *gradb = grad_db + sample * hidden_dim;
    float *gradW_out = grad_dW_out + sample * output_dim * hidden_dim;
    float *gradb_out = grad_db_out + sample * output_dim;

    // Forward pass: hidden layer.
    for (int i = 0; i < hidden_dim; i++) {
        float sum = bias[i];
        for (int j = 0; j < input_dim; j++) {
            sum += weights[i * input_dim + j] * x[j];
        }
        h[i] = (sum > 0) ? sum : 0;  // ReLU.
    }
    // Output layer.
    for (int i = 0; i < output_dim; i++) {
        float sum = bias_out[i];
        for (int j = 0; j < hidden_dim; j++) {
            sum += weights_out[i * hidden_dim + j] * h[j];
        }
        out[i] = sum;
    }

    // Backpropagation.
    // Compute delta for output: delta_out = out - label.
    float delta_out[16]; // assume output_dim <= 16.
    for (int i = 0; i < output_dim; i++) {
        delta_out[i] = out[i] - label[i];
        gradb_out[i] = delta_out[i];
    }
    // Gradients for output weights.
    for (int i = 0; i < output_dim; i++) {
        for (int j = 0; j < hidden_dim; j++) {
            gradW_out[i * hidden_dim + j] = delta_out[i] * h[j];
        }
    }
    // Backpropagate to hidden: delta_hidden = (W_out^T * delta_out) * (h>0)
    float delta_hidden[128]; // assume hidden_dim <= 128.
    for (int i = 0; i < hidden_dim; i++) {
        float sum = 0;
        for (int j = 0; j < output_dim; j++) {
            sum += weights_out[j * hidden_dim + i] * delta_out[j];
        }
        delta_hidden[i] = (h[i] > 0) ? sum : 0;
        gradb[i] = delta_hidden[i];
    }
    // Gradients for input-to-hidden weights.
    for (int i = 0; i < hidden_dim; i++) {
        for (int j = 0; j < input_dim; j++) {
            gradW[i * input_dim + j] = delta_hidden[i] * x[j];
        }
    }
}

__global__ void updateWeights(
    float *param, const float *grad_buffer,
    int count, int batch_size, float learning_rate)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= count) return;
    float sum = 0;
    for (int i = 0; i < batch_size; i++) {
        sum += grad_buffer[i * count + idx];
    }
    float grad_avg = sum / batch_size;
    param[idx] -= learning_rate * grad_avg;
}

} // extern "C"
"""

# Compile the CUDA kernels.
mod = SourceModule(cuda_kernel_code)
trainStep = mod.get_function("trainStep")
updateWeights = mod.get_function("updateWeights")

# Initialize network parameters for PyCUDA.
weights_cuda = (np.random.randn(hidden_dim, input_dim) * scale).astype(np.float32)
bias_cuda    = (np.random.randn(hidden_dim) * scale).astype(np.float32)
weights_out_cuda = (np.random.randn(output_dim, hidden_dim) * scale).astype(np.float32)
bias_out_cuda    = (np.random.randn(output_dim) * scale).astype(np.float32)

# Allocate GPU memory for parameters.
weights_gpu = cuda.mem_alloc(weights_cuda.nbytes)
bias_gpu    = cuda.mem_alloc(bias_cuda.nbytes)
weights_out_gpu = cuda.mem_alloc(weights_out_cuda.nbytes)
bias_out_gpu    = cuda.mem_alloc(bias_out_cuda.nbytes)
cuda.memcpy_htod(weights_gpu, weights_cuda)
cuda.memcpy_htod(bias_gpu, bias_cuda)
cuda.memcpy_htod(weights_out_gpu, weights_out_cuda)
cuda.memcpy_htod(bias_out_gpu, bias_out_cuda)

# Allocate GPU memory for activations, outputs, inputs, labels, and gradients.
hidden_gpu      = cuda.mem_alloc(batch_size * hidden_dim * np.float32().nbytes)
output_gpu      = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)
input_gpu       = cuda.mem_alloc(batch_size * input_dim * np.float32().nbytes)
labels_gpu      = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)
grad_dW_gpu     = cuda.mem_alloc(batch_size * hidden_dim * input_dim * np.float32().nbytes)
grad_db_gpu     = cuda.mem_alloc(batch_size * hidden_dim * np.float32().nbytes)
grad_dW_out_gpu = cuda.mem_alloc(batch_size * output_dim * hidden_dim * np.float32().nbytes)
grad_db_out_gpu = cuda.mem_alloc(batch_size * output_dim * np.float32().nbytes)

# Define grid and block dimensions.
threads_per_block = 256
grid_dim = ((batch_size + threads_per_block - 1) // threads_per_block, 1, 1)
block_dim = (threads_per_block, 1, 1)

print("\nStarting PyCUDA Training (Two-Layer Network)...")
cuda_start = time.time()
batch_count_cuda = 0
accuracy_cuda = 0
train_iter = iter(train_loader)
while accuracy_cuda < 0.6 and batch_count_cuda < max_batches:
    try:
        images, labels = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        images, labels = next(train_iter)
    x_batch = images.numpy()                     # shape: [batch_size, 784]
    y_batch = one_hot(labels.numpy(), output_dim)  # one-hot encoding.
    cuda.memcpy_htod(input_gpu, x_batch)
    cuda.memcpy_htod(labels_gpu, y_batch)
    # Launch forward/backprop kernel.
    trainStep(
        input_gpu, labels_gpu,
        hidden_gpu, output_gpu,
        weights_gpu, bias_gpu,
        weights_out_gpu, bias_out_gpu,
        grad_dW_gpu, grad_db_gpu,
        grad_dW_out_gpu, grad_db_out_gpu,
        np.int32(batch_size), np.int32(input_dim), np.int32(hidden_dim), np.int32(output_dim),
        block=block_dim, grid=grid_dim
    )
    # Launch updateWeights for each parameter.
    threads = 256

    # Update weights (input-to-hidden).
    count_w = np.int32(hidden_dim * input_dim)
    grid_update = ((int(count_w) + threads - 1) // threads, 1, 1)
    updateWeights(
        weights_gpu, grad_dW_gpu,
        count_w, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    # Update hidden biases.
    count_b = np.int32(hidden_dim)
    grid_update = ((int(count_b) + threads - 1) // threads, 1, 1)
    updateWeights(
        bias_gpu, grad_db_gpu,
        count_b, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    # Update weights (hidden-to-output).
    count_w_out = np.int32(output_dim * hidden_dim)
    grid_update = ((int(count_w_out) + threads - 1) // threads, 1, 1)
    updateWeights(
        weights_out_gpu, grad_dW_out_gpu,
        count_w_out, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    # Update output biases.
    count_b_out = np.int32(output_dim)
    grid_update = ((int(count_b_out) + threads - 1) // threads, 1, 1)
    updateWeights(
        bias_out_gpu, grad_db_out_gpu,
        count_b_out, np.int32(batch_size), np.float32(learning_rate),
        block=(threads,1,1), grid=grid_update
    )
    batch_count_cuda += 1

    # Every 100 batches, evaluate test accuracy.
    if batch_count_cuda % 100 == 0:
        cuda.memcpy_dtoh(weights_cuda, weights_gpu)
        cuda.memcpy_dtoh(bias_cuda, bias_gpu)
        cuda.memcpy_dtoh(weights_out_cuda, weights_out_gpu)
        cuda.memcpy_dtoh(bias_out_cuda, bias_out_gpu)
        accuracy_cuda = evaluate_model(weights_cuda, bias_cuda, weights_out_cuda, bias_out_cuda)
        print(f"PyCUDA: Batch {batch_count_cuda}, Test Accuracy: {accuracy_cuda*100:.2f}%")
cuda_end = time.time()
cuda_time = cuda_end - cuda_start
print(f"\nPyCUDA Training Finished in {cuda_time:.4f} seconds over {batch_count_cuda} batches.")
print(f"PyCUDA Test Accuracy: {accuracy_cuda*100:.2f}%")

###############################
# 2. PyTorch Implementation
###############################
print("\nStarting PyTorch Training...")

# Define a two-layer network in PyTorch with identical architecture.
class TwoLayerNet(nn.Module):
    def __init__(self):
        super(TwoLayerNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=True)
    def forward(self, x):
        x = x.view(-1, input_dim)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Use MSELoss on one-hot targets to mimic the custom kernel.
def to_one_hot(labels, num_classes=10):
    one_hot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    one_hot.scatter_(1, labels.view(-1,1), 1.0)
    return one_hot

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TwoLayerNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

torch_start = time.time()
batch_count_torch = 0
accuracy_torch = 0
train_iter_torch = iter(train_loader)
while accuracy_torch < 0.6 and batch_count_torch < max_batches:
    try:
        images, labels = next(train_iter_torch)
    except StopIteration:
        train_iter_torch = iter(train_loader)
        images, labels = next(train_iter_torch)
    images = images.to(device)
    labels = labels.to(device)
    # Convert labels to one-hot.
    one_hot_labels = to_one_hot(labels, output_dim)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, one_hot_labels)
    loss.backward()
    optimizer.step()
    batch_count_torch += 1
    if batch_count_torch % 100 == 0:
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                preds = torch.argmax(outputs, dim=1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
        accuracy_torch = correct / total
        print(f"PyTorch: Batch {batch_count_torch}, Test Accuracy: {accuracy_torch*100:.2f}%")
        model.train()
torch_end = time.time()
torch_time = torch_end - torch_start
print(f"\nPyTorch Training Finished in {torch_time:.4f} seconds over {batch_count_torch} batches.")
print(f"PyTorch Test Accuracy: {accuracy_torch*100:.2f}%")



Starting PyCUDA Training (Two-Layer Network)...
PyCUDA: Batch 100, Test Accuracy: 13.01%
PyCUDA: Batch 200, Test Accuracy: 14.63%
PyCUDA: Batch 300, Test Accuracy: 17.43%
PyCUDA: Batch 400, Test Accuracy: 22.13%
PyCUDA: Batch 500, Test Accuracy: 27.75%
PyCUDA: Batch 600, Test Accuracy: 35.57%
PyCUDA: Batch 700, Test Accuracy: 42.39%
PyCUDA: Batch 800, Test Accuracy: 48.74%
PyCUDA: Batch 900, Test Accuracy: 51.74%
PyCUDA: Batch 1000, Test Accuracy: 55.54%
PyCUDA: Batch 1100, Test Accuracy: 58.98%
PyCUDA: Batch 1200, Test Accuracy: 61.64%

PyCUDA Training Finished in 33.4456 seconds over 1200 batches.
PyCUDA Test Accuracy: 61.64%

Starting PyTorch Training...
PyTorch: Batch 100, Test Accuracy: 12.18%
PyTorch: Batch 200, Test Accuracy: 13.38%
PyTorch: Batch 300, Test Accuracy: 15.13%
PyTorch: Batch 400, Test Accuracy: 17.71%
PyTorch: Batch 500, Test Accuracy: 19.94%
PyTorch: Batch 600, Test Accuracy: 22.34%
PyTorch: Batch 700, Test Accuracy: 24.76%
PyTorch: Batch 800, Test Accuracy: 26.9

# 4. Comparison with PyTorch

The PyTorch implementation was set up to match the architecture and training procedure of our custom CUDA solution. In this setup:

- **Architecture:**  
  A two-layer network is used with one hidden layer of $128$ neurons (with ReLU activation) and an output layer of $10$ neurons. This exactly mirrors the network used in the PyCUDA implementation.

- **Loss Function and Update Rule:**  
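  Both implementations use a squared-error loss on one-hot encoded targets, and both update the parameters once per mini-batch: PyTorch backpropagates the batch loss, while the CUDA kernels average per-sample gradients over the mini-batch. The gradient scaling is not identical, however. With its default `reduction='mean'`, `nn.MSELoss` averages over all $64 \times 10$ output elements and keeps the factor $2$ from differentiating the square, whereas the kernel uses $\text{out} - \text{label}$ averaged over the $64$ samples only. At the same `learning_rate`, the PyTorch step is therefore $2/10 = 0.2$ times the size of the CUDA step.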

- **Training Details:**  
  Both models are trained on MNIST data with a batch size of $64$. The training loop in PyTorch updates the network after each mini-batch and evaluates test accuracy periodically (every 100 batches) until the test accuracy exceeds 60% or a maximum number of batches is reached.

- **Results:**  
  The results observed were as follows:
  - **PyTorch:**  
    - At Batch 100, Test Accuracy: 12.18%  
    - At Batch 200, Test Accuracy: 13.38%  
    - At Batch 300, Test Accuracy: 15.13%  
    - At Batch 400, Test Accuracy: 17.71%  
    - ...  
    - At Batch 4200, Test Accuracy: 60.33%  
    - Total Training Time: 67.6265 seconds over 4200 batches  
  - **PyCUDA:**  
    - At Batch 100, Test Accuracy: 13.01%  
    - At Batch 200, Test Accuracy: 14.63%  
    - At Batch 300, Test Accuracy: 17.43%  
    - At Batch 400, Test Accuracy: 22.13%  
    - ...  
    - At Batch 1200, Test Accuracy: 61.64%  
    - Total Training Time: 33.4456 seconds over 1200 batches  
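
To separate convergence speed from raw throughput, the totals listed above can be converted into an average cost per batch. This is a simple calculation on the reported numbers, not a new measurement; both figures include the periodic test-set evaluation and data loading, so they only approximate the time spent in the training step itself.

```python
# (total seconds, number of batches) as reported in the results above.
runs = {
    "PyCUDA": (33.4456, 1200),
    "PyTorch": (67.6265, 4200),
}

# Average wall-clock milliseconds per training batch, evaluation included.
ms_per_batch = {name: 1000.0 * seconds / batches
                for name, (seconds, batches) in runs.items()}

for name, ms in ms_per_batch.items():
    print(f"{name}: {ms:.2f} ms per batch")
```

On these numbers the PyTorch run spends less time per batch than the PyCUDA run, even though it needs more batches to reach the accuracy threshold.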

**Analysis:**  
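Both methods are trained on the same dataset and use the same network architecture with a squared-error loss. The PyTorch model, while easier to code and more robust, needed 4200 batches and about twice the wall-clock time (67.6 s versus 33.4 s) to reach roughly 60% accuracy, compared with 1200 batches for the custom PyCUDA solution. This gap should not be read as a kernel-speed advantage: per batch, the PyTorch run was faster (about 16.1 ms versus 27.9 ms, evaluation passes included). The difference in batch count more plausibly comes from the training setup. With the default `reduction='mean'`, `nn.MSELoss` scales gradients by $2/10$ relative to the kernel's $\text{out} - \text{label}$, so the same `learning_rate` gives PyTorch a smaller effective step, and `nn.Linear` uses its own default initialization rather than `randn * scale`. With careful low-level optimization, custom CUDA code can offer performance benefits, but a fair comparison requires matching these details, and such implementations demand a more complex development process than high-level frameworks like PyTorch.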
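
One detail to check when comparing batch counts is how each loss scales its gradient. PyTorch's `nn.MSELoss()` defaults to `reduction='mean'`, which averages the squared error over every element of the $64 \times 10$ output, while the CUDA kernels use `out - label` per sample and average over the batch only. The NumPy sketch below (an illustration with random data, not the notebook's code) checks the resulting factor of $2/10$ between the two gradients with respect to the outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C = 64, 10                                    # batch size and number of classes
out = rng.normal(size=(B, C))                    # stand-in for network outputs
target = np.eye(C)[rng.integers(0, C, size=B)]   # random one-hot targets

def mse_mean(o):
    """Mean squared error over all B*C elements (the 'mean' reduction)."""
    return np.mean((o - target) ** 2)

# Analytic gradient of the mean-reduced loss with respect to the outputs.
grad_mean = 2.0 * (out - target) / (B * C)

# Gradient the CUDA path applies: (out - label) per sample, averaged over the batch.
grad_kernel = (out - target) / B

# Finite-difference check of one entry of grad_mean.
eps = 1e-6
plus, minus = out.copy(), out.copy()
plus[0, 0] += eps
minus[0, 0] -= eps
fd = (mse_mean(plus) - mse_mean(minus)) / (2 * eps)

ratio = 2.0 / C   # expected ratio of the mean-reduced gradient to the kernel's gradient
print("finite difference matches analytic:", abs(fd - grad_mean[0, 0]) < 1e-7)
print("grad_mean == ratio * grad_kernel:", np.allclose(grad_mean, ratio * grad_kernel))
print("ratio:", ratio)
```

To bring the two runs onto the same step size, one option would be to scale the PyTorch learning rate by $10/2 = 5$; matching the weight initialization would be a separate change.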
