In [32]:
import torch

@torch.jit.script
def beam_search(log_probs: torch.Tensor, beam_size: int):
    """
    Performs beam search on a tensor of log probabilities.

    The distribution at each time step is shared by every hypothesis (it does
    not depend on previously chosen tokens), so the score of a sequence is the
    sum of the per-step log probabilities of its tokens.

    Args:
        log_probs (torch.Tensor): Tensor of shape (b, t, v) containing log probabilities.
        beam_size (int): Maximum number of beams to keep at each time step.

    Returns:
        sequences (torch.Tensor): Tensor of shape (b, k, t) containing the top
            sequences ordered from best to worst, where k = min(beam_size, v ** t).
        scores (torch.Tensor): Tensor of shape (b, k, 1) containing the summed
            log probabilities of the returned sequences (note the trailing
            singleton dimension).
    """

    b, t, v = log_probs.size()

    # Only v distinct first tokens exist, so the initial beam cannot be wider
    # than the vocabulary.
    initial_beam_size = min(beam_size, v)

    # Seed the beams with the best first tokens.
    topk_scores, topk_indices = torch.topk(log_probs[:, 0, :], initial_beam_size, dim=-1)
    sequences = topk_indices.unsqueeze(-1)  # (b, k, 1)
    scores = topk_scores  # (b, k)

    for step in range(1, t):
        # Score every (beam, next-token) continuation.
        current_log_probs = log_probs[:, step, :].unsqueeze(1)  # (b, 1, v)
        expanded_scores = scores.unsqueeze(-1) + current_log_probs  # (b, k, v)
        num_candidates = scores.size(1) * v
        flat_scores = expanded_scores.view(b, num_candidates)  # (b, k * v)

        # While the beam is still narrower than beam_size (beam_size > v) there
        # can be fewer than beam_size candidates; clamp k so topk does not raise.
        num_keep = min(beam_size, num_candidates)
        topk_flat_scores, topk_indices = flat_scores.topk(num_keep, dim=-1)  # (b, num_keep)

        # A flat index encodes (beam, token) as beam * v + token.
        beam_indices = topk_indices // v  # Indices of sequences to expand
        token_indices = topk_indices % v  # New tokens to append

        # Reorder the surviving prefixes, then append the chosen tokens.
        sequences = torch.gather(sequences, 1, beam_indices.unsqueeze(-1).expand(-1, -1, sequences.size(-1)))
        sequences = torch.cat([sequences, token_indices.unsqueeze(-1)], dim=-1)  # (b, num_keep, step + 1)

        # Update the scores
        scores = topk_flat_scores

    return sequences, scores.unsqueeze(-1)


In [33]:
# Demo problem dimensions.
batch_size = 2
sequence_length = 10
vocab_size = 30
beam_size = 2

# Prefer the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Simulated log probabilities: random logits are drawn on the CPU, normalised
# over the vocabulary axis, and then moved to the selected device.
log_probs = (
    torch.randn(batch_size, sequence_length, vocab_size)
    .log_softmax(dim=-1)
    .to(device)
)

# Run the search and show the surviving hypotheses with their scores.
sequences, scores = beam_search(log_probs, beam_size)

print("Top sequences:", sequences)
print("Scores:", scores)


Top sequences: tensor([[[24, 12, 20, 25, 24,  1, 24, 13, 23,  4],
         [24, 12, 20, 25, 24,  1, 24, 13, 16,  4]],

        [[ 1,  3, 13,  1, 14,  7, 28, 28, 21,  4],
         [ 1,  3,  6,  1, 14,  7, 28, 28, 21,  4]]], device='cuda:0')
Scores: tensor([[[-17.9353],
         [-17.9991]],

        [[-18.7701],
         [-18.9559]]], device='cuda:0')


In [34]:
# Show the shapes of both beam-search outputs as a tuple.
sequences.shape, scores.shape

(torch.Size([2, 2, 10]), torch.Size([2, 2, 1]))

In [38]:
# Per-batch statistics taken across the beam axis (dim=1); keepdim=True keeps
# them broadcastable against `scores`.
std_scores = scores.std(dim=1, keepdim=True)
mean_scores = scores.mean(dim=1, keepdim=True)

# Standardised (z-scored) beam scores within each batch element.
# NOTE(review): with only two beams this always yields +/-0.7071 — confirm
# that is the intended normalisation.
(scores - mean_scores) / std_scores

tensor([[[ 0.7071],
         [-0.7071]],

        [[ 0.7071],
         [-0.7071]]], device='cuda:0')

In [29]:
# Profile a single beam_search call with the autograd profiler.
# NOTE(review): use_device is hard-coded to 'cuda' while the demo cell falls
# back to the CPU; presumably this cell needs a CUDA-capable runtime — TODO confirm.
with torch.autograd.profiler.profile(use_device = 'cuda') as prof:
    # This call rebinds the notebook-level `sequences` and `scores`.
    sequences, scores = beam_search(log_probs, beam_size)
# Print the per-operator averages, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total"))


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            beam_search        18.14%     695.342us        99.55%       3.817ms       3.817ms     761.000us        19.91%       3.823ms       3.823ms             1  
                                     aten::floor_divide         5.20%     199.467us        15.51%     594.651us      33.036us     315.000us         8.24%     660.000us      36.667us            18  
         