Skip to content

Commit

Permalink
Avoid accidental type promotion in gamma sampler gradient. (google#2150)
Browse files Browse the repository at this point in the history
Reformat gamma sampler to use 2 space indent, consistent with the rest of JAX.
  • Loading branch information
hawkinsp committed Feb 3, 2020
1 parent 0644f5c commit 0b1d2fc
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 93 deletions.
186 changes: 93 additions & 93 deletions jax/random.py
Expand Up @@ -782,99 +782,99 @@ def _next_kxv(kxv):
0.017050642, -0.0021309345, 0.00085092385, -1.5248239e-07]]

def _gamma_grad_one(z, alpha):
# Ref 1: Pathwise Derivatives Beyond the Reparameterization Trick, Martin & Fritz
# Ref 2: Case 4 follows https://github.com/fritzo/notebooks/blob/master/gamma-reparameterized.ipynb

# TODO: use lax.cond instead of lax.while_loop when its batching rule is available
# See https://github.com/google/jax/issues/490
def _case1(zagf):
z, alpha, _, flag = zagf

# dz = - dCDF(z; a) / pdf(z; a)
# pdf = z^(a-1) * e^(-z) / Gamma(a)
# CDF(z; a) = IncompleteGamma(a, z) / Gamma(a)
# dCDF(z; a) = (dIncompleteGamma - IncompleteGamma * Digamma(a)) / Gamma(a)
# =: unnormalized_dCDF / Gamma(a)
# IncompleteGamma ~ z^a [ 1/a - z/(a+1) + z^2/2!(a+2) - z^3/3!(a+3) + z^4/4!(a+4) - z^5/5!(a+5) ]
# =: z^a * term1
# dIncompleteGamma ~ z^a * log(z) * term1 - z^a [1/a^2 - z/(a+1)^2 + z^2/2!(a+2)^2
# - z^3/3!(a+3)^2 + z^4/4!(a+4)^2 - z^5/5!(a+5)^2 ]
# =: z^a * log(z) * term1 - z^a * term2
# unnormalized_dCDF = z^a { [log(z) - Digamma(a)] * term1 - term2 }
zi = 1.0
update = zi / alpha
term1 = update
term2 = update / alpha
for i in range(1, 6):
zi = -zi * z / i
update = zi / (alpha + i)
term1 = term1 + update
term2 = term2 + update / (alpha + i)

unnormalized_cdf_dot = np.power(z, alpha) * ((np.log(z) - lax.digamma(alpha)) * term1 - term2)
unnormalized_pdf = np.power(z, alpha - 1) * np.exp(-z)
grad = -unnormalized_cdf_dot / unnormalized_pdf

return z, alpha, grad, ~flag

def _cond2(zagf):
z, alpha, _, flag = zagf
return (~flag) & (alpha > 8.0) & ((z < 0.9 * alpha) | (z > 1.1 * alpha))

def _case2(zagf):
z, alpha, _, flag = zagf

# Formula 58 of [1]
sqrt_8a = np.sqrt(8 * alpha)
z_minus_a = z - alpha
log_z_div_a = np.log(z / alpha)
sign = np.where(z < alpha, 1.0, -1.0)
term1 = 4 * (z + alpha) / (sqrt_8a * z_minus_a * z_minus_a)
term2 = log_z_div_a * (sqrt_8a / z_minus_a + sign * np.power(z_minus_a - alpha * log_z_div_a, -1.5))
term3 = z * (1.0 + 1.0 / (12 * alpha) + 1.0 / (288 * alpha * alpha)) / sqrt_8a
grad = (term1 + term2) * term3

return z, alpha, grad, ~flag

def _cond3(zagf):
z, alpha, _, flag = zagf
return (~flag) & (alpha > 8.0) & (z >= 0.9 * alpha) & (z <= 1.1 * alpha)

def _case3(zagf):
z, alpha, _, flag = zagf

# Formula 59 of [1]
z_div_a = np.divide(z, alpha)
aa = alpha * alpha
term1 = 1440 * alpha + 6 * z_div_a * (53 - 120 * z) - 65 * z_div_a * z_div_a + 3600 * z + 107
term2 = 1244160 * alpha * aa
term3 = 1 + 24 * alpha + 288 * aa
grad = term1 * term3 / term2

return z, alpha, grad, ~flag

def _case4(zagf):
z, alpha, _, flag = zagf

# Ref [2]
u = np.log(z / alpha)
v = np.log(alpha)
c = []
for i in range(8):
c.append(_bivariate_coef[0][i] + u * (_bivariate_coef[1][i] + u * _bivariate_coef[2][i]))
p = c[0] + v * (c[1] + v * (c[2] + v * c[3]))
q = c[4] + v * (c[5] + v * (c[6] + v * c[7]))
grad = np.exp(p / np.maximum(q, 0.01))

return z, alpha, grad, ~flag

_, _, grad, flag = lax.while_loop(lambda zagf: (~zagf[3]) & (zagf[0] < 0.8),
_case1,
(z, alpha, lax._const(alpha, 0.0), False))
_, _, grad, flag = lax.while_loop(_cond2, _case2, (z, alpha, grad, flag))
_, _, grad, flag = lax.while_loop(_cond3, _case3, (z, alpha, grad, flag))
_, _, grad, flag = lax.while_loop(lambda zagf: ~zagf[3], _case4, (z, alpha, grad, flag))
return grad
# Ref 1: Pathwise Derivatives Beyond the Reparameterization Trick, Martin & Fritz
# Ref 2: Case 4 follows https://github.com/fritzo/notebooks/blob/master/gamma-reparameterized.ipynb

# TODO: use lax.cond instead of lax.while_loop when its batching rule is available
# See https://github.com/google/jax/issues/490
def _case1(zagf):
z, alpha, _, flag = zagf

# dz = - dCDF(z; a) / pdf(z; a)
# pdf = z^(a-1) * e^(-z) / Gamma(a)
# CDF(z; a) = IncompleteGamma(a, z) / Gamma(a)
# dCDF(z; a) = (dIncompleteGamma - IncompleteGamma * Digamma(a)) / Gamma(a)
# =: unnormalized_dCDF / Gamma(a)
# IncompleteGamma ~ z^a [ 1/a - z/(a+1) + z^2/2!(a+2) - z^3/3!(a+3) + z^4/4!(a+4) - z^5/5!(a+5) ]
# =: z^a * term1
# dIncompleteGamma ~ z^a * log(z) * term1 - z^a [1/a^2 - z/(a+1)^2 + z^2/2!(a+2)^2
# - z^3/3!(a+3)^2 + z^4/4!(a+4)^2 - z^5/5!(a+5)^2 ]
# =: z^a * log(z) * term1 - z^a * term2
# unnormalized_dCDF = z^a { [log(z) - Digamma(a)] * term1 - term2 }
zi = 1.0
update = zi / alpha
term1 = update
term2 = update / alpha
for i in range(1, 6):
zi = -zi * z / i
update = zi / (alpha + i)
term1 = term1 + update
term2 = term2 + update / (alpha + i)

unnormalized_cdf_dot = np.power(z, alpha) * ((np.log(z) - lax.digamma(alpha)) * term1 - term2)
unnormalized_pdf = np.power(z, alpha - 1) * np.exp(-z)
grad = -unnormalized_cdf_dot / unnormalized_pdf

return z, alpha, grad, ~flag

def _cond2(zagf):
z, alpha, _, flag = zagf
return (~flag) & (alpha > 8.0) & ((z < 0.9 * alpha) | (z > 1.1 * alpha))

def _case2(zagf):
z, alpha, _, flag = zagf

# Formula 58 of [1]
sqrt_8a = np.sqrt(8 * alpha)
z_minus_a = z - alpha
log_z_div_a = np.log(z / alpha)
sign = np.where(z < alpha, lax._const(z, 1.0), lax._const(z, -1.0))
term1 = 4 * (z + alpha) / (sqrt_8a * z_minus_a * z_minus_a)
term2 = log_z_div_a * (sqrt_8a / z_minus_a + sign * np.power(z_minus_a - alpha * log_z_div_a, -1.5))
term3 = z * (1.0 + 1.0 / (12 * alpha) + 1.0 / (288 * alpha * alpha)) / sqrt_8a
grad = (term1 + term2) * term3

return z, alpha, grad, ~flag

def _cond3(zagf):
z, alpha, _, flag = zagf
return (~flag) & (alpha > 8.0) & (z >= 0.9 * alpha) & (z <= 1.1 * alpha)

def _case3(zagf):
z, alpha, _, flag = zagf

# Formula 59 of [1]
z_div_a = np.divide(z, alpha)
aa = alpha * alpha
term1 = 1440 * alpha + 6 * z_div_a * (53 - 120 * z) - 65 * z_div_a * z_div_a + 3600 * z + 107
term2 = 1244160 * alpha * aa
term3 = 1 + 24 * alpha + 288 * aa
grad = term1 * term3 / term2

return z, alpha, grad, ~flag

def _case4(zagf):
z, alpha, _, flag = zagf

# Ref [2]
u = np.log(z / alpha)
v = np.log(alpha)
c = []
for i in range(8):
c.append(_bivariate_coef[0][i] + u * (_bivariate_coef[1][i] + u * _bivariate_coef[2][i]))
p = c[0] + v * (c[1] + v * (c[2] + v * c[3]))
q = c[4] + v * (c[5] + v * (c[6] + v * c[7]))
grad = np.exp(p / np.maximum(q, 0.01))

return z, alpha, grad, ~flag

_, _, grad, flag = lax.while_loop(lambda zagf: (~zagf[3]) & (zagf[0] < 0.8),
_case1,
(z, alpha, lax._const(alpha, 0.0), False))
_, _, grad, flag = lax.while_loop(_cond2, _case2, (z, alpha, grad, flag))
_, _, grad, flag = lax.while_loop(_cond3, _case3, (z, alpha, grad, flag))
_, _, grad, flag = lax.while_loop(lambda zagf: ~zagf[3], _case4, (z, alpha, grad, flag))
return grad

def _gamma_grad(sample, a):
samples = np.reshape(sample, -1)
Expand Down
9 changes: 9 additions & 0 deletions tests/random_test.py
Expand Up @@ -329,6 +329,15 @@ def testGammaGrad(self, alpha):
self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
rtol=2e-2 if jtu.device_under_test() == "tpu" else 5e-4)

def testGammaGradType(self):
# Regression test for https://github.com/google/jax/issues/2130
key = random.PRNGKey(0)
a = np.array(1., dtype=np.float32)
b = np.array(3., dtype=np.float32)
f = lambda x, y: random.gamma(key=key, a=x, dtype=np.float32) / y
# Should not crash with a type error.
api.vjp(f, a, b)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
for dtype in [onp.float32, onp.float64]))
Expand Down

0 comments on commit 0b1d2fc

Please sign in to comment.