In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import numpy as np
import simplegrad as sg
import torch
from simplegrad import Tensor, ops

## Develop Ops

In [4]:
def log(tensor):
    out = Tensor(np.log(tensor.data), requires_grad=tensor.requires_grad)

    def _backward():
        tensor.grad += out.grad / tensor.data

    out._backward = _backward
    out._prev = {
        tensor,
    }
    return out


def exp(tensor):
    out = Tensor(np.exp(tensor.data), requires_grad=tensor.requires_grad)

    def _backward():
        if tensor.requires_grad:
            tensor.grad += out.data * out.grad

    out._backward = _backward
    out._prev = {
        tensor,
    }
    return out


def summation(tensor, axis=None, keepdims=False):
    """
    local_grad = d.sum(x) / d.xi = 1.
    therefore derivative of x is basically just out.grad
    broadcasted to the shape of the input tensor.
    """
    out = Tensor(
        np.sum(tensor.data, axis=axis, keepdims=keepdims),
        requires_grad=tensor.requires_grad,
    )

    def _backward():
        if tensor.requires_grad:
            input_shape, axes = tensor.data.shape, axis

            if not keepdims:
                if axis is None:  # if self.axes is None, take sum over all axes.
                    axes = tuple(i for i in range(len(input_shape)))
                elif isinstance(axis, int):
                    axes = (axis,)

                shape_range = range(len(input_shape))
                mask = np.array([0 if i in axes else 1 for i in shape_range])
                new_shape = np.array(input_shape) * mask + (1 - mask)
                grad = np.reshape(out.grad, new_shape)
                grad = np.broadcast_to(grad, input_shape)
            else:
                grad = np.broadcast_to(out.grad, input_shape)

            tensor.grad += grad

    out._backward = _backward
    out._prev = {
        tensor,
    }
    return out


def broadcast_to(tensor, shape):
    """this is interestingly the reverse of summation."""
    if tensor.shape == shape:  # Optimization: no-op if shapes match
        return tensor

    out_data = np.broadcast_to(tensor.data, shape)
    out = Tensor(out_data, requires_grad=tensor.requires_grad)

    input_shape = tensor.shape  # Capture input shape for backward pass

    def _backward():
        if tensor.requires_grad:
            ishape, oshape = tensor.data.shape, out.grad.shape
            ## in = (3, 1, 4), out = (3, 5, 4) -> aligned = (3, 1, 4)
            ## i think numpy only implicitly broadcast to prefix dims :/
            aligned = [1] * (len(oshape) - len(ishape)) + list(ishape)
            broadcast_axes = tuple([i for i, axis in enumerate(aligned) if axis == 1])
            grad = np.sum(out.grad, axis=broadcast_axes, keepdims=True)
            grad = np.reshape(grad, ishape)

            tensor.grad += grad

    out._backward = _backward
    out._prev = {
        tensor,
    }
    return out


def logsumexp(tensor, axis=None, keepdims=False):
    """
    mathematical operations, applied to 1D vector:
    forward: log(e^z1 + e^z2 + ... + e^zn) = sum(e^zi)
    backward: local_grad[i] = e^zi / sum(e^zi)
    ------
    for numerical stability:
    forward: log(sum(e^zi))  = log(sum(e^(zi - zmax)) * e^zmax)
                             = log(sum(e^(zi - zmax))) + log(e^zmax)
    backward: e^zi/sum(e^zi) = e^(zi - zmax) / sum(e^(zi - zmax))
    """
    max_z = np.max(tensor.data, axis=axis, keepdims=True)
    stable_z = tensor.data - max_z
    exp_stable_z = np.exp(stable_z)
    stable_sum = np.sum(exp_stable_z, axis=axis, keepdims=keepdims)
    max_term = max_z if keepdims else np.squeeze(max_z, axis=axis)
    data = np.log(stable_sum) + max_term
    out = Tensor(data, requires_grad=tensor.requires_grad)

    def _backward():
        if tensor.requires_grad:
            if axis is None:
                # For None axis, basically all dims.
                if not keepdims:
                    grad_shaped = out.grad * np.ones_like(tensor.data)
                    softmax_terms = exp_stable_z / np.sum(exp_stable_z)
                    tensor.grad += grad_shaped * softmax_terms
                else:
                    # keepdims=True with axis=None
                    softmax_terms = exp_stable_z / stable_sum
                    tensor.grad += out.grad * softmax_terms
            else:
                # For specific axis reduction
                grad_shaped = out.grad
                if not keepdims:
                    grad_shaped = np.expand_dims(grad_shaped, axis=axis)

                denom = (
                    stable_sum if keepdims else np.expand_dims(stable_sum, axis=axis)
                )
                softmax_terms = exp_stable_z / denom
                tensor.grad += grad_shaped * softmax_terms

    out._backward = _backward
    out._prev = {
        tensor,
    }
    return out


"""
    1. note: id(i, j) = 1{i == j}
    mathematical operations, applied to 1D vector:
    
    forward: softmax(z)[i] = e^zi / sum(e^z)
    backward: since softmax is a vector-to-vector function,
              the local_grad we need to compute is a Jacobian:
        local_grad[i,j] = softmax(z)[i] * (id(i,j) - softmax(z)[j])
    
    
    2. for numerical stability. 
    forward: softmax(z)[i] = e^zi / sum(e^z)
             = e^(zi) / e^(logsumexp(z))
             = e^(zi - logsumexp(z))
             
    backward (vectorized): 
        let ID.shape = local_grad.shape = (N, N). 
            ID[i, j] = id(i, j).
            local_grad[i, j] = d.s[i] / d.z[j]
            
        local_grad[i, j] = out[i] * ID[i, j] - out[i] * out[j]
        local_grad[i, :] = out[i] * ID[i, :] - out[i] * out[:]
        local_grad[:, :] = out[:] * ID[:, :] - out[:, None] * out[None, :]
                         = diag(out) - outer(out, out)

        shit, however, this is not really efficient.
"""


def softmax(tensor, axis: int = None):
    """
    to reduce headache, actually I should implement an exp ops,
    then let the chain rule do its job automatically.
    """
    lse = logsumexp(tensor, axis=axis, keepdims=True)
    # print(tensor.shape, lse.shape)
    lse_broadcast = broadcast_to(lse, tensor.data.shape)
    log_softmax = tensor - lse_broadcast
    out = exp(log_softmax)

    return out

In [6]:
import numpy as np
import torch
import torch.nn.functional as F
from simplegrad.tensor import Tensor


def test_logsumexp():
    # Define diverse test cases
    test_cases = [
        # Small values
        {"data": np.random.rand(3, 4) * 0.001, "axis": None, "keepdims": False},
        # Large values
        {"data": np.random.rand(3, 4) * 100, "axis": None, "keepdims": False},
        # Negative values
        {"data": np.random.rand(3, 4) * -10, "axis": None, "keepdims": False},
        # Mixed values
        {"data": np.random.rand(3, 4) * 2 - 1, "axis": None, "keepdims": False},
        # Single dimension reduction with keepdims=True
        {"data": np.random.rand(3, 4) * 2 - 1, "axis": 0, "keepdims": True},
        # Single dimension reduction with keepdims=False
        {"data": np.random.rand(3, 4) * 2 - 1, "axis": 1, "keepdims": False},
        # Multiple dimensions
        {"data": np.random.rand(2, 3, 4) * 2 - 1, "axis": None, "keepdims": False},
        # Multiple dimensions with specific axis
        {"data": np.random.rand(2, 3, 4) * 2 - 1, "axis": 1, "keepdims": False},
        # Multiple dimensions with tuple axis
        {"data": np.random.rand(2, 3, 4) * 2 - 1, "axis": (0, 2), "keepdims": False},
        # Multiple dimensions with tuple axis and keepdims=True
        {"data": np.random.rand(2, 3, 4) * 2 - 1, "axis": (0, 2), "keepdims": True},
    ]

    for i, test_case in enumerate(test_cases):
        data = test_case["data"]
        axis = test_case["axis"]
        keepdims = test_case["keepdims"]

        # Convert to tensors
        pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
        x = Tensor.from_torch(pt_x)

        # PyTorch version
        if axis is None:
            # PyTorch's logsumexp requires a specific dim
            expected = torch.logsumexp(
                pt_x, dim=tuple(range(pt_x.dim())), keepdim=keepdims
            )
        elif isinstance(axis, int):
            expected = torch.logsumexp(pt_x, dim=axis, keepdim=keepdims)
        else:
            # For multiple axes, we need to handle them one by one in PyTorch
            temp = pt_x
            # Process axes in reverse order to maintain correct dimensions
            for ax in sorted(axis, reverse=True):
                temp = torch.logsumexp(temp, dim=ax, keepdim=keepdims)
            expected = temp

        # Our implementation
        result = logsumexp(x, axis=axis, keepdims=keepdims)

        # Check forward pass
        np.testing.assert_allclose(
            result.data,
            expected.detach().numpy(),
            rtol=1e-5,
            atol=1e-5,
            err_msg=f"Forward pass failed for test case {i+1}: data shape {data.shape}, axis {axis}, keepdims {keepdims}",
        )
        print(f"LogSumExp forward test {i+1} passed!")

        # Compute gradients
        grad_output = torch.ones_like(expected)
        expected.backward(grad_output)
        result.backward()

        # Check backward pass
        np.testing.assert_allclose(
            x.grad,
            pt_x.grad.detach().numpy(),
            rtol=1e-5,
            atol=1e-5,
            err_msg=f"Backward pass failed for test case {i+1}: data shape {data.shape}, axis {axis}, keepdims {keepdims}",
        )
        print(f"LogSumExp backward test {i+1} passed!")


def test_logsumexp_specific_cases():
    """Test specific edge cases for logsumexp"""

    # Case 1: All elements are the same (tests numerical stability)
    data = np.ones((3, 3)) * 1000  # Large identical values
    pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
    x = Tensor.from_torch(pt_x)

    expected = torch.logsumexp(pt_x, dim=1, keepdim=False)
    result = logsumexp(x, axis=1, keepdims=False)

    np.testing.assert_allclose(
        result.data, expected.detach().numpy(), rtol=1e-5, atol=1e-5
    )
    print("LogSumExp specific case 1 (large identical values) passed!")

    # Case 2: Extreme differences between values (tests numerical stability)
    data = np.array([[1e-10, 1e10], [1e-10, 1e-10]])
    pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
    x = Tensor.from_torch(pt_x)

    expected = torch.logsumexp(pt_x, dim=1, keepdim=False)
    result = logsumexp(x, axis=1, keepdims=False)

    np.testing.assert_allclose(
        result.data, expected.detach().numpy(), rtol=1e-5, atol=1e-5
    )
    print("LogSumExp specific case 2 (extreme value differences) passed!")

    # Case 3: Test with softmax relation (logsumexp is used in softmax implementation)
    data = np.random.rand(5, 10) * 2 - 1
    pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
    x = Tensor.from_torch(pt_x)

    # Standard softmax calculation using logsumexp
    pt_logsumexp = torch.logsumexp(pt_x, dim=1, keepdim=True)
    pt_softmax = torch.exp(pt_x - pt_logsumexp)

    our_logsumexp = logsumexp(x, axis=1, keepdims=True)
    our_softmax = np.exp(x.data - our_logsumexp.data)

    np.testing.assert_allclose(
        our_softmax, pt_softmax.detach().numpy(), rtol=1e-5, atol=1e-5
    )
    print("LogSumExp specific case 3 (softmax relation) passed!")

In [7]:
print("Testing LogSumExp operation...")
test_logsumexp()

print("\nTesting LogSumExp specific cases...")
test_logsumexp_specific_cases()

print("\nAll LogSumExp tests completed successfully!")

Testing LogSumExp operation...
LogSumExp forward test 1 passed!
LogSumExp backward test 1 passed!
LogSumExp forward test 2 passed!
LogSumExp backward test 2 passed!
LogSumExp forward test 3 passed!
LogSumExp backward test 3 passed!
LogSumExp forward test 4 passed!
LogSumExp backward test 4 passed!
LogSumExp forward test 5 passed!
LogSumExp backward test 5 passed!
LogSumExp forward test 6 passed!
LogSumExp backward test 6 passed!
LogSumExp forward test 7 passed!
LogSumExp backward test 7 passed!
LogSumExp forward test 8 passed!
LogSumExp backward test 8 passed!
LogSumExp forward test 9 passed!
LogSumExp backward test 9 passed!
LogSumExp forward test 10 passed!
LogSumExp backward test 10 passed!

Testing LogSumExp specific cases...
LogSumExp specific case 1 (large identical values) passed!
LogSumExp specific case 2 (extreme value differences) passed!
LogSumExp specific case 3 (softmax relation) passed!

All LogSumExp tests completed successfully!


In [8]:
def test_summation():
    """Test the summation operation with challenging cases"""
    print("\n=== TESTING SUMMATION ===")

    # Define challenging test cases
    test_cases = [
        # Basic cases
        {
            "data": np.random.rand(5, 5),
            "axis": None,
            "keepdims": False,
            "name": "Basic 2D, all axes",
        },
        {
            "data": np.random.rand(5, 5),
            "axis": 0,
            "keepdims": False,
            "name": "Basic 2D, axis 0",
        },
        {
            "data": np.random.rand(5, 5),
            "axis": 1,
            "keepdims": True,
            "name": "Basic 2D, axis 1 with keepdims",
        },
        # Extreme values
        {
            "data": np.random.rand(10, 10) * 1e10,
            "axis": None,
            "keepdims": False,
            "name": "Large values (1e10)",
        },
        {
            "data": np.random.rand(10, 10) * 1e-10,
            "axis": 0,
            "keepdims": True,
            "name": "Small values (1e-10)",
        },
        {
            "data": np.array([[1e15, 1e-15], [1e-15, 1e15]]),
            "axis": 1,
            "keepdims": False,
            "name": "Mixed extreme values",
        },
        # Large dimensions
        {
            "data": np.random.rand(1000, 5),
            "axis": 0,
            "keepdims": False,
            "name": "Large first dimension (1000x5)",
        },
        {
            "data": np.random.rand(5, 1000),
            "axis": 1,
            "keepdims": True,
            "name": "Large second dimension (5x1000)",
        },
        # Higher dimensions
        {
            "data": np.random.rand(10, 10, 10),
            "axis": (0, 2),
            "keepdims": False,
            "name": "3D with multiple axes",
        },
        {
            "data": np.random.rand(5, 5, 5, 5),
            "axis": (1, 2),
            "keepdims": True,
            "name": "4D with multiple axes and keepdims",
        },
        # Special patterns
        {
            "data": np.ones((20, 20)),
            "axis": None,
            "keepdims": False,
            "name": "All ones",
        },
        {"data": np.zeros((20, 20)), "axis": 0, "keepdims": True, "name": "All zeros"},
        {"data": np.eye(20), "axis": 1, "keepdims": False, "name": "Identity matrix"},
        # Edge cases
        {
            "data": np.array([1.0]),
            "axis": None,
            "keepdims": False,
            "name": "Single value",
        },
        {
            "data": np.random.rand(1, 1, 1, 1),
            "axis": (1, 2),
            "keepdims": True,
            "name": "Multiple singleton dimensions",
        },
    ]

    for i, test_case in enumerate(test_cases):
        data = test_case["data"]
        axis = test_case["axis"]
        keepdims = test_case["keepdims"]
        name = test_case["name"]

        # Create tensors
        pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
        x = Tensor(data, requires_grad=True)

        # PyTorch sum
        if isinstance(axis, tuple):
            # For multiple axes in PyTorch, we need to do them one by one
            expected = pt_x
            for ax in sorted(axis, reverse=True):  # Start from the highest axis
                expected = expected.sum(dim=ax, keepdim=keepdims)
        else:
            expected = pt_x.sum(dim=axis, keepdim=keepdims)

        # Our summation
        result = summation(x, axis=axis, keepdims=keepdims)

        # Check forward pass
        try:
            np.testing.assert_allclose(
                result.data,
                expected.detach().numpy(),
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"Summation forward pass failed",
            )
            # print(f"  ✓ Forward pass successful")
        except Exception as e:
            print(f"  ✗ Forward pass failed: {e}")
            continue

        # Generate random gradient for backward pass
        grad_output_np = np.random.rand(*result.data.shape)
        if isinstance(grad_output_np, float):
            grad_output_np = np.float32(grad_output_np)
        else:
            grad_output_np = grad_output_np.astype(np.float32)
        grad_output_torch = torch.tensor(grad_output_np)

        # Compute gradients
        expected.backward(grad_output_torch)
        result.backward(grad_output_np)

        # Check backward pass
        try:
            np.testing.assert_allclose(
                x.grad,
                pt_x.grad.detach().numpy(),
                rtol=1e-4,
                atol=1e-5,
                err_msg=f"broadcast_to backward pass failed",
            )
            # print(f"  ✓ Backward pass successful")
            result_msg = "Successful."
        except Exception as e:
            result_msg = "Failed."
            print(f"  ✗ Backward pass failed: {e}")

        # Reset gradients
        pt_x.grad = None
        x.grad = np.zeros_like(x.data)

        print(
            f"Test case {i+1}: {name}."
            # f" Shape: {data.shape}, Axis: {axis}, Keepdims: {keepdims}."
            f" {result_msg}"
        )

In [11]:
def test_broadcast_to():
    """Test the broadcast_to operation with challenging cases"""
    print("\n=== TESTING BROADCAST_TO ===")

    # Define challenging test cases
    test_cases = [
        # Basic broadcasting
        {"data": np.random.rand(1), "shape": (10,), "name": "Scalar to vector"},
        {"data": np.random.rand(1, 5), "shape": (10, 5), "name": "Row to matrix"},
        {"data": np.random.rand(5, 1), "shape": (5, 10), "name": "Column to matrix"},
        # Extreme values
        {
            "data": np.random.rand(1, 3) * 1e9,
            "shape": (5, 3),
            "name": "Large values (1e9)",
        },
        {
            "data": np.random.rand(1, 3) * 1e-9,
            "shape": (5, 3),
            "name": "Small values (1e-9)",
        },
        {
            "data": np.array([[1e15], [1e-15]]),
            "shape": (2, 5),
            "name": "Mixed extreme values",
        },
        # Large dimensions
        {
            "data": np.random.rand(1, 5),
            "shape": (1000, 5),
            "name": "Broadcast to large first dim (1000)",
        },
        {
            "data": np.random.rand(5, 1),
            "shape": (5, 1000),
            "name": "Broadcast to large second dim (1000)",
        },
        # Higher dimensions
        {
            "data": np.random.rand(1, 5, 1),
            "shape": (10, 5, 8),
            "name": "3D broadcasting",
        },
        {
            "data": np.random.rand(1, 1, 1, 5),
            "shape": (7, 6, 5, 5),
            "name": "4D broadcasting",
        },
        # Multiple dimensions broadcasted
        {
            "data": np.random.rand(1, 1, 3),
            "shape": (8, 8, 3),
            "name": "Broadcasting multiple dimensions",
        },
        # Special patterns
        {"data": np.ones((1, 5)), "shape": (10, 5), "name": "Broadcasting ones"},
        {"data": np.zeros((1, 5)), "shape": (10, 5), "name": "Broadcasting zeros"},
        # No broadcasting (identity case)
        {
            "data": np.random.rand(5, 5),
            "shape": (5, 5),
            "name": "No broadcasting (same shape)",
        },
        # Edge cases
        {
            "data": np.array([1.0]),
            "shape": (1, 1, 1, 1),
            "name": "Scalar to higher dims",
        },
    ]

    for i, test_case in enumerate(test_cases):
        data = test_case["data"]
        shape = test_case["shape"]
        name = test_case["name"]

        # Create tensors
        pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
        x = Tensor(data, requires_grad=True)

        # PyTorch broadcast (expand)
        # Handle the case of broadcasting to higher dimensions
        if pt_x.dim() < len(shape):
            # Add dimensions to match the target shape
            expanded_dims = len(shape) - pt_x.dim()
            reshape_dims = [1] * expanded_dims + list(pt_x.shape)
            pt_x_reshaped = pt_x.reshape(reshape_dims)
            expected = pt_x_reshaped.expand(shape)
        else:
            expected = pt_x.expand(shape)

        # Our broadcast_to
        result = broadcast_to(x, shape)

        # Check forward pass
        try:
            np.testing.assert_allclose(
                result.data,
                expected.detach().numpy(),
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"broadcast_to forward pass failed",
            )
            # print(f"  ✓ Forward pass successful")
        except Exception as e:
            print(f"  ✗ Forward pass failed: {e}")
            continue

        # Generate random gradient for backward pass
        grad_output_np = np.random.rand(*shape).astype(np.float32)
        grad_output_torch = torch.tensor(grad_output_np)

        # Compute gradients
        expected.backward(grad_output_torch)
        result.backward(grad_output_np)

        # Check backward pass
        try:
            np.testing.assert_allclose(
                x.grad,
                pt_x.grad.detach().numpy(),
                rtol=1e-4,
                atol=1e-5,
                err_msg=f"broadcast_to backward pass failed",
            )
            # print(f"  ✓ Backward pass successful")
            result_msg = "Successful."
        except Exception as e:
            result_msg = "Failed."
            print(f"  ✗ Backward pass failed: {e}")

        # Reset gradients
        pt_x.grad = None
        x.grad = np.zeros_like(x.data)

        print(
            f"Test case {i+1}: {name}."
            # f" Shape: {data.shape}, Axis: {axis}, Keepdims: {keepdims}."
            f" {result_msg}"
        )

In [12]:
def test_softmax():
    """Test the softmax operation with challenging cases"""
    print("\n=== TESTING SOFTMAX ===")

    # Define challenging test cases
    test_cases = [
        # Basic cases
        {"data": np.random.rand(10), "axis": None, "name": "Basic 1D vector"},
        {"data": np.random.rand(5, 5), "axis": 1, "name": "Basic 2D, axis 1"},
        {"data": np.random.rand(5, 5), "axis": 0, "name": "Basic 2D, axis 0"},
        # Extreme values
        {"data": np.random.rand(10) * 1e9, "axis": None, "name": "Large values (1e9)"},
        {
            "data": np.random.rand(10) * 1e-9,
            "axis": None,
            "name": "Small values (1e-9)",
        },
        {
            "data": np.array([1e15, 1e-15, 0, -1e-15, -1e15]),
            "axis": None,
            "name": "Mixed extreme values",
        },
        # Numerical stability challenges
        {
            "data": np.array([1000, 0, -1000]),
            "axis": None,
            "name": "Very different values",
        },
        {
            "data": np.array([1e5, 1e5 + 1e-5]),
            "axis": None,
            "name": "Nearly identical large values",
        },
        {"data": np.ones(10) * 1e5, "axis": None, "name": "All identical large values"},
        # Large dimensions
        {
            "data": np.random.rand(1000, 5),
            "axis": 1,
            "name": "Large first dimension (1000x5)",
        },
        {
            "data": np.random.rand(5, 1000) * 1e9,
            "axis": 0,
            "name": "Large second dimension (5x1000)",
        },
        # Higher dimensions
        {"data": np.random.rand(10, 10, 10), "axis": 2, "name": "3D tensor, last axis"},
        {
            "data": np.random.rand(10, 10, 10),
            "axis": 1,
            "name": "3D tensor, middle axis",
        },
        {
            "data": np.random.rand(5, 5, 5, 5),
            "axis": 0,
            "name": "4D tensor, first axis",
        },
        # Special patterns
        {
            "data": np.zeros((10, 10)),
            "axis": 1,
            "name": "All zeros (uniform distribution)",
        },
        {
            "data": np.ones((10, 10)),
            "axis": 1,
            "name": "All ones (uniform distribution)",
        },
        {"data": np.eye(10), "axis": 1, "name": "Identity matrix"},
        # Edge cases
        {
            "data": np.array([42.0]),
            "axis": None,
            "name": "Single value (should be 1.0)",
        },
        {
            "data": np.zeros((1, 1, 1)),
            "axis": 1,
            "name": "Multiple singleton dimensions",
        },
    ]

    for i, test_case in enumerate(test_cases):
        data = test_case["data"]
        axis = test_case["axis"]
        name = test_case["name"]

        # print(f"\nTest case {i+1}: {name}")
        # print(f"  Shape: {data.shape}, Axis: {axis}")

        # Create tensors
        pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
        x = Tensor(data, requires_grad=True)

        # PyTorch softmax
        if axis is None:
            # Flatten for axis=None
            flattened = pt_x.reshape(-1)
            expected = torch.nn.functional.softmax(flattened, dim=0)
        else:
            expected = torch.nn.functional.softmax(pt_x, dim=axis)

        # Our softmax
        result = softmax(x, axis=axis)

        # Check forward pass
        try:
            np.testing.assert_allclose(
                result.data,
                expected.detach().numpy(),
                rtol=1e-5,
                atol=1e-5,
                err_msg=f"Softmax forward pass failed",
            )
            # print(f"  ✓ Forward pass successful")
        except Exception as e:
            print(f"  ✗ Forward pass failed: {e}")
            continue

        # Generate random gradient for backward pass
        grad_output_np = np.random.rand(*result.data.shape).astype(np.float32)

        # For PyTorch, ensure gradient has the right shape
        if axis is None:
            grad_output_torch = torch.tensor(grad_output_np.reshape(-1))
        else:
            grad_output_torch = torch.tensor(grad_output_np)

        # Compute gradients
        expected.backward(grad_output_torch)
        result.backward(grad_output_np)

        # Check backward pass
        try:
            np.testing.assert_allclose(
                x.grad,
                pt_x.grad.detach().numpy(),
                rtol=1e-4,
                atol=1e-5,
                err_msg=f"broadcast_to backward pass failed",
            )
            # print(f"  ✓ Backward pass successful")
            result_msg = "Successful."
        except Exception as e:
            result_msg = "Failed."
            print(f"  ✗ Backward pass failed: {e}")

        # Reset gradients
        pt_x.grad = None
        x.grad = np.zeros_like(x.data)

        print(
            f"Test case {i+1}: {name}."
            # f" Shape: {data.shape}, Axis: {axis}, Keepdims: {keepdims}."
            f" {result_msg}"
        )

In [13]:
test_summation()
test_broadcast_to()
test_softmax()


=== TESTING SUMMATION ===
Test case 1: Basic 2D, all axes. Successful.
Test case 2: Basic 2D, axis 0. Successful.
Test case 3: Basic 2D, axis 1 with keepdims. Successful.
Test case 4: Large values (1e10). Successful.
Test case 5: Small values (1e-10). Successful.
Test case 6: Mixed extreme values. Successful.
Test case 7: Large first dimension (1000x5). Successful.
Test case 8: Large second dimension (5x1000). Successful.
Test case 9: 3D with multiple axes. Successful.
Test case 10: 4D with multiple axes and keepdims. Successful.
Test case 11: All ones. Successful.
Test case 12: All zeros. Successful.
Test case 13: Identity matrix. Successful.
Test case 14: Single value. Successful.
Test case 15: Multiple singleton dimensions. Successful.

=== TESTING BROADCAST_TO ===
Test case 1: Scalar to vector. Successful.
Test case 2: Row to matrix. Successful.
Test case 3: Column to matrix. Successful.
Test case 4: Large values (1e9). Successful.
Test case 5: Small values (1e-9). Successful.
Tes

## Develop optimizer

In [74]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split

### Reference training

In [58]:
class DecimalToBase4NN(nn.Module):
    def __init__(
        self, input_digits=10, input_classes=10, hidden_size=128, output_digits=6
    ):
        """
        A neural network to convert decimal numbers to base-4 representation.

        Args:
            input_digits: Number of decimal digits to encode
            input_classes: Number of possible values for each input digit (10 for decimal)
            hidden_size: Size of hidden layer
            output_digits: Number of base-4 digits to predict
        """
        super(DecimalToBase4NN, self).__init__()
        self.input_digits = input_digits
        self.output_digits = output_digits

        # Calculate total input size (one-hot vectors for each digit)
        input_size = input_digits * input_classes

        # Network architecture: linear -> relu -> linear -> softmax (implicitly during loss calculation)
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(
                hidden_size, output_digits * 4
            ),  # 4 possible values (0,1,2,3) for each output digit
        )

    def forward(self, x):
        # Forward pass
        batch_size = x.size(0)
        output = self.model(x)

        # Reshape to [batch_size, output_digits, 4] for softmax across the 4 possible values
        output = output.view(batch_size, self.output_digits, 4)

        # Apply softmax to get probabilities for each digit
        output = torch.softmax(output, dim=2)

        return output


def decimal_to_digits(decimal_num, num_digits=10):
    """Convert a decimal number to its individual digits (right-padded with zeros)."""
    digits = []
    temp = decimal_num

    # Extract individual digits
    while temp > 0:
        digits.append(temp % 10)
        temp //= 10

    # Pad with zeros to reach num_digits
    while len(digits) < num_digits:
        digits.append(0)

    # Reverse since we calculated least significant digit first
    return digits[::-1]


def decimal_to_base4(decimal_num, max_digits=6):
    """Convert a decimal number to its base-4 representation."""
    if decimal_num == 0:
        return [0] * max_digits

    base4_digits = []
    temp = decimal_num

    while temp > 0:
        base4_digits.append(temp % 4)
        temp //= 4

    # Pad with zeros to reach max_digits
    while len(base4_digits) < max_digits:
        base4_digits.append(0)

    # Reverse since we calculated least significant digit first
    return base4_digits[::-1][-max_digits:]


def generate_training_data(
    num_samples=10000, max_decimal=4 ** 6 - 1, input_digits=10, output_digits=6
):
    """Generate training data: pairs of (one-hot encoded decimal digits, base-4 representation)."""
    # Create a list of all possible numbers from 0 to max_decimal
    all_possible_nums = np.arange(max_decimal + 1)

    # Sample from the complete list
    if num_samples <= len(all_possible_nums):
        decimal_nums = np.random.choice(
            all_possible_nums, size=num_samples, replace=False
        )
    else:
        # If num_samples is larger than possible numbers, use all numbers
        print(
            f"Warning: Requested {num_samples} samples but only {len(all_possible_nums)} numbers exist in the range."
        )
        print(f"Using all {len(all_possible_nums)} numbers.")
        decimal_nums = all_possible_nums
        np.random.shuffle(decimal_nums)  # Shuffle to ensure random order

    # Convert each decimal to its individual digits
    X_digits = [decimal_to_digits(num, input_digits) for num in decimal_nums]

    # One-hot encode the input decimal digits
    X = torch.zeros(num_samples, input_digits, 10)
    for i, dec_digits in enumerate(X_digits):
        for j, digit in enumerate(dec_digits):
            X[i, j, digit] = 1.0

    # Flatten the one-hot vectors
    X = X.view(num_samples, -1)

    # Convert each decimal to its base-4 representation
    y_list = [decimal_to_base4(num, output_digits) for num in decimal_nums]

    # One-hot encode the base-4 digits
    y = torch.zeros(num_samples, output_digits, 4)
    for i, base4_num in enumerate(y_list):
        for j, digit in enumerate(base4_num):
            y[i, j, digit] = 1.0

    return X, y, decimal_nums


def train_test_split_data(X, y, decimal_nums, test_size=0.2):
    """Split data into training and testing sets."""
    dataset = TensorDataset(X, y)

    # Calculate sizes
    dataset_size = len(dataset)
    test_size = int(test_size * dataset_size)
    train_size = dataset_size - test_size

    # Create tensor of indices and then split it
    all_indices = torch.arange(dataset_size)

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Randomly shuffle the indices
    shuffled_indices = all_indices[torch.randperm(dataset_size)]

    # Split the shuffled indices
    train_indices = shuffled_indices[:train_size]
    test_indices = shuffled_indices[train_size:]

    # Create dataset subsets
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)

    # Safely split the decimal numbers to match the X and y tensors
    # Ensure we're using the correct indices within range of decimal_nums
    safe_train_indices = train_indices[train_indices < len(decimal_nums)]
    safe_test_indices = test_indices[test_indices < len(decimal_nums)]

    if len(safe_train_indices) < len(train_indices) or len(safe_test_indices) < len(
        test_indices
    ):
        print(
            f"Warning: decimal_nums has {len(decimal_nums)} elements, but dataset has {dataset_size} elements."
        )
        print(f"Using only the valid indices for decimal_nums.")

    # If decimal_nums is shorter than X/y (which can happen after the change to generate_training_data),
    # we need to pad it or truncate the indices
    if len(decimal_nums) < dataset_size:
        # Pad decimal_nums with zeros to match dataset size
        padded_nums = np.zeros(dataset_size, dtype=decimal_nums.dtype)
        padded_nums[: len(decimal_nums)] = decimal_nums
        train_decimal_nums = padded_nums[train_indices.numpy()]
        test_decimal_nums = padded_nums[test_indices.numpy()]
    else:
        # Use original array as-is
        train_decimal_nums = decimal_nums[train_indices.numpy()]
        test_decimal_nums = decimal_nums[test_indices.numpy()]

    return train_dataset, test_dataset, train_decimal_nums, test_decimal_nums


def train_model(
    model,
    train_dataset,
    device,
    epochs=100,
    batch_size=128,
    learning_rate=0.01,
    log_steps=50,
):
    """Train the neural network using SGD optimizer."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    # Create DataLoader for training data
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        model.train()  # Set model to training mode
        total_loss = 0

        for batch_X, batch_y in train_loader:
            # Move batch to device
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()

            # Forward pass
            predictions = model(batch_X)

            # Reshape for cross entropy loss
            batch_size = batch_X.size(0)
            predictions = predictions.view(batch_size * model.output_digits, 4)
            targets = batch_y.view(batch_size * model.output_digits, 4)
            targets = torch.argmax(targets, dim=1)  # Convert one-hot to class indices

            # Calculate loss
            loss = criterion(predictions, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if (epoch + 1) % log_steps == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")


def evaluate_model(model, test_dataset, device):
    """Evaluate the model on the test dataset."""
    model.eval()  # Set model to evaluation mode
    test_loader = DataLoader(test_dataset, batch_size=128)

    correct_digits = 0
    total_digits = 0
    correct_numbers = 0
    total_numbers = 0

    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            # Move batch to device
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            # Forward pass
            predictions = model(batch_X)

            # Get predicted digits
            pred_digits = torch.argmax(predictions, dim=2)
            true_digits = torch.argmax(batch_y, dim=2)

            # Count correct digits
            correct_digits += (pred_digits == true_digits).sum().item()
            total_digits += pred_digits.numel()

            # Count correct numbers (all digits correct)
            correct_per_sample = (pred_digits == true_digits).all(dim=1)
            correct_numbers += correct_per_sample.sum().item()
            total_numbers += correct_per_sample.numel()

    digit_accuracy = correct_digits / total_digits * 100
    number_accuracy = correct_numbers / total_numbers * 100

    print(f"\nTest Results:")
    print(f"Digit-level accuracy: {digit_accuracy:.2f}%")
    print(f"Number-level accuracy: {number_accuracy:.2f}%")

    return digit_accuracy, number_accuracy


def test_specific_examples(
    model, test_decimal_nums, device, input_digits=10, output_digits=6, num_samples=20
):
    """Test the model on specific examples from the test set."""
    model.eval()

    # Randomly select samples from test set
    if num_samples > len(test_decimal_nums):
        num_samples = len(test_decimal_nums)

    indices = np.random.choice(len(test_decimal_nums), num_samples, replace=False)
    sample_nums = test_decimal_nums[indices]

    # Convert to one-hot encoded digits
    X_digits = [decimal_to_digits(num, input_digits) for num in sample_nums]
    X_test = torch.zeros(num_samples, input_digits, 10)
    for i, dec_digits in enumerate(X_digits):
        for j, digit in enumerate(dec_digits):
            X_test[i, j, digit] = 1.0
    X_test = X_test.view(num_samples, -1).to(device)

    with torch.no_grad():
        predictions = model(X_test)

    print("\nTesting specific examples from the test set:")
    for i, decimal_num in enumerate(sample_nums):
        # Get the predicted base-4 digits
        pred_digits = torch.argmax(predictions[i], dim=1).cpu().numpy()

        # Get the true base-4 digits
        true_digits = decimal_to_base4(decimal_num, output_digits)
        true_digits = np.array([int(x) for x in true_digits])

        # Convert predictions to a single number for easy comparison
        pred_decimal = sum(
            digit * (4 ** (output_digits - idx - 1))
            for idx, digit in enumerate(pred_digits)
        )

        is_correct = np.array_equal(pred_digits, true_digits)
        print(
            f"Decimal: {decimal_num}. Correct: {is_correct}"
            f"\n    Predicted Base-4: {pred_digits}"
            f"\n    True Base-4     : {true_digits}"
        )

        if not is_correct:
            print(f"    Predicted as decimal: {pred_decimal}")

In [61]:
# Main execution
if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Set device to cuda:1
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Parameters
    input_digits = 10  # Number of decimal digits to encode
    output_digits = 7  # Maximum number of base-4 digits
    max_decimal = 4 ** output_digits - 1  # Maximum decimal number we can represent
    hidden_size = 1024  # Size of hidden layer

    # Generate data - adjust num_samples to be at most max_decimal+1
    num_samples = min(20000, max_decimal + 1)
    print(f"Using {num_samples} samples (max possible: {max_decimal + 1})")

    X, y, decimal_nums = generate_training_data(
        num_samples=num_samples,
        max_decimal=max_decimal,
        input_digits=input_digits,
        output_digits=output_digits,
    )

    # Split data into train and test sets
    (
        train_dataset,
        test_dataset,
        train_decimal_nums,
        test_decimal_nums,
    ) = train_test_split_data(X, y, decimal_nums, test_size=0.1)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Testing samples: {len(test_dataset)}")

    # Create the model and move to device
    model = DecimalToBase4NN(
        input_digits=input_digits,
        input_classes=10,  # 10 possible values (0-9) for decimal
        hidden_size=hidden_size,
        output_digits=output_digits,
    ).to(device)

    print("Initial model:")
    evaluate_model(model, test_dataset, device)
    # Train the model
    train_model(
        model,
        train_dataset,
        device,
        epochs=2000,  # Increased epochs
        batch_size=512,
        learning_rate=0.05,
    )

    # Evaluate the model on the test set
    digit_accuracy, number_accuracy = evaluate_model(model, test_dataset, device)

    # Test specific examples
    test_specific_examples(
        model, test_decimal_nums, device, input_digits, output_digits, num_samples=10
    )

Using device: cuda:1
Using 16384 samples (max possible: 16384)
Training samples: 14746
Testing samples: 1638
Initial model:

Test Results:
Digit-level accuracy: 24.59%
Number-level accuracy: 0.00%
Epoch 50/2000, Loss: 1.1575
Epoch 100/2000, Loss: 1.1124
Epoch 150/2000, Loss: 1.0928
Epoch 200/2000, Loss: 1.0848
Epoch 250/2000, Loss: 1.0742
Epoch 300/2000, Loss: 1.0683
Epoch 350/2000, Loss: 1.0651
Epoch 400/2000, Loss: 1.0624
Epoch 450/2000, Loss: 1.0595
Epoch 500/2000, Loss: 1.0563
Epoch 550/2000, Loss: 1.0528
Epoch 600/2000, Loss: 1.0476
Epoch 650/2000, Loss: 1.0433
Epoch 700/2000, Loss: 1.0392
Epoch 750/2000, Loss: 1.0343
Epoch 800/2000, Loss: 1.0270
Epoch 850/2000, Loss: 1.0154
Epoch 900/2000, Loss: 0.9999
Epoch 950/2000, Loss: 0.9837
Epoch 1000/2000, Loss: 0.9715
Epoch 1050/2000, Loss: 0.9632
Epoch 1100/2000, Loss: 0.9578
Epoch 1150/2000, Loss: 0.9541
Epoch 1200/2000, Loss: 0.9513
Epoch 1250/2000, Loss: 0.9494
Epoch 1300/2000, Loss: 0.9478
Epoch 1350/2000, Loss: 0.9466
Epoch 1400/20

In [66]:
test_specific_examples(
    model, test_decimal_nums, device, input_digits, output_digits, num_samples=10
)


Testing specific examples from the test set:
Decimal: 12100. Correct: True
    Predicted Base-4: [2 3 3 1 0 1 0]
    True Base-4     : [2 3 3 1 0 1 0]
Decimal: 11182. Correct: False
    Predicted Base-4: [2 3 3 2 2 3 0]
    True Base-4     : [2 2 3 2 2 3 2]
    Predicted as decimal: 12204
Decimal: 6990. Correct: False
    Predicted Base-4: [1 2 2 1 0 3 2]
    True Base-4     : [1 2 3 1 0 3 2]
    Predicted as decimal: 6734
Decimal: 3037. Correct: False
    Predicted Base-4: [0 3 3 3 1 3 3]
    True Base-4     : [0 2 3 3 1 3 1]
    Predicted as decimal: 4063
Decimal: 7449. Correct: True
    Predicted Base-4: [1 3 1 0 1 2 1]
    True Base-4     : [1 3 1 0 1 2 1]
Decimal: 8475. Correct: False
    Predicted Base-4: [2 0 1 2 1 1 3]
    True Base-4     : [2 0 1 0 1 2 3]
    Predicted as decimal: 8599
Decimal: 14227. Correct: False
    Predicted Base-4: [3 2 3 2 1 1 3]
    True Base-4     : [3 1 3 2 1 0 3]
    Predicted as decimal: 15255
Decimal: 2737. Correct: False
    Predicted Base-4: [0

### SGD

In [1]:
import simplegrad as sg
import simplegrad.module as snn
from simplegrad import Tensor, ops
from simplegrad.module import Module

import torch
import torch.nn as nn
import numpy as np

In [11]:
class Parameter(Tensor):
    """A special kind of tensor that represents parameters."""


def _unpack_params(value: object):
    if isinstance(value, Parameter):
        return [value]
    elif isinstance(value, Module):
        return value.parameters()
    elif isinstance(value, dict):
        params = []
        for k, v in value.items():
            params += _unpack_params(v)
        return params
    elif isinstance(value, (list, tuple)):
        params = []
        for v in value:
            params += _unpack_params(v)
        return params
    else:
        return []


class Module:
    def parameters(self):
        raise NotImplementedError

    def zero_grad(self):
        for p in self.parameters():
            p.grad = np.zeros_like(p.data)

    def parameters(self):  # List[Tensor]
        """Return the list of parameters in the module."""
        return _unpack_params(self.__dict__)


class Linear(Module):
    def __init__(self, in_features, out_features):
        self.weight = Parameter(
            np.random.randn(out_features, in_features) * 0.01, requires_grad=True
        )
        self.bias = Parameter(np.zeros(out_features), requires_grad=True)

    def __call__(self, x):
        output = ops.matmul(x, self.weight.transpose())
        output = output + ops.broadcast_to(self.bias, output.shape)
        return output
        # return matmul(x, self.weight.transpose()) + self.bias

    def load_state_dict(self, state_dict):
        self.weight.data = state_dict["weight"].detach().numpy()
        self.bias.data = state_dict["bias"].detach().numpy()

    def state_dict(self):
        return dict(
            weight=self.weight.data,
            bias=self.bias.data,
        )


class ReLU(Module):
    def __call__(self, x: Tensor) -> Tensor:
        return ops.relu(x)


class Sequential(Module):
    def __init__(self, *modules):
        super().__init__()
        self.modules = modules

    def __call__(self, x: Tensor) -> Tensor:
        output = x
        for module in self.modules:
            output = module(output)
        return output

In [42]:
y = np.random.randint(low=0, high=4, size=6)
y = Tensor(y)
label = one_hot(n_dim=4, y=y)

In [44]:
label.data

array([[0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 1., 0., 0.]], dtype=float32)

In [40]:
def one_hot(n_dim: int, y: Tensor) -> Tensor:
    """
    n_dim (int): Number of classes (length of each one-hot vector).
    y (Tensor): Tensor of shape (batch_size,) with integer class indices.
    """
    batch_size = y.shape[0]
    one_hot_np = np.zeros((batch_size, n_dim), dtype=int)
    for i in range(batch_size):
        one_hot_np.data[i, int(y.data[i])] = 1
    one_hot_tensor = Tensor(one_hot_np, requires_grad=False)
    return one_hot_tensor

class SoftmaxLoss(Module):
    def __call__(self, logits: Tensor, y: Tensor):
        n_dim = logits.shape[-1]
        y_one_hot = one_hot(n_dim, y)
        zy = ops.summation(logits * y_one_hot, axes=(1,))
        zy = ops.reshape(zy, (-1, 1))
        zy = ops.broadcast_to(zy, logits.shape)

        losses = ops.logsumexp(logits - zy, axes=(1))
        total_loss = ops.summation(losses) / losses.shape[-1]
        return total_loss

In [9]:
class Optimizer:
    def __init__(self, params):
        self.params = params

    def step(self):
        raise NotImplementedError()

    def reset_grad(self):
        for p in self.params:
            p.grad = None


class NaiveSGD(Optimizer):
    """
    Stochastic Gradient Descent (SGD) with optional L2 regularization.

    - SGD minimizes a loss by updating parameters opposite to the gradient.
    - L2-regularized loss: \( L_{\text{reg}}(w) = L(w) + \frac{\lambda}{2} \|w\|_2^2 \)
      - \( \|w\|_2^2 \): squared L2 norm of parameters.
      - \( \lambda \): regularization strength (`weight_decay`).
    - Gradient: \( \nabla L_{\text{reg}}(w) = \nabla L(w) + \lambda w \)
    - Update rule: \( w \leftarrow w - \eta (\nabla L(w) + \lambda w) \)
      - \( \eta \): learning rate (`lr`).
    - Implementation: adds \( \lambda w \) to gradient as weight decay.
    """

    def __init__(self, params, lr=0.01, weight_decay=0.0):
        super().__init__(params)
        self.lr = lr
        self.weight_decay = weight_decay

    def step(self):
        for w in self.params:
            gradient = w.grad + self.weight_decay * w.data
            w.data = w.data - self.lr * gradient

#### Test forward, backward, step for linear

In [290]:
import torch
import torch.nn as nn


def test_linear_vs_pytorch():
    # Test parameters
    in_features, out_features = 20, 30
    batch_size = 2
    lr = 0.005

    # Pre-generate weights for identical initialization
    weight = np.random.randn(out_features, in_features).astype(np.float32) * 0.01
    bias = np.zeros(out_features, dtype=np.float32)

    # Initialize your implementation with fixed weights
    linear = Linear(in_features, out_features)
    linear.weight.data = weight.copy()
    linear.bias.data = bias.copy()

    # Initialize PyTorch equivalent with same weights
    torch_linear = nn.Linear(in_features, out_features, bias=True)
    torch_linear.weight.data = torch.tensor(weight, dtype=torch.float32)
    torch_linear.bias.data = torch.tensor(bias, dtype=torch.float32)

    # Generate random input
    np_x = np.random.randn(batch_size, in_features).astype(np.float32)
    x = Tensor(np_x, requires_grad=False)
    torch_x = torch.tensor(np_x, requires_grad=False, dtype=torch.float32)

    # Random target
    np_y = np.random.randn(batch_size, out_features).astype(np.float32)
    y = Tensor(np_y, requires_grad=False)
    torch_y = torch.tensor(np_y, requires_grad=False, dtype=torch.float32)

    # Forward pass
    output = linear(x)
    torch_output = torch_linear(torch_x)
    # return output, torch_output
    # Check output
    np.testing.assert_allclose(output.data, torch_output.detach().numpy(), rtol=1e-4)
    print("Forward pass matches PyTorch ✓")

    # Loss
    loss = summation((output - y) ** 2)
    torch_loss = ((torch_output - torch_y) ** 2).sum()

    # Backward pass
    loss.backward()
    torch_loss.backward()

    # Check gradients
    np.testing.assert_allclose(
        linear.weight.grad, torch_linear.weight.grad.numpy(), rtol=1e-4
    )
    np.testing.assert_allclose(
        linear.bias.grad, torch_linear.bias.grad.numpy(), rtol=1e-4
    )
    print("Gradients match PyTorch ✓")

    # Optimizer step
    opt = NaiveSGD([linear.weight, linear.bias], lr=lr)
    torch_opt = torch.optim.SGD(torch_linear.parameters(), lr=lr)

    opt.step()
    torch_opt.step()

    # Check updated weights
    np.testing.assert_allclose(
        linear.weight.data, torch_linear.weight.data.numpy(), rtol=1e-4
    )
    np.testing.assert_allclose(
        linear.bias.data, torch_linear.bias.data.numpy(), rtol=1e-4
    )
    print("Parameter updates match PyTorch ✓")

    print("All tests passed!")

#### Continue

In [10]:
net = snn.Sequential(snn.Linear(3, 4), snn.ReLU(), snn.Linear(4, 3))
sgd = NaiveSGD(net.parameters(), lr=0.04)

In [None]:
def train_model(
    model,
    train_dataset,
    device,
    epochs=100,
    batch_size=128,
    learning_rate=0.01,
    log_steps=50,
):
    """Train the neural network using SGD optimizer."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    # Create DataLoader for training data
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        model.train()  # Set model to training mode
        total_loss = 0

        for batch_X, batch_y in train_loader:
            # Move batch to device
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()

            # Forward pass
            predictions = model(batch_X)

            # Reshape for cross entropy loss
            batch_size = batch_X.size(0)
            predictions = predictions.view(batch_size * model.output_digits, 4)
            targets = batch_y.view(batch_size * model.output_digits, 4)
            targets = torch.argmax(targets, dim=1)  # Convert one-hot to class indices

            # Calculate loss
            loss = criterion(predictions, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if (epoch + 1) % log_steps == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")