In [1]:
import torch
import numpy as np

In [56]:
# Shared configuration for the cells below.
num_nodes = 5    # nodes per problem instance (drives the toy-matrix loop and the number of search steps)
batch_size = 1   # number of problem instances per batch
beam_width = 2   # hypotheses kept alive per instance; passed to Beamsearch

In [57]:
# Fix the RNG so the random transition scores are reproducible.
torch.manual_seed(42)

# Random two-class logits for every (from, to) node pair; softmax over the
# last axis turns them into probabilities and we keep the second class only.
y = torch.randn(1, 5, 5, 2).float()
y_pred = torch.softmax(y, dim=-1)[..., 1]

In [75]:
# Deterministic toy transition matrix: from node i, moving to node i+1 has
# probability 0.8 and moving to node i+2 has probability 0.2 (indices wrap
# around); every other transition, including staying at i, is 0.
# The tensor is sized from the config cell so it always matches the loop below.
y_pred = torch.zeros(batch_size, num_nodes, num_nodes, dtype=torch.float)

for i in range(num_nodes):
    j = (i + 1) % num_nodes
    k = (i + 2) % num_nodes
    # Same pattern for every instance in the batch.
    y_pred[:, i, j] = 0.8
    y_pred[:, i, k] = 0.2

In [213]:
class Beamsearch:
    """
    Batched beam search over a square matrix of node-to-node transition
    probabilities.

    Each hypothesis is a path of nodes. A hypothesis score is the product of
    the transition probabilities along its path, and a per-hypothesis mask
    prevents a node from being selected twice.

    Args:
        beam_width: number of hypotheses kept per batch element.
        trans_probs: tensor of shape (batch_size, num_nodes, num_nodes) where
            entry [b, i, j] is the probability of moving from node i to node j.
        random_start: if true, draw the start nodes uniformly at random;
            otherwise every hypothesis starts at node 0. Note that the first
            call to `step` only expands the hypothesis at beam position 0.

    References:
        [1]: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam.py
        [2]: https://github.com/alexnowakvila/QAP_pt/blob/master/src/tsp/beam_search.py
        [3]: https://github.com/chaitjo/graph-convnet-tsp/blob/master/utils/beamsearch.py
    """
    def __init__(self, beam_width, trans_probs, random_start=False):
        # Validate the input before reading any of its sizes so that a
        # malformed tensor fails with the intended message.
        assert len(trans_probs.shape) == 3, "transition probabilities need to be 3-dimensional"
        assert trans_probs.size(1) == trans_probs.size(2), "transition probabilities are not square"

        # all transition probabilities
        self.trans_probs = trans_probs
        self.batch_size = trans_probs.size(0)
        self.num_nodes = trans_probs.size(1)

        # Beamsearch parameters
        self.beam_width = beam_width

        # tensor data types and device (helper tensors live next to the input)
        self.device = trans_probs.device
        self.float = torch.float32
        self.long = torch.int64

        # All sizes below come from the instance, never from notebook globals.
        beam_shape = (self.batch_size, self.beam_width)

        if random_start:
            # starting at random nodes
            start_nodes = torch.randint(0, self.num_nodes, beam_shape, device=self.device)
        else:
            # starting at node zero
            start_nodes = torch.zeros(beam_shape, device=self.device)

        self.start_nodes = start_nodes.type(self.long)

        # Mask for constructing valid hypothesis (1 = node still available)
        self.mask = torch.ones(self.batch_size, self.beam_width, self.num_nodes,
                               dtype=self.float, device=self.device)
        self.update_mask(self.start_nodes)  # Mask the starting node of the beam search

        # Score for each hypothesis on the beam
        self.scores = torch.zeros(beam_shape, dtype=self.float, device=self.device)
        self.all_scores = []

        # Backpointers at each time-step
        self.parent_pointer = []

        # Outputs at each time-step
        self.next_nodes = [self.start_nodes]

    def get_current_nodes(self):
        """
        Get the nodes to expand at the current timestep.

        Returns:
            The most recently selected nodes, expanded to
            (batch_size, beam_width, num_nodes) so they can be used as a
            gather index into the transition matrix.
        """
        current_nodes = self.next_nodes[-1]
        current_nodes = current_nodes.unsqueeze(2).expand(self.batch_size,
                                                          self.beam_width,
                                                          self.num_nodes)
        return current_nodes

    def get_parent_pointer(self):
        """
        Get the parent (beam position) pointers recorded by the latest step.
        """
        return self.parent_pointer[-1]

    def step(self):
        """
        Transition to the next timestep of the beam search.

        Expands every hypothesis by every unvisited node, keeps the
        `beam_width` highest-scoring expansions per batch element and records
        their scores, parents and selected nodes.
        """
        current_nodes = self.get_current_nodes()
        # Row of outgoing probabilities for the last node of each hypothesis.
        trans_probs = self.trans_probs.gather(1, current_nodes)

        if len(self.parent_pointer) == 0:
            # first transition, only use the hypothesis at beam position 0;
            # `trans_probs` is a fresh tensor from gather, so writing into it
            # does not touch self.trans_probs
            beam_prob = trans_probs
            beam_prob[:, 1:] = 0
        else:
            # multiply the previous scores (probabilities) with the current ones
            expanded_scores = self.scores.unsqueeze(2).expand_as(trans_probs)  # b x beam_width x num_nodes
            beam_prob = trans_probs * expanded_scores

        # mask out visited nodes
        beam_prob = beam_prob * self.mask

        beam_prob = beam_prob.view(beam_prob.size(0), -1)  # flatten to (b x beam_width * num_nodes)

        # get k=beam_width best scores and indices
        best_scores, best_score_idxs = beam_prob.topk(self.beam_width,
                                                      dim=1, largest=True, sorted=True)

        self.scores = best_scores
        # flattened index = parent * num_nodes + node
        parent_index = torch.floor_divide(best_score_idxs, self.num_nodes).type(self.long)

        self.parent_pointer.append(parent_index)

        # next nodes
        next_node = best_score_idxs - (parent_index * self.num_nodes)  # convert flattened indices back
        self.next_nodes.append(next_node)

        # keep rows of the promising parents (for next masking)
        parent_mask = parent_index.unsqueeze(2).expand_as(self.mask)  # (batch_size, beam_width, num_nodes)
        self.mask = self.mask.gather(1, parent_mask)

        # mask next nodes (newly added nodes)
        self.update_mask(next_node)

    def update_mask(self, new_nodes):
        """
        Sets indices of new_nodes = 0 in the mask.

        Args:
            new_nodes: (batch_size, beam_width) of new node indices
        """
        index = torch.arange(0, self.num_nodes, dtype=self.long,
                             device=self.device).expand_as(self.mask)
        new_nodes = new_nodes.unsqueeze(2).expand_as(self.mask)

        # set the mask = 0 at the new_node_idx positions
        update_mask = 1 - torch.eq(index, new_nodes).type(self.float)

        self.mask = self.mask * update_mask

    def sort_best(self):
        """
        Sort the beam of every batch element by score, best first.

        Returns:
            (scores, ids), both (batch_size, beam_width): the scores sorted in
            descending order along the beam axis and the beam positions they
            came from.
        """
        # Scores are (batch_size, beam_width), so the beam axis is dim 1.
        return torch.sort(self.scores, dim=1, descending=True)

    def get_best(self):
        """
        Get the score and index of the best hypothesis in the beam.

        Returns:
            (score, idx), each of shape (batch_size,): the highest score per
            batch element and its beam position.
        """
        scores, ids = self.sort_best()

        # Position 0 along the beam axis holds the best entry after sorting.
        return scores[:, 0], ids[:, 0]

    def get_beam(self, beam_idx):
        """
        Construct the path of the hypothesis at the given beam position by
        following the backpointers from the last step to the start.

        Args:
            beam_idx: Index of the beam to construct (0 = best, ..., n = worst)

        Returns:
            Tensor of shape (batch_size, number of steps taken + 1) holding
            the node indices of the path in visiting order.
        """
        # TODO: Fix assertion (with vehicles)
        # assert self.num_nodes == len(self.parent_pointer) + 1

        prev_pointer = torch.full((self.batch_size, 1), beam_idx,
                                  dtype=self.long, device=self.device)
        last_node = self.next_nodes[-1].gather(1, prev_pointer)

        path = [last_node]

        # Walk backwards: parent_pointer[i] maps a beam position after step i
        # to the beam position it was expanded from, which indexes next_nodes[i].
        for i in range(len(self.parent_pointer) - 1, -1, -1):
            prev_pointer = self.parent_pointer[i].gather(1, prev_pointer)
            last_node = self.next_nodes[i].gather(1, prev_pointer)

            path.append(last_node)

        path = list(reversed(path))
        path = torch.cat(path, dim=-1)

        return path

In [214]:
# Display the transition matrix that is passed to the beam search below.
# NOTE(review): the saved output contains values other than 0 / 0.8 / 0.2, so it
# appears to stem from the random softmax matrix rather than the deterministic
# one defined just above — confirm by restarting the kernel and re-running.
y_pred

tensor([[[0.3918, 0.0471, 0.1286, 0.1734, 0.9169],
         [0.2668, 0.5420, 0.8222, 0.1416, 0.7185],
         [0.8625, 0.7068, 0.5043, 0.6735, 0.5679],
         [0.7524, 0.6256, 0.6101, 0.8428, 0.5195],
         [0.4787, 0.6768, 0.8746, 0.9288, 0.8710]]])

In [215]:
# Build the search over the current transition matrix; `beamsearch` and `bs`
# refer to the same object.
bs = Beamsearch(beam_width, trans_probs=y_pred)
beamsearch = bs

# The start node is already on the path, so num_nodes - 1 expansions
# complete a tour through every node.
for step in range(num_nodes - 1):
    bs.step()

In [216]:
# Beam position 0 holds the highest-scoring hypothesis; rebuild its node
# sequence from the stored backpointers and show it.
shortest_tours = beamsearch.get_beam(beam_idx=0)
shortest_tours

tensor([[0, 4, 2, 3, 1]])