In [53]:
import torch
import gpytorch
from dak.kernels.laplace_kernel import LaplaceProductKernel
from dak.utils.sparse_design.design_class import HyperbolicCrossDesign, SparseGridDesign
from dak.utils.operators.chol_inv import mk_chol_inv, tmk_chol_inv

In [55]:
# Degree-3 design on [0, 1] with dyadic ordering of its points.
design_factory = HyperbolicCrossDesign(dyadic_sort=True, return_neighbors=True)
dyadic_design = design_factory(deg=3, input_lb=0, input_ub=1)

# Factor returned by mk_chol_inv for the Laplace kernel on this design,
# requested in upper-triangular form.
chol_inv = mk_chol_inv(
    dyadic_design=dyadic_design,
    markov_kernel=LaplaceProductKernel(lengthscale=1.),
    upper=True,
)  # [m, m] size tensor

# Column-vector view of the design points, as used for kernel evaluation below.
design_points = dyadic_design.points.reshape(-1, 1)  # [m, 1] size tensor
print(dyadic_design.points)

tensor([0.5000, 0.2500, 0.7500, 0.1250, 0.3750, 0.6250, 0.8750])


In [3]:
# Densify the factor so its sparsity pattern is visible in the printout.
chol_inv_dense = chol_inv.to_dense()
print(chol_inv_dense)

tensor([[ 1.0000, -1.2416, -1.2416,  0.0000, -1.4069, -1.4069,  0.0000],
        [ 0.0000,  1.5942,  0.0000, -1.8764, -1.4069,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.5942,  0.0000,  0.0000, -1.4069, -1.8764],
        [ 0.0000,  0.0000,  0.0000,  2.1262,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.8358,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.8358,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.1262]])


In [5]:
from linear_operator.utils.cholesky import psd_safe_cholesky  # from gpytorch's linear_operator

@torch.no_grad()  # remove the decorator if gradients must flow through the factorisation
def inv_cholesky_transpose(K: torch.Tensor, jitter: float = 1e-6) -> torch.Tensor:
    """
    Compute U = L^{-T} for K = L L^T (L lower-triangular), so that K^{-1} = U U^T.

    Args:
        K: SPD matrix of shape (..., n, n); handed to psd_safe_cholesky.
        jitter: forwarded to psd_safe_cholesky as its `jitter` argument.

    Returns:
        The upper-triangular solution U of the system L^T U = I.
    """
    # Lower-triangular Cholesky factor of K.
    chol_lower = psd_safe_cholesky(K, jitter=jitter)  # (..., n, n)

    # Right-hand side: identity matching K's size, dtype and device.
    identity = torch.eye(K.size(-1), dtype=K.dtype, device=K.device)

    # L^T is upper-triangular, hence upper=True for the triangular solve.
    return torch.linalg.solve_triangular(
        chol_lower.transpose(-2, -1), identity, upper=True
    )

# --- quick check (sketch only, not executed) ---
# X = torch.randn(128, 3); K = kernel(X, X) + 1e-4*I, etc.
# U = inv_cholesky_transpose(K)
# err = torch.linalg.norm(K @ (U @ U.T) - torch.eye(K.size(-1), device=K.device))

# Dense reference: evaluate the Laplace kernel on the design points and factor it
# with inv_cholesky_transpose, to compare against the chol_inv printed earlier.
reference_kernel = LaplaceProductKernel(lengthscale=1.)
covar = reference_kernel(design_points)
# Alternative that was tried and left disabled:
# K_inv = gpytorch.root_inv_decomposition(covar, method='cholesky').to_dense()
# print(K_inv)
L_inv_T = inv_cholesky_transpose(covar)
print(L_inv_T)

tensor([[ 1.0000e+00, -1.2416e+00, -1.2416e+00, -1.1921e-07, -1.4069e+00,
         -1.4069e+00, -2.8689e-07],
        [ 0.0000e+00,  1.5942e+00, -1.1309e-08, -1.8764e+00, -1.4069e+00,
         -5.1470e-07, -6.6061e-07],
        [ 0.0000e+00,  0.0000e+00,  1.5942e+00, -2.6396e-08, -1.6125e-07,
         -1.4069e+00, -1.8764e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  2.1262e+00, -1.6530e-07,
          3.4646e-07,  4.5889e-07],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.8358e+00,
          9.6151e-08,  4.6208e-07],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          2.8358e+00, -2.2046e-07],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  2.1262e+00]])


In [7]:
# Rebuild the kernel matrix from the factor: (L L^T)^{-1}.
# NOTE: `L` here is the dense factor (requested with upper=True); a later cell
# reassigns `L` to the integer number of levels.
L = chol_inv.to_dense()
K_uu_inv = torch.matmul(L, L.T)
K_uu = torch.linalg.inv(K_uu_inv)
print(K_uu)

tensor([[1.0000, 0.7788, 0.7788, 0.6873, 0.8825, 0.8825, 0.6873],
        [0.7788, 1.0000, 0.6065, 0.8825, 0.8825, 0.6873, 0.5353],
        [0.7788, 0.6065, 1.0000, 0.5353, 0.6873, 0.8825, 0.8825],
        [0.6873, 0.8825, 0.5353, 1.0000, 0.7788, 0.6065, 0.4724],
        [0.8825, 0.8825, 0.6873, 0.7788, 1.0000, 0.7788, 0.6065],
        [0.8825, 0.6873, 0.8825, 0.6065, 0.7788, 1.0000, 0.7788],
        [0.6873, 0.5353, 0.8825, 0.4724, 0.6065, 0.7788, 1.0000]])


In [8]:
# Direct kernel evaluation on the design points, for comparison with the
# matrix recovered from chol_inv.
direct_kernel = LaplaceProductKernel(lengthscale=1.)
K_uu_lap = direct_kernel(design_points)
print(K_uu_lap)

tensor([[1.0000, 0.7788, 0.7788, 0.6873, 0.8825, 0.8825, 0.6873],
        [0.7788, 1.0000, 0.6065, 0.8825, 0.8825, 0.6873, 0.5353],
        [0.7788, 0.6065, 1.0000, 0.5353, 0.6873, 0.8825, 0.8825],
        [0.6873, 0.8825, 0.5353, 1.0000, 0.7788, 0.6065, 0.4724],
        [0.8825, 0.8825, 0.6873, 0.7788, 1.0000, 0.7788, 0.6065],
        [0.8825, 0.6873, 0.8825, 0.6065, 0.7788, 1.0000, 0.7788],
        [0.6873, 0.5353, 0.8825, 0.4724, 0.6065, 0.7788, 1.0000]])


In [9]:
# Project the kernel row of a single query point through chol_inv, once per query.
# (disabled sparse variant: y = torch.sparse.mm(chol_inv.T, h.T).T)
for query in (0.35, 0.65):
    x = torch.Tensor([query]).unsqueeze(dim=-1)
    h = LaplaceProductKernel(lengthscale=1.)(x, design_points)
    out = torch.matmul(h, chol_inv)
    print(out)

tensor([[ 8.6071e-01,  3.7387e-01, -1.4539e-07, -5.6704e-09,  2.8185e-01,
          2.0832e-07, -3.2482e-08]])
tensor([[ 8.6071e-01, -1.4539e-07,  3.7387e-01, -3.2482e-08,  1.3319e-07,
          2.8185e-01, -5.6704e-09]])


In [46]:
def dyadic_nonzero_indices(x: torch.Tensor, L: int) -> torch.LongTensor:
    """
    Per-level active column index J_s(x) for every entry of x, in dyadic ordering.

    Columns are laid out level by level: level s occupies the 2^{s-1} columns
    starting at 0-based offset 2^{s-1} - 1, so level 1 is the single DC column 0.

    Args:
        x: 1-D tensor of query locations.
        L: number of dyadic levels.

    Returns:
        (N, L) int64 tensor; entry [n, s-1] is the global 0-based column index
        selected at level s for x[n].
    """
    assert x.ndim == 1, "x must be 1D"

    # Grid resolution per level: 2^s for s = 1..L.
    levels = torch.arange(1, L + 1, device=x.device, dtype=torch.int64)
    resolution = torch.pow(2, levels)  # (L,)

    # Cell index ceil(2^s * x), restricted to [1, 2^s - 1].
    cell = torch.ceil(x.unsqueeze(1) * resolution.unsqueeze(0).to(x.dtype)).to(torch.int64)
    cell = torch.minimum(cell.clamp(min=1), (resolution - 1).unsqueeze(0))  # (N, L)

    # Even cells step down by one so the index is odd: {1, 3, ..., 2^s - 1}.
    odd_cell = torch.where(cell % 2 == 0, cell - 1, cell)

    # 1-based position of that odd index inside its level.
    position = (odd_cell + 1) // 2  # (N, L)

    # Number of columns preceding level s.
    level_start = resolution // 2 - 1  # (L,)

    return level_start.unsqueeze(0) + (position - 1)  # (N, L), long


def dyadic_phi_values(x: torch.Tensor, L: int, sigma: float = 1.0, ell_c: float = 1.0):
    """
    Evaluate the per-level active φ entries for each query point (dyadic order).

    Args:
      x:     (N,) tensor of query locations.
      L:     number of dyadic levels.
      sigma: amplitude multiplying every value.
      ell_c: lengthscale dividing every distance.

    Returns:
      idx:  (N, L) long — active column index per level, from dyadic_nonzero_indices
      vals: (N, L) — column 0 is sigma * exp(-|x - 1/2| / ell_c); column s-1 (s >= 2)
            is |2 sigma rho_s sinh((x - u_s) / ell_c)| with rho_s = exp(-2^{-s} / ell_c)
            and u_s the odd multiple of 2^{-s} selected for x.
    """
    assert x.ndim == 1, "x must be 1D"

    # Column indices per level (level 1 is the DC column).
    idx = dyadic_nonzero_indices(x, L)  # (N, L)

    vals = torch.zeros((x.shape[0], L), dtype=x.dtype, device=x.device)

    # Level 1: kernel value against the midpoint u = 1/2.
    vals[:, 0] = sigma * torch.exp(-torch.abs(x - 0.5) / ell_c)

    # Levels 2..L: one active detail column each.
    for col in range(1, L):
        s = col + 1
        n_cells = 2 ** s
        h_s = 2.0 ** (-s)
        rho_s = torch.exp(torch.tensor(-h_s / ell_c, dtype=x.dtype, device=x.device))

        # Same odd-index selection as dyadic_nonzero_indices, redone for this level.
        cell = torch.ceil(x * n_cells).to(torch.int64).clamp(min=1, max=n_cells - 1)
        odd_cell = torch.where(cell % 2 == 0, cell - 1, cell)

        # Signed offset of x from the selected odd grid point u_s = odd_cell * 2^{-s}.
        offset = x - odd_cell.to(x.dtype) * h_s

        # The variant normalised by 1/sqrt(1 - rho_s^2) is not applied here;
        # the absolute value makes the result independent of the offset's sign.
        vals[:, col] = torch.abs((2.0 * sigma * rho_s) * torch.sinh(offset / ell_c))

    return idx, vals


def dyadic_phi_dense(x: torch.Tensor, L: int, sigma: float = 1.0, ell_c: float = 1.0):
    """
    Assemble dense φ(x) rows of shape (N, m), m = 2^L - 1.

    Each row holds the L values returned by dyadic_phi_values at their
    per-level column indices; every other entry is zero.
    """
    col_idx, col_vals = dyadic_phi_values(x, L, sigma, ell_c)  # (N, L), (N, L)

    n_cols = 2 ** L - 1
    dense = torch.zeros((x.shape[0], n_cols), dtype=x.dtype, device=x.device)

    # Row n receives col_vals[n, j] at column col_idx[n, j]; the indices of a
    # row come from disjoint per-level column ranges, so no write collides.
    dense.scatter_(1, col_idx, col_vals)
    return dense  # (N, m)

In [60]:
def dyadic_psi(x: torch.Tensor, L: int, sigma: float = 1.0, ell_c: float = 1.0):
    """
    Batched per-level column indices and ψ values under dyadic ordering.

    Args
    ----
    x : (...,) tensor with values in [0, 1].
        Works with any number of leading batch dims; a 0-dim tensor is
        promoted to shape (1,) before processing.
    L : int, number of dyadic levels (total columns m = 2^L - 1).
    sigma : float, currently unused by the computation; kept in the signature.
    ell_c : float, factor multiplying the distances inside the sinh terms.

    Returns
    -------
    idx : (..., L) long tensor
        0-based global column indices in dyadic order for each level (DC is level 1).
        The leading shape matches x, with an extra trailing dim of size L.
    psi : (..., L) tensor
        sqrt(2 / sinh(2 * ell_c * 2^{-s})) * sinh(ell_c * (2^{-s} - |x - u_s|)),
        where u_s is the design point stored at column idx[..., s-1].
    """
    if x.ndim == 0:
        x = x.unsqueeze(0)  # promote scalar to shape (1,)
    device = x.device

    # 2^s for s=1..L
    pow2 = torch.pow(2, torch.arange(1, L+1, device=device, dtype=torch.int64))  # (L,)

    # k_s = ceil(2^s * x) clamped to [1, 2^s - 1]
    ks = torch.ceil(x[..., None] * pow2.to(x.dtype)).to(torch.int64)  # (..., L)
    ks = torch.clamp(ks, min=1)
    ks_max = (pow2 - 1)  # (L,)
    ks = torch.minimum(ks, ks_max)  # (..., L)

    # r_s^(odd): force to be odd (right endpoint index made odd)
    # if ks even -> ks-1, else ks
    rs = ks - ((ks & 1) == 0).to(torch.int64)  # (..., L), odd in {1,3,...,2^s-1}

    # position within level s: t_s in {1,...,2^{s-1}}
    ts = (rs + 1) // 2  # (..., L)

    # offsets: number of columns before level s (0-based indexing)
    offsets = (pow2 // 2) - 1  # (L,)

    # global 0-based indices: J_s = offset(s) + (t_s - 1)
    idx = offsets + (ts - 1)  # (..., L)

    # Design points in dyadic order. They are built without reference to x, so
    # move them to x's device: torch.gather needs input and index co-located.
    u = HyperbolicCrossDesign(dyadic_sort=True, return_neighbors=True)(deg=L, input_lb=0, input_ub=1).points  # (2^L-1,)
    u = u.to(device)
    view_shape = (1,) * x.dim() + (u.shape[0],)      # (1,1,...,1, 2^L-1)
    u_selected = torch.gather(u.view(view_shape).expand(*x.shape, -1), dim=-1, index=idx)  # (..., L)

    delta = torch.abs(x.unsqueeze(-1) - u_selected)  # |x - u_s|, (..., L)
    pow2_f = (1.0 / pow2).to(x.dtype)  # (L,), 2^{-s} for s=1..L

    psi = torch.sqrt(2 / torch.sinh(pow2_f * 2 * ell_c)) * torch.sinh(ell_c * (pow2_f - delta))

    return idx, psi  # (..., L) long, (..., L) float


In [61]:
# Two query points in [0, 1] and a 3-level dyadic hierarchy.
L = 3
x = torch.tensor([0.35, 0.65])  # N=2 points in [0,1]

# Predicted active column per level for each query.
idx = dyadic_nonzero_indices(x, L)
print(idx)

# Kernel rows projected through chol_inv.
h = LaplaceProductKernel(lengthscale=1.)(x, design_points)
out = torch.matmul(h, chol_inv)
print(out)

# Pick out the projected entries at the predicted columns.
out_selected = out.gather(dim=-1, index=idx)
print(out_selected)

tensor([[0, 1, 4],
        [0, 2, 5]])
tensor([[ 8.6071e-01,  3.7387e-01, -1.4539e-07, -5.6704e-09,  2.8185e-01,
          2.0832e-07, -3.2482e-08],
        [ 8.6071e-01, -1.4539e-07,  3.7387e-01, -3.2482e-08,  1.3319e-07,
          2.8185e-01, -5.6704e-09]])
tensor([[0.8607, 0.3739, 0.2818],
        [0.8607, 0.3739, 0.2818]])


In [70]:
# Same two query points repeated along a leading batch dimension: shape (2, 2).
x = torch.tensor([0.35, 0.65]).repeat(2, 1)
idx, psi = dyadic_psi(x, L)
for shown in (idx, psi, psi.shape):
    print(shown)

tensor([[[0, 1, 4],
         [0, 2, 5]],

        [[0, 1, 4],
         [0, 2, 5]]])
tensor([[[0.4660, 0.2950, 0.2818],
         [0.4660, 0.2950, 0.2818]],

        [[0.4660, 0.2950, 0.2818],
         [0.4660, 0.2950, 0.2818]]])
torch.Size([2, 2, 3])


In [41]:
# This scratch cell works on 1-D queries with (N, L) indices. Rebuild both here
# instead of inheriting `x` / `idx` from whichever cell ran last: in file order
# the preceding cell leaves a 2-D `x` and 3-D `idx`, which torch.gather on the
# 2-D `points` below rejects.
x = torch.tensor([0.35, 0.65])
idx = dyadic_nonzero_indices(x, L)
n_queries = x.shape[0]

# Design points broadcast to one row per query, then the per-level selection.
points = dyadic_design.points.expand(n_queries, -1)
print(points)
selected_points = torch.gather(points, dim=1, index=idx)
print(selected_points)
distances = torch.abs(x.unsqueeze(-1) - selected_points)
print(distances)

k = torch.arange(1, L + 1, dtype=torch.float32)   # 1,2,...,L
t = torch.pow(0.5, k).expand(n_queries, -1)       # [1/2, 1/4, ..., 1/2^L] per row
# equivalently: t = 2.0 ** (-k)
print(t)

# Same half-widths obtained from integer powers of two.
pow2 = torch.pow(2, torch.arange(1, L+1, dtype=torch.int64))
print((1.0/pow2).expand(n_queries, -1))
print(pow2[None,:])

tensor([[0.5000, 0.2500, 0.7500, 0.1250, 0.3750, 0.6250, 0.8750],
        [0.5000, 0.2500, 0.7500, 0.1250, 0.3750, 0.6250, 0.8750]])
tensor([[0.5000, 0.2500, 0.3750],
        [0.5000, 0.7500, 0.6250]])
tensor([[0.1500, 0.1000, 0.0250],
        [0.1500, 0.1000, 0.0250]])
tensor([[0.5000, 0.2500, 0.1250],
        [0.5000, 0.2500, 0.1250]])
tensor([[0.5000, 0.2500, 0.1250],
        [0.5000, 0.2500, 0.1250]])
tensor([[2, 4, 8]])


In [13]:
# Two query points in [0, 1] and a 3-level dyadic hierarchy.
L = 3
x = torch.tensor([0.35, 0.65])  # N=2 points in [0,1]

# Compact (per-level) and dense representations of φ(x).
idx, phi = dyadic_phi_values(x, L)
phi_dense = dyadic_phi_dense(x, L)
for shown in (idx, phi, phi_dense):
    print(shown)

# Reading the dense rows back at idx gives the compact values again.
phi_selected = phi_dense.gather(dim=1, index=idx)
print(phi_selected)

tensor([[0, 1, 4],
        [0, 2, 5]])
tensor([[0.8607, 0.1560, 0.0441],
        [0.8607, 0.1560, 0.0441]])
tensor([[0.8607, 0.1560, 0.0000, 0.0000, 0.0441, 0.0000, 0.0000],
        [0.8607, 0.0000, 0.1560, 0.0000, 0.0000, 0.0441, 0.0000]])
tensor([[0.8607, 0.1560, 0.0441],
        [0.8607, 0.1560, 0.0441]])
