In [1]:
import numpy as np
from source.metrics import get_risk_approximation
from psruq.source import vectorizer_uncertainty_scores
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Number of ensemble members (size of the Dirichlet sample in the
# commented-out sampling code of the experiment loop).
N_ensemble = 20
# Number of random candidates drawn by the brute-force search.
N_optim_searches = 1_000_000
# Length of every probability vector.
n_classes = 2
# Scale of the Dirichlet concentration parameters.
max_concentration = 5

# Number of repetitions of the whole experiment.
n_exp = 1

In [3]:
def central_prediction_spherical(
    logits_: np.ndarray,
):
    """Compute a central prediction from L2-normalised ensemble probabilities.

    Steps performed:
      1. Convert logits to probabilities with ``safe_softmax``
         (presumably a softmax over the last axis -- defined elsewhere).
      2. Divide every probability vector by its L2 norm and average the
         result over axis 0.
      3. Remove from that average its component along the uniform vector
         ``x0 = 1/K`` and rescale the orthogonal remainder.

    Parameters
    ----------
    logits_ : np.ndarray
        Array whose last axis has length ``K`` (classes) and whose axis 0 is
        averaged out.  Assumed to be 3-D, ``(members, objects, K)``, because
        ``x0`` is built with shape ``(1, 1, K)`` -- TODO confirm with callers.

    Returns
    -------
    np.ndarray
        Array of shape ``(1, objects, K)``.
    """
    probs = safe_softmax(logits_)
    K = logits_.shape[-1]

    # Project each probability vector onto the unit sphere, then average
    # over axis 0.
    norms = np.linalg.norm(probs, axis=-1, keepdims=True, ord=2)
    x = np.mean(probs / norms, axis=0, keepdims=True)

    # Uniform vector and its norm.  keepdims=True keeps the reduced axis so
    # the norm broadcasts per object rather than against the class axis.
    x0 = np.ones(K).reshape(1, 1, K) / K
    x0_norm = np.linalg.norm(x0, ord=2, axis=-1, keepdims=True)

    # Part of x orthogonal to x0.
    y_orth = x - np.sum(x * x0, axis=-1, keepdims=True) * (x0 / x0_norm**2)
    # Shape (1, objects, 1).  The previous `.ravel()` produced shape
    # (objects,), which was divided against the last (class) axis and was
    # therefore only correct for a single object.
    y_orth_norm = np.linalg.norm(y_orth, ord=2, axis=-1, keepdims=True)

    central_pred = x0 + (y_orth / np.sqrt(1 - y_orth_norm**2)) * x0_norm
    return central_pred


def central_prediction_logscore(logits_: np.ndarray):
    """Return ``safe_softmax`` applied to the logits averaged over axis 0.

    The averaged axis is kept with length 1, so the input to
    ``safe_softmax`` has the same number of dimensions as ``logits_``.
    """
    mean_logits = np.mean(logits_, axis=0, keepdims=True)
    return safe_softmax(mean_logits)

In [4]:
# Pairwise loss that the random search minimises.
# Available choices: pairwise_brier, pairwise_kl, pairwise_prob_diff, pairwise_spherical
FUNC_MINIMIZE = pairwise_prob_diff
# Closed-form central prediction that the search result is compared against.
# Available choices: central_prediction_neglog, central_prediction_brier,
# central_prediction_logscore, central_prediction_maxprob, central_prediction_spherical
# NOTE(review): confirm that the selected loss and central prediction belong together.
THEORETICAL_CENTRAL_PREDICTION = central_prediction_neglog

In [5]:
# Per-experiment L1 distance between the best candidate found by random
# search and the closed-form central prediction.
l1_err = []

for exp_id in tqdm(range(n_exp)):
    # Random alternative to the fixed ensemble below:
    # sampled_vectors = np.random.dirichlet(
    #     alpha=np.random.rand(n_classes) * max_concentration,
    #     size=N_ensemble
    # ).reshape(N_ensemble, 1, n_classes)
    sampled_vectors = np.array(
        [
            [[0.89914033, 0.10085967]],
            [[0.4389841, 0.5610159]],
            [[0.99207102, 0.00792898]],
            [[0.57802041, 0.42197959]],
            [[0.89409964, 0.10590036]],
            [[0.60881584, 0.39118416]],
            [[0.9533906, 0.0466094]],
            [[0.91530455, 0.08469545]],
            [[0.24826006, 0.75173994]],
            [[0.36068223, 0.63931777]],
            [[0.60374964, 0.39625036]],
            [[0.96204757, 0.03795243]],
            [[0.99746462, 0.00253538]],
            [[0.51398327, 0.48601673]],
            [[0.88041703, 0.11958297]],
            [[0.44736386, 0.55263614]],
            [[0.44013821, 0.55986179]],
            [[0.77684085, 0.22315915]],
            [[0.40832816, 0.59167184]],
            [[0.21596403, 0.78403597]],
        ]
    )
    # Loop-invariant: take the log of the ensemble once instead of on every
    # search step (assumes FUNC_MINIMIZE does not modify its arguments).
    sampled_logits = np.log(sampled_vectors)

    loss = []
    p_vals = []

    # N_optim_searches random candidates plus one final hard-coded candidate.
    for rep in range(N_optim_searches + 1):
        if rep == N_optim_searches:
            # Hard-coded candidate evaluated last; it equals the vector `z`
            # displayed further down in this notebook.
            central_prediction = np.array([0.99577723, 0.00422277]).reshape(
                1, 1, n_classes
            )
        else:
            # Without this `else` the hard-coded candidate above was always
            # overwritten by a random draw and never evaluated.
            central_prediction = np.random.dirichlet(
                alpha=max_concentration * np.ones(n_classes), size=1
            ).reshape(1, 1, n_classes)
        loss.append(
            FUNC_MINIMIZE(np.log(central_prediction), sampled_logits).mean()
        )
        p_vals.append(central_prediction.ravel())

    central_prediction_theory = THEORETICAL_CENTRAL_PREDICTION(np.log(sampled_vectors))
    p_vals = np.vstack(p_vals)

    # L1 distance between the lowest-loss candidate and the closed-form
    # prediction (the L1 norm already takes absolute values).
    err = np.linalg.norm(
        p_vals[np.argmin(loss)].ravel() - central_prediction_theory.ravel(),
        ord=1,
        axis=-1,
    )

    l1_err.append(err)

100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:45<00:00, 45.40s/it]


In [6]:
# Per-component absolute gap between the lowest-loss candidate and the
# closed-form central prediction.
np.abs(p_vals[np.argmin(loss)].ravel() - central_prediction_theory.ravel())

array([0.43357605, 0.00128463])

In [7]:
# Closed-form central prediction, flattened to 1-D for display.
central_prediction_theory.ravel()

array([0.53354849, 0.03159083])

In [8]:
# Candidate with the lowest recorded loss.
p_vals[np.argmin(loss)].ravel()

array([0.96712454, 0.03287546])

In [9]:
# Candidate with the highest recorded loss.
p_vals[np.argmax(loss)].ravel()

array([0.02508621, 0.97491379])

In [9]:
# Mean L1 error over all experiments.
np.mean(l1_err)

0.06658520691916528

In [4]:
# Fixed ensemble of 20 two-class probability vectors, stored with shape
# (20, 1, 2): one row per vector, with a singleton middle axis.
sampled_vectors = np.array(
    [
        [0.89914033, 0.10085967],
        [0.4389841, 0.5610159],
        [0.99207102, 0.00792898],
        [0.57802041, 0.42197959],
        [0.89409964, 0.10590036],
        [0.60881584, 0.39118416],
        [0.9533906, 0.0466094],
        [0.91530455, 0.08469545],
        [0.24826006, 0.75173994],
        [0.36068223, 0.63931777],
        [0.60374964, 0.39625036],
        [0.96204757, 0.03795243],
        [0.99746462, 0.00253538],
        [0.51398327, 0.48601673],
        [0.88041703, 0.11958297],
        [0.44736386, 0.55263614],
        [0.44013821, 0.55986179],
        [0.77684085, 0.22315915],
        [0.40832816, 0.59167184],
        [0.21596403, 0.78403597],
    ]
).reshape(20, 1, 2)

In [5]:
# Number of classes = length of the last axis.
dim = sampled_vectors.shape[-1]

# Mean over axis 0 of the element-wise reciprocals of the probabilities.
inverse_probs = 1 / sampled_vectors
x = np.mean(inverse_probs, axis=0)

# Uniform vector with entries 1/dim, reshaped to match x.
x_0 = (np.ones(dim) / dim).reshape(*x.shape)

# Split x into its projection onto x_0 and the remainder orthogonal to x_0.
x0_sq_norm = np.linalg.norm(x_0, ord=2) ** 2
x_parallel = x_0 * np.sum(x_0 * x) / x0_sq_norm
x_perp = x - x_parallel

In [6]:
# Scaling coefficient applied to x_perp when building z in a later cell.
# The axis=-1 sum makes k an array with one entry per row of x_parallel.
k = (dim - np.sum(x_parallel * x_0, axis=-1)) / np.linalg.norm(x_perp) ** 2

In [7]:
# Show the first (here: only) entry of k.
k[0]

-0.03329541647082493

In [8]:
# Show x, the axis-0 mean of 1 / sampled_vectors.
x

array([[ 1.8742439, 31.6547535]])

In [9]:
# Candidate point: the uniform vector x_0 shifted along x_perp by the scalar k[0].
z = x_0 + k[0] * x_perp

In [10]:
# Show the candidate point z.
z

array([[0.99577723, 0.00422277]])

In [11]:
# Check whether the components of z sum to 1.
z.sum()

0.9999999999999997

1.8742439016279302

In [16]:
# Difference between the second and first component of x.
x_flat = x.ravel()
A = x_flat[1] - x_flat[0]

# '+' root of the quadratic  A * p**2 - (A - 2) * p - 1 = 0.
discriminant = (A - 2) ** 2 + 4 * A
answ_anal = (A - 2 + np.sqrt(discriminant)) / (2 * A)

In [19]:
# Two-class vector built from the closed-form root and its complement to 1.
z_anal = np.array([answ_anal, 1 - answ_anal])
z_anal

array([0.96754727, 0.03245273])