In [1]:
batch_size = 3
beam_size = 4
summary_len = 11
vocab_size = 113
time_step = 6
alpha = 1.0
eos_id = 0
min_value = -1e10
import numpy as np

In [2]:
def softmax(X, theta = 1.0, axis = None):
    """
    Compute the softmax of each element along an axis of X.

    Parameters
    ----------
    X: ND-Array. Probably should be floats. 
    theta (optional): float parameter, used as a multiplier
        prior to exponentiation. Default = 1.0
    axis (optional): axis to compute values along. Default is the 
        first non-singleton axis.

    Returns an array the same size as X. The result will sum to 1
    along the specified axis.
    """

    # make X at least 2d
    y = np.atleast_2d(X)

    # find axis
    if axis is None:
        axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)

    # multiply y against the theta parameter, 
    y = y * float(theta)

    # subtract the max for numerical stability
    y = y - np.expand_dims(np.max(y, axis = axis), axis)

    # exponentiate y
    y = np.exp(y)

    # take the sum along the specified axis
    ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)

    # finally: divide elementwise
    p = y / ax_sum

    # flatten if X was 1D
    if len(X.shape) == 1: p = p.flatten()

    return p

In [3]:
# previous log probs
log_probs = np.random.randint(0, 10, size=(batch_size, beam_size))
log_probs.astype(np.float32)
log_probs

array([[8, 5, 0, 8],
       [0, 5, 0, 7],
       [7, 4, 5, 3]])

In [4]:
# previous sequences
seqs = np.random.randint(0, vocab_size, size=(batch_size, beam_size, summary_len))
seqs

array([[[ 94,  47,  68,  98,  68,  61,  85,  49,  92, 100,  18],
        [ 30,  68,  27, 102,   4,  82,  50,  72,  54,  69,  63],
        [ 79,  92,  23,  65,  38,  98,  40,  28,   7,  59,  32],
        [102,   4,  54,   7,  91,  35,  98,  85,  66, 110,   7]],

       [[ 30,  85, 100, 101, 100,   2,  61,  15,  63,  38,  61],
        [ 87,  66,  21,  70, 101,  67, 109,  32,  21,  11,  61],
        [ 70,  39,  51,  11,   9,  65,  70,  51,  48,  47,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23]],

       [[ 81,  22,  71,  35, 112,  99,  65,  11,  86,  25, 109],
        [ 14, 103,  43,  99,  46,  52,  39,  80,  36,  84,  27],
        [ 38,   8,  95,   0,  18,  72,  72,  45,  36,  10,  91],
        [ 38,  46,  55,  94,  37,  11,  15,  42,  37,   4, 112]]])

In [5]:
# define current log probs
step_log_probs = np.random.randn(batch_size * beam_size, vocab_size)
step_log_probs = softmax(step_log_probs)
step_log_probs

array([[0.00866284, 0.09126753, 0.76027364, ..., 0.04482722, 0.0384534 ,
        0.01495241],
       [0.13010064, 0.09403806, 0.01349366, ..., 0.06896785, 0.03216171,
        0.04853784],
       [0.02818063, 0.13119897, 0.00556592, ..., 0.09018056, 0.09089643,
        0.06391199],
       ...,
       [0.05048084, 0.02752168, 0.0215683 , ..., 0.18469691, 0.16138069,
        0.04793495],
       [0.3051425 , 0.10034695, 0.03642663, ..., 0.02293839, 0.021237  ,
        0.02232829],
       [0.01683126, 0.20786609, 0.0195119 , ..., 0.06064144, 0.09925008,
        0.26470203]])

In [6]:
vo_size = step_log_probs.shape[-1]
step_log_probs = np.reshape(step_log_probs, [batch_size, beam_size, vo_size])
curr_log_probs = np.expand_dims(log_probs, 2) + step_log_probs  # add current vocab beam with previous one word

# length penalty
length_penalty = np.power((5.0 + (time_step + 1)) / 6.0, alpha)
curr_scores = curr_log_probs / length_penalty

In [7]:
curr_scores.shape

(3, 4, 113)

In [8]:
# Select top-k candidates
curr_scores = np.reshape(curr_scores, [-1, beam_size * vo_size])# [b, beam * v]
curr_scores.shape

(3, 452)

In [9]:
def get_top_k(scores, k):
    indices = np.argsort(scores, -1)[:, :(-k-1):-1]
    ret_scores = np.stack([np.take(scores[i], indices[i], axis=-1) for i in range(scores.shape[0])], axis=0)
    return ret_scores, indices

In [10]:
# get indices like: [3, vocab + 7, vocab * 3 + 180, ...], vocab_idx + beam_offset
top_scores, top_indices = get_top_k(curr_scores, 2 * beam_size) # [b, 2 * beam]

In [11]:
# indices and scores are correct, top-k indices
top_indices

array([[  2,  53, 423,  72, 429, 447, 397, 347],
       [403, 351, 365, 345, 390, 394, 424, 402],
       [ 83,  79,  28,  29,  45,  60,   4,  56]], dtype=int64)

In [12]:
top_scores

array([[4.38013682, 4.3093334 , 4.20907989, 4.1969112 , 4.18637286,
        4.16149883, 4.15171153, 4.14103024],
       [3.75622229, 3.67183793, 3.66811912, 3.6610425 , 3.66010466,
        3.65048016, 3.63808207, 3.62871759],
       [3.7836841 , 3.64934348, 3.64421052, 3.63404299, 3.63087913,
        3.62407101, 3.62306727, 3.61905244]])

In [13]:
beam_indices = top_indices // vo_size  # [b, 2 * beam]
symbol_indices = top_indices % vo_size  # [b, 2 * beam]

In [14]:
beam_indices

array([[0, 0, 3, 0, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3, 3, 3],
       [0, 0, 0, 0, 0, 0, 0, 0]], dtype=int64)

In [15]:
symbol_indices

array([[  2,  53,  84,  72,  90, 108,  58,   8],
       [ 64,  12,  26,   6,  51,  55,  85,  63],
       [ 83,  79,  28,  29,  45,  60,   4,  56]], dtype=int64)

In [16]:
symbol_indices.shape

(3, 8)

In [17]:
def gather_2d(params, indices, name=None):
    return np.stack([np.take(params[i], indices[i], axis=0) for i in range(params.shape[0])], axis=0)

In [18]:
# Expand sequences
# Get previous decoder sequence given beam indices
candidate_seqs = gather_2d(seqs, beam_indices)  # [b, 2 * beam, q']
candidate_seqs

array([[[ 94,  47,  68,  98,  68,  61,  85,  49,  92, 100,  18],
        [ 94,  47,  68,  98,  68,  61,  85,  49,  92, 100,  18],
        [102,   4,  54,   7,  91,  35,  98,  85,  66, 110,   7],
        [ 94,  47,  68,  98,  68,  61,  85,  49,  92, 100,  18],
        [102,   4,  54,   7,  91,  35,  98,  85,  66, 110,   7],
        [102,   4,  54,   7,  91,  35,  98,  85,  66, 110,   7],
        [102,   4,  54,   7,  91,  35,  98,  85,  66, 110,   7],
        [102,   4,  54,   7,  91,  35,  98,  85,  66, 110,   7]],

       [[ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  99, 106,  92,  66,  23],
        [ 20,  30, 100,

In [19]:
# put current decoded word id to given sequences
candidate_seqs[:, :, time_step] = symbol_indices

In [20]:
candidate_seqs

array([[[ 94,  47,  68,  98,  68,  61,   2,  49,  92, 100,  18],
        [ 94,  47,  68,  98,  68,  61,  53,  49,  92, 100,  18],
        [102,   4,  54,   7,  91,  35,  84,  85,  66, 110,   7],
        [ 94,  47,  68,  98,  68,  61,  72,  49,  92, 100,  18],
        [102,   4,  54,   7,  91,  35,  90,  85,  66, 110,   7],
        [102,   4,  54,   7,  91,  35, 108,  85,  66, 110,   7],
        [102,   4,  54,   7,  91,  35,  58,  85,  66, 110,   7],
        [102,   4,  54,   7,  91,  35,   8,  85,  66, 110,   7]],

       [[ 20,  30, 100,  74,  54,  37,  64, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  12, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  26, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,   6, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  51, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  55, 106,  92,  66,  23],
        [ 20,  30, 100,  74,  54,  37,  85, 106,  92,  66,  23],
        [ 20,  30, 100,

In [194]:
log_probs = np.array([[0.] + [min_value] * (beam_size - 1)], dtype=np.float32)  # [1, beam]
log_probs = np.tile(log_probs, [batch_size, 1])  # [b, beam]
alive_scores = np.zeros_like(log_probs)  # [b, beam]

fin_seqs = np.zeros([batch_size, beam_size, summary_len], np.int32)  # [b, beam, 1]
fin_scores = np.full([batch_size, beam_size], min_value)  # [b, beam]
fin_flags = np.zeros([batch_size, beam_size], np.bool)  # [b, beam]

In [155]:
fin_scores

array([[-1.e+10, -1.e+10, -1.e+10, -1.e+10],
       [-1.e+10, -1.e+10, -1.e+10, -1.e+10],
       [-1.e+10, -1.e+10, -1.e+10, -1.e+10]])

In [164]:
# Expand sequences
# Suppress finished sequences, if current decode word is eos
flags = np.equal(symbol_indices, eos_id)  # [b, beam]
# with our 2 * beam results, we set those eos score to -inf
alive_scores = top_scores + flags * min_value  # [b, 2 * beam]
# and keep top beam ones
alive_scores, alive_indices = get_top_k(alive_scores, beam_size)  # [b, beam]
# get their correspond vocab ids
alive_symbols = gather_2d(symbol_indices, alive_indices)  # [b, beam]
# and their correspond beam indices
alive_indices = gather_2d(beam_indices, alive_indices)  # [b, beam]
# get their correspond previous sequences
alive_seqs = gather_2d(seqs, alive_indices)  # [b, beam, q']
# concat, ta_da -_-
alive_seqs[:, :, time_step] = alive_symbols
alive_log_probs = alive_scores * length_penalty

In [165]:
alive_seqs

array([[[ 10,  44,  24,  26,  79, 101,  24,  29, 110,   5, 100],
        [ 10,  44,  24,  26,  79, 101, 107,  29, 110,   5, 100],
        [ 10,  44,  24,  26,  79, 101,  88,  29, 110,   5, 100],
        [ 10,  44,  24,  26,  79, 101, 101,  29, 110,   5, 100]],

       [[ 86,  48,  50,   3,   8, 108,  33,  65,  45,  27,  85],
        [ 86,  48,  50,   3,   8, 108,  43,  65,  45,  27,  85],
        [ 86,  48,  50,   3,   8, 108, 110,  65,  45,  27,  85],
        [ 86,  48,  50,   3,   8, 108,  93,  65,  45,  27,  85]],

       [[ 29,  96,  37,  73,  50,  81,  39,  83,  84,  95, 109],
        [ 29,  96,  37,  73,  50,  81,  51,  83,  84,  95, 109],
        [ 29,  96,  37,  73,  50,  81, 100,  83,  84,  95, 109],
        [ 29,  96,  37,  73,  50,  81,  91,  83,  84,  95, 109]]])

In [193]:
# Select finished sequences
step_fin_scores = top_scores + (1.0 - flags) * min_value  # [b, 2 * beam]
fin_flags = np.concatenate([fin_flags, flags], axis=1)  # [batch, 3 * beam]
fin_scores = np.concatenate([fin_scores, step_fin_scores], axis=1)
fin_scores, fin_indices = get_top_k(fin_scores, beam_size)  # [b, beam]
fin_flags = gather_2d(fin_flags, fin_indices)

# we always keep beam fin_seqs along with their scores and use current candidate to update
fin_seqs = np.concatenate([fin_seqs, candidate_seqs], axis=1)  # [b, 3 * beam, q' + 1]
fin_seqs = gather_2d(fin_seqs, fin_indices)  # [b, beam?, q' + 1]

In [169]:
fin_flags = np.concatenate([fin_flags, flags], axis=1)  # [batch, 3 * beam]

In [171]:
fin_flags.shape

(3, 12)

In [172]:
fin_scores = np.concatenate([fin_scores, step_fin_scores], axis=1)
fin_scores

array([[-1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10],
       [-1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10],
       [-1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10, -1.00000000e+10,
        -1.00000000e+10, -1.00000000e+10,  4.09773027e+00]])

In [173]:
fin_scores, fin_indices = get_top_k(fin_scores, beam_size)  # [b, beam]

In [175]:
fin_indices

array([[ 4,  5,  6,  7],
       [ 4,  5,  6,  7],
       [11,  4,  5,  6]], dtype=int64)

In [176]:
fin_flags = gather_2d(fin_flags, fin_indices)

In [177]:
fin_flags

array([[False, False, False, False],
       [False, False, False, False],
       [ True, False, False, False]])

In [195]:
fin_seqs = np.concatenate([fin_seqs, candidate_seqs], axis=1)

In [192]:
candidate_seqs[-1][-1][-1] = 0

IndexError: index 11 is out of bounds for axis 0 with size 8

In [196]:
fin_seqs[-1]

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [ 29,  96,  37,  73,  50,  81,  39,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81,  51,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81, 100,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81,  91,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81,  30,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81,  94,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81, 105,  83,  84,  95, 109],
       [ 29,  96,  37,  73,  50,  81,  25,  83,  84,  95,   0]])

In [197]:
fin_seqs = gather_2d(fin_seqs, fin_indices)  # [b, beam?, q' + 1]

In [198]:
fin_seqs

array([[[ 10,  44,  24,  26,  79, 101,  24,  29, 110,   5, 100],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [ 10,  44,  24,  26,  79, 101, 107,  29, 110,   5, 100],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]],

       [[ 86,  48,  50,   3,   8, 108,  33,  65,  45,  27,  85],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [ 86,  48,  50,   3,   8, 108,  43,  65,  45,  27,  85],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]],

       [[ 29,  96,  37,  73,  50,  81,  25,  83,  84,  95,   0],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [ 29,  96,  37,  73,  50,  81,  39,  83,  84,  95, 109],
        [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]]])

In [206]:
np.any(fin_flags)

True

In [205]:
fin_flags[-1][-1] = True

In [25]:
candidate_seqs[:, 0, :].shape

(3, 11)