In [9]:
from dak.kernels.laplace_kernel import LaplaceProductKernel
from dak.utils.sparse_design.design_class import HyperbolicCrossDesign
from dak.utils.operators.chol_inv import mk_chol_inv

In [10]:
L = 4 # Level-L dyadic grid: m=2^L-1 grid points
dyadic_design = HyperbolicCrossDesign(dyadic_sort=True, return_neighbors=True)(deg=L, input_lb=0, input_ub=1)
design_points = dyadic_design.points.reshape(-1, 1)  # [m, 1] size tensor
print(f'Dyadic sorted design points: {dyadic_design.points}')

Dyadic sorted design points: tensor([0.5000, 0.2500, 0.7500, 0.1250, 0.3750, 0.6250, 0.8750, 0.0625, 0.1875,
        0.3125, 0.4375, 0.5625, 0.6875, 0.8125, 0.9375])


In [11]:
import torch

x = torch.tensor([0.35, 0.65])

In [12]:
# PART 1: Use Cholesky decomposition phi(x) = k(x,U)L^{-T}
chol_inv = mk_chol_inv(
    dyadic_design=dyadic_design,
    markov_kernel=LaplaceProductKernel(lengthscale=1.),
    upper=True)  # [m, m] size tensor
k_xu = LaplaceProductKernel(lengthscale=1.)(x, design_points)
phi = torch.matmul(k_xu, chol_inv)
print(f"x: {x}")
print(f"phi(x): {phi}")

x: tensor([0.3500, 0.6500])
phi(x): tensor([[ 8.6071e-01,  3.7387e-01, -1.4539e-07, -5.6704e-09,  2.8185e-01,
          2.0832e-07, -3.2482e-08,  8.6208e-08,  3.4005e-09,  9.9880e-02,
         -7.8084e-08, -1.1129e-07,  3.7601e-09, -1.4542e-07,  6.6069e-08],
        [ 8.6071e-01, -1.4539e-07,  3.7387e-01, -3.2482e-08,  1.3319e-07,
          2.8185e-01, -5.6704e-09,  6.6069e-08,  1.2297e-07, -1.1129e-07,
          4.1125e-08, -7.3236e-08,  9.9880e-02, -4.3964e-08,  8.6208e-08]])


In [13]:
# PART 2: Use compact support and sparse psi
def dyadic_psi(x: torch.Tensor, L: int, sigma: float = 1.0, ell_c: float = 1.0):
    """
    Batch-wise dyadic nonzero indices.

    Args
    ----
    x : (...,) tensor with values in [0, 1].
        Works with any number of leading batch dims.
    L : int, number of dyadic levels (total columns m = 2^L - 1).

    Returns
    -------
    idx : (..., L) long tensor
        0-based global column indices in dyadic order for each level (DC is level 1).
        The returned shape matches the leading shape of x, with an extra trailing dim of size L.
    """
    if x.ndim == 0:
        x = x.unsqueeze(0)  # promote scalar to shape (1,)
    device, dtype = x.device, x.dtype

    # 2^s for s=1..L
    pow2 = torch.pow(2, torch.arange(1, L+1, device=device, dtype=torch.int64))  # (L,)

    # k_s = ceil(2^s * x) clamped to [1, 2^s - 1]
    ks = torch.ceil(x[..., None] * pow2.to(x.dtype)).to(torch.int64)  # (..., L)
    ks = torch.clamp(ks, min=1)
    ks_max = (pow2 - 1)  # (L,)
    ks = torch.minimum(ks, ks_max)  # (..., L)

    # r_s^(odd): force to be odd (right endpoint index made odd)
    # if ks even -> ks-1, else ks
    rs = ks - ((ks & 1) == 0).to(torch.int64) # (..., L), odd in {1,3,...,2^s-1}

    # position within level s: t_s in {1,...,2^{s-1}}
    ts = (rs + 1) // 2  # (..., L)

    # offsets: number of columns before level s (0-based indexing)
    offsets = (pow2 // 2) - 1  # (L,)

    # global 0-based indices: J_s = offset(s) + (t_s - 1)
    idx = offsets + (ts - 1)  # (..., L)

    u = HyperbolicCrossDesign(dyadic_sort=True, return_neighbors=True)(deg=L, input_lb=0, input_ub=1).points # (2^L-1,)
    view_shape = (1,) * x.dim() + (u.shape[0],)      # (1,1,...,1, 2^L-1)
    u_selected = torch.gather(u.view(view_shape).expand(*x.shape, -1), dim=-1, index=idx) # (..., L)

    delta = torch.abs(x.unsqueeze(-1) - u_selected) # |x - m2^{-l}|
    # pow2_f = (1.0 / pow2).view(view_shape).expand(*x.shape, -1)  # (..., L), 2^{-l} for l=1..L
    pow2_f = (1.0 / pow2).to(x.dtype)

    psi =  sigma * torch.sqrt(2 / torch.sinh(pow2_f * 2 * ell_c)) * torch.sinh(ell_c * (pow2_f - delta))

    return idx, psi  # (..., L)

In [14]:
# Let's find out the non-zero indices using wavelet design
# Compared to phi_dense, the selected nonzero indices are correct
idx, psi = dyadic_psi(x, L)
print(f"Non-zero idx:\n {idx}")
phi_nonzero = torch.gather(phi, dim=-1, index=idx)
print(f"Non-zero phi(x):\n {phi_nonzero}")
print(f"psi(x):\n {psi}")

Non-zero idx:
 tensor([[ 0,  1,  4,  9],
        [ 0,  2,  5, 12]])
Non-zero phi(x):
 tensor([[0.8607, 0.3739, 0.2818, 0.0999],
        [0.8607, 0.3739, 0.2818, 0.0999]])
psi(x):
 tensor([[0.4660, 0.2950, 0.2818, 0.0999],
        [0.4660, 0.2950, 0.2818, 0.0999]])


In [15]:
# Then let's verify the batch-wise computing of non-zero indices and psi
x = torch.tensor([[0.35, 0.65], [0.65, 0.35]]) # (B, N), batch size=2, feature=2
idx, psi = dyadic_psi(x, L)
print(f"Non-zero idx:\n {idx}")
print(f"psi(x):\n {psi}")
print(f"psi shape: {psi.shape}")

Non-zero idx:
 tensor([[[ 0,  1,  4,  9],
         [ 0,  2,  5, 12]],

        [[ 0,  2,  5, 12],
         [ 0,  1,  4,  9]]])
psi(x):
 tensor([[[0.4660, 0.2950, 0.2818, 0.0999],
         [0.4660, 0.2950, 0.2818, 0.0999]],

        [[0.4660, 0.2950, 0.2818, 0.0999],
         [0.4660, 0.2950, 0.2818, 0.0999]]])
psi shape: torch.Size([2, 2, 4])
