In [1]:
from functools import partial
from typing import Callable, Union

import torch
from torch.distributions import Distribution, Gamma, Normal, StudentT

from rlaopt.kernels import RBFLinOp, Matern12LinOp, Matern32LinOp, Matern52LinOp

In [2]:
# dist = Normal(torch.tensor([0.0], device="cpu"), torch.tensor([1.0], device="cpu"))
# sample = dist.sample((1000,))
# print(sample.shape)
# print(type(sample))
# print(type(sample.shape[0]))
# print(sample.device)

In [3]:
def _random_features(X: torch.Tensor, lengthscale: Union[float, torch.Tensor], num_features: int, Omega_dist: Distribution) -> torch.Tensor:
    X_whitened = X / lengthscale

    scale_factor = (2.0 / num_features) ** 0.5
    # Omega = Omega_fn(X_whitened.shape[1], num_features, device=X.device, dtype=X.dtype)
    # The sample method adds an extra dimension since the distribution parameters are tensors, so we need to squeeze it out
    Omega = Omega_dist.sample((X_whitened.shape[1], num_features)).squeeze(-1)
    B = 2 * torch.pi * torch.rand(num_features, device=X.device, dtype=X.dtype)
    return scale_factor * torch.cos(X_whitened @ Omega + B)


def rbf_random_features(X: torch.Tensor, lengthscale: Union[float, torch.Tensor], num_features: int) -> torch.Tensor:
    loc = torch.tensor([0.0], device=X.device, dtype=X.dtype)
    scale = torch.tensor([1.0], device=X.device, dtype=X.dtype)
    return _random_features(X, lengthscale, num_features, Normal(loc=loc, scale=scale))


def matern_random_features(X: torch.Tensor, lengthscale: Union[float, torch.Tensor], num_features: int, nu: float) -> torch.Tensor:
    df = torch.tensor([2.0 * nu], device=X.device, dtype=X.dtype)
    loc = torch.tensor([0.0], device=X.device, dtype=X.dtype)
    scale = torch.tensor([1.0], device=X.device, dtype=X.dtype)
    return _random_features(X, lengthscale, num_features, StudentT(df=df, loc=loc, scale=scale))

In [4]:
# Work in float64: X, the lengthscale tensor and the other tensors created without an
# explicit dtype below pick up this default.
torch.set_default_dtype(torch.float64)
SEED = 0
# Trailing ';' suppresses the returned torch.Generator repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f08a41b3910>

In [5]:
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

n = 4  # number of data points
d = 10  # input dimension
# NOTE(review): rows of X have E||x - y||^2 = 2, so r / lengthscale is ~14 at lengthscale 0.1 and
# every off-diagonal kernel value is ~0; the diffs below then mostly reflect Monte Carlo noise.
# A lengthscale around 1 would also exercise the off-diagonal structure.
lengthscale_scalar = 0.1
# Per-dimension lengthscales (alternating 0.1 / 0.2); needs exactly one entry per input dimension.
lengthscale_tensor = torch.tensor([0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2], device=device)
assert lengthscale_tensor.shape == (d,), "lengthscale_tensor must have one entry per input dimension"
num_features = 10 ** 4

# Entries have variance 1/d, so each row has E||x||^2 = 1.
X = torch.randn(n, d, device=device) / (d ** 0.5)

In [6]:
def test_rf_fn(kernel_linop_class, rf_fn, X, lengthscale, num_features):
    kernel_linop = kernel_linop_class(X, X, {"lengthscale": lengthscale})
    K_true = kernel_linop @ torch.eye(kernel_linop.shape[1], device=kernel_linop.device)
    K_rf = rf_fn(X, lengthscale, num_features)
    K_approx = K_rf @ K_rf.T
    return K_true - K_approx

RBF

In [7]:
diff = test_rf_fn(RBFLinOp, rbf_random_features, X, lengthscale_scalar, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[-0.0158,  0.0121, -0.0035, -0.0098],
        [ 0.0121, -0.0066, -0.0117,  0.0070],
        [-0.0035, -0.0117, -0.0100, -0.0114],
        [-0.0098,  0.0070, -0.0114, -0.0047]], device='cuda:0')


In [8]:
diff = test_rf_fn(RBFLinOp, rbf_random_features, X, lengthscale_tensor, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[ 7.9411e-03, -4.5772e-03, -1.5838e-05, -1.0340e-02],
        [-4.5772e-03,  5.4056e-03, -6.2336e-03, -8.8580e-03],
        [-1.5838e-05, -6.2336e-03, -7.7012e-04, -1.0339e-02],
        [-1.0340e-02, -8.8580e-03, -1.0339e-02, -4.6142e-03]], device='cuda:0')


Matern-1/2

In [9]:
diff = test_rf_fn(Matern12LinOp, partial(matern_random_features, nu=0.5), X, lengthscale_scalar, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[ 0.0031,  0.0043, -0.0163,  0.0186],
        [ 0.0043, -0.0081, -0.0039, -0.0091],
        [-0.0163, -0.0039,  0.0038,  0.0081],
        [ 0.0186, -0.0091,  0.0081, -0.0021]], device='cuda:0')


In [10]:
diff = test_rf_fn(Matern12LinOp, partial(matern_random_features, nu=0.5), X, lengthscale_tensor, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[ 0.0105, -0.0141, -0.0017,  0.0171],
        [-0.0141,  0.0008,  0.0049,  0.0120],
        [-0.0017,  0.0049, -0.0037,  0.0197],
        [ 0.0171,  0.0120,  0.0197,  0.0113]], device='cuda:0')


Matern-3/2

In [11]:
diff = test_rf_fn(Matern32LinOp, partial(matern_random_features, nu=1.5), X, lengthscale_scalar, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[ 0.0064,  0.0125,  0.0196,  0.0030],
        [ 0.0125,  0.0041, -0.0010, -0.0057],
        [ 0.0196, -0.0010,  0.0041,  0.0064],
        [ 0.0030, -0.0057,  0.0064, -0.0028]], device='cuda:0')


In [12]:
diff = test_rf_fn(Matern32LinOp, partial(matern_random_features, nu=1.5), X, lengthscale_tensor, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[-0.0005,  0.0175, -0.0043,  0.0197],
        [ 0.0175,  0.0012, -0.0235,  0.0093],
        [-0.0043, -0.0235,  0.0167,  0.0104],
        [ 0.0197,  0.0093,  0.0104, -0.0039]], device='cuda:0')


Matern-5/2

In [13]:
diff = test_rf_fn(Matern52LinOp, partial(matern_random_features, nu=2.5), X, lengthscale_scalar, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[-0.0039,  0.0058,  0.0098,  0.0099],
        [ 0.0058,  0.0100, -0.0125, -0.0098],
        [ 0.0098, -0.0125,  0.0105, -0.0118],
        [ 0.0099, -0.0098, -0.0118,  0.0099]], device='cuda:0')


In [14]:
diff = test_rf_fn(Matern52LinOp, partial(matern_random_features, nu=2.5), X, lengthscale_tensor, num_features)
print(diff)
# Each approximate kernel entry averages num_features random terms, so its error shrinks like
# 1/sqrt(num_features); the factor 5 is a loose bound.
assert diff.abs().max().item() < 5.0 / num_features ** 0.5, "random-feature approximation error too large"

tensor([[-0.0076,  0.0013,  0.0163, -0.0114],
        [ 0.0013,  0.0046, -0.0240,  0.0043],
        [ 0.0163, -0.0240,  0.0068,  0.0110],
        [-0.0114,  0.0043,  0.0110,  0.0072]], device='cuda:0')
