# Beam search

# In [None]:
import torch
import torch.nn.functional as F

def beam_search(decoder, hidden, context, beam_width=3, max_length=10, *,
                sos_token=0, eos_token=None, device=None):
    """Decode one token sequence from ``decoder`` using beam search.

    At every step each live hypothesis is extended by its most likely next
    tokens, and only the ``beam_width`` best hypotheses overall are kept.
    Hypotheses are scored by the SUM of token log-probabilities. This ranks
    them the same way as the product of probabilities, but does not underflow
    to 0.0 on long sequences (which would make all candidates tie).

    Args:
        decoder: Callable ``decoder(decoder_input, hidden, context)`` returning
            ``(output, hidden)``. ``output`` is normalised over ``dim=1`` and
            row ``0`` is read, so it is presumably ``(1, vocab_size)`` logits
            — TODO confirm against the decoder in use.
        hidden: Initial decoder state; handed to ``decoder`` as-is.
        context: Extra argument passed unchanged to every ``decoder`` call.
        beam_width: Number of hypotheses kept per step; must be >= 1.
        max_length: Maximum number of decoding steps (generated tokens).
        sos_token: Token id fed to the decoder on the first step.
        eos_token: If not None, hypotheses ending in this token are finished.
            They are carried over unchanged instead of being extended, and
            decoding stops early once every kept hypothesis is finished.
        device: Device for the decoder input tensors (None = PyTorch default).

    Returns:
        The token ids (list of int) of the highest-scoring hypothesis, not
        including ``sos_token``.

    Raises:
        ValueError: If ``beam_width`` is less than 1.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    def _finished(seq):
        # A hypothesis is done once it has emitted the EOS token.
        return eos_token is not None and bool(seq) and seq[-1] == eos_token

    # Each hypothesis: (token ids, cumulative log-probability, decoder state).
    beams = [([], 0.0, hidden)]

    # Pure inference: scores are extracted as Python floats, so no autograd graph.
    with torch.no_grad():
        for _ in range(max_length):
            candidates = []

            for seq, score, state in beams:
                if _finished(seq):
                    # Keep competing with its final score; extending it is meaningless.
                    candidates.append((seq, score, state))
                    continue

                last_token = seq[-1] if seq else sos_token
                decoder_input = torch.tensor([last_token], device=device)
                output, new_state = decoder(decoder_input, state, context)
                log_probs = F.log_softmax(output, dim=1)

                # topk raises if k exceeds the vocabulary size.
                k = min(beam_width, log_probs.size(-1))
                top_log_probs, top_indices = torch.topk(log_probs, k)

                for log_p, token in zip(top_log_probs[0].tolist(),
                                        top_indices[0].tolist()):
                    candidates.append((seq + [token], score + log_p, new_state))

            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]

            if eos_token is not None and all(_finished(seq) for seq, _, _ in beams):
                break

    return beams[0][0]