In [None]:
import numpy as np
import matplotlib.pyplot as plt

def gen_numbers(array_size=5000, seed=444):
    """
    Generate a reproducible bimodal sample as a column vector.

    The first half is drawn from N(0.2, 0.1**2) and the second half from
    N(0.7, 0.1**2), so a histogram of the result shows two separated peaks.

    Parameters
    ----------
    array_size (optional): total number of samples. Default = 5000.
        Odd sizes are supported: the second mode receives the extra sample.
    seed (optional): seed for NumPy's global random generator. Default = 444.

    Returns an array of shape (array_size, 1).

    Side effects: reseeds NumPy's global RNG and sets NumPy's print
    precision to 3 decimals (both kept for compatibility).
    """
    np.random.seed(seed)
    np.set_printoptions(precision=3)
    # Split so both halves always add up to array_size, even when it is odd.
    n_low = array_size // 2
    n_high = array_size - n_low
    # Drawing with a 2-D size fills the column in the same order as drawing
    # 1-D and reshaping, so the values match the previous implementation.
    rnd1 = np.random.normal(0.2, 0.1, size=(n_low, 1))
    rnd2 = np.random.normal(0.7, 0.1, size=(n_high, 1))
    return np.concatenate((rnd1, rnd2), axis=0)

def plot_data_hist(ax, data, title='A uniform distribution', bins=40, value_range=(0.0, 1.0)):
    """
    Draw a histogram of ``data`` on the given Axes.

    Parameters
    ----------
    ax: matplotlib Axes to draw into. Grid, y-limit and title are applied
        to this Axes, not to pyplot's "current" Axes, so each subplot of a
        multi-panel figure gets its own styling.
    data: array-like of values to histogram.
    title (optional): title set on ``ax``.
    bins (optional): number of histogram bins. Default = 40.
    value_range (optional): (low, high) passed as ``range`` to ``hist``;
        values outside it are not counted. Default = (0.0, 1.0).

    Returns None.
    """
    n, _bins, _patches = ax.hist(x=data, bins=bins, range=value_range, color='#0504aa', rwidth=0.85)
    ax.grid(axis='y', alpha=0.75)
    ax.set_title(title)
    # np.max also handles the list-of-arrays that hist returns for multi-column data
    maxfreq = np.max(n)
    # Set a clean upper y-axis limit: round up to the next multiple of 100,
    # or add 100 when the peak is already a multiple of 100.
    ax.set_ylim(top=np.ceil(maxfreq / 100) * 100 if maxfreq % 100 else maxfreq + 100)

def softmax(X, theta=1.0, axis=None):
    """
    Compute the softmax of each element along an axis of X.

    Parameters
    ----------
    X: array-like (ND-array, list, tuple, ...). Probably should be floats.
    theta (optional): float parameter, used as a multiplier
        prior to exponentiation. Default = 1.0
    axis (optional): axis to compute values along. Default is the
        first non-singleton axis (axis 0 if every axis has length 1).

    Returns an array the same size as X. The result will sum to 1
    along the specified axis. 1-D input gives 1-D output; other inputs
    with fewer than two dimensions are returned as 2-D.
    """
    # make X at least 2d
    y = np.atleast_2d(X)

    # find axis: the first non-singleton one; fall back to 0 so inputs whose
    # axes all have length 1 (e.g. a single value) don't raise StopIteration
    if axis is None:
        axis = next((i for i, dim in enumerate(y.shape) if dim > 1), 0)

    # multiply y against the theta parameter
    y = y * float(theta)

    # subtract the max for numerical stability; keepdims keeps it broadcastable
    y = y - np.max(y, axis=axis, keepdims=True)

    y = np.exp(y)

    # divide elementwise by the sum along the same axis
    p = y / np.sum(y, axis=axis, keepdims=True)

    # flatten if X was 1D; np.ndim works for lists/tuples, unlike X.shape
    if np.ndim(X) == 1:
        p = p.flatten()

    return p

def attention(query, key, value, scale=False):
    """
    Dot-product attention: softmax(query @ key^T) @ value.

    Parameters
    ----------
    query: array of shape (..., n_queries, d_k).
    key: array of shape (..., n_keys, d_k).
    value: array of shape (..., n_keys, d_v).
    scale (optional): if True, divide the scores by sqrt(d_k) before the
        softmax (scaled dot-product attention). Default = False, which
        keeps the unscaled behaviour.

    Returns an array of shape (..., n_queries, d_v): each query row is a
    weighted combination of the value rows, with weights from a softmax
    over the keys.
    """
    # Swap only the last two axes so batched (>2-D) keys are transposed per
    # batch; for 2-D keys this is identical to np.transpose.
    key_t = np.swapaxes(key, -1, -2) if np.ndim(key) > 1 else key
    dot_product = np.matmul(query, key_t)
    if scale:
        # scale by the feature dimension d_k, not by the number of keys
        dot_product = dot_product / np.sqrt(np.shape(key)[-1])
    attention_weights = softmax(dot_product, axis=-1)
    return np.matmul(attention_weights, value)

def plot_attentions(data, total_plots=5, attn_count_between=1):
    """
    Plot how a distribution evolves under repeated residual self-attention.

    Row i shows the histogram after i * attn_count_between applications of
    ``data = (data + attention(data, data, data)) / 2``.

    Parameters
    ----------
    data: array used as query, key and value — presumably 2-D
        (n_samples, d), as produced by gen_numbers. For 2-D input,
        attention builds an (n_samples, n_samples) score matrix, so memory
        grows quadratically with n_samples.
    total_plots (optional): number of stacked histograms. Default = 5.
    attn_count_between (optional): attention updates applied between two
        consecutive histograms. Default = 1.

    Returns None; the figure is displayed with plt.show().
    """
    # squeeze=False keeps axs 2-D, so indexing also works when total_plots == 1
    _fig, axs = plt.subplots(total_plots, 1, sharex=True, figsize=(10, 10), squeeze=False)
    for i in range(total_plots):
        # cumulative number of attention updates applied so far
        runs = i * attn_count_between
        plot_data_hist(axs[i, 0], data, 'Distribution with ' + str(runs) + ' attention runs')
        # no update is needed after the last histogram
        if i < total_plots - 1:
            for _ in range(attn_count_between):
                data = (data + attention(data, data, data)) / 2.0
    plt.show()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

plot_attentions(gen_numbers())