In [1]:
import jax.numpy as jnp
import jax
import numpy as np

In [2]:
_recip__speedOfLight = 3.3356409519815204
_n__ = 1.32548384613875
_tan__thetaC = (_n__**2.-1.)**0.5

In [3]:
@jax.jit
def closest_distance_dom_track(dom_pos, track_pos, track_dir):
    """
    dom_pos: 1D jax array with 3 components [x, y, z]
    track_pos: 1D jax array with 3 components [x, y, z]
    track_dir: 1D jax array with 3 components [dir_x, dir_y, dir_z]
    """
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # vector: closest point on track -> dom
    v_d = v_a - jnp.dot(v_a, track_dir) * track_dir
    dist = jnp.linalg.norm(v_d)
    return dist

# Generalize to matrix input for dom_pos with shape (N_DOMs, 3).
# Output will be in form of (N_DOMs, 1)
closest_distance_dom_track_v = jax.jit(jax.vmap(closest_distance_dom_track, (0, None, None), 0))

In [4]:
def convert_spherical_to_cartesian_direction(x):
    """
    """
    track_theta = x[0]
    track_phi = x[1]
    track_dir_x = jnp.sin(track_theta) * jnp.cos(track_phi)
    track_dir_y = jnp.sin(track_theta) * jnp.sin(track_phi)
    track_dir_z = jnp.cos(track_theta)
    direction = jnp.array([track_dir_x, track_dir_y, track_dir_z])
    return direction

# Generalize to matrix input for x with shape (N_DOMs, 2) for theta and phi angles.
# Output will be in form of (N_DOMs, 3) for dir_x, dir_y, dir_z
convert_spherical_to_cartesian_direction_v = jax.jit(jax.vmap(closest_distance_dom_track, 0, 0)) 

In [5]:
dom_pos = jnp.array([0, 0, 100])
track_pos = jnp.array([0, 0, 0])
track_theta = jnp.deg2rad(90.)
track_phi = jnp.deg2rad(100.)
track_dir = convert_spherical_to_cartesian_direction(jnp.array([track_theta, track_phi]))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [6]:
print(closest_distance_dom_track(dom_pos, track_pos, track_dir))

100.0


In [7]:
# now try with batched inputs
dom_pos_v = np.random.normal(0, 500, (100, 3))
print(dom_pos_v.shape)

(100, 3)


In [8]:
print(closest_distance_dom_track_v(dom_pos_v, track_pos, track_dir))

[ 734.9613   1132.05      411.29584   670.467     615.91864   385.05234
  567.0591    280.3011    604.7056   1474.9904     17.213976  393.51273
 1043.9243   1079.267     699.37006   499.469     154.96368   729.6035
  671.9363    595.9386   1013.52576   412.34747   466.35406   161.17166
  872.0785   1097.5822    914.09503   997.76404   477.5296    367.04446
 1083.1692    743.09607   786.1463    579.9173    215.93436   658.3945
  687.6606     77.9546    942.1284    213.10603   222.6412    806.3688
  792.4577    623.7761    874.5432    392.2688    119.551704  429.0185
  880.1689    995.70874   292.1692    374.62088  1044.158     230.21112
  795.58325   230.00607  1229.6013    739.72186  1086.9563    452.9667
 1367.9958    870.35596   426.91028   678.17737   817.7927    942.7511
 1103.4459    312.6761    511.7104   1130.4279    348.9942    244.69107
  414.97946   502.27075   537.46136   462.35666   835.7597    192.31586
 1055.8003    621.78546   225.70782   613.53625   604.58746   874.2267

In [9]:
def light_travel_time(dom_pos, track_pos, track_dir):
    """
    Computes the direct, unscattered time it takes for a photon to travel from 
    the track to the dom.
    See Eq. 4 of the AMANDA track reco paper https://arxiv.org/pdf/astro-ph/0407044
    """
    closest_dist = closest_distance_dom_track(dom_pos, track_pos, track_dir)
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # distance muon travels from support point to point of closest approach.
    d1 = jnp.dot(v_a, track_dir)
    # distance that muon travels beyond closest approach until photon hits.
    d2 = closest_dist * _tan__thetaC 
    return (d1+d2) * _recip__speedOfLight

# Generalize to matrix input for dom_pos with shape (N_DOMs, 3).
# Output will be in form of (N_DOMs, 1)
light_travel_time_v = jax.jit(jax.vmap(light_travel_time, (0, None, None), 0))

In [10]:
print(light_travel_time(dom_pos, track_pos, track_dir))

290.20215


In [11]:
print(light_travel_time_v(dom_pos_v, track_pos, track_dir))

[ 3697.579     3304.5317     -58.9647    2435.801     1984.549
  1768.092     1697.8025   -1379.6575     518.9842    1617.301
 -2991.5251    1309.7119    5617.3037    6037.9404     810.4508
  -560.91736    702.6794    2529.8762    3890.486     3454.4038
  3905.3098    -575.76556   2029.754      126.12225   2245.5742
  4742.1323     543.7948    2251.4482     569.20776   2017.5795
  3664.6245    3749.7583    3806.6274     237.86401   1088.8313
   655.47485   4575.8086   -1696.121     3682.4167    -281.97598
   511.88828   1132.6306     341.84747    324.13254    -27.847332
 -3009.6797     664.77075   -851.27185   3401.593     2453.259
  -203.605     2481.4385    3242.6323    -679.4367     -91.04192
  1452.0001    7179.0454    2202.5168    1605.7339    2530.356
  5286.1396    1617.7802    -534.16504   1499.0654    3540.3552
  3688.9028    4167.475     1356.9147   -1067.4904    4221.172
 -1371.9436    5448.301     2312.448    -1303.6566    1586.5406
  -303.65512   2315.8386    1262.0613    

In [12]:
@jax.jit
def z_component_closest_point_on_track(dom_pos, track_pos, track_dir):
    """
    dom_pos: 1D jax array with 3 components [x, y, z]
    track_pos: 1D jax array with 3 components [x, y, z]
    track_dir: 1D jax array with 3 components [dir_x, dir_y, dir_z]
    """
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # vector: closest point on track -> dom
    v_c = track_pos + jnp.dot(v_a, track_dir) * track_dir
    return v_c[2]

z_component_closest_point_on_track_v = jax.jit(jax.vmap(z_component_closest_point_on_track, (0, None, None), 0))

In [13]:
print(z_component_closest_point_on_track(dom_pos, track_pos, track_dir))

1.9106854e-13


In [14]:
print(z_component_closest_point_on_track_v(dom_pos_v, track_pos, track_dir))

[-2.05044362e-05 -2.52883410e-07  1.64139019e-05 -6.42232453e-06
 -2.58338423e-06 -8.52649646e-06 -6.83846338e-07  2.87391031e-05
  1.61954595e-05  3.48988797e-05  3.98566081e-05 -2.19798267e-06
 -3.39115868e-05 -3.80797064e-05  1.59759838e-05  2.63448092e-05
 -3.31502360e-06 -5.40620749e-06 -2.54291408e-05 -2.26046923e-05
 -1.26330578e-05  2.32262319e-05 -8.86357248e-06  4.47646289e-06
  3.73757621e-06 -2.04024800e-05  2.76361334e-05  8.44031001e-06
  1.07009246e-05 -1.24806875e-05 -6.83055941e-06 -2.08788551e-05
 -1.99869264e-05  1.89366783e-05 -6.05663263e-06  1.64485773e-05
 -3.38118625e-05  2.51910969e-05 -1.24273683e-05  1.17993404e-05
  1.75888295e-06  1.58230941e-05  2.56567619e-05  1.94740969e-05
  3.36230078e-05  5.43575079e-05 -4.16494186e-06  2.74705453e-05
 -1.11036188e-05  5.71755027e-06  1.37790403e-05 -1.82711337e-05
 -2.78419043e-06  1.76582944e-05  3.14483550e-05 -1.02805880e-05
 -4.73160580e-05 -7.31585374e-07  2.02938736e-05 -1.59327428e-05
 -1.72477448e-05  1.18989

In [15]:
@jax.jit
def rho_dom_relative_to_track(dom_pos, track_pos, track_dir):
    """
    clean up and verify!
    """
    v1 = dom_pos - track_pos
    closestapproach = track_pos + jnp.dot(v1, track_dir)*track_dir
    v2 = dom_pos - closestapproach
    zdir = jnp.cross(track_dir, jnp.cross(jnp.array([0,0,1]), track_dir))
    positivedir = jnp.cross(track_dir, zdir)
    ypart = v2-v2*jnp.dot(zdir, v2)
    zpart = v2-ypart
    z = jnp.dot(zpart, zdir)
    y = jnp.dot(ypart, positivedir)
    return jnp.arctan2(y,z)

rho_dom_relative_to_track_v = jax.jit(jax.vmap(rho_dom_relative_to_track, (0, None, None), 0))

In [16]:
print(rho_dom_relative_to_track(dom_pos, track_pos, track_dir))

4.73666e-16


In [17]:
print(jnp.rad2deg(rho_dom_relative_to_track_v(dom_pos_v, track_pos, track_dir)))

[-84.14582    -56.270733   -89.27325    -87.143036    17.247957
 -22.87404    -36.350452    15.077467   -55.26094      2.4157414
  15.572105     4.300561    39.48524     16.895443    38.063335
  45.16409     64.50946      8.6174965  -35.84839    -60.964863
  54.506367   -28.21697     39.59431    -77.21685    -48.857677
 -50.99233     61.26645    -24.964264    18.522024    43.64505
 -56.936005     6.652291    24.874058    89.09147    -42.763157
  39.254135   -39.30967     43.73764    -29.293903    65.02472
  66.77908     10.062869    55.36843     81.10722    -66.92159
  17.728918    82.39804    -77.74851     64.45234    -52.864597
  -7.4588356   47.935333   -49.33473     41.898563   -16.176819
 -58.919342    -9.361332    48.75691    -82.82179     46.0653
 -40.862408   -26.05654     40.90024    -85.43946     45.076168
 -32.6996      44.291355   -62.426823    57.864433     4.4835825
 -44.634666   -63.181137    20.470926   -63.72171    -68.275795
 -53.440094    83.675514   -83.50586    -55