In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
import numpy as np
from sklearn.datasets import load_iris

In [15]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x71e5d875a6d0>

In [7]:
X, y = load_iris(return_X_y=True)

In [24]:
def check_for_nan(tensor, label):
    if torch.isnan(tensor).any():
        print(f"{label} contains NaN values")

In [8]:
X = torch.tensor(X, dtype=torch.float32)

In [40]:
def iso_kernel(X, all_X, eta, psi):
    map_tmp = None
    if all_X is None:
        all_X = X
    np.random.seed(42)
    samples_index = [
        np.random.choice(len(all_X), psi, replace=False) for _ in range(1)
    ]
    for s_index in samples_index:
        samples = all_X[s_index]
        dist = -2*eta*torch.cdist(X, samples)
        log_soft_max_dist = F.log_softmax(dist, dim=1)
        soft_max_dist = torch.clamp(torch.exp(log_soft_max_dist), min=0, max=1)
        check_for_nan(soft_max_dist,"soft_max_dist")
        soft_dist = torch.sqrt(soft_max_dist)

        check_for_nan(soft_dist,"soft_dist")
        if map_tmp is None:
            map_tmp = soft_dist
        else:
            map_tmp = torch.hstack([map_tmp, soft_dist])
        if torch.mm(soft_dist, soft_dist.T).isinf().any():
            print(soft_dist)
        if torch.mm(soft_dist, soft_dist.T).isnan().any():
            print(soft_dist)
    ik_similarity = torch.mm(map_tmp, map_tmp.T) / len(samples_index)
    assert ik_similarity.shape == (X.shape[0], X.shape[0])
    return ik_similarity
Xs = torch.tensor(X, dtype=torch.float32, requires_grad=True)
x_map= iso_kernel(X = Xs, all_X=None, eta=100, psi=8)

  Xs = torch.tensor(X, dtype=torch.float32, requires_grad=True)


In [38]:
Xs = torch.tensor(X, dtype=torch.float32, requires_grad=True)
x_map= iso_kernel(X = Xs, all_X=None, eta=100, psi=8)


  Xs = torch.tensor(X, dtype=torch.float32, requires_grad=True)


In [39]:
x_map.sum().backward()

RuntimeError: Function 'ExpBackward0' returned nan values in its 0th output.

In [37]:
Xs.grad

tensor([[ 4.1730e-01,  8.0712e+00,  1.7841e+00, -4.6473e+00],
        [-5.3073e-01,  8.3827e-01,  1.0163e+00, -1.9493e+00],
        [-9.3007e-01,  2.4024e+00,  7.1492e-01, -1.8400e+00],
        [-8.8756e-01,  2.0060e+00,  1.6668e+00, -1.6377e+00],
        [ 8.4285e-01,  6.7309e+00,  1.8402e+00, -2.8285e+00],
        [-1.0328e+02, -7.1586e+01, -4.2397e+01,  3.4431e+01],
        [-8.9760e-01,  3.9387e+00,  1.5046e+00, -1.4610e+00],
        [-6.9808e-01,  5.7369e+00,  2.8753e+00, -3.9332e+00],
        [-7.4143e-01,  1.1697e+00,  1.0369e+00, -1.1221e+00],
        [-3.9023e-01,  1.9284e+00,  1.9847e+00, -2.9866e+00],
        [-6.7373e+01, -9.4827e+01, -4.4913e+01,  2.5789e+01],
        [-4.0918e-01,  4.8921e+00,  3.1472e+00, -2.5774e+00],
        [-4.7456e-01,  1.2881e+00,  1.1438e+00, -2.2429e+00],
        [-5.9438e-01,  1.5283e+00,  3.4799e-01, -1.2482e+00],
        [-3.1820e+01, -3.8702e+01, -5.2870e+01,  6.8827e+00],
        [-1.5511e+01, -1.1195e+00, -1.6857e+01,  8.4329e+00],
        

In [98]:
def IK_Kernel(X, psi, t, eta):
    map_tmp = None
    for i in range(t):
        samples_index = torch.randperm(len(X))[:psi]
        samples = X[samples_index]
        dist = torch.cdist(X, samples)
        soft_dist = torch.exp(-eta * dist) / torch.sqrt(
            torch.exp(-2 * eta * dist).sum(dim=1)
        ).view(-1, 1)
        if map_tmp is None:
            map_tmp = soft_dist
        else:
            map_tmp = torch.hstack([map_tmp, soft_dist])
    return torch.matmul(map_tmp, map_tmp.T) / t


In [120]:
psi = 10
t = 200
samples_inds = [np.random.choice(len(X), psi) for _ in range(t)]
Xs = torch.tensor(X, dtype=torch.float32, requires_grad=True)
x_map= IK_Kernel(Xs, 0.9, samples_inds)

  Xs = torch.tensor(X, dtype=torch.float32, requires_grad=True)


In [106]:
x_map

tensor([[1.0000, 0.9718, 0.9717,  ..., 0.1082, 0.1041, 0.1321],
        [0.9718, 1.0000, 0.9962,  ..., 0.1127, 0.1081, 0.1385],
        [0.9717, 0.9962, 1.0000,  ..., 0.1029, 0.0990, 0.1267],
        ...,
        [0.1082, 0.1127, 0.1029,  ..., 1.0000, 0.9799, 0.9563],
        [0.1041, 0.1081, 0.0990,  ..., 0.9799, 1.0000, 0.9310],
        [0.1321, 0.1385, 0.1267,  ..., 0.9563, 0.9310, 1.0000]],
       grad_fn=<DivBackward0>)

In [121]:
x_map.sum().backward()

In [123]:
Xs.grad

tensor([[ 4.0772e+00, -7.9528e+00,  1.6242e+01,  4.8093e+00],
        [ 1.9147e+01, -1.0088e+01,  1.7106e+01,  9.2099e+00],
        [ 7.4152e+00,  3.3817e-01,  9.4550e+00,  6.2629e+00],
        [ 1.1261e+00, -6.9827e+00,  2.8333e+01,  8.8170e+00],
        [ 5.0610e-01, -1.6842e-01,  1.5150e+01,  4.8825e+00],
        [ 7.0885e+00, -6.1876e+00,  3.3950e+01,  1.6040e+01],
        [ 6.1307e-01,  4.0934e+00,  1.6945e+01,  1.1881e+01],
        [ 5.7725e+00, -6.6664e+00,  2.4124e+01,  6.3290e+00],
        [-1.4240e+00, -9.7120e+00,  1.9084e+01,  7.3879e+00],
        [ 1.4067e+01, -4.0080e+00,  2.2629e+01,  9.4858e-01],
        [ 1.0521e+01, -6.1557e+00,  1.7991e+01,  3.8390e+00],
        [ 4.2270e-01,  6.5634e-01,  3.1538e+01,  6.8427e+00],
        [ 1.1204e+01, -9.6265e+00,  1.6842e+01,  1.2700e-02],
        [-4.7312e+00, -2.5094e+00,  1.0070e+00,  7.8222e-01],
        [ 1.6562e+01, -6.4953e+00,  7.8912e+00,  4.1789e+00],
        [ 1.7012e+01,  1.4176e+01,  2.3876e+01,  1.4694e+01],
        

In [87]:
k_similar/ t

tensor([[1.0000, 0.9997, 0.9997,  ..., 0.9471, 0.9472, 0.9529],
        [0.9997, 1.0000, 0.9999,  ..., 0.9481, 0.9481, 0.9540],
        [0.9997, 0.9999, 1.0000,  ..., 0.9455, 0.9456, 0.9515],
        ...,
        [0.9471, 0.9481, 0.9455,  ..., 1.0000, 0.9998, 0.9994],
        [0.9472, 0.9481, 0.9456,  ..., 0.9998, 1.0000, 0.9992],
        [0.9529, 0.9540, 0.9515,  ..., 0.9994, 0.9992, 1.0000]])

In [None]:
def IK_Kernel(x, psi, t, eps):
    map_tmp = None
    for i in range(t):
        samples_index = torch.randperm(len(x))[:psi]
        samples = x[samples_index]

        if map_tmp is None:
            map_tmp = F.softmax(-eps * torch.cdist(x, samples), dim=1)
        else:
            map_tmp = torch.vstack([map_tmp ,F.softmax(-eps * torch.cdist(x, samples), dim=1)])

    return torch.matmul(map_tmp, map_tmp.T).mean(dim=0)