In [1]:
# Reload edited modules (e.g. the local `network` / `trafos`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import numpy as np

import jax
import jax.numpy as jnp
from jax import config

# Enable 64-bit floats before importing modules that may create JAX arrays at
# import time (the project modules and the TFP JAX substrate). Otherwise such
# constants could silently stay float32.
config.update("jax_enable_x64", True)

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

# Location of the local `network` / `trafos` modules. Set the JAX_RECO_PYTHON
# environment variable to override it instead of editing this absolute path.
sys.path.insert(0, os.environ.get("JAX_RECO_PYTHON", "/home/storage2/hans/jax_reco/python"))
from network import TriplePandleNet
from trafos import transform_network_outputs, transform_dimensions



In [3]:
# Build the network from the directory below. The path is relative to the
# notebook's working directory.
# NOTE(review): presumably holds the trained network parameters — confirm.
net = TriplePandleNet('../../data/network/')

In [4]:
# Geometry for a single evaluation, turned into the network input `x` by
# `transform_dimensions`. The inputs are named `raw_*` so later cells that
# assign `dist` (a TFP distribution) and `z` (the network output) do not
# silently reuse these names for something different.
# NOTE(review): units follow whatever `transform_dimensions` expects — TODO confirm.
raw_dist = 25
raw_z = -210
rho = 0.0
zenith = np.pi / 2
azimuth = 0.0

# argument order is transform_dimensions(dist, rho, z, zenith, azimuth)
x = transform_dimensions(raw_dist, rho, raw_z, zenith, azimuth)

In [5]:
# Shape of one unbatched network output.
print(np.shape(net.eval(x)))

(9,)


In [6]:
# Time a single (unbatched) network evaluation. `block_until_ready()` waits for
# JAX's asynchronous dispatch to finish, so the timing includes the computation.
%timeit net.eval(x).block_until_ready()

168 µs ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
# Low-level check (disabled): time the raw `_eval_network` call directly, bypassing `net.eval`.
#from network import _eval_network
#params = net.get_network_params()
#print(params[0][0].shape, params[0][1].shape)
#%timeit _eval_network(x, params).block_until_ready()

In [8]:
# Now evaluate `batch_size` inputs in one batched call. All rows share the
# input `x` except column 0, which is jittered around 0.025.
# NOTE(review): column 0 is presumably the transformed distance
# (25 -> 0.025?) — confirm against `transform_dimensions`.

# The legacy global seed is kept on purpose: the per-seed reference
# log-likelihoods noted at the end of the notebook are tied to this RNG stream.
np.random.seed(2)

batch_size = 500

# One copy of `x` per row (np.repeat returns a new, writable array);
# `x` itself is left untouched.
xx = np.repeat(np.asarray(x)[np.newaxis, :], batch_size, axis=0)
# Draw exactly `batch_size` samples. The old hard-coded 500 raised a
# broadcast error as soon as `batch_size` changed.
xx[:, 0] = np.random.normal(0.025, 0.001, batch_size)
xx = jnp.array(xx)

# Show which device holds the batch. A bare `xx.devices()` mid-cell is never displayed.
print(xx.devices())
print(xx.shape)

(500, 7)


In [9]:
# Time the batched network evaluation alone (no PDF evaluation).
%timeit net.eval_on_batch(xx).block_until_ready()

951 µs ± 1.92 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# Explicitly jit-wrap the batched evaluator. The recorded timings for In [9] and
# In [11] (951 µs vs 917 µs) show little gain.
# NOTE(review): `eval_on_batch` is presumably already jitted inside network.py — confirm.
eval_batch = jax.jit(net.eval_on_batch)

In [11]:
# Same timing as above, through the explicitly jitted wrapper.
%timeit eval_batch(xx).block_until_ready()

917 µs ± 3.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
# Observed times: the constant 20 for every batch row, laid out as a single
# (1, batch_size) float64 sample. NOTE(review): units of 20 not stated — TODO confirm.
times = np.full((1, batch_size), 20.0)

In [13]:
def build_time_pdf(net_out):
    """Build the Gamma-mixture distribution over a batch of observed times.

    Args:
        net_out: batched network output (e.g. from `net.eval_on_batch`). It is
            mapped by `transform_network_outputs` to mixture logits and Gamma
            concentration / rate parameters.

    Returns:
        A `tfd.Independent` distribution. `reinterpreted_batch_ndims=1` folds
        the batch axis into the event, so `log_prob` sums the per-row
        log-densities into one value per sample of times.
    """
    logits, a, b = transform_network_outputs(net_out)
    return tfd.Independent(
        distribution=tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=tfd.Gamma(
                concentration=a,
                rate=b,
                # density 0 (log-prob -inf) for times outside the Gamma support
                force_probs_to_zero_outside_support=True,
            ),
        ),
        reinterpreted_batch_ndims=1,
    )


# Distinct names so the earlier scalar `dist` / `z` inputs are not clobbered.
net_out_batch = net.eval_on_batch(xx)
time_pdf = build_time_pdf(net_out_batch)

# jit the bound method; the distribution's parameters are captured as constants
log_prob = jax.jit(time_pdf.log_prob)
# Gamma-mixture log-likelihood of `times`, summed over the batch
log_prob(times)

Array([-2249.39156307], dtype=float64)

In [14]:
# Same pipeline as the previous cell, but the network outputs come from the
# explicitly jitted `eval_batch`. The result should agree with Out[13].
z = eval_batch(xx)
logits, a, b = transform_network_outputs(z)

# one Gamma mixture per batch row; Independent treats the whole batch as one event
dist = tfd.Independent(
    distribution=tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=logits),
        components_distribution=tfd.Gamma(
            concentration=a, rate=b, force_probs_to_zero_outside_support=True
        ),
    ),
    reinterpreted_batch_ndims=1,
)

# compiled log-density of `times` (reused by the timing cell below)
log_prob = jax.jit(dist.log_prob)
log_prob(times)

Array([-2249.39156307], dtype=float64)

In [15]:
# Time only the jitted PDF evaluation; the network outputs are fixed inside `dist`.
%timeit log_prob(times).block_until_ready()

47.6 µs ± 2.09 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
# Time the complete pipeline (batched network evaluation + PDF evaluation)
# compiled as a single jitted function.

@jax.jit
def do_likelihood_for_500_doms(x, obs_times=None):
    """Summed Gamma-mixture log-likelihood for a batch of network inputs.

    Despite the name, nothing in the body hard-codes 500. The network is
    evaluated on every row of `x`, and the matching entries of `obs_times`
    are scored.

    Args:
        x: batched network inputs, e.g. `xx` of shape (batch_size, 7).
        obs_times: observed times whose trailing dimension matches the batch,
            e.g. `times` of shape (1, batch_size). Defaults to the global
            `times`.

            NOTE: under `jax.jit` that default, like the global `net`, is read
            only while tracing and is then baked into the compiled function.
            Later reassignments of `times` are NOT seen. Pass `obs_times`
            explicitly when it varies.

    Returns:
        Array of summed log-likelihoods, one per sample of `obs_times`
        (shape (1,) for `times`).
    """
    if obs_times is None:
        obs_times = times

    # batched network evaluation -> mixture logits and Gamma parameters
    z = net.eval_on_batch(x)
    logits, a, b = transform_network_outputs(z)

    # one Gamma mixture per row; Independent sums the per-row log-densities
    dist = tfd.Independent(
        distribution=tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=tfd.Gamma(
                concentration=a,
                rate=b,
                force_probs_to_zero_outside_support=True,
            ),
        ),
        reinterpreted_batch_ndims=1,
    )
    return dist.log_prob(obs_times)

In [17]:
# The first call traces and compiles; the value should match Out[13] / Out[14].
do_likelihood_for_500_doms(xx)

Array([-2249.39156307], dtype=float64)

In [19]:
%timeit -n 10000 do_likelihood_for_500_doms(xx).block_until_ready()

949 µs ± 21.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
# Peek at the first rows instead of dumping all batch_size x 7 inputs.
xx[:5]

In [None]:
# Print the current log-likelihood for comparison with the per-seed reference values noted in the cells below.
print(log_prob(times))

In [None]:
# Array([-2248.86100657], dtype=float64) # seed 0

In [None]:
# Array([-2247.3552])  # presumably seed 0: ~1.5 above the float64 value, run settings not recorded (TODO)

In [None]:
# Array([-2248.8613])  # presumably seed 0: agrees with the float64 value to ~3e-4, run settings not recorded (TODO)

In [None]:
#Array([-2252.40582564], dtype=float64) # seed 1

In [None]:
# Array([-2250.8875])  # presumably seed 1: ~1.5 above the float64 value, run settings not recorded (TODO)

In [None]:
# Array([-2252.4062])  # presumably seed 1: agrees with the float64 value to ~4e-4, run settings not recorded (TODO)

In [None]:
# Array([-2249.39156307], dtype=float64)  # seed 2 (matches Out[13], Out[14] and Out[17] above)

In [None]:
# Array([-2247.905])  # presumably seed 2: ~1.5 above the float64 value, run settings not recorded (TODO)

In [None]:
# Array([-2249.392])  # presumably seed 2: agrees with the float64 value to ~4e-4, run settings not recorded (TODO)