In [1]:
import jax.numpy as jnp
import jax
import numpy as np

In [2]:
_recip__speedOfLight = 3.3356409519815204
_n__ = 1.32548384613875
_tan__thetaC = (_n__**2.-1.)**0.5

In [3]:
@jax.jit
def closest_distance_dom_track(dom_pos, track_pos, track_dir):
    """
    dom_pos: 1D jax array with 3 components [x, y, z]
    track_pos: 1D jax array with 3 components [x, y, z]
    track_dir: 1D jax array with 3 components [dir_x, dir_y, dir_z]
    """
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # vector: closest point on track -> dom
    v_d = v_a - jnp.dot(v_a, track_dir) * track_dir
    dist = jnp.linalg.norm(v_d)
    return dist

# Generalize to matrix input for dom_pos with shape (N_DOMs, 3).
# Output will be in form of (N_DOMs, 1)
closest_distance_dom_track_v = jax.jit(jax.vmap(closest_distance_dom_track, (0, None, None), 0))

In [4]:
def convert_spherical_to_cartesian_direction(x):
    """
    """
    track_theta = x[0]
    track_phi = x[1]
    track_dir_x = jnp.sin(track_theta) * jnp.cos(track_phi)
    track_dir_y = jnp.sin(track_theta) * jnp.sin(track_phi)
    track_dir_z = jnp.cos(track_theta)
    direction = jnp.array([track_dir_x, track_dir_y, track_dir_z])
    return direction

# Generalize to matrix input for x with shape (N_DOMs, 2) for theta and phi angles.
# Output will be in form of (N_DOMs, 3) for dir_x, dir_y, dir_z
convert_spherical_to_cartesian_direction_v = jax.jit(jax.vmap(closest_distance_dom_track, 0, 0)) 

In [5]:
dom_pos = jnp.array([0, 0, 100])
track_pos = jnp.array([0, 0, 0])
track_theta = jnp.deg2rad(90.)
track_phi = jnp.deg2rad(100.)
track_dir = convert_spherical_to_cartesian_direction(jnp.array([track_theta, track_phi]))

In [6]:
print(closest_distance_dom_track(dom_pos, track_pos, track_dir))

100.0


In [7]:
# now try with batched inputs
dom_pos_v = np.random.normal(0, 500, (100, 3))
print(dom_pos_v.shape)

(100, 3)


In [8]:
print(closest_distance_dom_track_v(dom_pos_v, track_pos, track_dir))

[ 707.6191   444.44476  585.76874  542.6779   883.7168   400.80795
  140.01144  217.77444  554.7531   998.4975   330.2196   966.15845
  495.51086 1409.6699   159.54184  145.17502  843.3794  1256.3119
  324.3534   344.45926  829.6318   619.9755   819.57135 1128.7021
  650.09827  279.83295  276.25815  296.12143  676.7285   285.95758
 1113.8235   214.23154  217.8703   827.8846   480.3064  1346.2316
  760.9748  1548.1155   653.6748   408.8055   435.17502  461.7747
  549.71136  537.83923 1059.3977   675.5304   299.35046  550.416
  815.6207   568.8242  1150.5029   553.8918    95.22571  696.60156
  493.1495   308.61426  562.9281  1101.3491   710.5719   243.505
 1009.2475   641.2852   907.95715  782.254    416.2805  1413.0557
  801.53546  152.53952  734.7734   772.6073   782.9695   364.57834
  756.2235   405.74063  392.8775   928.0058  1095.9796   888.8098
  655.668    556.0227   651.6336   265.76727  474.69757  197.37643
  991.9401   664.2653   949.62317 2105.4084   575.8727   177.7666
 1722.

In [9]:
def light_travel_time(dom_pos, track_pos, track_dir):
    """
    Computes the direct, unscattered time it takes for a photon to travel from 
    the track to the dom.
    See Eq. 4 of the AMANDA track reco paper https://arxiv.org/pdf/astro-ph/0407044
    """
    closest_dist = closest_distance_dom_track(dom_pos, track_pos, track_dir)
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # distance muon travels from support point to point of closest approach.
    d1 = jnp.dot(v_a, track_dir)
    # distance that muon travels beyond closest approach until photon hits.
    d2 = closest_dist * _tan__thetaC 
    return (d1+d2) * _recip__speedOfLight

# Generalize to matrix input for dom_pos with shape (N_DOMs, 3).
# Output will be in form of (N_DOMs, 1)
light_travel_time_v = jax.jit(jax.vmap(light_travel_time, (0, None, None), 0))

In [10]:
print(light_travel_time(dom_pos, track_pos, track_dir))

290.20215


In [11]:
print(light_travel_time_v(dom_pos_v, track_pos, track_dir))

[ 1017.23785  3122.096     710.60583 -1618.4631   2545.9282   2767.0447
   190.271    1775.1626    942.2193   -328.16565  2515.9272   3773.9934
 -3393.5183   3559.433     292.6404  -1000.936    1385.4203   5746.5083
  2736.7593    478.19208  2937.6816    556.5861   2489.246    -175.28606
  5085.5356   1804.1371   1306.4128   2129.0627   4293.428    1042.1727
  4896.9717   3806.8376  -3491.8953   1663.1935   2172.5132   2948.52
  3676.8713   7796.4106   3103.4983   2721.9492   -394.99655  3278.282
  3886.84    -1432.6108   4366.291     585.0739   3874.8955   2057.148
  2810.4124   1950.8395   6008.2354   -421.59882  -406.0062   1868.438
  1740.5648   -964.0184    331.9091   2225.7356   1077.562     382.01245
  3676.3079     69.72129  2788.1736   1836.4861   1624.331    2948.7266
  3892.6929    792.6432   5407.2524   4524.313     684.73364  -378.85553
  4149.692    1411.7701   -418.3996   1084.3256   1203.7319   1237.3722
  1679.7257   3776.9866   2900.3096   3282.4666   1932.1149  -1367

In [12]:
@jax.jit
def z_component_closest_point_on_track(dom_pos, track_pos, track_dir):
    """
    dom_pos: 1D jax array with 3 components [x, y, z]
    track_pos: 1D jax array with 3 components [x, y, z]
    track_dir: 1D jax array with 3 components [dir_x, dir_y, dir_z]
    """
    
    # vector track support point -> dom
    v_a = dom_pos - track_pos 
    # vector: closest point on track -> dom
    v_c = track_pos + jnp.dot(v_a, track_dir) * track_dir
    return v_c[2]

z_component_closest_point_on_track_v = jax.jit(jax.vmap(z_component_closest_point_on_track, (0, None, None), 0))

In [13]:
print(z_component_closest_point_on_track(dom_pos, track_pos, track_dir))

1.9106854e-13


In [14]:
print(z_component_closest_point_on_track_v(dom_pos_v, track_pos, track_dir))

[ 1.3579879e-05 -2.4011188e-05  1.2964231e-05  4.1846448e-05
  2.4422931e-07 -2.1017942e-05  2.8311317e-06 -1.4980576e-05
  8.7495928e-06  4.2272350e-05 -2.0411626e-05 -1.2713579e-05
  6.3313659e-05  6.9644357e-06  2.2323713e-06  1.8637484e-05
  1.3917931e-05 -2.7527823e-05 -2.3528570e-05  6.8330801e-06
 -6.9462462e-06  1.6283411e-05 -1.4523815e-06  4.5220531e-05
 -4.1919964e-05 -1.3000239e-05 -6.6138368e-06 -1.6638744e-05
 -3.0527193e-05 -2.7822825e-06 -2.1813908e-05 -4.1739084e-05
  5.4044409e-05  9.6886397e-06 -1.0203746e-05  1.2557549e-05
 -1.9243806e-05 -4.3293428e-05 -1.5810650e-05 -2.0122856e-05
  2.1725484e-05 -2.5398862e-05 -3.0029467e-05  3.9226965e-05
 -1.6929449e-05  1.8022800e-05 -3.9393937e-05 -6.0257530e-06
 -5.8112914e-06 -3.9326023e-06 -3.4981400e-05  2.6588781e-05
  8.9417927e-06  2.0064713e-06 -4.0549330e-06  2.4369150e-05
  1.7058195e-05  1.2716504e-05  1.2901663e-05  4.2542524e-06
 -9.7948323e-06  2.3473845e-05 -2.0083971e-06  5.6824656e-06
 -5.4550242e-06  1.50961

In [15]:
@jax.jit
def rho_dom_relative_to_track(dom_pos, track_pos, track_dir):
    """
    clean up and verify!
    """
    v1 = dom_pos - track_pos
    closestapproach = track_pos + jnp.dot(v1, track_dir)*track_dir
    v2 = dom_pos - closestapproach
    zdir = jnp.cross(track_dir, jnp.cross(jnp.array([0,0,1]), track_dir))
    positivedir = jnp.cross(track_dir, zdir)
    ypart = v2-v2*jnp.dot(zdir, v2)
    zpart = v2-ypart
    z = jnp.dot(zpart, zdir)
    y = jnp.dot(ypart, positivedir)
    return jnp.arctan2(y,z)

rho_dom_relative_to_track_v = jax.jit(jax.vmap(rho_dom_relative_to_track, (0, None, None), 0))

In [16]:
print(rho_dom_relative_to_track(dom_pos, track_pos, track_dir))

-7.275959e-16


In [17]:
print(jnp.rad2deg(rho_dom_relative_to_track_v(dom_pos_v, track_pos, track_dir)))

[ 41.263985   87.94616    -6.786927  -63.200016  -82.97944   -58.947006
 -54.462128    2.845565   84.1405     66.703926  -82.5882     -2.0476346
 -40.266926  -35.769245  -24.61014    33.089184   73.08122    87.871086
  55.093487  -46.686176  -22.64364   -88.706764   30.508406   48.75631
 -13.698643  -55.81961   -65.86062    39.34551   -67.13592   -27.188456
 -55.725258  -74.30694    -4.00967   -41.39106    59.667046   59.887024
   7.96906   -25.541775  -51.673126   31.242788    4.0546136 -16.027323
  58.34345    49.27972    12.611174   28.267601   26.25734    37.687016
 -32.305588   80.619415  -69.876785   66.5613     41.637154  -63.3831
  40.136257   60.44382     2.3590398   9.980967   54.91069    37.76351
  -1.8509852  53.574516   -9.472357   17.793621  -22.095016  -18.486582
  29.691921   12.197003   33.915092   14.364562   18.23489   -88.91232
  23.818516   58.721745  -88.82185    77.775406   -8.866069   11.076788
 -12.43299   -20.886436    5.1235213  55.69004    73.97324    48.066