In [6]:
import jax.numpy as jnp
import jax
import numpy as np

In [12]:
# Physical constants for the direct (unscattered) photon-timing model.
# 1/c: this value equals 1/0.299792458, i.e. nanoseconds per metre if lengths are
# in metres (assumed units -- TODO confirm against the DOM geometry source).
_recip__speedOfLight = 3.3356409519815204
# Refractive index of the medium, used to derive the Cherenkov angle below.
_n__ = 1.32548384613875
# Cherenkov angle satisfies cos(theta_C) = 1/n, hence tan(theta_C) = sqrt(n**2 - 1).
_tan__thetaC = (_n__**2.-1.)**0.5

In [2]:
@jax.jit
def closest_distance_dom_track(dom_pos, track_pos, track_dir):
    """
    Perpendicular distance between a DOM and an infinite straight track.

    dom_pos: 1D jax array with 3 components [x, y, z]
    track_pos: 1D jax array with 3 components [x, y, z], any point on the track
    track_dir: 1D jax array with 3 components [dir_x, dir_y, dir_z]; must be a
        unit vector (the projection below does not normalise it)

    Returns a scalar: the length of the component of (dom_pos - track_pos)
    orthogonal to track_dir.
    """
    # vector track support point -> dom
    v_a = dom_pos - track_pos
    # remove the along-track component -> vector from closest point on track to dom
    v_d = v_a - jnp.dot(v_a, track_dir) * track_dir
    return jnp.linalg.norm(v_d)


# Batched over DOMs: dom_pos has shape (N_DOMs, 3), track_pos/track_dir are shared.
# Output has shape (N_DOMs,) -- one scalar distance per DOM.
closest_distance_dom_track_v = jax.jit(jax.vmap(closest_distance_dom_track, (0, None, None), 0))

In [27]:
@jax.jit
def convert_spherical_to_cartesian_direction(x):
    """
    Convert a direction given as spherical angles into a Cartesian unit vector.

    x: 1D jax array [theta, phi] in radians; theta is the polar angle measured
       from +z, phi the azimuth in the x-y plane measured from +x.

    Returns a 1D jax array [dir_x, dir_y, dir_z] of unit length.
    """
    track_theta, track_phi = x[0], x[1]
    sin_theta = jnp.sin(track_theta)
    return jnp.array([
        sin_theta * jnp.cos(track_phi),
        sin_theta * jnp.sin(track_phi),
        jnp.cos(track_theta),
    ])


# Batched version: x has shape (N, 2) holding (theta, phi) rows;
# output has shape (N, 3) holding (dir_x, dir_y, dir_z) rows.
convert_spherical_to_cartesian_direction_v = jax.jit(
    jax.vmap(convert_spherical_to_cartesian_direction, 0, 0)
)

In [116]:
# Reference geometry for the single-DOM checks: a DOM 100 length units straight
# above the track support point, and a horizontal track (theta = 90 deg) at phi = 100 deg.
dom_pos = jnp.array([0, 0, 100])
track_pos = jnp.array([0, 0, 0])
track_theta, track_phi = (jnp.deg2rad(angle_deg) for angle_deg in (90., 100.))
track_dir = convert_spherical_to_cartesian_direction(
    jnp.array([track_theta, track_phi])
)

In [114]:
# Sanity check: the DOM sits 100 units above the support point and the track is
# horizontal, so the perpendicular distance should be ~100.
print(closest_distance_dom_track(dom_pos, track_pos, track_dir))

100.0


In [102]:
# now try with batched inputs
dom_pos_v = np.random.normal(0, 500, (100, 3))
print(dom_pos_v.shape)

(100, 3)


In [103]:
# Batched call: one perpendicular distance per row of dom_pos_v (a 1D result of length 100).
print(closest_distance_dom_track_v(dom_pos_v, track_pos, track_dir))

[ 116.697     103.03797  1024.2969    675.87854   324.52417   564.75385
  593.78436  1235.7764    936.8248    759.6192    732.629     680.7577
  863.4089    310.74988   377.15936    56.720856  779.1402    628.3892
  361.79376  1021.67975   941.6828    420.49622   851.24335   167.74657
  716.1566    440.80118  1015.7084    730.66785   788.8878    180.54437
  647.207     621.15      309.26633   765.8268    636.1266    419.40942
 1103.3851    348.2061    431.17737  1247.2603    887.40894  1068.9005
 1015.7297    433.2806   1112.8293   1181.8934    589.6558   1531.7487
 1432.3542    834.2106    788.1943    754.28516   318.9307    623.6772
 1061.9905    537.28253   702.5405    739.42957   431.49142   343.5018
  671.57635   821.2296    794.8022    764.21106   317.91022   914.76495
  813.06903   562.7532    701.4595    343.72568  1101.1233    487.846
  391.1694    202.05447   145.65205  1030.8912    641.91254   209.4851
  915.09216   797.6607    360.9073    670.15063   701.9792    420.35327
 

In [104]:
@jax.jit
def light_travel_time(dom_pos, track_pos, track_dir):
    """
    Computes the arrival time of direct, unscattered Cherenkov photons at the dom,
    measured from the moment the muon (moving at c) passes track_pos.
    See Eq. 4 of the AMANDA track reco paper https://arxiv.org/pdf/astro-ph/0407044

    t = (d1 + d * tan(theta_C)) / c, where d1 is the signed distance along the
    track from track_pos to the point of closest approach and d is the
    perpendicular dom-track distance. The result is negative when
    d1 < -d * tan(theta_C), i.e. for doms far enough behind the support point.

    dom_pos: 1D jax array [x, y, z]
    track_pos: 1D jax array [x, y, z], any point on the track
    track_dir: 1D jax array [dir_x, dir_y, dir_z], assumed to be a unit vector

    Returns a scalar time in the units of _recip__speedOfLight * length
    (presumably ns if positions are in metres -- TODO confirm).
    """
    closest_dist = closest_distance_dom_track(dom_pos, track_pos, track_dir)

    # vector track support point -> dom
    v_a = dom_pos - track_pos
    # distance muon travels from support point to point of closest approach.
    d1 = jnp.dot(v_a, track_dir)
    # distance that muon travels beyond closest approach until photon hits.
    d2 = closest_dist * _tan__thetaC
    return (d1 + d2) * _recip__speedOfLight


# Batched over DOMs: dom_pos has shape (N_DOMs, 3), track_pos/track_dir are shared.
# Output has shape (N_DOMs,) -- one scalar time per DOM.
light_travel_time_v = jax.jit(jax.vmap(light_travel_time, (0, None, None), 0))

In [105]:
# Single-DOM check. For the reference geometry (DOM 100 units above a horizontal track
# through the origin) d1 ~ 0 and d2 = 100 * tan(theta_C) ~ 87.0, so the expected value
# is ~ 87.0 * 3.3356 ~ 290.2.
# NOTE(review): the saved output below (714.68) does not match that; the execution
# counts suggest this cell likely ran before In [116] set the current inputs -- re-run.
print(light_travel_time(dom_pos, track_pos, track_dir))

714.6805


In [106]:
# Batched times for the random DOMs. Negative values occur when d1 < -d * tan(theta_C),
# i.e. the DOM is far enough behind the track support point.
print(light_travel_time_v(dom_pos_v, track_pos, track_dir))

[ 2196.436    -133.52304  6188.9683   4406.825    -961.5291  -1031.0009
  2907.7957   1853.7916   1952.452    3301.1682   -761.5907   2573.8696
  3919.2664     35.49175  2223.2693    216.15514    40.06466  1072.2708
  1190.1101   2627.092    2102.7      1343.219    2889.6064    381.19952
  1620.1749   -932.86707  2151.446    1627.505    5686.2803   -842.6394
  2182.1836   2307.6047    653.97656  2657.977   -1887.4792    198.09268
  3934.6064   2550.028    -576.31525  5993.978    3045.8594   6391.0176
  2407.0278   -138.13371  3076.17     3031.8308   1230.5631   4661.1133
  3902.7239   4866.5864    323.64197  2052.4185    766.22815   284.89655
  2949.425    1482.6832   1821.5427   1349.6788   2893.443   -2416.0962
  2072.226    3989.8372    462.36456  1327.839    -509.3201   3141.083
  1971.9211   1502.481    4157.901   -3429.4792   3363.6418   2869.0334
  1604.5126   1333.4427   -479.80463  1476.8468    831.3002   2155.0654
  3490.2937   3330.2654    851.69293  2757.5088   1052.7855  -

In [107]:
@jax.jit
def z_component_closest_point_on_track(dom_pos, track_pos, track_dir):
    """
    dom_pos: 1D jax array with 3 components [x, y, z]
    track_pos: 1D jax array with 3 components [x, y, z]
    track_dir: 1D jax array with 3 components [dir_x, dir_y, dir_z]
    """
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # vector: closest point on track -> dom
    v_c = track_pos + jnp.dot(v_a, track_dir) * track_dir
    return v_c[2]

z_component_closest_point_on_track_v = jax.jit(jax.vmap(z_component_closest_point_on_track, (0, None, None), 0))

In [108]:
# Horizontal track through the origin: the closest point lies on the track at z = 0,
# so the result should be ~0 up to float32 round-off (cos(90 deg) is not exactly 0).
print(z_component_closest_point_on_track(dom_pos, track_pos, track_dir))

-3.5456906e-06


In [109]:
# Every random DOM projects onto the same horizontal track at z = 0, so all values are ~0
# (only float32 round-off remains).
print(z_component_closest_point_on_track_v(dom_pos_v, track_pos, track_dir))

[-2.4345327e-05  5.6681943e-06 -4.2150285e-05 -3.2046213e-05
  2.4941724e-05  3.4987836e-05 -1.5524145e-05  2.2702434e-05
  1.0040683e-05 -1.4372546e-05  3.7841488e-05 -7.8406920e-06
 -1.8525388e-05  1.1352435e-05 -1.4791816e-05 -6.7556340e-07
  2.9104971e-05  9.8454939e-06 -1.8371162e-06  4.4268168e-06
  8.2565020e-06 -1.6111308e-06 -5.4948487e-06  1.3838186e-06
  6.0031966e-06  2.8988032e-05  1.0432834e-05  6.4589872e-06
 -4.4515204e-05  1.7908313e-05 -3.9837450e-06 -6.6182479e-06
  3.1910781e-06 -5.7077759e-06  4.8925802e-05  1.3353854e-05
 -9.6003296e-06 -2.0174895e-05  2.3949611e-05 -3.1115938e-05
 -6.1671244e-06 -4.3101805e-05  7.0843712e-06  1.8287445e-05
  2.0082077e-06  5.2156915e-06  6.2981571e-06 -2.8305685e-06
  3.3278670e-06 -3.2049938e-05  2.5733149e-05  1.7888535e-06
  2.0876039e-06  1.9984449e-05  1.7357802e-06  1.0025384e-06
  2.8465633e-06  1.0432963e-05 -2.1507914e-05  4.4724809e-05
 -1.6160587e-06 -2.1054226e-05  2.4166549e-05  1.1661580e-05
  1.8764225e-05 -6.37465

In [110]:
@jax.jit
def rho_dom_relative_to_track(dom_pos, track_pos, track_dir):
    """
    clean up and verify!
    """
    v1 = dom_pos - track_pos
    closestapproach = track_pos + jnp.dot(v1, track_dir)*track_dir
    v2 = dom_pos - closestapproach
    zdir = jnp.cross(track_dir, jnp.cross(jnp.array([0,0,1]), track_dir))
    positivedir = jnp.cross(track_dir, zdir)
    ypart = v2-v2*jnp.dot(zdir, v2)
    zpart = v2-ypart
    z = jnp.dot(zpart, zdir)
    y = jnp.dot(ypart, positivedir)
    return jnp.arctan2(y,z)

rho_dom_relative_to_track_v = jax.jit(jax.vmap(rho_dom_relative_to_track, (0, None, None), 0))

In [117]:
# DOM directly above a horizontal track lies along the +z reference direction -> ~0 rad.
print(rho_dom_relative_to_track(dom_pos, track_pos, track_dir))

-7.275959e-16


In [118]:
# Azimuths of the random DOMs around the track, converted to degrees.
# NOTE(review): every saved value lies within [-90, 90] degrees, whereas random DOMs
# around a track should cover the full [-180, 180] range; this suggests a projection
# error in rho_dom_relative_to_track -- re-check and re-run.
print(jnp.rad2deg(rho_dom_relative_to_track_v(dom_pos_v, track_pos, track_dir)))

[ 74.58642   -68.693855    2.4342232  33.765476   35.656597  -51.508778
  80.984665  -46.959     -78.1858     53.63857   -10.687195  -55.516094
  39.04053     7.1179194  -9.368165  -15.61844   -49.60108   -32.80054
 -78.55552   -82.79323   -66.30072    42.41334    21.603401   36.380325
 -60.30434   -77.70556   -56.632244  -43.650112  -86.88298    49.77658
  38.31964    17.12973   -52.95567   -48.882378   38.667793   62.803127
  42.114002   41.20255   -89.27945   -62.668125  -13.036718   -7.5279145
  28.309353  -27.571009   55.77942     5.755886   75.195724   89.078
  10.856768  -87.10581   -61.512215   -1.7035322 -81.38656   -25.890308
 -62.117947  -68.075966  -29.728767   36.95242   -77.44444    34.972668
  75.917336  -83.59593   -48.08782     3.994736   75.398506  -21.476469
  87.66253    58.463306   68.783165  -42.63527    21.590569   66.589386
  58.559753   20.769266  -40.55096    47.11135   -47.580128  -31.702522
  86.65136   -33.09812   -36.9316     37.816116   27.410305    7.898