In [9]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
%autoreload 2

In [11]:
import numpy as np
import torch
import simplegrad as sg
from simplegrad import Tensor
from simplegrad import ops

In [113]:
def log(tensor):
    """
    Element-wise natural logarithm with autograd support.

    forward:  out = ln(x)
    backward: d(out)/dx = 1 / x, so x.grad += out.grad / x

    Args:
        tensor: input Tensor; ``np.log`` is applied to ``tensor.data``.

    Returns:
        A new Tensor holding ``np.log(tensor.data)`` whose ``requires_grad``
        flag is copied from the input.
    """
    out = Tensor(np.log(tensor.data), requires_grad=tensor.requires_grad)

    def _backward():
        # Only accumulate into inputs that track gradients, the same guard
        # logsumexp and softmax apply in their backward closures.
        if tensor.requires_grad:
            tensor.grad += out.grad / tensor.data

    out._backward = _backward
    out._prev = {tensor, }
    return out
    
def logsumexp(tensor, axis=None, keepdims=False):
    """
    Numerically stable log-sum-exp reduction with autograd support.

    mathematical operations, applied to 1D vector:
    forward: logsumexp(z) = log(e^z1 + e^z2 + ... + e^zn) = log(sum(e^zi))
    backward: local_grad[i] = e^zi / sum(e^zi)
    ------
    for numerical stability:
    forward: log(sum(e^zi))  = log(sum(e^(zi - zmax)) * e^zmax)
                             = log(sum(e^(zi - zmax))) + zmax
    backward: e^zi/sum(e^zi) = e^(zi - zmax) / sum(e^(zi - zmax))

    Args:
        tensor: input Tensor; the reduction is computed from ``tensor.data``.
        axis: None (reduce over every axis), an int, or a tuple of ints.
        keepdims: if True, the reduced axes are kept with length 1.

    Returns:
        A new Tensor with the reduced values; its ``requires_grad`` flag is
        copied from the input.
    """
    max_z = np.max(tensor.data, axis=axis, keepdims=True)
    # A non-finite max (slice is entirely -inf, or contains +inf) would make
    # the shift below evaluate inf - inf = nan. Shifting by 0 instead lets the
    # result come out as -inf / +inf respectively.
    max_z = np.where(np.isfinite(max_z), max_z, np.zeros_like(max_z))
    exp_stable_z = np.exp(tensor.data - max_z)

    # Keep the reduced axes here: the backward pass needs a denominator that
    # broadcasts against the input regardless of `keepdims`.
    sum_kept = np.sum(exp_stable_z, axis=axis, keepdims=True)
    if keepdims:
        stable_sum, max_term = sum_kept, max_z
    else:
        stable_sum = np.squeeze(sum_kept, axis=axis)
        max_term = np.squeeze(max_z, axis=axis)

    # log(0) = -inf is the intended answer for an all -inf slice, so silence
    # the divide-by-zero warning numpy would raise for it.
    with np.errstate(divide="ignore"):
        data = np.log(stable_sum) + max_term
    out = Tensor(data, requires_grad=tensor.requires_grad)

    def _backward():
        if not tensor.requires_grad:
            return
        upstream = out.grad
        if axis is not None and not keepdims:
            # Re-insert the reduced axes so the incoming gradient broadcasts
            # against the input. With axis=None the output holds a single
            # value, which broadcasts as-is.
            upstream = np.expand_dims(upstream, axis=axis)
        # d logsumexp / d z_i = softmax(z)_i
        tensor.grad += upstream * (exp_stable_z / sum_kept)

    out._backward = _backward
    out._prev = {tensor, }
    return out

def softmax(tensor, axis: int = None):
    """
    Softmax along `axis` (over every element when `axis` is None), with
    autograd support.

    forward (written through logsumexp for numerical stability):
        softmax(z)[i] = e^zi / sum(e^z)
                      = e^(zi - logsumexp(z))

    backward, with s = softmax(z) and id(i, j) = 1{i == j}:
        local_grad[i, j] = s[i] * (id(i, j) - s[j])
            i == j:  s[i] * (1 - s[i])
            i != j: -s[i] * s[j]
        Contracting that Jacobian with the upstream gradient g gives
            grad_z = s * (g - sum(g * s))
        which is what the backward closure accumulates.
    """
    log_normalizer = logsumexp(tensor, axis=axis, keepdims=True).data
    probabilities = np.exp(tensor.data - log_normalizer)
    out = Tensor(probabilities, requires_grad=tensor.requires_grad)

    def _backward():
        if not tensor.requires_grad:
            return
        # sum_j g_j * s_j, kept broadcastable against the input
        weighted_total = np.sum(out.grad * out.data, axis=axis, keepdims=True)
        tensor.grad += out.data * (out.grad - weighted_total)

    out._backward = _backward
    out._prev = {tensor, }
    return out

In [108]:
# Smoke check: tiny-magnitude inputs, reduced over every axis.
test_case = {
    "data": np.random.rand(3, 4) * 0.001,
    "axis": None,
    "keepdims": False,
}
data, axis, keepdims = test_case["data"], test_case["axis"], test_case["keepdims"]

# Route the data through a float32 torch tensor before wrapping it.
pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
x = Tensor.from_torch(pt_x)

# Left as the last expression so the notebook displays the resulting Tensor.
logsumexp(x, axis=axis, keepdims=keepdims)

Tensor(2.485480785369873, requires_grad=True)

In [119]:
# Randomized forward/backward comparison of logsumexp against torch.logsumexp.
print("Testing LogSumExp operation...")
test_logsumexp()

# Hand-picked cases: large identical values, extreme differences, softmax relation.
print("\nTesting LogSumExp specific cases...")
test_logsumexp_specific_cases()

# Reached only if neither call above raised an assertion error.
print("\nAll LogSumExp tests completed successfully!")

Testing LogSumExp operation...
LogSumExp forward test 1 passed!
LogSumExp backward test 1 passed!
LogSumExp forward test 2 passed!
LogSumExp backward test 2 passed!
LogSumExp forward test 3 passed!
LogSumExp backward test 3 passed!
LogSumExp forward test 4 passed!
LogSumExp backward test 4 passed!
LogSumExp forward test 5 passed!
LogSumExp backward test 5 passed!
LogSumExp forward test 6 passed!
LogSumExp backward test 6 passed!
LogSumExp forward test 7 passed!
LogSumExp backward test 7 passed!
LogSumExp forward test 8 passed!
LogSumExp backward test 8 passed!
LogSumExp forward test 9 passed!
LogSumExp backward test 9 passed!
LogSumExp forward test 10 passed!
LogSumExp backward test 10 passed!

Testing LogSumExp specific cases...
LogSumExp specific case 1 (large identical values) passed!
LogSumExp specific case 2 (extreme value differences) passed!
LogSumExp specific case 3 (softmax relation) passed!

All LogSumExp tests completed successfully!


In [94]:
import numpy as np
import torch
import torch.nn.functional as F

from simplegrad.tensor import Tensor


def test_logsumexp(seed=0):
    """
    Check logsumexp (forward and backward) against torch.logsumexp.

    Inputs are drawn from a locally seeded generator so that a failing case
    can be reproduced by re-running with the same seed.

    Args:
        seed: seed for the numpy random generator that produces the inputs.

    Raises:
        AssertionError: if a forward value or an input gradient differs from
            PyTorch's beyond rtol=1e-5 / atol=1e-5.
    """
    rng = np.random.default_rng(seed)

    # Define diverse test cases
    test_cases = [
        # Small values
        {"data": rng.random((3, 4)) * 0.001, "axis": None, "keepdims": False},
        # Large values
        {"data": rng.random((3, 4)) * 100, "axis": None, "keepdims": False},
        # Negative values
        {"data": rng.random((3, 4)) * -10, "axis": None, "keepdims": False},
        # Mixed values
        {"data": rng.random((3, 4)) * 2 - 1, "axis": None, "keepdims": False},
        # Single dimension reduction with keepdims=True
        {"data": rng.random((3, 4)) * 2 - 1, "axis": 0, "keepdims": True},
        # Single dimension reduction with keepdims=False
        {"data": rng.random((3, 4)) * 2 - 1, "axis": 1, "keepdims": False},
        # Multiple dimensions
        {"data": rng.random((2, 3, 4)) * 2 - 1, "axis": None, "keepdims": False},
        # Multiple dimensions with specific axis
        {"data": rng.random((2, 3, 4)) * 2 - 1, "axis": 1, "keepdims": False},
        # Multiple dimensions with tuple axis
        {"data": rng.random((2, 3, 4)) * 2 - 1, "axis": (0, 2), "keepdims": False},
        # Multiple dimensions with tuple axis and keepdims=True
        {"data": rng.random((2, 3, 4)) * 2 - 1, "axis": (0, 2), "keepdims": True}
    ]

    for i, test_case in enumerate(test_cases):
        data = test_case["data"]
        axis = test_case["axis"]
        keepdims = test_case["keepdims"]

        # Convert to tensors
        pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
        x = Tensor.from_torch(pt_x)

        # PyTorch reference. torch.logsumexp needs an explicit dim, and it
        # accepts either an int or a tuple of ints, so axis=None is spelled
        # out as "all dims" and int/tuple axes are passed through unchanged.
        reduce_dims = tuple(range(pt_x.dim())) if axis is None else axis
        expected = torch.logsumexp(pt_x, dim=reduce_dims, keepdim=keepdims)

        # Our implementation
        result = logsumexp(x, axis=axis, keepdims=keepdims)

        # Check forward pass
        np.testing.assert_allclose(
            result.data,
            expected.detach().numpy(),
            rtol=1e-5, atol=1e-5,
            err_msg=f"Forward pass failed for test case {i+1}: data shape {data.shape}, axis {axis}, keepdims {keepdims}"
        )
        print(f"LogSumExp forward test {i+1} passed!")

        # Compute gradients; PyTorch is seeded with an all-ones upstream
        # gradient of the output's shape.
        grad_output = torch.ones_like(expected)
        expected.backward(grad_output)
        result.backward()

        # Check backward pass
        np.testing.assert_allclose(
            x.grad,
            pt_x.grad.detach().numpy(),
            rtol=1e-5, atol=1e-5,
            err_msg=f"Backward pass failed for test case {i+1}: data shape {data.shape}, axis {axis}, keepdims {keepdims}"
        )
        print(f"LogSumExp backward test {i+1} passed!")


def test_logsumexp_specific_cases(seed=0):
    """
    Test specific edge cases for logsumexp against PyTorch.

    Cases 1 and 2 check forward values on inputs that stress numerical
    stability; case 3 checks that exp(x - logsumexp(x)) reproduces the same
    softmax values as the equivalent PyTorch expression.

    Args:
        seed: seed for the numpy random generator that produces the case 3
            input, so a failure can be reproduced.

    Raises:
        AssertionError: if any comparison exceeds rtol=1e-5 / atol=1e-5.
    """
    stability_cases = [
        # Case 1: all elements are the same, and large
        (
            np.ones((3, 3)) * 1000,
            "LogSumExp specific case 1 (large identical values) passed!",
        ),
        # Case 2: extreme differences between values
        (
            np.array([[1e-10, 1e10], [1e-10, 1e-10]]),
            "LogSumExp specific case 2 (extreme value differences) passed!",
        ),
    ]
    for data, passed_message in stability_cases:
        pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
        x = Tensor.from_torch(pt_x)

        expected = torch.logsumexp(pt_x, dim=1, keepdim=False)
        result = logsumexp(x, axis=1, keepdims=False)

        np.testing.assert_allclose(result.data, expected.detach().numpy(), rtol=1e-5, atol=1e-5)
        print(passed_message)

    # Case 3: softmax relation, softmax(x) = exp(x - logsumexp(x))
    rng = np.random.default_rng(seed)
    data = rng.random((5, 10)) * 2 - 1
    pt_x = torch.tensor(data, dtype=torch.float32, requires_grad=True)
    x = Tensor.from_torch(pt_x)

    # Standard softmax calculation using logsumexp
    pt_logsumexp = torch.logsumexp(pt_x, dim=1, keepdim=True)
    pt_softmax = torch.exp(pt_x - pt_logsumexp)

    our_logsumexp = logsumexp(x, axis=1, keepdims=True)
    our_softmax = np.exp(x.data - our_logsumexp.data)

    np.testing.assert_allclose(our_softmax, pt_softmax.detach().numpy(), rtol=1e-5, atol=1e-5)
    print("LogSumExp specific case 3 (softmax relation) passed!")


In [85]:
# Run the logsumexp-vs-PyTorch forward/backward comparison on its own.
# NOTE(review): the saved output below lacks the per-case numbering that the
# current test_logsumexp prints — presumably from an earlier revision; re-run to confirm.
test_logsumexp()

LogSumExp forward test passed!
LogSumExp backward test passed!


In [47]:
# Random (2, 3, 4) tensor used as input by the scratch cells that reference X.
# No seed is set in this cell, so the values change on every run.
X = torch.rand(2, 3, 4)

In [78]:
# Probe how np.expand_dims handles a tuple axis: new length-1 axes are
# inserted at positions 0 and 2 of the result, so a length-4 vector
# becomes shape (1, 4, 1).
Y = torch.rand(4)
np.expand_dims(Y.numpy(), axis=(0,2)).shape

(1, 4, 1)

In [71]:
# PyTorch reference: reduce X over dim 1, keeping that dim with length 1.
torch.logsumexp(X, dim=1, keepdims=True).numpy()

array([[[1.482567 , 1.7950699, 1.6493765, 1.4997053]],

       [[1.7428087, 1.6379938, 1.7570441, 1.3153229]]], dtype=float32)

In [73]:
# Same reduction through the notebook's logsumexp, for comparison with PyTorch.
# NOTE(review): the saved output has shape (2, 1, 4) even though keepdims
# defaults to False in the logsumexp defined in this notebook — presumably
# produced by an earlier revision; re-run to confirm.
Y = logsumexp(Tensor.from_torch(X), axis=1).data
Y

array([[[1.482567 , 1.7950699, 1.6493764, 1.4997053]],

       [[1.7428087, 1.6379938, 1.7570441, 1.3153229]]], dtype=float32)

In [74]:
# Drop every length-1 axis from Y to see the reduced values without the kept dim.
np.squeeze(Y)

array([[1.482567 , 1.7950699, 1.6493764, 1.4997053],
       [1.7428087, 1.6379938, 1.7570441, 1.3153229]], dtype=float32)

In [62]:
# Show the two slices of X along its first axis.
X[0], X[1]

(tensor([[0.3181, 0.8340, 0.0570, 0.2979],
         [0.1563, 0.9888, 0.6898, 0.1264],
         [0.6209, 0.0290, 0.7662, 0.6925]]),
 tensor([[0.8639, 0.0777, 0.8982, 0.2250],
         [0.6347, 0.8138, 0.0857, 0.2969],
         [0.3748, 0.5920, 0.8112, 0.1204]]))