In [45]:
from src.models import SphericalVAE
from scipy.special import ive
from torch.distributions.kl import kl_divergence


import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double)
svae = SphericalVAE(3, 2, encoder_params={"layer_sizes" : [5]}, decoder_params={"layer_sizes" : [5]})

In [46]:
output = svae(x)
px, pz, qz, z = [output[k] for k in ["px", "pz", "qz", "z"]]

kl_term = kl_divergence(qz, pz)
log_px = px.log_prob(x).sum(-1)

loss = -log_px + kl_term
loss = loss.mean()

First, we try to retrieve the loss wrt. $k$, $\mu$ and the parameters of the decoder, and then use these to figure out the gradient wrt. the parameters of the model as a whole:

In [47]:
loss_d_k.mean()

tensor(0.0227, dtype=torch.float64)

In [48]:
svae.zero_grad()

log_px_d_k, = torch.autograd.grad(log_px, qz.k, grad_outputs=torch.ones_like(qz.k),  retain_graph=True)
kl_term_d_k, = torch.autograd.grad(kl_term, qz.k, grad_outputs=torch.ones_like(qz.k),  retain_graph=True)

loss_d_k = (-log_px_d_k + kl_term_d_k) / len(qz.k)

loss_d_mu = torch.autograd.grad(loss, qz.mu, retain_graph=True)
loss_d_decoder = torch.autograd.grad(loss, svae.decoder.parameters(), retain_graph=True)

torch.autograd.backward(svae.decoder.parameters(), grad_tensors=loss_d_decoder, retain_graph=True)
torch.autograd.backward(qz.k, grad_tensors=loss_d_k, retain_graph=True)
torch.autograd.backward(qz.mu, grad_tensors=loss_d_mu, retain_graph=True)

decomposed = []
for param in svae.parameters():
    decomposed.append(param.grad.clone())


svae.zero_grad()
loss.mean().backward(retain_graph=True)

usual = []
for param in svae.parameters():
    usual.append(param.grad.clone())



In [49]:
all(torch.isclose(a,b).all() for a, b in zip(decomposed, usual))

True

Then we can add the correspond gradient wrt. $k$, again using autograd to calculate the gradient of the correction term wrt. kappa.

In [50]:
svae.zero_grad()

log_px_d_k, = torch.autograd.grad(log_px, qz.k, grad_outputs=torch.ones_like(qz.k),  retain_graph=True)
kl_term_d_k, = torch.autograd.grad(kl_term, qz.k, grad_outputs=torch.ones_like(qz.k),  retain_graph=True)

loss_d_mu, = torch.autograd.grad(loss, qz.mu, retain_graph=True)
loss_d_decoder = torch.autograd.grad(loss, svae.decoder.parameters(), retain_graph=True)

In [51]:
eps = qz.saved_for_grad["eps"]
w = qz.saved_for_grad["w"]
b = qz.saved_for_grad["b"]

corr_term = (
    w * qz.k
    + 1 / 2 * (qz.m - 3) * torch.log(1 - w ** 2)
    + torch.log(torch.abs(((-2 * b) / (((b - 1) * eps + 1) ** 2))))
)

corr_term_d_k, = torch.autograd.grad(corr_term, qz.k, grad_outputs=torch.ones_like(corr_term), retain_graph=True)
corr_term_d_k

tensor([0.5009, 0.6706], dtype=torch.float64)

Now, we can construct the corrected gradient

In [52]:
with torch.no_grad():
    g_cor = log_px * ( -ive(qz.m/2, qz.k)/ive(qz.m/2-1, qz.k) + corr_term_d_k)
g_cor

tensor([-0.1673, -4.4906], dtype=torch.float64)

In [53]:
log_px_d_k_adj = log_px_d_k + g_cor
loss_d_k = (-log_px_d_k_adj + kl_term_d_k) / len(qz.k)

In [54]:
torch.autograd.backward(svae.decoder.parameters(), grad_tensors=loss_d_decoder, retain_graph=True)
torch.autograd.backward(qz.k, grad_tensors=loss_d_k, retain_graph=True)
torch.autograd.backward(qz.mu, grad_tensors=loss_d_mu, retain_graph=True)

In [55]:
adjusted = []
for param in svae.parameters():
    adjusted.append(param.grad.clone())

In [56]:
for a, b in zip(usual, adjusted):
    print("Without correction")
    print(a.numpy())
    print("With correction")
    print(b.numpy())

Without correction
[[ 0.          0.          0.        ]
 [-3.84098219 -4.80122774 -5.76147329]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]]
With correction
[[ 0.          0.          0.        ]
 [-1.0841667  -1.35520837 -1.62625005]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]]
Without correction
[ 0.         -0.96024555  0.          0.          0.        ]
With correction
[ 0.         -0.27104167  0.          0.          0.        ]
Without correction
[[ 0.          4.6254999   0.          0.          0.        ]
 [ 0.         -9.42140313  0.          0.          0.        ]
 [ 0.          1.33468114  0.          0.          0.        ]]
With correction
[[ 0.          4.6254999   0.          0.          0.        ]
 [ 0.         -9.42140313  0.          0.          0.        ]
 [ 0.          2.98612796  0.          0.          0.        ]]