In [2]:
import numpy as np
from microtorch import Tensor
from microtorch import nn

import torch

In [44]:
import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

def softmax_grad(probs):
    tensor = probs.clone().detach()
    flat = torch.flatten(tensor)
    diagonal = torch.diagflat(flat)
    off_diagonal = torch.outer(flat, flat)
    return diagonal - off_diagonal
    

2.3.0


In [45]:
y = torch.arange (4.)                            # no batch dimension
probs_y = y.softmax (dim = -1)

print ('y = ...')
print (y)
print ('probs_y = ...')
print (probs_y)

jacyA = softmax_grad (probs_y)                   # works with no batch dimension
jacyB = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), y)

print ('jacyA = ...')
print (jacyA)
print ('torch.allclose (jacyA, jacyB) =', torch.allclose (jacyA, jacyB))

y = ...
tensor([0., 1., 2., 3.])
probs_y = ...
tensor([0.0321, 0.0871, 0.2369, 0.6439])
jacyA = ...
tensor([[ 0.0310, -0.0028, -0.0076, -0.0206],
        [-0.0028,  0.0796, -0.0206, -0.0561],
        [-0.0076, -0.0206,  0.1808, -0.1525],
        [-0.0206, -0.0561, -0.1525,  0.2293]])
torch.allclose (jacyA, jacyB) = True


In [46]:
x = torch.arange (20).reshape (5, 4)             # has batch dimension
x = x.float().requires_grad_()
probs_x = x.softmax (dim = -1)

print ('x = ...')
print (x)
print ('probs_x = ...')
print (probs_x)

jacxA = softmax_grad (probs_x)                   # incorrectly flattens batch and features dimensions together
jacxB = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), x)

print ('jacxA.shape =', jacxA.shape)             # wrong shape for jacobian of 2d tensor
print ('jacxB.shape =', jacxB.shape)

jfA = jacxA.flatten()                            # reshaping doesn't help
jfB = jacxB.flatten()
print ('(jfA - jfB).abs().max() =', (jfA - jfB).abs().max())

print ('(jfA == 0).sum() =', (jfA == 0).sum())   # too many non-zero elements (but ...)
print ('(jfB == 0).sum() =', (jfB == 0).sum())

print ('torch.allclose (jfA[torch.where (jfB != 0)], jfB[torch.where (jfB != 0)]) =', torch.allclose (jfA[torch.where (jfB != 0)], jfB[torch.where (jfB != 0)]))

x = ...
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.],
        [16., 17., 18., 19.]], requires_grad=True)
probs_x = ...
tensor([[0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]], grad_fn=<SoftmaxBackward0>)
jacxA.shape = torch.Size([20, 20])
jacxB.shape = torch.Size([5, 4, 5, 4])
(jfA - jfB).abs().max() = tensor(0.4146)
(jfA == 0).sum() = tensor(0)
(jfB == 0).sum() = tensor(320)
torch.allclose (jfA[torch.where (jfB != 0)], jfB[torch.where (jfB != 0)]) = True


In [47]:
z = torch.randn (5, 4, requires_grad = True)     # has batch dimension

probs_z = z.softmax (dim = -1)

print ('z = ...')
print (z)
print ('probs_z = ...')
print (probs_z)
print ('probs_z.sum (dim = -1) = ...')
print (probs_z.sum (dim = -1))                   # probs sum to one, a constant, for each row in batch

probs_z.sum().backward()                         # sum over batch, as well as features, to get a scalar
gradzA = z.grad                                  # gradient of sum of softmax is zero because sum is constant
jaczB = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), z)
gradzB = jaczB.sum (dim = (-2, -1))              # sum of partials (jacobian) is derivative (gradient) of sum

print ('gradzA = ...')
print (gradzA)
print ('torch.allclose (gradzA, gradzB, atol = 1.e-7) =', torch.allclose (gradzA, gradzB, atol = 1.e-7))

z = ...
tensor([[-1.2262, -0.0093,  1.5420, -0.4657],
        [ 1.8567,  1.9776, -0.4322,  1.3667],
        [ 0.7131, -0.3869, -0.2535, -1.6675],
        [ 0.9962,  0.9391,  1.4148,  0.6343],
        [-0.0776, -1.1175, -0.6481,  0.6530]], requires_grad=True)
probs_z = ...
tensor([[0.0446, 0.1504, 0.7097, 0.0953],
        [0.3518, 0.3970, 0.0357, 0.2155],
        [0.5538, 0.1843, 0.2107, 0.0512],
        [0.2403, 0.2270, 0.3653, 0.1674],
        [0.2503, 0.0885, 0.1415, 0.5197]], grad_fn=<SoftmaxBackward0>)
probs_z.sum (dim = -1) = ...
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)
gradzA = ...
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [6.6016e-08, 2.1975e-08, 2.5112e-08, 6.1062e-09],
        [1.4326e-08, 1.3531e-08, 2.1772e-08, 9.9763e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]])
torch.allclose (gradzA, gradzB, atol = 1.e-7) = True


In [172]:
def softmax(t: Tensor, dim: int = -1):
    m = t.data.max()
    e = np.exp(t.data - m)
    s = e / e.sum(axis=dim, keepdims=True)
    out = Tensor(s, _children=(t,), _op='softmax')

    def softmax_backward():
        for i in range(len(s)):
            si = s[i].reshape(1, -1)
            j = np.diagflat(si) - si.T@si  # jacobian matrix
            _j = np.concatenate((j, np.zeros_like(j)), axis=-1).reshape(4, 2, 4)
            print(_j.shape)
            print(_j.sum(axis=0))
            t.grad += out.grad@j
        
    out._backward = softmax_backward
    return out

In [281]:
np.transpose(np.ones((2, 3, 4)), axis=[1, 2, 0])

TypeError: transpose() got an unexpected keyword argument 'axis'

In [334]:
def softmax(t: Tensor, dim: int = -1):
    m = t.data.max()
    e = np.exp(t.data - m)
    s = e / e.sum(axis=dim, keepdims=True)
    out = Tensor(s, _children=(t,), _op='softmax')

    s = np.moveaxis(s, dim, -1)
    shape = s.shape
    s = s.reshape(-1, s.shape[-1])

    def softmax_backward():
        grad = np.zeros_like(s)
        for i in range(len(s)):
            si = s[i].reshape(1, -1)
            j = np.diagflat(si) - si.T@si  # jacobian matrix
            grad[i] += j.sum(-1)
        grad = grad.reshape(shape)
        grad = np.moveaxis(grad, -1, dim)
        t.grad += grad
        
    out._backward = softmax_backward
    return out

In [348]:
torch.manual_seed(42)
x = torch.arange(4).float().reshape(1, 2, 2).numpy() * 10
x = Tensor(x)
x

Tensor([[[ 0., 10.],
         [20., 30.]]], dtype=float32)

In [352]:
s = softmax(x, -2)
s

Tensor([[[2.0611535e-09, 2.0611537e-09],
         [1.0000000e+00, 1.0000000e+00]]], dtype=float32)

In [353]:
s.backward()

In [354]:
x.grad

array([[[-3.6379788e-12, -3.6674464e-08],
        [-2.0611535e-09,  2.0872665e-08]]], dtype=float32)