<a href="https://colab.research.google.com/github/ralfcam/stajax/blob/main/stajax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import erf, betainc
from typing import Tuple

def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t``.

    With ``x = df / (t**2 + df)`` the regularized incomplete beta function
    ``I_x(df/2, 1/2)`` is the probability mass in both tails beyond ``|t|``,
    so each tail holds half of it.

    Args:
        t: Value(s) at which to evaluate the CDF.
        df: Degrees of freedom; broadcast against ``t``.

    Returns:
        ``P(T <= t)``.
    """
    x = df / (t**2 + df)
    one_tail = 0.5 * betainc(df / 2, 0.5, x)
    # Right of the centre the CDF is 1 - upper tail; left of it, the lower tail.
    return jnp.where(t >= 0, 1 - one_tail, one_tail)

def f_cdf(x, df1, df2):
    """Evaluate the F-distribution CDF through the regularized incomplete beta function.

    Args:
        x: Value(s) of the F statistic.
        df1: Numerator degrees of freedom.
        df2: Denominator degrees of freedom.

    Returns:
        ``betainc(df1/2, df2/2, df1*x / (df1*x + df2))``.
    """
    scaled = df1 * x
    beta_arg = scaled / (scaled + df2)
    return betainc(df1 / 2, df2 / 2, beta_arg)

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray = None, axis: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run a batch of t-tests with ``vmap``.

    ``axis`` is handed to ``vmap`` as the mapped axis of every input, so one
    test is produced per index along ``axis``. Without ``y`` the sample mean
    is tested against zero; with ``y`` an unequal-variance (Welch) two-sample
    test is run.

    Returns:
        Tuple ``(t, p)`` of t-statistics and two-sided p-values.
    """
    def two_sided_p(stat, dof):
        return 2 * (1 - t_cdf(jnp.abs(stat), dof))

    def one_sample(a):
        size = a.shape[0]
        stat = jnp.mean(a) / (jnp.std(a, ddof=1) / jnp.sqrt(size))
        return stat, two_sided_p(stat, size - 1)

    def two_sample(a, b):
        size_a, size_b = a.shape[0], b.shape[0]
        # Squared standard error of the mean contributed by each sample.
        se2_a = jnp.var(a, ddof=1) / size_a
        se2_b = jnp.var(b, ddof=1) / size_b
        stat = (jnp.mean(a) - jnp.mean(b)) / jnp.sqrt(se2_a + se2_b)
        # Welch-Satterthwaite degrees of freedom.
        dof = (se2_a + se2_b)**2 / (se2_a**2 / (size_a - 1) + se2_b**2 / (size_b - 1))
        return stat, two_sided_p(stat, dof)

    if y is not None:
        return vmap(two_sample, in_axes=(axis, axis))(x, y)
    return vmap(one_sample, in_axes=(axis,))(x)

def vectorized_anova(groups: jnp.ndarray, axis: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run a batch of one-way ANOVAs with ``vmap``.

    Every slice taken along ``axis`` is treated as a ``(groups, samples)``
    table in which all groups have the same number of samples.

    Returns:
        Tuple ``(f, p)`` of F-statistics and upper-tail p-values.
    """
    def anova_one(table):
        n_groups, per_group = table.shape[0], table.shape[1]
        df_between = n_groups - 1
        df_within = n_groups * (per_group - 1)

        group_means = jnp.mean(table, axis=1)
        overall_mean = jnp.mean(table)

        ms_between = jnp.sum(per_group * (group_means - overall_mean)**2) / df_between
        residuals = table - group_means[:, None]
        ms_within = jnp.sum(residuals**2) / df_within

        f_stat = ms_between / ms_within
        return f_stat, 1 - f_cdf(f_stat, df_between, df_within)

    return vmap(anova_one, in_axes=(axis,))(groups)

def vectorized_correlation(x: jnp.ndarray, y: jnp.ndarray, axis: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute Pearson correlations and their p-values for a batch of pairs.

    ``axis`` is passed to ``vmap`` as the mapped axis of both inputs, so one
    coefficient is produced per index along ``axis``.

    Returns:
        Tuple ``(r, p)`` of correlation coefficients and two-sided p-values.
    """
    def pearson_one(a, b):
        dof = a.shape[0] - 2
        coef = jnp.corrcoef(a, b)[0, 1]
        # t-statistic for testing a zero correlation.
        stat = coef * jnp.sqrt(dof / (1 - coef**2))
        p_value = 2 * (1 - t_cdf(jnp.abs(stat), dof))
        return coef, p_value

    return vmap(pearson_one, in_axes=(axis, axis))(x, y)

# Demo: run each batched test once on random normal data and print the
# result shapes. A fresh subkey is split off before every draw so that no
# two arrays share the same random stream.

# Set up the random key
key = random.PRNGKey(0)

# Generate data for t-test
key, subkey = random.split(key)
x = random.normal(subkey, shape=(100, 1000))  # shape (100, 1000)
key, subkey = random.split(key)
y = random.normal(subkey, shape=(100, 1000))

# Perform t-tests
# NOTE(review): axis=1 is the axis vmap maps over, so this runs 1000 tests
# on 100 samples each (hence the (1000,) result shapes). Use axis=0 if
# 100 tests of 1000 samples were intended -- confirm.
t_stats, p_values = vectorized_ttest(x, y, axis=1)
print("T-test results:")
print("T-statistics shape:", t_stats.shape)
print("P-values shape:", p_values.shape)

# Generate data for ANOVA
key, subkey = random.split(key)
groups = random.normal(subkey, shape=(100, 50, 20))  # 100 experiments, 50 groups, 20 samples each

# Perform ANOVA (mapped over axis 0: one F-test per experiment)
f_stats, p_values = vectorized_anova(groups, axis=0)
print("\nANOVA results:")
print("F-statistics shape:", f_stats.shape)
print("P-values shape:", p_values.shape)

# Generate data for correlation
key, subkey = random.split(key)
x_corr = random.normal(subkey, shape=(100, 1000))
key, subkey = random.split(key)
y_corr = random.normal(subkey, shape=(100, 1000))

# Perform correlation
# NOTE(review): as with the t-test, axis=1 yields 1000 correlations over
# 100 points each -- confirm this is the intended orientation.
r_values, p_values = vectorized_correlation(x_corr, y_corr, axis=1)
print("\nCorrelation results:")
print("R-values shape:", r_values.shape)
print("P-values shape:", p_values.shape)

T-test results:
T-statistics shape: (1000,)
P-values shape: (1000,)

ANOVA results:
F-statistics shape: (100,)
P-values shape: (100,)

Correlation results:
R-values shape: (1000,)
P-values shape: (1000,)


In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import erf, betainc
import statsmodels.api as sm
from scipy import stats
from typing import Tuple

# JAX implementations

def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t``.

    ``I_x(df/2, 1/2)`` with ``x = df / (t**2 + df)`` is the combined mass of
    both tails beyond ``|t|``; half of it sits in each tail.

    Args:
        t: Value(s) at which to evaluate the CDF.
        df: Degrees of freedom; broadcast against ``t``.

    Returns:
        ``P(T <= t)``.
    """
    x = df / (t**2 + df)
    one_tail = 0.5 * betainc(df / 2, 0.5, x)
    # Non-negative t: everything except the upper tail; negative t: the lower tail.
    return jnp.where(t >= 0, 1 - one_tail, one_tail)

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray, axis: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run a batch of unequal-variance (Welch) two-sample t-tests with ``vmap``.

    ``axis`` is used as the mapped axis of both inputs, so one test is
    produced per index along ``axis``.

    Returns:
        Tuple ``(t, p)`` of t-statistics and two-sided p-values.
    """
    def welch(a, b):
        size_a, size_b = a.shape[0], b.shape[0]
        # Squared standard error of the mean for each sample.
        se2_a = jnp.var(a, ddof=1) / size_a
        se2_b = jnp.var(b, ddof=1) / size_b
        combined = se2_a + se2_b

        stat = (jnp.mean(a) - jnp.mean(b)) / jnp.sqrt(combined)
        # Welch-Satterthwaite degrees of freedom.
        dof = combined**2 / (se2_a**2 / (size_a - 1) + se2_b**2 / (size_b - 1))
        return stat, 2 * (1 - t_cdf(jnp.abs(stat), dof))

    return vmap(welch, in_axes=(axis, axis))(x, y)

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import erf, betainc
from scipy import stats
from typing import Tuple

# JAX implementations

def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t``.

    The regularized incomplete beta function ``I_x(df/2, 1/2)`` at
    ``x = df / (t**2 + df)`` equals the two-tailed mass beyond ``|t|``.

    Args:
        t: Value(s) at which to evaluate the CDF.
        df: Degrees of freedom; broadcast against ``t``.

    Returns:
        ``P(T <= t)``.
    """
    x = df / (t**2 + df)
    one_tail = 0.5 * betainc(df / 2, 0.5, x)
    # Pick the complement of the upper tail for t >= 0, the lower tail otherwise.
    return jnp.where(t >= 0, 1 - one_tail, one_tail)

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run one unequal-variance (Welch) t-test per leading-axis index of ``x`` and ``y``.

    Returns:
        Tuple ``(t, p)`` of t-statistics and two-sided p-values.
    """
    def welch(a, b):
        size_a = a.shape[0]
        size_b = b.shape[0]
        # Squared standard error of the mean for each sample.
        se2_a = jnp.var(a, ddof=1) / size_a
        se2_b = jnp.var(b, ddof=1) / size_b

        stat = (jnp.mean(a) - jnp.mean(b)) / jnp.sqrt(se2_a + se2_b)
        # Welch-Satterthwaite degrees of freedom.
        dof = (se2_a + se2_b)**2 / (se2_a**2 / (size_a - 1) + se2_b**2 / (size_b - 1))
        p_value = 2 * (1 - t_cdf(jnp.abs(stat), dof))
        return stat, p_value

    return vmap(welch)(x, y)

# The vectorized_anova and vectorized_correlation functions remain the same as before

# Test functions

def test_ttest():
    """Compare the JAX t-test with ``scipy.stats.ttest_ind`` on 100 random pairs.

    Prints the largest absolute deviation of the t-statistics and of the
    p-values across the 100 tests.
    """
    np.random.seed(0)
    x = np.random.normal(0, 1, (100, 1000))
    y = np.random.normal(0.1, 1, (100, 1000))

    # vectorized_ttest maps over the leading axis, so the (100, 1000) arrays
    # are passed as-is: one test per row. Transposing would run 1000 tests of
    # 100 samples and no longer line up with the 100 SciPy results below.
    jax_x = jnp.array(x)
    jax_y = jnp.array(y)

    jax_t, jax_p = vectorized_ttest(jax_x, jax_y)

    sm_t = np.zeros(100)
    sm_p = np.zeros(100)
    for i in range(100):
        # The JAX code does not pool the variances, so compare against
        # Welch's test rather than SciPy's pooled default.
        sm_t[i], sm_p[i] = stats.ttest_ind(x[i], y[i], equal_var=False)

    print("T-test comparison:")
    print("Max absolute difference in t-statistic:", np.max(np.abs(jax_t - sm_t)))
    print("Max absolute difference in p-value:", np.max(np.abs(jax_p - sm_p)))

# The test_anova and test_correlation functions remain the same as before

# Entry point: run the comparison tests when this code is executed directly.
# NOTE(review): test_anova and test_correlation are not defined in this cell
# (the comment above says they "remain the same as before"); they must be
# defined elsewhere with zero-argument signatures or these calls raise
# NameError/TypeError -- confirm.
if __name__ == "__main__":
    test_ttest()
    test_anova()
    test_correlation()

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import erf, betainc
from scipy import stats
from typing import Tuple

# JAX implementations

def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t``.

    Based on the identity that ``I_x(df/2, 1/2)`` with
    ``x = df / (t**2 + df)`` is the total mass of both tails beyond ``|t|``.

    Args:
        t: Value(s) at which to evaluate the CDF.
        df: Degrees of freedom; broadcast against ``t``.

    Returns:
        ``P(T <= t)``.
    """
    x = df / (t**2 + df)
    one_tail = 0.5 * betainc(df / 2, 0.5, x)
    # For t >= 0 remove the upper tail from 1; for t < 0 the CDF is the lower tail.
    return jnp.where(t >= 0, 1 - one_tail, one_tail)

def single_ttest(xy):
    """Unequal-variance (Welch) two-sample t-test for one ``(x, y)`` pair.

    The sample sizes are read from the leading dimension of each array.

    Returns:
        Tuple ``(t, p)``: the t-statistic and its two-sided p-value.
    """
    sample_a, sample_b = xy
    size_a = sample_a.shape[0]
    size_b = sample_b.shape[0]

    # Squared standard error of the mean for each sample.
    se2_a = jnp.var(sample_a, ddof=1) / size_a
    se2_b = jnp.var(sample_b, ddof=1) / size_b
    combined = se2_a + se2_b

    stat = (jnp.mean(sample_a) - jnp.mean(sample_b)) / jnp.sqrt(combined)
    # Welch-Satterthwaite approximation of the degrees of freedom.
    dof = combined**2 / (se2_a**2 / (size_a - 1) + se2_b**2 / (size_b - 1))
    return stat, 2 * (1 - t_cdf(jnp.abs(stat), dof))

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply ``single_ttest`` to every leading-axis pair of ``x`` and ``y`` in one batch.

    ``jax.vmap`` is used instead of ``lax.map``: ``lax.map`` evaluates the
    pairs one after another, while ``vmap`` turns them into a single
    vectorized computation.

    Args:
        x: First samples; the leading axis indexes the individual tests.
        y: Second samples, with the same leading-axis length as ``x``.

    Returns:
        Tuple ``(t, p)`` with one entry per index of the leading axis.
    """
    return jax.vmap(lambda a, b: single_ttest((a, b)))(x, y)

# Test functions

def test_ttest():
    """Compare the JAX t-test with ``scipy.stats.ttest_ind`` on 100 random pairs.

    Prints the largest absolute deviation of the t-statistics and of the
    p-values across the 100 tests.
    """
    np.random.seed(0)
    x = np.random.normal(0, 1, (100, 1000))
    y = np.random.normal(0.1, 1, (100, 1000))

    jax_x = jnp.array(x)
    jax_y = jnp.array(y)

    jax_t, jax_p = vectorized_ttest(jax_x, jax_y)

    sm_t = np.zeros(100)
    sm_p = np.zeros(100)
    for i in range(100):
        # single_ttest does not pool the variances (Welch's test), so the
        # reference must use equal_var=False instead of the pooled default.
        sm_t[i], sm_p[i] = stats.ttest_ind(x[i], y[i], equal_var=False)

    print("T-test comparison:")
    print("Max absolute difference in t-statistic:", np.max(np.abs(jax_t - sm_t)))
    print("Max absolute difference in p-value:", np.max(np.abs(jax_p - sm_p)))

# Entry point: run the JAX-vs-SciPy t-test comparison when executed directly.
if __name__ == "__main__":
    test_ttest()


In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy import stats as jstats
import numpy as np
from scipy import stats
from typing import Tuple

def single_ttest(xy):
    """Unequal-variance (Welch) two-sample t-test for one ``(x, y)`` pair.

    The sample sizes are read from the leading dimension of each array.

    Args:
        xy: Tuple ``(x, y)`` holding the two samples.

    Returns:
        Tuple ``(t, p)``: the t-statistic and its two-sided p-value.
    """
    # jax.scipy.stats.t offers only pdf/logpdf (no cdf), so the tail
    # probability is taken from the regularized incomplete beta function.
    from jax.scipy.special import betainc

    x, y = xy
    n1, n2 = x.shape[0], y.shape[0]
    mean1, mean2 = jnp.mean(x), jnp.mean(y)
    var1, var2 = jnp.var(x, ddof=1), jnp.var(y, ddof=1)

    # Standard error of the difference of means without pooling the variances.
    se = jnp.sqrt(var1/n1 + var2/n2)
    t = (mean1 - mean2) / se

    # Welch-Satterthwaite degrees of freedom.
    df = (var1/n1 + var2/n2)**2 / ((var1/n1)**2/(n1-1) + (var2/n2)**2/(n2-1))
    # Two-sided tail mass of Student's t: I_{df/(t^2+df)}(df/2, 1/2).
    p = betainc(df / 2, 0.5, df / (t**2 + df))
    return t, p

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply ``single_ttest`` to every leading-axis pair of ``x`` and ``y`` in one batch.

    ``jax.vmap`` replaces ``lax.map`` because ``lax.map`` processes the pairs
    sequentially, whereas ``vmap`` evaluates all of them as one vectorized
    computation.

    Args:
        x: First samples; the leading axis indexes the individual tests.
        y: Second samples, with the same leading-axis length as ``x``.

    Returns:
        Tuple ``(t, p)`` with one entry per index of the leading axis.
    """
    return jax.vmap(lambda a, b: single_ttest((a, b)))(x, y)

def test_ttest():
    """Compare the JAX t-test with ``scipy.stats.ttest_ind`` on 100 random pairs.

    Prints the largest absolute deviation of the t-statistics and of the
    p-values across the 100 tests.
    """
    np.random.seed(0)
    x = np.random.normal(0, 1, (100, 1000))
    y = np.random.normal(0.1, 1, (100, 1000))

    jax_x = jnp.array(x)
    jax_y = jnp.array(y)

    jax_t, jax_p = vectorized_ttest(jax_x, jax_y)

    sm_t = np.zeros(100)
    sm_p = np.zeros(100)
    for i in range(100):
        # The JAX side is Welch's test (variances are not pooled), so ask
        # SciPy for the same variant instead of its pooled default.
        sm_t[i], sm_p[i] = stats.ttest_ind(x[i], y[i], equal_var=False)

    print("T-test comparison:")
    print("Max absolute difference in t-statistic:", np.max(np.abs(jax_t - sm_t)))
    print("Max absolute difference in p-value:", np.max(np.abs(jax_p - sm_p)))

# Entry point: run the JAX-vs-SciPy t-test comparison when executed directly.
if __name__ == "__main__":
    test_ttest()


In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import betainc
import numpy as np
from scipy import stats
from typing import Tuple

def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t``.

    ``1 - I_x(df/2, 1/2)`` with ``x = df / (t**2 + df)`` is the mass between
    ``-|t|`` and ``|t|``; half of it is added to or subtracted from 0.5
    according to the sign of ``t``.
    """
    beta_arg = df / (t**2 + df)
    central_mass = 1 - betainc(df/2, 0.5, beta_arg)
    return 0.5 * (1 + jnp.sign(t) * central_mass)

def single_ttest(xy):
    """Unequal-variance (Welch) two-sample t-test for one ``(x, y)`` pair.

    The sample sizes are read from the leading dimension of each array.

    Returns:
        Tuple ``(t, p)``: the t-statistic and its two-sided p-value.
    """
    first, second = xy
    count_1, count_2 = first.shape[0], second.shape[0]

    # Variance of each sample mean (sample variance divided by sample size).
    mean_var_1 = jnp.var(first, ddof=1) / count_1
    mean_var_2 = jnp.var(second, ddof=1) / count_2

    diff = jnp.mean(first) - jnp.mean(second)
    stat = diff / jnp.sqrt(mean_var_1 + mean_var_2)

    # Welch-Satterthwaite approximation of the degrees of freedom.
    numerator = (mean_var_1 + mean_var_2)**2
    denominator = mean_var_1**2 / (count_1 - 1) + mean_var_2**2 / (count_2 - 1)
    dof = numerator / denominator

    p_value = 2 * (1 - t_cdf(jnp.abs(stat), dof))
    return stat, p_value

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply ``single_ttest`` to every leading-axis pair of ``x`` and ``y`` in one batch.

    Uses ``jax.vmap`` rather than ``lax.map``: ``lax.map`` walks through the
    pairs sequentially, while ``vmap`` evaluates them together as a single
    vectorized computation.

    Args:
        x: First samples; the leading axis indexes the individual tests.
        y: Second samples, with the same leading-axis length as ``x``.

    Returns:
        Tuple ``(t, p)`` with one entry per index of the leading axis.
    """
    return jax.vmap(lambda a, b: single_ttest((a, b)))(x, y)

def test_ttest():
    """Compare the JAX t-test with ``scipy.stats.ttest_ind`` on 100 random pairs.

    Prints the largest absolute deviation of the t-statistics and of the
    p-values across the 100 tests.
    """
    np.random.seed(0)
    x = np.random.normal(0, 1, (100, 1000))
    y = np.random.normal(0.1, 1, (100, 1000))

    jax_x = jnp.array(x)
    jax_y = jnp.array(y)

    jax_t, jax_p = vectorized_ttest(jax_x, jax_y)

    sm_t = np.zeros(100)
    sm_p = np.zeros(100)
    for i in range(100):
        # single_ttest uses Welch's degrees of freedom, so the reference has
        # to be Welch's test too (equal_var=False), not the pooled default.
        sm_t[i], sm_p[i] = stats.ttest_ind(x[i], y[i], equal_var=False)

    print("T-test comparison:")
    print("Max absolute difference in t-statistic:", np.max(np.abs(jax_t - sm_t)))
    print("Max absolute difference in p-value:", np.max(np.abs(jax_p - sm_p)))

# Entry point: run the JAX-vs-SciPy t-test comparison when executed directly.
if __name__ == "__main__":
    test_ttest()


T-test comparison:
Max absolute difference in t-statistic: 7.1525574e-07
Max absolute difference in p-value: 0.0011698604


In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import betainc
import numpy as np
from scipy import stats
import time
from typing import Tuple

# JAX implementation (as before)

def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t``.

    The mass between ``-|t|`` and ``|t|`` is ``1 - I_x(df/2, 1/2)`` with
    ``x = df / (t**2 + df)``; the sign of ``t`` decides on which side of 0.5
    half of that mass is placed.
    """
    ratio = df / (t**2 + df)
    inner_mass = 1 - betainc(df/2, 0.5, ratio)
    signed_mass = jnp.sign(t) * inner_mass
    return 0.5 * (1 + signed_mass)

def single_ttest(xy):
    """Unequal-variance (Welch) two-sample t-test for one ``(x, y)`` pair.

    The sample sizes are read from the leading dimension of each array.

    Returns:
        Tuple ``(t, p)``: the t-statistic and its two-sided p-value.
    """
    left, right = xy
    n_left = left.shape[0]
    n_right = right.shape[0]

    # Squared standard error of the mean for each sample.
    se2_left = jnp.var(left, ddof=1) / n_left
    se2_right = jnp.var(right, ddof=1) / n_right
    se2_total = se2_left + se2_right

    t_stat = (jnp.mean(left) - jnp.mean(right)) / jnp.sqrt(se2_total)
    # Welch-Satterthwaite approximation of the degrees of freedom.
    dof = se2_total**2 / (se2_left**2 / (n_left - 1) + se2_right**2 / (n_right - 1))
    upper = t_cdf(jnp.abs(t_stat), dof)
    return t_stat, 2 * (1 - upper)

def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply ``single_ttest`` to every leading-axis pair of ``x`` and ``y`` in one batch.

    ``jax.vmap`` is used in place of ``lax.map``: ``lax.map`` runs the pairs
    one at a time, whereas ``vmap`` evaluates all pairs as a single
    vectorized computation.

    Args:
        x: First samples; the leading axis indexes the individual tests.
        y: Second samples, with the same leading-axis length as ``x``.

    Returns:
        Tuple ``(t, p)`` with one entry per index of the leading axis.
    """
    return jax.vmap(lambda a, b: single_ttest((a, b)))(x, y)

# Speed test function

def speed_test(n_tests, n_samples):
    """Time the JAX t-test against a per-row SciPy loop and print the result.

    Args:
        n_tests: Number of independent t-tests (rows).
        n_samples: Number of samples per test (columns).
    """
    print(f"Running speed test with {n_tests} tests and {n_samples} samples each")

    # Generate data
    np.random.seed(0)
    x = np.random.normal(0, 1, (n_tests, n_samples))
    y = np.random.normal(0.1, 1, (n_tests, n_samples))

    # JAX implementation
    jax_x = jnp.array(x)
    jax_y = jnp.array(y)

    # Warm-up call: the first invocation for a given shape pays for tracing
    # and compilation, which would otherwise dominate the measurement.
    jax.block_until_ready(vectorized_ttest(jax_x, jax_y))

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start_time = time.perf_counter()
    jax.block_until_ready(vectorized_ttest(jax_x, jax_y))
    jax_time = time.perf_counter() - start_time
    print(f"JAX implementation time: {jax_time:.4f} seconds")

    # Scipy implementation (Welch's test, matching what single_ttest computes)
    start_time = time.perf_counter()
    for i in range(n_tests):
        stats.ttest_ind(x[i], y[i], equal_var=False)
    scipy_time = time.perf_counter() - start_time
    print(f"Scipy implementation time: {scipy_time:.4f} seconds")

    print(f"Speedup: {scipy_time / jax_time:.2f}x\n")

# Run speed tests
# Each call benchmarks one (number of tests, samples per test) combination.
if __name__ == "__main__":
    speed_test(100, 1000)    # 100 tests, 1000 samples each
    speed_test(1000, 1000)   # 1000 tests, 1000 samples each
    speed_test(100, 10000)   # 100 tests, 10000 samples each
    speed_test(1000, 10000)  # 1000 tests, 10000 samples each

Running speed test with 100 tests and 1000 samples each
JAX implementation time: 0.5553 seconds
Scipy implementation time: 0.0787 seconds
Speedup: 0.14x

Running speed test with 1000 tests and 1000 samples each
JAX implementation time: 0.9086 seconds
Scipy implementation time: 0.8050 seconds
Speedup: 0.89x

Running speed test with 100 tests and 10000 samples each
JAX implementation time: 0.8440 seconds
Scipy implementation time: 0.1191 seconds
Speedup: 0.14x

Running speed test with 1000 tests and 10000 samples each
JAX implementation time: 0.9904 seconds
Scipy implementation time: 0.8481 seconds
Speedup: 0.86x



In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.scipy.special import betainc
from typing import Tuple
import time

# Use the GPU backend only when one is actually present. Setting
# 'jax_platform_name' to 'gpu' unconditionally makes every later computation
# fail on CPU-only machines.
try:
    jax.devices('gpu')
except RuntimeError:
    # No usable GPU backend: keep JAX's default platform.
    pass
else:
    jax.config.update('jax_platform_name', 'gpu')

# Optimized T-test implementation

@jit
def t_cdf(t, df):
    """Return the CDF of Student's t-distribution evaluated at ``t`` (jit-compiled).

    ``1 - I_x(df/2, 1/2)`` with ``x = df / (t**2 + df)`` is the mass between
    ``-|t|`` and ``|t|``; half of it is added to or subtracted from 0.5
    depending on the sign of ``t``.
    """
    beta_arg = df / (t**2 + df)
    central_mass = 1 - betainc(df/2, 0.5, beta_arg)
    return 0.5 * (1 + jnp.sign(t) * central_mass)

@jit
def vectorized_ttest(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run one unequal-variance (Welch) t-test per row of ``x`` and ``y``.

    All statistics are reduced along axis 1, so axis 0 indexes the tests.

    Returns:
        Tuple ``(t, p)`` of t-statistics and two-sided p-values.
    """
    size_x = x.shape[1]
    size_y = y.shape[1]

    # Squared standard error of the mean for every row.
    se2_x = jnp.var(x, axis=1, ddof=1) / size_x
    se2_y = jnp.var(y, axis=1, ddof=1) / size_y
    se2_total = se2_x + se2_y

    stat = (jnp.mean(x, axis=1) - jnp.mean(y, axis=1)) / jnp.sqrt(se2_total)
    # Welch-Satterthwaite approximation of the degrees of freedom.
    dof = se2_total**2 / (se2_x**2 / (size_x - 1) + se2_y**2 / (size_y - 1))
    return stat, 2 * (1 - t_cdf(jnp.abs(stat), dof))

# Vectorized ANOVA implementation

@jit
def f_cdf(x, df1, df2):
    """Evaluate the F-distribution CDF via the regularized incomplete beta function (jit-compiled).

    Args:
        x: Value(s) of the F statistic.
        df1: Numerator degrees of freedom.
        df2: Denominator degrees of freedom.
    """
    scaled = df1 * x
    beta_arg = scaled / (scaled + df2)
    return betainc(df1 / 2, df2 / 2, beta_arg)

@jit
def vectorized_anova(groups: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run one balanced one-way ANOVA per leading-axis entry of ``groups``.

    Args:
        groups: Array indexed as ``(experiment, group, sample)``; every group
            has the same number of samples.

    Returns:
        Tuple ``(f, p)`` of F-statistics and upper-tail p-values, each with
        one entry per experiment.
    """
    k = groups.shape[1]  # number of groups
    n = groups.shape[2]  # samples per group
    df1, df2 = k - 1, k * (n - 1)

    group_means = jnp.mean(groups, axis=2, keepdims=True)
    grand_mean = jnp.mean(groups, axis=(1, 2), keepdims=True)

    # Reduce over BOTH trailing axes. Summing only over axis 1 leaves a
    # trailing length-1 axis, which would broadcast against the within-group
    # term into an (experiments, experiments) matrix.
    between_group_var = jnp.sum(n * (group_means - grand_mean)**2, axis=(1, 2)) / df1
    within_group_var = jnp.sum((groups - group_means)**2, axis=(1, 2)) / df2

    f = between_group_var / within_group_var
    # Upper tail of the F distribution, evaluated directly as
    # I_{df2/(df2+df1*f)}(df2/2, df1/2) to avoid the cancellation in 1 - CDF.
    p = betainc(df2 / 2, df1 / 2, df2 / (df2 + df1 * f))
    return f, p

# Vectorized correlation implementation

@jit
def vectorized_correlation(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the Pearson correlation and its p-value for every row pair.

    Args:
        x: First variables; statistics are reduced along axis 1.
        y: Second variables, same shape as ``x``.

    Returns:
        Tuple ``(r, p)`` of correlation coefficients and two-sided p-values,
        one entry per row.
    """
    n = x.shape[1]
    mean_x, mean_y = jnp.mean(x, axis=1, keepdims=True), jnp.mean(y, axis=1, keepdims=True)
    std_x, std_y = jnp.std(x, axis=1), jnp.std(y, axis=1)

    r = jnp.sum((x - mean_x) * (y - mean_y), axis=1) / (n * std_x * std_y)
    # Rounding can push |r| marginally above 1 for (nearly) collinear rows,
    # which would make 1 - r**2 negative and the p-value NaN.
    r = jnp.clip(r, -1.0, 1.0)

    # With t = r * sqrt((n-2) / (1-r^2)) the beta argument df / (t^2 + df)
    # simplifies to 1 - r^2, so the two-sided p-value is I_{1-r^2}((n-2)/2, 1/2).
    # This form needs no division by (1 - r^2) and stays finite at |r| = 1.
    p = betainc((n - 2) / 2, 0.5, 1 - r**2)
    return r, p

# Test functions

def test_ttest(n_tests, n_samples):
    """Generate two random normal arrays and run ``vectorized_ttest`` on them.

    Args:
        n_tests: Number of rows (independent tests).
        n_samples: Number of samples per row.

    Returns:
        The ``(t, p)`` tuple returned by ``vectorized_ttest``.
    """
    # Draw x and y from different subkeys: reusing one key would make y an
    # exact copy of x, so every t-statistic would be 0.
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (n_tests, n_samples))
    y = jax.random.normal(key_y, (n_tests, n_samples))
    return vectorized_ttest(x, y)

def test_anova(n_experiments, n_groups, n_samples):
    """Generate random normal groups and run ``vectorized_anova`` on them.

    Returns:
        The ``(f, p)`` tuple returned by ``vectorized_anova``.
    """
    shape = (n_experiments, n_groups, n_samples)
    samples = jax.random.normal(jax.random.PRNGKey(0), shape)
    return vectorized_anova(samples)

def test_correlation(n_tests, n_samples):
    """Generate two random normal arrays and run ``vectorized_correlation`` on them.

    Args:
        n_tests: Number of rows (independent pairs).
        n_samples: Number of samples per row.

    Returns:
        The ``(r, p)`` tuple returned by ``vectorized_correlation``.
    """
    # Draw x and y from different subkeys: reusing one key would make y an
    # exact copy of x, i.e. every pair perfectly correlated.
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (n_tests, n_samples))
    y = jax.random.normal(key_y, (n_tests, n_samples))
    return vectorized_correlation(x, y)

# Performance test function

def run_performance_test(test_func, *args):
    """Time a single call of ``test_func(*args)`` after one warm-up call.

    Args:
        test_func: Callable to benchmark.
        *args: Positional arguments forwarded to ``test_func``.

    Returns:
        Tuple ``(result, seconds)``: the result of the timed call and its
        wall-clock duration.
    """
    # Warm-up run, so that one-off costs of the first call are not measured.
    warmup = test_func(*args)
    jax.block_until_ready(warmup)

    # Timed run; perf_counter is monotonic and has a higher resolution than
    # time.time, which matters for millisecond-scale measurements.
    start_time = time.perf_counter()
    result = test_func(*args)
    # Wait for the computation to finish before reading the clock.
    jax.block_until_ready(result)
    elapsed = time.perf_counter() - start_time

    return result, elapsed

# Main execution

# Entry point: report the first JAX device, then benchmark each batched
# test once through run_performance_test (which does a warm-up call first)
# and print the measured time.
if __name__ == "__main__":
    print("Running tests on:", jax.devices()[0])

    # T-test performance
    _, time_taken = run_performance_test(test_ttest, 10000, 1000)
    print(f"T-test (10000 tests, 1000 samples each) time: {time_taken:.4f} seconds")

    # ANOVA performance
    _, time_taken = run_performance_test(test_anova, 1000, 5, 1000)
    print(f"ANOVA (1000 experiments, 5 groups, 1000 samples each) time: {time_taken:.4f} seconds")

    # Correlation performance
    _, time_taken = run_performance_test(test_correlation, 10000, 1000)
    print(f"Correlation (10000 pairs, 1000 samples each) time: {time_taken:.4f} seconds")


Running tests on: cuda:0
T-test (10000 tests, 1000 samples each) time: 0.0042 seconds
ANOVA (1000 experiments, 5 groups, 1000 samples each) time: 0.0488 seconds
Correlation (10000 pairs, 1000 samples each) time: 0.0054 seconds
