In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import sys

sys.path.insert(0, '/home/storage2/hans/jax_reco/python')
from network import TriplePandleNet
from trafos import transform_network_outputs, transform_dimensions

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
import jax.numpy as jnp
import jax



In [3]:
net = TriplePandleNet('../../data/network/')

In [4]:
dist = 25
z = -210
rho = 0.0
zenith = np.pi/2
azimuth = 0.0

x = transform_dimensions(dist, rho, z, zenith, azimuth)

In [5]:
print(net.eval(x).shape)

(9,)


In [6]:
# time single evaluation
%timeit net.eval(x).block_until_ready()

72.6 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
# testing low level
#from network import _eval_network
#params = net.get_network_params()
#print(params[0][0].shape, params[0][1].shape)
#%timeit _eval_network(x, params).block_until_ready()

In [8]:
# now try 500 evaluations in parallel on the gpu
batch_size = 500

x = np.array(x)
xx = x[np.newaxis, :]
xx = np.repeat(xx, batch_size, axis=0)
y = np.random.normal(0.025, 0.001, 500)
xx[:, 0] = y
xx = jnp.array(xx)
print(xx.shape)

(500, 7)


In [9]:
%timeit net.eval_on_batch(xx).block_until_ready()

127 µs ± 6.85 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
eval_batch = jax.jit(net.eval_on_batch)

In [11]:
%timeit eval_batch(xx).block_until_ready()

112 µs ± 8.16 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
z = net.eval_on_batch(xx)
print(z.shape)
logits, a, b = transform_network_outputs(z)

dist = tfd.Independent(
    distribution = tfd.MixtureSameFamily(
              mixture_distribution=tfd.Categorical(
                  logits=logits
                  ),
              components_distribution=tfd.Gamma(
                concentration=a,    
                rate=b,
                force_probs_to_zero_outside_support=True
                  )
            ),
    reinterpreted_batch_ndims=1
)

log_prob = jax.jit(dist.log_prob)

(500, 9)


In [13]:
# evaluate gamma likelihood 
times = 20*np.ones(batch_size).reshape(1, batch_size)

In [14]:
log_prob(times)

Array([-2245.1377], dtype=float32)

In [15]:
%timeit log_prob(times).block_until_ready()

51.7 µs ± 2.89 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
# time combination of everything

@jax.jit
def do_likelihood_for_500_doms(time):
    # 500 NN evaluations
    z = net.eval_on_batch(xx)
    logits, a, b = transform_network_outputs(z)

    # 500 PDF evaluations
    dist = tfd.Independent(
        distribution = tfd.MixtureSameFamily(
                  mixture_distribution=tfd.Categorical(
                      logits=logits
                      ),
                  components_distribution=tfd.Gamma(
                    concentration=a,    
                    rate=b,
                    force_probs_to_zero_outside_support=True
                      )
                ),
        reinterpreted_batch_ndims=1
    )
    return dist.log_prob(time)

In [17]:
print(do_likelihood_for_500_doms(times))

[-2246.6128]


In [18]:
%timeit do_likelihood_for_500_doms(times).block_until_ready()

54.3 µs ± 4.67 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
