In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]= "platform"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from einops import rearrange
from functools import partial

In [None]:
# jax.config.update('jax_platform_name', 'cpu')
# jax.config.update("jax_enable_x64", True)  # Enables float64 precision globally

In [3]:
@partial(jax.jit, static_argnames=["pool", "win", "stride", "pad"])
def shrink(X, pool, win, stride, pad):
    """
    Apply shrink operation using max pooling and neighborhood construction.

    Parameters
    ----------
    X : jnp.ndarray
        Input array of shape (batch, height, width, channels).
    pool : int
        Pooling window size.
    win : int
        Neighborhood window size.
    stride : int
        Stride size for neighborhood construction.
    pad : int
        Padding size.

    Returns
    -------
    jnp.ndarray
        Transformed array.
    """

    # ---- max pooling ----
    X = jax.lax.reduce_window(
        X,
        -jnp.inf,
        jax.lax.max,
        (1, pool, pool, 1),
        (1, pool, pool, 1),
        padding="VALID",
    )

    # ---- neighborhood construction ----
    X = rearrange(X, "b h w c -> b c h w")
    X = jnp.pad(X, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="reflect")
    X = jax.lax.conv_general_dilated_patches(
        lhs=X, filter_shape=(win, win), window_strides=(stride, stride), padding="VALID"
    )
    X = rearrange(X, "b (c p) h w -> b h w p c", p=win**2)
    return X


In [4]:
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["win", "stride"])
def extract_patches(X, win, stride):
    B, H, W, C = X.shape
    out_h = (H - win) // stride + 1
    out_w = (W - win) // stride + 1

    @partial(jax.jit)
    def get_patch(i, j, X):
        """Dynamically extract a patch at position (i, j)."""
        return jax.lax.dynamic_slice(X, (0, i * stride, j * stride, 0), (B, win, win, C))

    # Vectorized patch extraction
    i_vals = jnp.arange(out_h)
    j_vals = jnp.arange(out_w)
    patches = jax.vmap(lambda i: jax.vmap(lambda j: get_patch(i, j, X))(j_vals))(i_vals)

    return patches.reshape(B * out_h * out_w, win * win * C)

@partial(jax.jit, static_argnames=["pool", "win", "stride", "pad", "batch_size"])
def compute_patch_stats(X, pool, win, stride, pad, batch_size):
    B, H, W, C = X.shape
    # X = X.astype(jnp.float32)
    num_kernel = win * win * C
    num_batches = max(B // batch_size, 1)
    X = X[:num_batches * batch_size]

    # ---- Apply Max Pooling ----
    X = jax.lax.reduce_window(
        X, -jnp.inf, jax.lax.max,
        (1, pool, pool, 1),  # Pooling size
        (1, pool, pool, 1),  # Stride
        padding="VALID"
    )

    # ---- Apply Padding ----
    X = jnp.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), mode="reflect")

    X_batch = jnp.split(X, num_batches)

    # ---- Compute Statistics ----
    dc_batch = []
    bias = 0
    mean = jnp.zeros((1, num_kernel))
    for X in X_batch:
        patches = extract_patches(X, win, stride)
        dc_local = jnp.mean(patches, axis=-1, keepdims=-1)
        X_centered = patches - dc_local
        bias_local = jnp.max(jnp.linalg.norm(X_centered, axis=-1))
        mean_local = jnp.mean(X_centered, axis=0, keepdims=True)

        dc_batch.append(dc_local)
        bias = jnp.maximum(bias_local, bias)
        mean += mean_local / num_batches

    # ---- Compute Covariance ----
    covariance = jnp.zeros((num_kernel, num_kernel))
    for X, dc in zip(X_batch, dc_batch):
        patches = extract_patches(X, win, stride)
        X_centered = (patches - dc - mean)
        covariance += X_centered.T @ X_centered
        # covariance += jnp.einsum("ni,nj->ij", patches - dc - mean, patches - dc - mean)


    covariance = covariance / (batch_size  * num_batches * H * W - 1)
    return covariance, mean
    # return mean

In [None]:
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["win", "stride"])
def _extract_patches(X, win, stride):
    """
    Extracts sliding patches dynamically from input X using `jax.lax.dynamic_slice`.

    Args:
        X: Input tensor (B, H, W, C)
        win: Patch size.
        stride: Stride for sliding window.

    Returns:
        Patches of shape (B * out_H * out_W, win * win * C)
    """
    B, H, W, C = X.shape
    out_h = (H - win) // stride + 1
    out_w = (W - win) // stride + 1

    def get_patch(i, j, X):
        return jax.lax.dynamic_slice(X, (0, i * stride, j * stride, 0), (B, win, win, C))

    # Vectorized patch extraction
    i_vals, j_vals = jnp.arange(out_h), jnp.arange(out_w)
    patches = jax.vmap(lambda i: jax.vmap(lambda j: get_patch(i, j, X))(j_vals))(i_vals)

    return patches.reshape(B * out_h * out_w, win * win * C)

@partial(jax.jit, static_argnames=["extract_patches", "num_kernel"])
def _compute_statistics(X_batch, extract_patches, num_kernel):
    """
    Computes local mean (dc_local), bias, and mean update for each batch.

    Args:
        X_batch: Batched input tensor.
        win: Patch size.
        stride: Stride for patches.
        num_batches: Number of mini-batches.
        num_kernel: Number of elements in a single patch.

    Returns:
        mean, bias, dc_batch
    """
    def scan_statistics(carry, X_cur):
        mean, bias = carry  # Unpack carry state
        patches = extract_patches(X_cur)
        dc_local = jnp.mean(patches, axis=-1, keepdims=True)  # Mean per patch
        X_centered = patches - dc_local

        bias_local = jnp.max(jnp.linalg.norm(X_centered, axis=-1))
        mean_local = jnp.mean(X_centered, axis=0, keepdims=True)

        new_mean = mean + mean_local
        new_bias = jnp.maximum(bias, bias_local)

        return (new_mean, new_bias), dc_local

    mean_init = jnp.zeros((1, num_kernel))
    bias_init = 0.0

    (final_mean, final_bias), dc_batch = jax.lax.scan(scan_statistics, (mean_init, bias_init), X_batch)
    final_mean /= len(X_batch)
    
    return final_mean, final_bias, dc_batch

@partial(jax.jit, static_argnames=["num_kernel", "extract_patches"])
def _compute_covariance(X_batch, dc_batch, mean, extract_patches,  num_kernel):
    def scan_covariance(carry, inputs):
        X_cur, dc = inputs
        patches = extract_patches(X_cur)
        X_centered = patches - dc - mean

        carry += jnp.einsum("...i,...j->ij", X_centered, X_centered, precision=jax.lax.Precision.HIGHEST)
        return carry, None

    covariance_init = jnp.zeros((num_kernel, num_kernel))

    covariance, _ = jax.lax.scan(scan_covariance, covariance_init, (X_batch, dc_batch))

    return covariance 

@partial(jax.jit, static_argnames=["pool", "win", "stride", "pad", "batch_size"])
def compute_patch_stats(X, pool, win, stride, pad, batch_size):
    """
    Compute the mean, bias, and covariance of sliding patches.

    Args:
        X: Input tensor (B, H, W, C)
        pool: Pooling size.
        win: Patch size.
        stride: Stride for patches.
        pad: Padding size.
        batch_size: Batch size.

    Returns:
        covariance matrix, mean
    """
    B, H, W, C = X.shape
    num_kernel = win * win * C
    num_batches = max(B // batch_size, 1)
    X = X[:num_batches * batch_size]

    # ---- Apply Max Pooling ----
    X = jax.lax.reduce_window(
        X, -jnp.inf, jax.lax.max,
        (1, pool, pool, 1),
        (1, pool, pool, 1),
        padding="VALID"
    )

    # ---- Apply Padding ----
    X = jnp.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), mode="reflect")

    X_batch = jnp.stack(jnp.split(X, num_batches))

    # ---- Compute Mean and Bias ----
    mean, bias, dc_batch = _compute_statistics(X_batch, lambda X: _extract_patches(X, win, stride), num_kernel)

    # ---- Compute Covariance ----
    covariance = _compute_covariance(X_batch, dc_batch, mean, lambda X: _extract_patches(X, win, stride), num_kernel)
    covariance /= (batch_size * num_batches * H * W - 1)

    return covariance, mean

In [57]:
@partial(jax.jit, static_argnames=["pool", "win", "stride", "pad"])
def compute(X, pool, win, stride, pad):
    # X = X.astype(jnp.float32)
    X = shrink(X, pool, win, stride, pad)
    X = rearrange(X, "... p c -> (...) (p c)")

    dc = jnp.mean(X, axis=-1, keepdims=-1)
    X = X - dc
    bias = jnp.max(jnp.linalg.norm(X, axis=-1))
    mean = jnp.mean(X, axis=0, keepdims=True)

    # covariance = (X - mean).T @ (X - mean)
    covariance = jnp.einsum("ni,nj->ij", X - mean, X - mean)
    covariance /= (X.shape[0] - 1)
    return covariance, mean

In [None]:
X = np.random.randn(50000, 128, 128, 3) + 2

In [17]:
a = compute(X, 1, 7, 1, 3)
a[0].block_until_ready()

NameError: name 'compute' is not defined

In [181]:
a = np.ones((10, 100, 10))
rearrange(a, "n (p p) c -> n p p c").shape

EinopsError:  Error while processing rearrange-reduction pattern "n (p p) c -> n p p c".
 Input tensor shape: (10, 100, 10). Additional info: {}.
 Indexing expression contains duplicate dimension "p"

In [94]:
# jax.profiler.start_trace("./jax-trace")
# b = compute_patch_stats(X[:10], 1, 7, 1, 3, 1)
# b[0].block_until_ready()

b = compute_patch_stats(X, 1, 7, 1, 3, 100)
b[0].block_until_ready()
# jax.profiler.stop_trace()

Array([[ 0.99270904, -0.00725159, -0.00722823, ..., -0.00708457,
        -0.00731058, -0.00726368],
       [-0.00725159,  0.99279356, -0.00727239, ..., -0.00731609,
        -0.00706959, -0.00724925],
       [-0.00722823, -0.00727239,  0.9926517 , ..., -0.00736245,
        -0.00731914, -0.00707388],
       ...,
       [-0.00708457, -0.00731609, -0.00736245, ...,  0.992648  ,
        -0.00726744, -0.00719396],
       [-0.00731058, -0.00706959, -0.00731914, ..., -0.00726744,
         0.9928077 , -0.0072573 ],
       [-0.00726368, -0.00724925, -0.00707388, ..., -0.00719396,
        -0.0072573 ,  0.9926555 ]], dtype=float32)

In [90]:
jnp.abs(jnp.load("gt.npy") - b[0]).max()

Array(0.00029544, dtype=float32)

In [63]:
np.save("gt.npy", b[0])

In [12]:
print(b[0].dtype)

float32


In [12]:
jnp.allclose(np.load("b.npy"), b[0])

Array(False, dtype=bool)

In [21]:
jnp.abs(b[0]-jnp.eye(a[0].shape[0])).sum()

Array(151.95303, dtype=float32)

In [22]:
jnp.allclose(a[1], b[1])

Array(True, dtype=bool)

In [None]:
X = np.random.randn(10000, 128, 128, 3)

In [55]:
# jax.profiler.start_trace("./jax-trace")
# compute_channel_wise = jax.vmap(lambda X: compute_patch_stats(X, 1, 5, 1, 3, 10))
Xt = rearrange(X, "n h w c ->  c n h w 1")
print(Xt.shape)
b = jnp.array([compute_patch_stats(X_channel, 1, 7, 1, 3, 40)[0] for X_channel in Xt])
# b = compute_channel_wise(Xt)
# b = jax.lax.map(lambda X: compute_patch_stats(X, 1, 5, 1, 3, 10), Xt)
b[0].block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof", backend="gpu")
print(b.shape)
# jax.profiler.stop_trace()

(3, 10000, 128, 128, 1)
(3, 49, 49)


In [64]:
X = np.random.randn(5000, 128, 128, 1)

In [45]:
# jax.profiler.start_trace("./jax-trace")
b = compute_patch_stats(X, 1, 7, 1, 3, 40)
b[0].block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof", backend="gpu")
print(b[0].shape)
# jax.profiler.stop_trace()

(147, 147)
