In [1]:
import numpy as np
import random

In [6]:
class BeamSearch:
    '''
    Toy beam-search decoder over a small fixed vocabulary.

    Basic steps for beam search:
    1. initialize the beam with a single path containing only the SOS token
    2. expand every unfinished path with its top-k next-token candidates,
       giving up to k * k candidate paths
    3. score the candidate paths (product of token probabilities) and keep
       the best k as the beam for the next step
    4. stop when every kept path ends in EOS or max_len steps have run;
       otherwise go back to step 2

    The next-token distribution produced by `decode_step` is random noise
    standing in for a real seq2seq decoder.
    '''

    def __init__(self, seed=None):
        '''
        Args:
            seed: optional seed for the random stand-in decoder. When None
                (the default) the global `random` module state is used.
        '''
        self.vocab = {
            0: 'a',
            1: 'b',
            2: 'c',
            3: 'd',
            4: 'e',
            5: 'SOS',
            6: 'EOS'
        }
        # reverse lookup: token string -> token index
        self.vocab_pair = {word: idx for idx, word in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        # `random` (module) and `random.Random` both expose `.random()`
        self.rng = random if seed is None else random.Random(seed)

    @staticmethod
    def softmax(x):
        '''
        Compute softmax values for each set of scores in x.

        The max is subtracted before exponentiating for numerical stability;
        entries equal to -inf get probability 0.
        '''
        e_x = np.exp(np.asarray(x, dtype=float) - np.max(x))
        return e_x / e_x.sum()

    @staticmethod
    def reduce_mul(l):
        '''Return the product of all values in l (1.0 for an empty iterable).'''
        out = 1.0
        for x in l:
            out *= x
        return out

    @staticmethod
    def check_all_done(seqs):
        '''
        Return True when every entry in seqs is marked finished.

        Args:
            seqs: iterable of (seq, score, done) triples as built in
                beam_search_step; only the trailing `done` flag is read.
        '''
        for seq in seqs:
            if not seq[-1]:
                return False
        return True

    def decode_step(self, encoder_context, input_seq):
        '''
        Produce the next-token distribution for input_seq.

        Stand-in for a real seq2seq decoder: encoder_context and input_seq
        are accepted for interface compatibility but are not used; the
        scores are random.

        Returns:
            list of (token_index, probability) pairs sorted by probability,
            highest first.
        '''
        words_prob = [self.rng.random() for _ in range(self.vocab_size)]
        # SOS may only start a sequence, never be generated: a -inf logit
        # gives it probability exactly 0 after softmax.
        words_prob[self.vocab_pair['SOS']] = -np.inf
        words_prob = self.softmax(words_prob)
        output_step = list(enumerate(words_prob))
        output_step.sort(key=lambda pair: pair[1], reverse=True)
        return output_step

    def beam_search_step(self, encoder_context, top_seqs, k):
        '''
        Advance the beam by one token.

        Each sequence is a list of (token_index, probability) pairs and its
        score is the product of those probabilities. Sequences that already
        end in EOS are carried over unchanged; every other sequence is
        extended with its k most probable next tokens. The best k candidates
        are kept.

        input: [[word,word],[word,word],[word,word]]
        output: [[word,word,word],[word,word,word],[word,word,word]]

        Returns:
            (topk_seqs, all_done): the new beam (best first) and whether
            every kept sequence ends in EOS.
        '''
        eos = self.vocab_pair['EOS']
        all_seqs = []
        for seq in top_seqs:
            seq_score = self.reduce_mul(score for _, score in seq)
            if seq[-1][0] == eos:
                # finished: keeps competing on its score but is not extended
                all_seqs.append((seq, seq_score, True))
                continue
            # get current step using encoder_context & seq
            current_step = self.decode_step(encoder_context, seq)
            for word_index, word_score in current_step[:k]:
                all_seqs.append((seq + [(word_index, word_score)],
                                 seq_score * word_score,
                                 word_index == eos))
        all_seqs.sort(key=lambda cand: cand[1], reverse=True)
        kept = all_seqs[:k]
        topk_seqs = [seq for seq, _, _ in kept]
        # inspect the done flags of the kept candidates, not the raw sequences
        all_done = self.check_all_done(kept)
        return topk_seqs, all_done

    def beam_search(self, encoder_context, beam_width, max_len):
        '''
        Run beam search from SOS for at most max_len decoding steps.

        Returns:
            up to beam_width sequences, best first; each is a list of
            (token_index, probability) pairs starting with (SOS, 1.0).
        '''
        top_seqs = [[(self.vocab_pair['SOS'], 1.0)]]
        for _ in range(max_len):
            top_seqs, all_done = self.beam_search_step(
                encoder_context, top_seqs, beam_width)
            if all_done:
                break
        return top_seqs
    

In [7]:
# Instantiate the beam-search helper class defined in the previous cell.
# NOTE(review): `Beam_Search` is not snake_case; kept unchanged in case other
# cells reference this name.
Beam_Search = BeamSearch()
