In [1]:
import jax

# Enable 64-bit precision immediately after importing JAX. The flag must be set at
# startup, before any JAX arrays are created: arrays created earlier (including any a
# JAX-based library might build at import time) keep 32-bit dtypes.
# NOTE(review): the comparison against LAL below presumably needs float64 — confirm.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import matplotlib.pyplot as plt  # NOTE(review): not used in the cells visible here
from lalinference import BurstSineGaussian
from ripplegw.waveforms import SineGaussian

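# 2048 Hz sampling rate

Compare the JAX `SineGaussian` waveforms against LALInference's `BurstSineGaussian` for four parameter sets. Each printed row gives the maximum difference between the two implementations for the cross (left) and plus (right) polarizations.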

In [5]:
# Compare the JAX SineGaussian against LALInference's BurstSineGaussian at 2048 Hz.
# A single variable feeds both generators so their sampling can never disagree.
sample_rate = 2048
sine_gaussian = SineGaussian(sample_rate, 8.0)

quality = jnp.array([3.0, 10.0, 100.0, 55.0])
frequency = jnp.array([100.0, 500.0, 800.0, 961.0])
hrss = jnp.array([1e-23, 1e-22, 1e-21, 4e-10])
phase = jnp.array([0.0, jnp.pi / 2.0, jnp.pi, 2 * jnp.pi])
eccentricity = jnp.array([0., 0.5, 1, 0.1])

# NOTE(review): assumes SineGaussian returns (cross, plus); LAL below returns
# (hplus, hcross), i.e. the opposite order — confirm against the ripplegw API.
cross, plus = sine_gaussian(quality, frequency, hrss, phase, eccentricity)

for i in range(len(quality)):
    hplus_series, hcross_series = BurstSineGaussian(
        Q=quality[i].item(),
        centre_frequency=frequency[i].item(),
        hrss=hrss[i].item(),
        eccentricity=eccentricity[i].item(),
        phase=phase[i].item(),
        delta_t=1 / sample_rate,
    )
    hplus = hplus_series.data.data
    hcross = hcross_series.data.data
    n_samples = len(hplus)
    # Centre the LAL waveform on the middle of the JAX waveform. Deriving `stop` from
    # `start` guarantees the slice has exactly n_samples points for odd or even lengths.
    start = len(cross[i]) // 2 - n_samples // 2
    stop = start + n_samples
    cross_, plus_ = cross[i][start:stop], plus[i][start:stop]
    print(jnp.max(jnp.abs(hcross - cross_)), jnp.max(jnp.abs(hplus - plus_)))

1.1754943508222875e-38 1.1754943508222875e-38
4.70197740328915e-38 1.88079096131566e-37
0.0 1.504632769052528e-36
4.1359030627651384e-25 4.1359030627651384e-25


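# 4096 Hz sampling rate

Repeat the comparison against LALInference's `BurstSineGaussian` at a 4096 Hz sampling rate. Each printed row gives the maximum difference for the cross (left) and plus (right) polarizations.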

In [6]:
# Compare the JAX SineGaussian against LALInference's BurstSineGaussian at 4096 Hz.
# A single variable feeds both generators so their sampling can never disagree.
sample_rate = 4096
sine_gaussian = SineGaussian(sample_rate, 8.0)

quality = jnp.array([3.0, 10.0, 100.0, 55.0])
frequency = jnp.array([100.0, 500.0, 800.0, 961.0])
hrss = jnp.array([1e-23, 1e-22, 1e-21, 4e-10])
phase = jnp.array([0.0, jnp.pi / 2.0, jnp.pi, 2 * jnp.pi])
eccentricity = jnp.array([0., 0.5, 1, 0.1])

# NOTE(review): assumes SineGaussian returns (cross, plus); LAL below returns
# (hplus, hcross), i.e. the opposite order — confirm against the ripplegw API.
cross, plus = sine_gaussian(quality, frequency, hrss, phase, eccentricity)

for i in range(len(quality)):
    hplus_series, hcross_series = BurstSineGaussian(
        Q=quality[i].item(),
        centre_frequency=frequency[i].item(),
        hrss=hrss[i].item(),
        eccentricity=eccentricity[i].item(),
        phase=phase[i].item(),
        delta_t=1 / sample_rate,
    )
    hplus = hplus_series.data.data
    hcross = hcross_series.data.data
    n_samples = len(hplus)
    # Centre the LAL waveform on the middle of the JAX waveform. Deriving `stop` from
    # `start` guarantees the slice has exactly n_samples points for odd or even lengths.
    start = len(cross[i]) // 2 - n_samples // 2
    stop = start + n_samples
    cross_, plus_ = cross[i][start:stop], plus[i][start:stop]
    # Use the maximum *absolute* deviation for both polarizations; a signed max would
    # hide negative differences.
    print(jnp.max(jnp.abs(hcross - cross_)), jnp.max(jnp.abs(hplus - plus_)))

1.1754943508222875e-38 1.1754943508222875e-38
1.88079096131566e-37 1.88079096131566e-37
0.0 1.504632769052528e-36
4.1359030627651384e-25 4.1359030627651384e-25
