In [None]:
def _decode_one(seq):
    chars = [vocab[i] for i in seq]
    raw = ''.join(chars)

In [2]:
vocab = ['p', ' ', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '?']

In [4]:
import torch
seq = torch.tensor([1,2,3])

In [5]:
vocab[seq]

TypeError: only integer tensors of a single element can be converted to an index

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 -16,
 -15,
 -14,
 -13,
 -12,
 -11,
 -10,
 -9,
 -8,
 -7,
 -6,
 -5,
 -4,
 -3,
 -2,
 -1]

: 

In [1]:
import torch

@torch.jit.script
def beam_search(log_probs: torch.Tensor, beam_size: int):
    """
    Performs beam search on a tensor of log probabilities.

    Args:
        log_probs (torch.Tensor): Tensor of shape (b, t, v) containing log probabilities.
        beam_size (int): Number of beams to keep at each time step.

    Returns:
        sequences (torch.Tensor): Tensor of shape (b, beam_size, t) containing the top sequences.
        scores (torch.Tensor): Tensor of shape (b, beam_size) containing the scores of the top sequences.
    """
    
    b, t, v = log_probs.size()
    
    initial_beam_size = min(beam_size, v) # At the very first step (time step 0), we can't have more beams than the vocabulary size. This line ensures that the initial number of beams considered doesn't exceed the number of possible first tokens.

    topk_scores, topk_indices = torch.topk(log_probs[:, 0, :], initial_beam_size, dim=-1) # Returns the k largest elements of the given input tensor along a given dimension
    sequences = topk_indices.unsqueeze(-1)  # (b, initial_beam_size, 1)
    scores = topk_scores  # (b, initial_beam_size)

    for step in range(1, t):
        # Expand the current sequences with all possible next tokens
        current_log_probs = log_probs[:, step, :].unsqueeze(1)  # (b, 1, v)
        expanded_scores = scores.unsqueeze(-1) + current_log_probs  # (b, beam_size, v)
        flat_scores = expanded_scores.view(b, -1)  # (b, beam_size * v)

        # Select the top-k scores and their corresponding indices
        topk_flat_scores, topk_indices = flat_scores.topk(beam_size, dim=-1)  # (b, beam_size)
        beam_indices = topk_indices // v  # Indices of sequences to expand
        token_indices = topk_indices % v  # New tokens to append

        # Gather the sequences to expand and append the new tokens
        sequences = torch.gather(sequences, 1, beam_indices.unsqueeze(-1).expand(-1, -1, sequences.size(-1)))
        sequences = torch.cat([sequences, token_indices.unsqueeze(-1)], dim=-1)  # (b, beam_size, step+1)

        # Update the scores
        scores = topk_flat_scores

    return sequences, scores.unsqueeze(-1)


In [2]:
batch_size = 128
sequence_length = 100
vocab_size = 29
beam_size = 2

# Simulate log probabilities
log_probs = torch.randn(batch_size, sequence_length, vocab_size).log_softmax(dim=-1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log_probs = log_probs.to(device)

# Perform beam search
sequences, scores = beam_search(log_probs, beam_size)

print("Top sequences:", sequences) # bsz, beamsize,seq_len
print("Scores:", scores) # bsz, beamsize,1

Top sequences: tensor([[[ 4,  2, 17,  ...,  9,  0, 19],
         [ 4,  2, 17,  ...,  9,  0, 19]],

        [[23, 13, 26,  ..., 19, 10, 23],
         [23, 13, 26,  ..., 19, 10, 23]],

        [[26, 17,  9,  ..., 12, 27, 19],
         [26, 17,  9,  ..., 12, 27, 19]],

        ...,

        [[23, 25, 17,  ..., 12, 18,  9],
         [23, 25, 17,  ..., 12, 18,  9]],

        [[21, 25, 16,  ..., 28,  3, 26],
         [21, 25, 16,  ..., 28,  3, 26]],

        [[24,  8, 14,  ...,  7, 28, 16],
         [24,  8, 14,  ...,  7, 28, 16]]], device='cuda:0')
Scores: tensor([[[-176.1211],
         [-176.1288]],

        [[-186.1906],
         [-186.1932]],

        [[-185.5934],
         [-185.6017]],

        [[-181.5307],
         [-181.5324]],

        [[-184.2902],
         [-184.2911]],

        [[-178.2433],
         [-178.2538]],

        [[-174.4158],
         [-174.4240]],

        [[-179.0708],
         [-179.0797]],

        [[-178.4495],
         [-178.4589]],

        [[-178.6902],
      

In [3]:
sequences.shape, scores.shape

(torch.Size([128, 2, 100]), torch.Size([128, 2, 1]))

In [4]:
path_probs = torch.gather(log_probs, 2, sequences.transpose(1,2)).transpose(1,2) # bsz, beamsize, T
path_probs

tensor([[[-2.4390, -1.8395, -1.1481,  ..., -1.7797, -1.0971, -2.0236],
         [-2.4390, -1.8395, -1.1481,  ..., -1.7797, -1.0971, -2.0236]],

        [[-1.3271, -1.8809, -1.8068,  ..., -1.6930, -2.0635, -1.4433],
         [-1.3271, -1.8809, -1.8068,  ..., -1.6930, -2.0635, -1.4433]],

        [[-1.9541, -2.1843, -2.0299,  ..., -2.4101, -1.3928, -1.9366],
         [-1.9541, -2.1843, -2.0299,  ..., -2.4101, -1.3928, -1.9366]],

        ...,

        [[-1.8473, -1.8991, -2.1747,  ..., -1.4464, -1.9908, -2.1479],
         [-1.8473, -1.8991, -2.1747,  ..., -1.4464, -1.9908, -2.1479]],

        [[-1.3115, -1.7654, -1.5687,  ..., -1.4245, -1.6708, -2.2221],
         [-1.3115, -1.7654, -1.5687,  ..., -1.4245, -1.6708, -2.2221]],

        [[-2.2707, -2.4410, -2.2451,  ..., -1.7021, -1.4997, -2.0378],
         [-2.2707, -2.4410, -2.2451,  ..., -1.7021, -1.4997, -2.0378]]],
       device='cuda:0')

In [5]:
mean = scores.mean(dim=1, keepdim=True)
std = scores.std(dim=1, keepdim=True)

scores = (scores - mean) / std

In [6]:
path_probs*scores

tensor([[[-1.7246, -1.3007, -0.8118,  ..., -1.2585, -0.7758, -1.4309],
         [ 1.7246,  1.3007,  0.8118,  ...,  1.2585,  0.7758,  1.4309]],

        [[-0.9329, -1.3222, -1.2701,  ..., -1.1901, -1.4506, -1.0146],
         [ 0.9439,  1.3378,  1.2851,  ...,  1.2041,  1.4677,  1.0265]],

        [[-1.3818, -1.5446, -1.4354,  ..., -1.7042, -0.9848, -1.3694],
         [ 1.3818,  1.5446,  1.4354,  ...,  1.7042,  0.9848,  1.3694]],

        ...,

        [[-1.3062, -1.3429, -1.5377,  ..., -1.0227, -1.4077, -1.5188],
         [ 1.3062,  1.3429,  1.5377,  ...,  1.0227,  1.4077,  1.5188]],

        [[-0.9274, -1.2483, -1.1093,  ..., -1.0073, -1.1814, -1.5712],
         [ 0.9274,  1.2483,  1.1093,  ...,  1.0073,  1.1814,  1.5712]],

        [[-1.6056, -1.7260, -1.5875,  ..., -1.2036, -1.0605, -1.4409],
         [ 1.6056,  1.7260,  1.5875,  ...,  1.2036,  1.0605,  1.4409]]],
       device='cuda:0')

In [7]:
sequences # bsz,beam,T

tensor([[[ 4,  2, 17,  ...,  9,  0, 19],
         [ 4,  2, 17,  ...,  9,  0, 19]],

        [[23, 13, 26,  ..., 19, 10, 23],
         [23, 13, 26,  ..., 19, 10, 23]],

        [[26, 17,  9,  ..., 12, 27, 19],
         [26, 17,  9,  ..., 12, 27, 19]],

        ...,

        [[23, 25, 17,  ..., 12, 18,  9],
         [23, 25, 17,  ..., 12, 18,  9]],

        [[21, 25, 16,  ..., 28,  3, 26],
         [21, 25, 16,  ..., 28,  3, 26]],

        [[24,  8, 14,  ...,  7, 28, 16],
         [24,  8, 14,  ...,  7, 28, 16]]], device='cuda:0')

In [8]:
# If using PyTorch
import torch

vocab = [' ', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '?']
len(vocab) # 29-1 is the shape 

idx2char = {i:c for i,c in enumerate(vocab)}

def ctc_merge_string(s: str, blank_char='?'):
    merged = []
    prev = None
    for c in s:
        if c == blank_char:
            prev = None  # reset repetition check on blank
            continue
        if c != prev:
            merged.append(c)
        prev = c
    return ''.join(merged)
   

# decode_seq(sequences, vocab)[0]
# sequences.shape torch.Size([128, 2, 100])

In [9]:
# 1) move to CPU & to plain Python list of lists
sentences = []
for b in range(sequences.shape[0]):
    print(b)
    rows = sequences[b].cpu().tolist() # beam,T
    decoded_beams = [ctc_merge_string( ''.join(idx2char[i] for i in row) ) for row in rows]
    sentences.append(decoded_beams)
print(sentences)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
[["CAPOF'COKLCQZMCOBRWIZTVFPC JIAVSMELKDSVT'JSMHIPXPC'SGAIYNSDPHGEUCQR'ZBRARET SIOVOARFRNDAGUH R", "CAPOF'COKLCQZMCOBRWIZTVFPC JIAVSMELKDSVT'JSMHIPXPC'SGAIYNSDPHUEUCQR'ZBRARET SIOVOARFRNDAGUH R"], ["VLYS GSFHFMWYKET NXGZAZS'DFWCZ' I'Y'UQ'FOIC W OCHZFXHISWLMBMJFGDZUJDXCXNYFIXJDVIVJYJRPRDRIV", "VLYS GSFHFMWYKET NXGZAZS'DFWCZ' I'Y'UQ'FOIC W OCHZFXHISWLMBMJFGDZUJDXCXNYFIXJDVIVJYJRPGDRIV"], ["YPHEV AZCNQBRLOCPHIQB'QISFZEV WJGZLNSJIDV 'NAYFYJPGDMHMJYIDWKBVIW UEJDMUVHXDYEUOAXSRBKDNE'RGKZR", "YPHEV AZCNQBRLOCP IQB'QISFZEV WJGZLNSJIDV 'NAYFYJPGDMHMJYIDWKBVIW UEJDMUVHXDYEUOAXSRBKDNE'RGKZR"], ["CJPVZ M

In [10]:
sentences

[["CAPOF'COKLCQZMCOBRWIZTVFPC JIAVSMELKDSVT'JSMHIPXPC'SGAIYNSDPHGEUCQR'ZBRARET SIOVOARFRNDAGUH R",
  "CAPOF'COKLCQZMCOBRWIZTVFPC JIAVSMELKDSVT'JSMHIPXPC'SGAIYNSDPHUEUCQR'ZBRARET SIOVOARFRNDAGUH R"],
 ["VLYS GSFHFMWYKET NXGZAZS'DFWCZ' I'Y'UQ'FOIC W OCHZFXHISWLMBMJFGDZUJDXCXNYFIXJDVIVJYJRPRDRIV",
  "VLYS GSFHFMWYKET NXGZAZS'DFWCZ' I'Y'UQ'FOIC W OCHZFXHISWLMBMJFGDZUJDXCXNYFIXJDVIVJYJRPGDRIV"],
 ["YPHEV AZCNQBRLOCPHIQB'QISFZEV WJGZLNSJIDV 'NAYFYJPGDMHMJYIDWKBVIW UEJDMUVHXDYEUOAXSRBKDNE'RGKZR",
  "YPHEV AZCNQBRLOCP IQB'QISFZEV WJGZLNSJIDV 'NAYFYJPGDMHMJYIDWKBVIW UEJDMUVHXDYEUOAXSRBKDNE'RGKZR"],
 ["CJPVZ MEXAXTSVNLGCESJ'PIJBFNABWPIUDF'PVHPMZWORZBANLHTFTQVRETZYJPIFZREZHTCAMJIVNTQ XHJYGRDKLZOPYA",
  "CJPVZ MEXAXTSVNLGCESJ'PIJBFNABWPIUDF'PVHPMZWORZBANLHTFTQVRETZYJPIFZREZHTCAMJIVNTQDXHJYGRDKLZOPYA"],
 ["GEHRFSGZXQNJYG RBJLYXNMSQGYFOYISRFRFLRDLBIW'T N JAFEXFY'RGEPIFH FSXVPSPNFMOKNRBSX JUCDWYND KCZUA",
  "GEHRFSGZXQNJYG RBJLYXNMSQGYFOYISRFRFLRDLBIW'T N JAFEXFY'RGEPIFY FSXVPSPNFMOKNRBSX JUCDWYND KC

In [16]:
def _decode_one( seq):
    chars = [vocab[i] for i in seq]
    raw = ''.join(chars)
    return raw


In [17]:
sequences.shape

torch.Size([128, 2, 100])

In [18]:
_decode_one(sequences[0,0])

"CAPOF'COKLCQZMC?OBRWIZTTVFPC JIAV?SMELKDSVT'?JSMHIIPXPC'SGAIYNSDPHGEUCQR'ZBRARE?T  SIOVOARFRNDAGUH R"

In [None]:
"""
1.  beamctc decoder
2. 


"""