In [2]:
import torch

def create_simplex_points(N: int):
    """Return the N vertices of a regular simplex, centred at the origin, in R^(N-1).

    The standard basis vectors e_1, ..., e_N are the vertices of a regular
    simplex in R^N. Subtracting their centroid (1/N, ..., 1/N) moves every
    vertex into the (N-1)-dimensional subspace orthogonal to the all-ones
    vector. The centred rows are then expressed in an orthonormal basis of that
    subspace, which drops the redundant coordinate while preserving all
    distances.

    Args:
        N: Number of vertices (assumed >= 1).

    Returns:
        Tensor of shape (N, N-1). Every pairwise distance is sqrt(2), every row
        has norm sqrt((N-1)/N), and the rows sum to zero.
    """
    eye = torch.eye(N)
    # Subtract the centroid (1/N, ..., 1/N) from every standard basis vector.
    centered = eye - 1.0 / N  # (N, N)

    # `centered` has rank N-1: its null space is the all-ones direction. The
    # first N-1 right singular vectors therefore span exactly the subspace the
    # rows live in. torch.svd is deprecated; torch.linalg.svd returns Vh = V^T.
    _, _, Vh = torch.linalg.svd(centered)
    basis = Vh[: N - 1].transpose(0, 1)  # (N, N-1), orthonormal columns
    return centered @ basis  # (N, N-1)

def random_orthonormal_embed(points_nminus1: torch.Tensor, D: int, generator: "torch.Generator | None" = None):
    """Isometrically embed (N, dim) points into R^D using a random orthonormal frame.

    A Gaussian (D, dim) matrix is orthonormalised with a reduced QR
    decomposition. The result Q satisfies Q^T Q = I, so the map x -> x @ Q^T
    preserves inner products, and therefore all norms and pairwise distances.

    Args:
        points_nminus1: Tensor of shape (N, dim); for simplex input dim = N-1.
        D: Target ambient dimension; must be >= dim.
        generator: Optional torch.Generator for a reproducible random frame.
            It must live on the same device as ``points_nminus1``. When None,
            the global RNG is used, as before.

    Returns:
        Tensor of shape (N, D) with the same dtype and device as the input.

    Raises:
        ValueError: If dim > D, since no isometric embedding exists.
    """
    N, dim = points_nminus1.shape  # dim = N-1 for simplex input
    if dim > D:
        raise ValueError(f"dim = {dim}, but D={D} < dim. Equidistant embedding is impossible.")

    # Draw in the input's dtype/device so the final matmul also works for
    # float64 or GPU inputs. With the defaults (float32, CPU, no generator)
    # this consumes the global RNG exactly like the original call.
    M = torch.randn(
        D, dim,
        generator=generator,
        dtype=points_nminus1.dtype,
        device=points_nminus1.device,
    )  # (D, dim)

    # Reduced QR: Q is (D, dim) with orthonormal columns.
    Q, _ = torch.linalg.qr(M)

    return points_nminus1 @ Q.transpose(0, 1)  # (N, D)

def scale_to_target_norm(points: torch.Tensor, target_norm: float):
    """Rescale each row of ``points`` to Euclidean norm ``target_norm``.

    Each row is scaled independently. If all rows already share a common norm
    (as simplex vertices do), this is a uniform scaling, so ratios of pairwise
    distances are preserved.

    Args:
        points: Tensor whose rows (reduced over dim 1) are rescaled, e.g. (N, D).
        target_norm: Desired L2 norm of every row.

    Returns:
        Tensor with the same shape as ``points``.

    Raises:
        ValueError: If any row has zero norm. Its direction is undefined, and
            dividing by zero would otherwise silently produce NaN/inf.
    """
    # torch.norm is deprecated in favour of torch.linalg.vector_norm.
    norms = torch.linalg.vector_norm(points, ord=2, dim=1, keepdim=True)  # (N, 1)
    if bool((norms == 0).any()):
        raise ValueError("Cannot rescale a zero-norm row to a target norm.")
    return points * (target_norm / norms)  # (N, D)

def pairwise_distance_stats(pts: torch.Tensor):
    """Summarise how uniform the pairwise distances and norms of a point set are.

    Args:
        pts: Tensor of shape (N, D).

    Returns:
        An 8-tuple:
            pairwise_dists: (N*(N-1)/2,) distances for each unordered pair
                (i < j), in row-major order of the upper triangle.
            mean_pairwise_dist, std_pairwise_dist: their mean and std.
            tdm_pairwise: sum of absolute deviations from that mean.
            norms: (N,) L2 norm of each point (distance from the origin).
            mean_norm, std_norm: their mean and std.
            tdm_norms: sum of absolute deviations of the norms from their mean.

        Both stds use PyTorch's default unbiased estimator. They are NaN when
        there is only one value, e.g. a single pair when N == 2.
    """
    dist_mat = torch.cdist(pts, pts, p=2)  # (N, N)
    N = pts.shape[0]

    # Strict upper triangle: each unordered pair once, diagonal excluded.
    # The indices are built directly on pts.device instead of via a dense CPU
    # mask; the order matches boolean-mask indexing (row-major).
    rows, cols = torch.triu_indices(N, N, offset=1, device=pts.device)
    pairwise_dists = dist_mat[rows, cols]  # (N*(N-1)/2,)

    mean_pairwise_dist = pairwise_dists.mean()
    std_pairwise_dist = pairwise_dists.std()
    # Total (L1) deviation from the mean: zero iff all distances are equal.
    tdm_pairwise = torch.sum(torch.abs(pairwise_dists - mean_pairwise_dist))

    # torch.norm is deprecated in favour of torch.linalg.vector_norm.
    norms = torch.linalg.vector_norm(pts, ord=2, dim=1)  # (N,)
    mean_norm = norms.mean()
    std_norm = norms.std()
    tdm_norms = torch.sum(torch.abs(norms - mean_norm))

    return (pairwise_dists, mean_pairwise_dist, std_pairwise_dist, tdm_pairwise,
            norms, mean_norm, std_norm, tdm_norms)

In [3]:
SEED = 0      # seeds the random orthonormal frame so the saved file is reproducible
N = 10        # number of simplex vertices
D = 3 * 32 * 32  # ambient dimension; presumably a flattened 3x32x32 image — confirm against the consumer of the .pt file
target_norm = 0.01  # common distance of every vertex from the origin

# random_orthonormal_embed draws from the global RNG; without a seed every run
# saves a different file under the same name.
torch.manual_seed(SEED)

points_nminus1 = create_simplex_points(N)   # (N,N-1)

points_D = random_orthonormal_embed(points_nminus1, D)  # (N, D)

points_scaled = scale_to_target_norm(points_D, target_norm)  # (N, D)

(pairwise_dists, mean_pd, std_pd, tdm_pd,
    norms, mean_n, std_n, tdm_n) = pairwise_distance_stats(points_scaled)

print(f"== {N}-point simplex in {D}-dim (random orthonormal embedding) ==")
print(f"Scaled to target norm: {target_norm}")
print("Pairwise distances:", pairwise_dists)
print(f"Mean pairwise distance = {mean_pd:.6f}, Std = {std_pd:.6f}")
print(f"Total deviation from mean (pairwise) = {tdm_pd:.6f}")

print("\nDistance from origin:", norms)
print(f"Mean norm = {mean_n:.6f}, Std = {std_n:.6f}")
print(f"Total deviation from mean (norms) = {tdm_n:.6f}")

# Guard the saved artifact: every vertex must lie at target_norm and all pairs
# must be equidistant (up to float32 round-off).
assert torch.allclose(norms, torch.full_like(norms, target_norm), rtol=1e-4), \
    "vertex norms deviate from target_norm"
assert torch.allclose(pairwise_dists, torch.full_like(pairwise_dists, mean_pd.item()), rtol=1e-4), \
    "vertices are not equidistant"

torch.save(points_scaled, f"prior_mean_simplex_{D}dim_{N}point_radius{target_norm}.pt")

== 10-point simplex in 3072-dim (random orthonormal embedding) ==
Scaled to target norm: 0.01
Pairwise distances: tensor([0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149,
        0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149,
        0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149,
        0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149,
        0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149, 0.0149])
Mean pairwise distance = 0.014907, Std = 0.000000
Total deviation from mean (pairwise) = 0.000000

Distance from origin: tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
        0.0100])
Mean norm = 0.010000, Std = 0.000000
Total deviation from mean (norms) = 0.000000
