In [4]:
import torch
import numpy as np
from torch import nn

In [40]:
score = [[[-0.1, -0.4, -0.5, -0.6], [-0.4, -0.8, -0.4, -0.3], [-0.9, -0.5, -0.2, -0.7]],
                 [[-0.6, -0.3, -0.8, -0.9], [-0.9, -0.5, -0.7, -0.4], [-0.7, -0.8, -0.2, -0.8]]]
top_score = [[-0.1, -0.2, -0.3], [-0.2, -0.3, -0.4]]
top_beamid = [[0, 2, 1], [2, 0, 1]]
top_wordid = [[0, 2, 3], [2, 1, 3]]

In [42]:
score = torch.tensor(score)

In [44]:
top_canditate, wordId = score.max(dim = 2)
top_canditate, wordId

(tensor([[-0.1000, -0.3000, -0.2000],
         [-0.3000, -0.4000, -0.2000]]),
 tensor([[0, 3, 2],
         [1, 3, 2]]))

In [49]:
top_K_beams, top_beamid = torch.sort(top_canditate, dim = 1, descending=True)
top_K_beams, top_beamid 

(tensor([[-0.1000, -0.2000, -0.3000],
         [-0.2000, -0.3000, -0.4000]]),
 tensor([[0, 2, 1],
         [2, 0, 1]]))

In [53]:
def topK(score):
    """
    For every example in a batch, we have generated K candidates (partial sequences) in beam search (K=beam_size).
    For each candidate, we will search for the next token from a vocabulary of V tokens (V=vocab_size). 
    So, we expand the old candidates to get (K * V) new candidates in total, and then we will prune the new
    candidates to only keep the top K candidates in the beam.
    Args: 
        score: torch.FloatTensor, [batch_size, beam_size, vocab_size]. 
               This tensor has the cummulated score of selecting the next token from a vocabuary
               for each old candidate for each example from a batch.
    Return:
        top_score: torch.FloatTensor, [batch_size, beam_size], the scores of the top K new candidates in the beam after pruning 
        top_beamid: torch.LongTensor, [batch_size, beam_size], the beam ids of the top K new candidates
        top_wordid: torch.LongTensor, [batch_size, beam_size], the word ids of the next tokens to construct the top K new candidates
    
    Example:
        Assuming batch_size = 2, beam_size = 3, vocab_size = 4, we have the inputs and outputs as follows:
        score = [[[-0.1, -0.4, -0.5, -0.6], [-0.4, -0.8, -0.4, -0.3], [-0.9, -0.5, -0.2, -0.7]],
                 [[-0.6, -0.3, -0.8, -0.9], [-0.9, -0.5, -0.7, -0.4], [-0.7, -0.8, -0.2, -0.8]]]
        top_score = [[-0.1, -0.2, -0.3], [-0.2, -0.3, -0.4]]
        top_beamid = [[0, 2, 1], [2, 0, 1]]
        top_wordid = [[0, 2, 3], [2, 1, 3]]  

    """
    top_score, top_wordid = score.max(dim = 2)
    top_score, top_beamid = torch.sort(top_score, dim = 1, descending=True)
    top_wordid = torch.gather(top_wordid, dim=1, index=top_beamid)
    return top_score, top_beamid, top_wordid

In [55]:
top_score, top_beamid, top_wordid = topK(score)
top_score, top_beamid, top_wordid

(tensor([[-0.1000, -0.2000, -0.3000],
         [-0.2000, -0.3000, -0.4000]]),
 tensor([[0, 2, 1],
         [2, 0, 1]]),
 tensor([[0, 2, 3],
         [2, 1, 3]]))

In [43]:
sorted, indices = torch.sort(score, dim = 2, descending=True)
sorted, indices

(tensor([[[-0.1000, -0.4000, -0.5000, -0.6000],
          [-0.3000, -0.4000, -0.4000, -0.8000],
          [-0.2000, -0.5000, -0.7000, -0.9000]],
 
         [[-0.3000, -0.6000, -0.8000, -0.9000],
          [-0.4000, -0.5000, -0.7000, -0.9000],
          [-0.2000, -0.7000, -0.8000, -0.8000]]]),
 tensor([[[0, 1, 2, 3],
          [3, 0, 2, 1],
          [2, 1, 3, 0]],
 
         [[1, 0, 2, 3],
          [3, 1, 2, 0],
          [2, 0, 1, 3]]]))

In [None]:
sorted_1, indices_1 = torch.sort(sorted, dim = 2, descending=True)
sorted, indices

In [36]:
t_sorted = sorted[:, : , 0]
t_indices = indices[:, : , 0]


In [37]:
t_sorted, t_indices = torch

(tensor([[-0.1000, -0.3000, -0.2000],
         [-0.3000, -0.4000, -0.2000]]),
 tensor([[0, 3, 2],
         [1, 3, 2]]))

In [27]:
sortedScores = torch.sort(score.max(dim = 2).values, dim=2, descending = True)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [24]:
sortedScores


torch.return_types.sort(
values=tensor([[[-0.1000, -0.4000, -0.5000, -0.6000],
         [-0.3000, -0.4000, -0.4000, -0.8000],
         [-0.2000, -0.5000, -0.7000, -0.9000]],

        [[-0.3000, -0.6000, -0.8000, -0.9000],
         [-0.4000, -0.5000, -0.7000, -0.9000],
         [-0.2000, -0.7000, -0.8000, -0.8000]]]),
indices=tensor([[[0, 1, 2, 3],
         [3, 0, 2, 1],
         [2, 1, 3, 0]],

        [[1, 0, 2, 3],
         [3, 1, 2, 0],
         [2, 0, 1, 3]]]))

In [82]:
hiddens = torch.FloatTensor([[[0.6566, 0.2719, 0.7182, 0.1083],
                                    [0.1627, 0.4812, 0.1167, 0.4318],
                                    [0.1817, 0.2578, 0.5769, 0.2610]],

                                    [[0.9372, 0.4993, 0.5471, 0.9169],
                                    [0.8798, 0.6168, 0.6097, 0.8790],
                                    [0.6642, 0.4301, 0.5542, 0.3670]]])
beam_id = torch.LongTensor([[[1,1,1,1], [1,1,1,1], [0,0,0,0]], [[2,2,2,2], [0,0,0,0],[0,0,0,0]]])
# new_hiddens = torch.FloatTensor([[[0.1627, 0.4812, 0.1167, 0.4318],
#                                 [0.1627, 0.4812, 0.1167, 0.4318],
#                                 [0.6566, 0.2719, 0.7182, 0.1083]],

#                                 [[0.6642, 0.4301, 0.5542, 0.3670],
#                                 [0.9372, 0.4993, 0.5471, 0.9169],
#                                         [0.9372, 0.4993, 0.5471, 0.9169]]])
    # """


In [86]:
expanded_beam_idbeam_id= torch.LongTensor([[1, 1, 0], [2, 0, 0]])
expanded_beam_id = beam_id.unsqueeze(2).expand(-1, -1, hiddens.size(2))
expanded_beam_id

tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [0, 0, 0, 0]],

        [[2, 2, 2, 2],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]])

In [88]:
new_hiddens = torch.gather(hiddens, 1, expanded_beam_id)
new_hiddens = torch.gather(hiddens, 1, expanded_beam_id)
new_hiddens

tensor([[[0.1627, 0.4812, 0.1167, 0.4318],
         [0.1627, 0.4812, 0.1167, 0.4318],
         [0.6566, 0.2719, 0.7182, 0.1083]],

        [[0.6642, 0.4301, 0.5542, 0.3670],
         [0.9372, 0.4993, 0.5471, 0.9169],
         [0.9372, 0.4993, 0.5471, 0.9169]]])

In [72]:
beam_id

tensor([[1, 1, 0],
        [2, 0, 0]])

In [64]:
new_hiddens

tensor([[[0.1627, 0.4812, 0.7182],
         [0.1817, 0.2719, 0.7182]]])