In [110]:
import torch

@torch.jit.script
def beam_search(log_probs: torch.Tensor, beam_size: int):
    """
    Performs beam search on a tensor of log probabilities.

    Every time step is scored independently of the tokens chosen before it:
    the score of a sequence is the sum of ``log_probs[batch, step, token]``
    over its steps.

    Args:
        log_probs (torch.Tensor): Tensor of shape (b, t, v) containing log probabilities.
            Step 0 is indexed unconditionally, so t must be at least 1.
        beam_size (int): Maximum number of beams to keep at each time step.

    Returns:
        sequences (torch.Tensor): Tensor of shape (b, k, t) containing the token
            indices of the top sequences, ordered from best to worst score.
        scores (torch.Tensor): Tensor of shape (b, k, 1) containing the summed
            log probability of each returned sequence (note the trailing
            singleton axis).

        k equals beam_size, except when fewer than beam_size distinct sequences
        exist (beam_size > v ** t); then every sequence is returned.
    """

    b, t, v = log_probs.size()

    # Step 0: there are only v possible first tokens, so the beam cannot be
    # wider than the vocabulary yet.
    initial_beam_size = min(beam_size, v)

    # topk returns the k largest entries along the vocabulary axis, sorted descending.
    topk_scores, topk_indices = torch.topk(log_probs[:, 0, :], initial_beam_size, dim=-1)
    sequences = topk_indices.unsqueeze(-1)  # (b, k, 1)
    scores = topk_scores  # (b, k)

    for step in range(1, t):
        # Expand every live beam with every possible next token.
        current_log_probs = log_probs[:, step, :].unsqueeze(1)  # (b, 1, v)
        expanded_scores = scores.unsqueeze(-1) + current_log_probs  # (b, k, v)

        # Explicit size instead of -1, which cannot be inferred for an empty tensor.
        num_candidates = scores.size(1) * v
        flat_scores = expanded_scores.view(b, num_candidates)  # (b, k * v)

        # The beam may still be narrower than beam_size (when beam_size > v),
        # so never ask topk for more entries than there are candidates.
        num_keep = min(beam_size, num_candidates)
        topk_flat_scores, topk_indices = flat_scores.topk(num_keep, dim=-1)  # (b, num_keep)

        # A flat index encodes (beam, token) as beam * v + token.
        beam_indices = topk_indices // v  # which existing sequence each winner extends
        token_indices = topk_indices % v  # which token is appended to it

        # Reorder the kept prefixes to match the winners, then append the new tokens.
        sequences = torch.gather(sequences, 1, beam_indices.unsqueeze(-1).expand(-1, -1, sequences.size(-1)))
        sequences = torch.cat([sequences, token_indices.unsqueeze(-1)], dim=-1)  # (b, num_keep, step + 1)

        # Carry the accumulated scores of the surviving beams forward.
        scores = topk_flat_scores

    return sequences, scores.unsqueeze(-1)


In [130]:
batch_size = 2
sequence_length = 100
vocab_size = 29
beam_size = 2
seed = 0

# Seed the global RNG so the simulated inputs, and therefore every decoded beam
# shown in the cells below, are the same on each run of the notebook.
torch.manual_seed(seed)

# Simulate log probabilities: random logits normalised over the vocabulary axis
log_probs = torch.randn(batch_size, sequence_length, vocab_size).log_softmax(dim=-1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log_probs = log_probs.to(device)

# Perform beam search
sequences, scores = beam_search(log_probs, beam_size)

print("Top sequences:", sequences)  # (batch_size, beam_size, sequence_length)
print("Scores:", scores)  # (batch_size, beam_size, 1)

Top sequences: tensor([[[ 3, 19, 19,  1, 11,  8,  8,  2,  4, 26, 27, 17,  7, 12, 21,  7,  9,
           7, 15,  1, 25, 13, 25,  3, 13, 28, 25, 20, 24,  2,  8, 23, 26,  1,
          22,  6, 22,  5, 12,  8, 10, 12, 24, 22, 28,  8, 17, 27,  6, 10,  6,
          23, 25, 18, 22,  9,  0,  4, 10, 26,  4, 20,  5, 21, 17, 25, 19,  0,
          11, 22, 10, 11, 15,  1, 25, 17,  9, 25, 23, 10, 22, 15,  3,  5,  1,
          28, 12, 16, 28, 22, 22, 14, 18, 26, 17, 27,  1, 24, 19, 17],
         [ 3, 19, 19,  1, 11,  8,  8,  2,  4, 26, 27, 17,  7, 12, 21,  7,  9,
           7, 15,  1, 16, 13, 25,  3, 13, 28, 25, 20, 24,  2,  8, 23, 26,  1,
          22,  6, 22,  5, 12,  8, 10, 12, 24, 22, 28,  8, 17, 27,  6, 10,  6,
          23, 25, 18, 22,  9,  0,  4, 10, 26,  4, 20,  5, 21, 17, 25, 19,  0,
          11, 22, 10, 11, 15,  1, 25, 17,  9, 25, 23, 10, 22, 15,  3,  5,  1,
          28, 12, 16, 28, 22, 22, 14, 18, 26, 17, 27,  1, 24, 19, 17]],

        [[21, 25, 13, 12,  9, 10,  6,  0, 23, 14,  5, 18,  9,

In [131]:
# Display the shapes of both beam_search outputs; scores carries a trailing singleton axis.
sequences.shape, scores.shape

(torch.Size([2, 2, 100]), torch.Size([2, 2, 1]))

In [132]:
# Look up the log-probability of every chosen token.
# The index tensor must line up with log_probs' (bsz, T, vocab) layout, so the
# beam and time axes are swapped before the gather and swapped back after it.
path_probs = log_probs.gather(2, sequences.permute(0, 2, 1)).permute(0, 2, 1)  # (bsz, beamsize, T)
path_probs

tensor([[[-2.3404, -2.5064, -2.4116, -2.1333, -1.6098, -1.8109, -1.6270,
          -2.2315, -1.5690, -2.1118, -2.1869, -1.9060, -2.2188, -1.4155,
          -0.9839, -1.8076, -2.1683, -1.6998, -1.6267, -2.0404, -2.3530,
          -2.4981, -1.0137, -1.8512, -2.0450, -1.8122, -1.8771, -2.1188,
          -2.2018, -2.0547, -2.2928, -1.8956, -2.3741, -1.9083, -1.6760,
          -2.1687, -1.5837, -1.8111, -1.9235, -2.1176, -1.9232, -1.8182,
          -1.8214, -1.8918, -2.1509, -1.3543, -1.9072, -2.0839, -2.1099,
          -1.4305, -1.7079, -2.2080, -2.1379, -2.0946, -1.8048, -2.0222,
          -1.9320, -2.2414, -1.7560, -2.4706, -1.9054, -1.7765, -1.0247,
          -1.3049, -2.0192, -2.2386, -2.2207, -1.6373, -1.0833, -1.5956,
          -2.0073, -2.2533, -1.4145, -1.5185, -2.3554, -1.4078, -2.1353,
          -2.3884, -1.4423, -1.0621, -2.0759, -1.6714, -1.6209, -1.6636,
          -1.9561, -1.7320, -1.8773, -1.6799, -1.7005, -1.6985, -1.5929,
          -2.0375, -1.2639, -1.8439, -2.3617, -1.66

In [133]:
# Standardise the beam scores within each batch item (statistics taken over the beam axis).
# NOTE: this rebinds `scores`, so running the cell a second time normalises the
# already-normalised values; re-run the beam_search cell first to restore the raw scores.
mean = scores.mean(dim=1, keepdim=True)
std = scores.std(dim=1, keepdim=True)

# Identical beam scores give std == 0, and 0 / 0 would turn every entry into NaN;
# clamping the divisor yields 0 in that case and leaves all other results unchanged.
scores = (scores - mean) / std.clamp(min=1e-8)

In [134]:
# Scale each beam's per-step log-probabilities by that beam's normalised score;
# the trailing singleton axis of scores broadcasts across the time axis.
path_probs*scores

tensor([[[-1.6549, -1.7723, -1.7052, -1.5085, -1.1383, -1.2805, -1.1505,
          -1.5779, -1.1094, -1.4933, -1.5464, -1.3478, -1.5690, -1.0009,
          -0.6957, -1.2781, -1.5332, -1.2019, -1.1503, -1.4428, -1.6639,
          -1.7664, -0.7168, -1.3090, -1.4460, -1.2814, -1.3273, -1.4982,
          -1.5569, -1.4529, -1.6213, -1.3404, -1.6787, -1.3494, -1.1851,
          -1.5335, -1.1199, -1.2806, -1.3601, -1.4974, -1.3599, -1.2856,
          -1.2879, -1.3377, -1.5209, -0.9576, -1.3486, -1.4735, -1.4919,
          -1.0115, -1.2077, -1.5613, -1.5117, -1.4811, -1.2762, -1.4299,
          -1.3661, -1.5849, -1.2416, -1.7470, -1.3473, -1.2562, -0.7246,
          -0.9227, -1.4278, -1.5829, -1.5703, -1.1577, -0.7660, -1.1282,
          -1.4194, -1.5934, -1.0002, -1.0738, -1.6655, -0.9955, -1.5099,
          -1.6888, -1.0199, -0.7510, -1.4679, -1.1818, -1.1461, -1.1763,
          -1.3832, -1.2247, -1.3275, -1.1879, -1.2024, -1.2010, -1.1264,
          -1.4407, -0.8937, -1.3038, -1.6699, -1.17

In [135]:
sequences # (bsz, beam, T) token indices returned by beam_search

tensor([[[ 3, 19, 19,  1, 11,  8,  8,  2,  4, 26, 27, 17,  7, 12, 21,  7,  9,
           7, 15,  1, 25, 13, 25,  3, 13, 28, 25, 20, 24,  2,  8, 23, 26,  1,
          22,  6, 22,  5, 12,  8, 10, 12, 24, 22, 28,  8, 17, 27,  6, 10,  6,
          23, 25, 18, 22,  9,  0,  4, 10, 26,  4, 20,  5, 21, 17, 25, 19,  0,
          11, 22, 10, 11, 15,  1, 25, 17,  9, 25, 23, 10, 22, 15,  3,  5,  1,
          28, 12, 16, 28, 22, 22, 14, 18, 26, 17, 27,  1, 24, 19, 17],
         [ 3, 19, 19,  1, 11,  8,  8,  2,  4, 26, 27, 17,  7, 12, 21,  7,  9,
           7, 15,  1, 16, 13, 25,  3, 13, 28, 25, 20, 24,  2,  8, 23, 26,  1,
          22,  6, 22,  5, 12,  8, 10, 12, 24, 22, 28,  8, 17, 27,  6, 10,  6,
          23, 25, 18, 22,  9,  0,  4, 10, 26,  4, 20,  5, 21, 17, 25, 19,  0,
          11, 22, 10, 11, 15,  1, 25, 17,  9, 25, 23, 10, 22, 15,  3,  5,  1,
          28, 12, 16, 28, 22, 22, 14, 18, 26, 17, 27,  1, 24, 19, 17]],

        [[21, 25, 13, 12,  9, 10,  6,  0, 23, 14,  5, 18,  9, 19, 10, 22, 14

In [136]:
# If using PyTorch
import torch

# Output alphabet: position i in this list is the character for token index i.
# The final entry '?' is the character ctc_merge_string treats as the blank by default.
vocab = [' ', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '?']

# A bare len(vocab) in the middle of a cell is evaluated and thrown away, so the
# expected size is asserted instead (it matches vocab_size = 29 used for log_probs).
assert len(vocab) == 29, "vocab must contain exactly 29 symbols"

# token index -> character
idx2char = dict(enumerate(vocab))

def ctc_merge_string(s: str, blank_char='?'):
    """
    Collapse a raw CTC-style string: runs of the same character become a single
    character and every blank is dropped.

    A blank between two equal characters keeps them separate, because each
    character is only compared with the one immediately before it.

    Args:
        s (str): Raw decoded string, one character per time step.
        blank_char: Character that marks a blank step (default '?').

    Returns:
        str: The merged string with repeats collapsed and blanks removed.
    """
    kept = [
        cur
        for pos, cur in enumerate(s)
        # keep a character when it is not a blank and does not repeat its predecessor
        if cur != blank_char and (pos == 0 or s[pos - 1] != cur)
    ]
    return ''.join(kept)
   

# decode_seq(sequences, vocab)[0]
# sequences.shape torch.Size([128, 2, 100])

In [137]:
# Bring the index tensor to the host as nested Python lists (batch -> beam -> time),
# then turn every beam's token indices into a string via idx2char.
sentences = [
    [''.join(idx2char[token_id] for token_id in beam) for beam in beams]
    for beams in sequences.cpu().tolist()
]
print(sentences)

[["BRR'JGGACYZPFKTFHFN'XLXBL?XSWAGVY'UEUDKGIKWU?GPZEIEVXQUH CIYCSDTPXR JUIJN'XPHXVIUNBD'?KO?UUMQYPZ'WRP", "BRR'JGGACYZPFKTFHFN'OLXBL?XSWAGVY'UEUDKGIKWU?GPZEIEVXQUH CIYCSDTPXR JUIJN'XPHXVIUNBD'?KO?UUMQYPZ'WRP"], ["TXLKHIE VMDQHRIUMYAFBNDTDIXRQAUFXXQ 'NFLBNDBPIWUN'C'VPD K QXNBIMYPSAAS'HBPGFANZKYXNYIKOMURBDGCMPHFKX", "TXLKHIE VMDQHRIUMYAFBNDTDIXRQAUFXXQ 'NFLBNDBPIWUN'C'VPD KDQXNBIMYPSAAS'HBPGFANZKYXNYIKOMURBDGCMPHFKX"]]


In [138]:
# Print the CTC-merged form of every beam, one line per beam, batch item by batch item.
for beam_strings in sentences:
    for raw_text in beam_strings:
        merged_text = ctc_merge_string(raw_text)
        print(merged_text)

BR'JGACYZPFKTFHFN'XLXBLXSWAGVY'UEUDKGIKWUGPZEIEVXQUH CIYCSDTPXR JUIJN'XPHXVIUNBD'KOUMQYPZ'WRP
BR'JGACYZPFKTFHFN'OLXBLXSWAGVY'UEUDKGIKWUGPZEIEVXQUH CIYCSDTPXR JUIJN'XPHXVIUNBD'KOUMQYPZ'WRP
TXLKHIE VMDQHRIUMYAFBNDTDIXRQAUFXQ 'NFLBNDBPIWUN'C'VPD K QXNBIMYPSAS'HBPGFANZKYXNYIKOMURBDGCMPHFKX
TXLKHIE VMDQHRIUMYAFBNDTDIXRQAUFXQ 'NFLBNDBPIWUN'C'VPD KDQXNBIMYPSAS'HBPGFANZKYXNYIKOMURBDGCMPHFKX


In [140]:
# Display the shape of the decoded index tensor.
sequences.shape

torch.Size([2, 2, 100])

In [29]:
# Profile a single beam_search call and list the recorded ops by total CUDA time.
# NOTE: the call rebinds `sequences` and `scores`, replacing whatever earlier cells
# stored under those names (including the normalised scores).
with torch.autograd.profiler.profile(use_device = 'cuda') as prof:
    sequences, scores = beam_search(log_probs, beam_size)

op_summary = prof.key_averages().table(sort_by="cuda_time_total")
print(op_summary)


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            beam_search        18.14%     695.342us        99.55%       3.817ms       3.817ms     761.000us        19.91%       3.823ms       3.823ms             1  
                                     aten::floor_divide         5.20%     199.467us        15.51%     594.651us      33.036us     315.000us         8.24%     660.000us      36.667us            18  
         