In [1]:
import matplotlib.pyplot as plt
import numpy as np

from sklearn.metrics import pairwise_distances

# Promote floating-point problems (divide-by-zero, overflow, invalid/NaN) to
# FloatingPointError so they surface immediately. Underflow is left silent:
# exp() of a very negative number flushing to 0.0 is the correct limit for the
# Gaussian affinities and softmax weights computed in this notebook, not an error.
np.seterr(all='raise', under='ignore')


def affinity_from_distances(distances, scale=1.0):
    """Gaussian affinity ``exp(-d**2 / (2 * scale**2))`` applied element-wise.

    Parameters
    ----------
    distances : np.ndarray or float
        Distances (e.g. a pairwise-distance matrix); any shape is accepted.
    scale : float
        Length scale of the Gaussian; smaller values make affinity decay faster.

    Returns
    -------
    np.ndarray or float
        Affinities with the same shape as ``distances``; 1.0 where the distance
        is 0, decreasing towards 0.0 as the distance grows.
    """
    # Large distances (or a small scale) make the exponent very negative, and
    # squaring tiny distances can underflow too. Flushing those values to 0.0
    # is correct, so underflow is allowed even if numpy is set to raise on it.
    with np.errstate(under='ignore'):
        return np.exp(-distances**2 / (2.0 * scale**2))


def softmax(x):
    """Numerically stable softmax along axis 1 (row-wise for a 2-D array).

    Each row's maximum is subtracted before exponentiating, so every exponent
    is <= 0 and ``np.exp`` cannot overflow. Very negative logits underflow to
    exactly 0.0, which is the correct limit; that underflow is explicitly
    allowed so it is not promoted to ``FloatingPointError`` when numpy is
    configured with ``np.seterr(all='raise')``.

    Parameters
    ----------
    x : np.ndarray
        Logits with at least two dimensions; normalised over axis 1.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``x`` whose entries along axis 1 sum to 1.

    Raises
    ------
    FloatingPointError
        May be raised (when numpy is set to raise on invalid values) if ``x``
        contains non-finite values, e.g. a row whose max is ``inf`` gives
        ``inf - inf``. Earlier code returned an all-zero matrix here, which is
        not a probability distribution and caused 0/0 in callers.
    """
    with np.errstate(under='ignore'):
        e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        # For finite input the row maximum contributes exp(0) == 1, so each row
        # sum is >= 1. Dividing tiny weights by it may still underflow, which is fine.
        return e_x / e_x.sum(axis=1, keepdims=True)


def soft_jaccard_similarity(dx, dy, T=1.0, scale=1.0):
    """Mean soft Jaccard similarity between two soft neighbourhood structures.

    For each distance matrix, Gaussian affinities are computed with
    ``affinity_from_distances`` and the diagonal is set to 0. Each row is then
    turned into a probability distribution by a softmax at temperature ``T``.
    The row-wise soft Jaccard index ``sum(min(sx, sy)) / sum(max(sx, sy))`` is
    averaged over rows.

    Parameters
    ----------
    dx, dy : np.ndarray
        Square pairwise-distance matrices of the same shape ``(n, n)``. Row i
        of ``dx`` is compared with row i of ``dy``.
    T : float
        Softmax temperature applied to the affinities.
    scale : float
        Length scale passed to ``affinity_from_distances``.

    Returns
    -------
    float
        Mean over rows of the soft Jaccard index; each row value lies in [0, 1]
        because ``min <= max`` element-wise and all weights are non-negative.

    Raises
    ------
    ValueError
        If ``dx`` or ``dy`` is not a square 2-D matrix, or their shapes differ.
        This catches raw feature matrices passed by mistake in place of
        distance matrices.
    """
    dx = np.asarray(dx)
    dy = np.asarray(dy)
    for name, d in (('dx', dx), ('dy', dy)):
        if d.ndim != 2 or d.shape[0] != d.shape[1]:
            raise ValueError(
                f'{name} must be a square pairwise-distance matrix, got shape {d.shape}')
    if dx.shape != dy.shape:
        raise ValueError(
            f'dx and dy must have the same shape, got {dx.shape} and {dy.shape}')

    def compute_membership(d, T, scale):
        # Row-normalised soft membership derived from the distance matrix d.
        a = affinity_from_distances(d, scale)
        # NOTE(review): a diagonal logit of 0 does not remove the self-pair from
        # the softmax. It only gives it relative weight exp(0) = 1, against
        # exp(a_ij / T) for the other points. Use -np.inf if self should be fully excluded.
        np.fill_diagonal(a, 0)
        return softmax(a / T)

    sx = compute_membership(dx, T, scale)
    sy = compute_membership(dy, T, scale)

    J = np.minimum(sx, sy).sum(axis=1) / np.maximum(sx, sy).sum(axis=1)
    return J.mean()

In [5]:
def plot_jaccard_curve_for_scale(X, Y, scales, Ts, metric='euclidean'):
    """Plot soft Jaccard similarity against temperature, one curve per scale.

    ``soft_jaccard_similarity`` works on pairwise-distance matrices, so the
    distances of ``X`` and ``Y`` are computed here. They do not depend on
    ``scale`` or ``T``, so they are computed once and reused for every point.

    Parameters
    ----------
    X, Y : array-like of shape (n_samples, n_features)
        Two data matrices with the same number of rows; row i of each is
        compared. With ``metric='precomputed'`` they are instead taken to be
        square distance matrices and used as-is.
    scales : iterable of float
        Affinity length scales; one curve is drawn per value.
    Ts : sequence of float
        Temperatures evaluated along the (log-scaled) x-axis.
    metric : str, default 'euclidean'
        Distance metric forwarded to ``pairwise_distances``.
    """
    dx = pairwise_distances(X, metric=metric)
    dy = pairwise_distances(Y, metric=metric)

    _, ax = plt.subplots()
    for scale in scales:
        Js = [soft_jaccard_similarity(dx, dy, T=T, scale=scale) for T in Ts]
        ax.plot(Ts, Js, label=f'scale={scale}')
    ax.legend()
    ax.set_xlabel('Temperature')
    ax.set_ylabel('Soft Jaccard Similarity')
    ax.set_ylim(0, 1)
    ax.set_xscale('log')
    plt.show()

SEED = 0
num_samples = 1000
num_features = 10

# A seeded generator makes Restart & Run All reproduce the same data and figures.
rng = np.random.default_rng(SEED)
# X and Y are sampled independently of each other; presumably this is a
# baseline for how similar unrelated neighbourhoods look.
X = rng.normal(0, 1, (num_samples, num_features))
Y = rng.normal(0, 1, (num_samples, num_features))
scales = [10.0**k for k in range(-1, 2)]  # 0.1, 1.0, 10.0
Ts = np.logspace(-3, 2, 100)  # 100 log-spaced temperatures in [1e-3, 1e2]

plot_jaccard_curve_for_scale(X, Y, scales, Ts)


0.8944753810755031

In [None]:
def foo(dx, scale, Ts=None):
    """Plot the average sorted softmax profile for several temperatures.

    For each temperature T, every row of ``softmax(-affinity / T)`` is sorted
    in descending order. The sorted rows are averaged, and the profile is
    divided by its first (largest) entry so every curve starts at 1.

    Parameters
    ----------
    dx : np.ndarray
        Pairwise-distance matrix, e.g. from ``pairwise_distances``.
    scale : float
        Length scale passed to ``affinity_from_distances``; used in the title.
    Ts : iterable of float, optional
        Temperatures to plot. Defaults to 20 log-spaced values in [1e-2, 1e2],
        the same set as before this parameter existed.
    """
    if Ts is None:
        Ts = np.logspace(-2, 2, 20)
    affinity = affinity_from_distances(dx, scale=scale)

    _, ax = plt.subplots(figsize=(10, 10))
    for T in Ts:
        # NOTE(review): the logits are -affinity / T, so the least-similar points
        # get the most mass, and the diagonal is not zeroed. soft_jaccard_similarity
        # builds memberships from +affinity / T with a zeroed diagonal. Confirm
        # the sign here is intended.
        probs = softmax(-affinity / T)
        # Sort each row descending, then average across rows.
        profile = np.sort(probs, axis=1)[:, ::-1].mean(axis=0)
        profile = profile / profile[0]
        # Ranks start at 1: a log-scaled x-axis cannot show rank 0, which would
        # hide the normalised top entry of every curve.
        ranks = np.arange(1, profile.size + 1)
        ax.plot(ranks, profile, label=f'T={T:.3g}')
    ax.legend()
    ax.set_xscale('log')
    ax.set_xlabel('Rank')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Average probability')
    ax.set_title('Scale={}'.format(scale))
    plt.show()


# Euclidean distances between every pair of rows of X, shared by all plots below.
dx = pairwise_distances(X, X, metric='euclidean')

# One rank/probability figure per affinity scale.
for scale in (0.1, 1.0, 5.0, 10.0, 100.0):
    foo(dx, scale)
