In [2]:
import torch
import torch.nn.functional as F

### Simple functions

$$ y = f(x) = \sum_{i} \left(x_i^2 + 2 x_i\right), \qquad \frac{\partial y}{\partial x_i} = 2 x_i + 2 $$

In [81]:
# Leaf tensor x = [0., 1., ..., 9.] tracked by autograd, and the scalar
# y = sum_i (x_i^2 + 2 x_i) built on top of it.
x = torch.arange(10, dtype=torch.float32, requires_grad=True)
y = (x.pow(2) + 2 * x).sum()

In [82]:
# Hand-derived gradient: dy/dx_i = 2 x_i + 2. Computed from x.detach() so this
# reference value is a plain tensor and is not itself recorded by autograd.
dy_dx_analytic = 2 * x.detach() + 2

In [83]:
# .backward() accumulates into x.grad, so clear it first: this keeps the cell
# idempotent (re-running yields dy/dx again, not a multiple of it).
x.grad = None
# retain_graph=True keeps y's graph alive so this cell can be executed again.
y.backward(retain_graph=True)  # populates x.grad with dy/dx

In [84]:
# Snapshot of the gradient autograd wrote into x.grad (despite the name, it comes
# from backpropagation, not finite differences); cloned so later writes to
# x.grad do not change it.
dy_dx_numeric = torch.clone(x.grad)

In [85]:
# Exact comparison is fine here: x holds small integers, so both the autograd
# and the hand-derived gradients are exactly representable in float32.
(dy_dx_numeric == dy_dx_analytic).all().item()

True

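$$ y = W_{hy} h $$
$$ p = \operatorname{softmax}(y) $$
$$ \text{loss} = -\log p_k \quad \text{for the target class } k $$
With $e_k$ the one-hot encoding of $k$, the gradient we verify below is
$$ \frac{\partial\, \text{loss}}{\partial W_{hy}} = (p - e_k)\, h^{\top} $$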

In [205]:
# Softmax + cross-entropy on a single random example.
torch.manual_seed(0)  # fix w and h so the displayed loss is reproducible

n = 10  # number of classes (rows of w)
m = 20  # size of h (columns of w)

w = torch.randn(n, m, requires_grad=True)
# h must be (m, 1) for the matmul below; entries are integers drawn from [0, 20).
h = torch.randint(20, (m, 1), dtype=torch.float)
y = torch.matmul(w, h)
p = F.softmax(y, dim=0)

label = torch.zeros_like(p)
label[5] = 1.  # one-hot target: class 5

# Use log_softmax rather than log(softmax): with h entries up to 19 summed over
# m terms, entries of y can differ enough that some softmax outputs underflow to
# exactly 0 in float32. log(0) = -inf and 0 * -inf = nan would then poison the
# loss and every gradient; log_softmax stays finite.
loss = -torch.sum(label * F.log_softmax(y, dim=0))

In [206]:
loss  # display the scalar loss (a 0-dim tensor carrying its grad_fn)

tensor(5.5026, grad_fn=<NegBackward>)

In [207]:
# Clear any gradient left by an earlier run: .backward() accumulates into w.grad.
w.grad = None
# retain_graph=True keeps the graph alive so re-running this cell does not raise
# "Trying to backward through the graph a second time".
loss.backward(retain_graph=True)  # fills w.grad with d(loss)/dw

In [208]:
# Closed form for softmax + cross-entropy: d(loss)/dy = p - label, and since
# y = w @ h, d(loss)/dw = (p - label) @ h^T, an (n, 1) @ (1, m) outer product.
# detach(): this is a reference value, so keep it out of the autograd graph.
w_analytic_grad = (p - label).detach() @ h.view(1, -1)

In [210]:
# Compare with a tolerance: autograd reaches w.grad through the softmax/log
# backward, which rounds differently from the closed form, so bitwise equality
# is not guaranteed. Use w.grad directly; .data is discouraged.
torch.allclose(w_analytic_grad, w.grad)

True

In [211]:
# Shape-strict comparison (like torch.equal) that still tolerates the float
# rounding differences between the closed form and autograd.
w_analytic_grad.shape == w.grad.shape and torch.allclose(w_analytic_grad, w.grad)

True