In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Sparse Attention is Consistent with Feature Importance

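A simple notebook for reproducing our results. Run either Option 1 (every experiment in the `experiments/` directory) or Option 2 (a hand-picked list of configs from that directory), not both — the two options overlap, so running both repeats the same experiments.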

In [None]:
from importlib import reload
import logging
from ane_research.config import Config
from ane_research.utils.logger import initialize_logger
from IPython.display import Image, display

# Re-import `logging` to clear any logging relics (e.g. handlers) left over from
# earlier runs in this kernel; this must happen before the logger below is built.
reload(logging)
# Project logger; `in_notebook=True` presumably adapts output for Jupyter — see
# ane_research.utils.logger.initialize_logger to confirm.
logger = initialize_logger(Config.logger_name, in_notebook=True)

## Option 1: Run All Experiments

In [None]:
from ane_research.utils.experiments import run_all_experiments_in_dir

# Run the experiments in the `experiments/` directory, then render every figure each one saved.
experiments = run_all_experiments_in_dir('experiments')
for experiment in experiments:
    for figure_path in experiment.get_figure_paths():
        display(Image(filename=figure_path))

## Option 2: Run Specific Experiments

In [None]:
from ane_research.utils.experiments import run_experiment

# Edit as necessary - These are the 8 runs we performed in the paper
# ({sst, imdb} x {tanh, sdp} x {sparsemax, softmax}).
experiments_list = [
    'experiments/sst_tanh_sparsemax.jsonnet',
    'experiments/sst_tanh_softmax.jsonnet',
    'experiments/sst_sdp_sparsemax.jsonnet',
    'experiments/sst_sdp_softmax.jsonnet',
    'experiments/imdb_tanh_sparsemax.jsonnet',
    'experiments/imdb_tanh_softmax.jsonnet',
    'experiments/imdb_sdp_sparsemax.jsonnet',
    'experiments/imdb_sdp_softmax.jsonnet',
]

# Fail fast on a mistyped config path *before* any experiment starts running,
# rather than crashing partway through the list.
bad_paths = [path for path in experiments_list if not path.endswith('.jsonnet')]
assert not bad_paths, f'Experiment configs must be .jsonnet files: {bad_paths}'

for experiment_file in experiments_list:
    experiment = run_experiment(experiment_file)
    graphs = experiment.get_figure_paths()
    for graph in graphs:
        # These graphs are big!
        display(Image(filename=graph))
        

In [None]:
import math
import numpy as np
import numbers
import scipy.stats as stats
import scipy.special as special
from typing import Any, Tuple

def kendall_top_k(a: Any, b: Any, k: int = None, kIsNonZero: bool = False, p: float = 0.5) -> Tuple[float, int]:
    """Compute the top-k kendall-tau correlation metric for the given ranked lists.

    Implements the K^(p) distance of Fagin et al. between the top-k items of two score lists and
    normalizes it to a correlation. Larger entries are treated as higher ranked: the top k of a
    list are its k largest entries, and rank 1 is its largest entry.

    Args:
        a (ArrayLike):
            The first list of scores. Can be any array-like type (e.g. list, numpy array or
            cpu-bound tensor); it is flattened before use.
        b (ArrayLike):
            The second list of scores, with the same number of elements as ``a``.
        k (int, optional):
            Only the top "k" elements are compared. Defaults to the size of the first list;
            values larger than the list size are clamped to it.
        kIsNonZero (bool, optional):
            If True, overrides k to be the minimum number of non-zero elements in either list.
            Defaults to False
        p (float, optional):
            The penalty parameter in the range [0, 1], charged for a pair whose items both appear
            in one top-k list but neither appears in the other. Defaults to the neutral case, p = 0.5

    Raises:
        ValueError: If p is not a real number in [0, 1], if the lists are not equal in length,
            or if k is negative.

    Returns:
        Tuple[float, int]: The correlation in [-1, 1] and the value used for k. The correlation
        is NaN when fewer than two distinct items appear across both top-k lists, since there
        are then no pairs to compare.

    Note:
        Ties are not treated specially: np.argpartition breaks ties at the top-k boundary
        arbitrarily, and scipy.stats.rankdata gives tied entries their average rank.
    """
    if not (isinstance(p, numbers.Real) and 0 <= p <= 1):
        raise ValueError("The penalty parameter p must be numeric and in the range [0,1]")

    x = np.array(a).flatten()
    y = np.array(b).flatten()

    if x.size != y.size:
        raise ValueError("The ranked lists must have same lengths")

    if kIsNonZero:
        k = min(np.count_nonzero(x), np.count_nonzero(y))
    elif k is None:
        k = x.size

    if k < 0:
        raise ValueError("k must be a non-negative integer")
    k = min(k, x.size)

    # Indices of the k largest entries, e.g. [1, 2, 3, 4] with k = 3 --> indices of [2, 3, 4].
    # k == 0 must be special-cased: argpartition(x, -0)[-0:] would select *every* index.
    if k == 0:
        x_top_k = y_top_k = np.array([], dtype=np.intp)
    else:
        x_top_k = np.argpartition(x, -k)[-k:]
        y_top_k = np.argpartition(y, -k)[-k:]

    # 1-based ranks with 1 = largest value, e.g. [55, 42, 89, 100] --> [3, 4, 2, 1]
    x_ranks = x.size + 1 - stats.rankdata(x)
    y_ranks = y.size + 1 - stats.rankdata(y)

    # Using the explicit notation of Fagin et al. with references to their equation numbers
    Z = np.intersect1d(x_top_k, y_top_k)
    S = np.setdiff1d(x_top_k, y_top_k)
    T = np.setdiff1d(y_top_k, x_top_k)
    z = Z.size

    # Every pair counted below lies in the union of both top-k sets; normalize by its pair count.
    n_pairs = special.comb(S.size + T.size + z, 2)
    if n_pairs == 0:
        return (float("nan"), k)

    # Equation 1: i and j appear in both top k lists. Penalize per the number of shared pairs that are discordant.
    # Both directions of disagreement must be counted: Z is sorted by index, not by x, so a one-sided
    # test like (ry[j] < ry[i]) & (rx[j] > rx[i]) would miss every discordant pair with rx[j] < rx[i].
    rx, ry = x[Z], y[Z]
    discordant = 0
    for i in range(z - 1):
        x_after, y_after = rx[i + 1:], ry[i + 1:]
        discordant += np.count_nonzero(
            ((x_after > rx[i]) & (y_after < ry[i])) | ((x_after < rx[i]) & (y_after > ry[i]))
        )
    eqn1 = float(discordant)

    # Equation 2: i and j both appear in one top k list, and exactly one of i or j appears in the other
    eqn2 = (k - z) * (k + z + 1) - np.sum(x_ranks[S]) - np.sum(y_ranks[T])

    # Equation 3: i, but not j, appears in one top k list and j, but not i, appears in the other
    eqn3 = (k - z) ** 2

    # Equation 4: i and j both appear in one top k list, but neither i nor j appears in the other
    eqn4 = 2 * p * special.comb(k - z, 2)

    if k == z:
        # Identical top-k sets: S and T are empty, so only equation 1 can contribute.
        assert eqn2 == 0 and eqn3 == 0 and eqn4 == 0

    kendall_distance_with_penalty = eqn1 + eqn2 + eqn3 + eqn4

    # Normalize the distance to a correlation in the range [-1, 1]
    correlation = 1 - 2 * kendall_distance_with_penalty / n_pairs

    return (correlation, k)
    
# Sanity check: when k is the full list length and there are no ties, kendall_top_k
# should reduce to the ordinary Kendall tau computed by scipy.
from scipy.stats import kendalltau

# Two example score vectors over the same 21 positions.
scores_a = [0.00924963, 0.04445712, 0.02014468, 0.04744543, 0.01616605, 0.03727183,
            0.02536283, 0.02776813, 0.01844794, 0.10662094, 0.10609859, 0.0161676,
            0.16745505, 0.02182889, 0.02090209, 0.00257502, 0.03120976, 0.03747579,
            0.00226188, 0.01272965, 0.00909619]
scores_b = [0.00896137, 0.05222384, 0.02614146, 0.05706653, 0.01223692, 0.05689022,
            0.03800762, 0.03169076, 0.01273555, 0.14963414, 0.13315924, 0.02639366,
            0.26344773, 0.03380115, 0.02689381, 0.01485158, 0.0552543, 0.06618203,
            0.01013477, 0.02311206, 0.01089478]

# The same two examples as descending ranks (1 = largest score). Kendall tau only depends
# on relative order, so both pairs should yield the same correlation.
a = [18, 5, 13, 4, 16, 7, 10, 9, 14, 2, 3, 15, 1, 11, 12, 20, 8, 6, 21, 17, 19]
b = [21, 8, 14, 5, 18, 6, 9, 11, 17, 2, 3, 13, 1, 10, 12, 16, 7, 4, 20, 15, 19]

k = len(a)
assert len(a) == len(b) == len(scores_a) == len(scores_b)

for label, (first, second) in {'scores': (scores_a, scores_b), 'ranks': (a, b)}.items():
    top_k_tau, k_used = kendall_top_k(first, second, k=k)
    scipy_tau, _ = kendalltau(first, second, nan_policy='propagate')
    print(f'{label}: kendall_top_k = {top_k_tau:.6f} (k={k_used}), scipy kendalltau = {scipy_tau:.6f}')
