In [1]:
import torch
import numpy as np

In [2]:
beam_width = 10
batch_size = 2
num_nodes = 6

In [3]:
torch.manual_seed(42)

y = torch.randn(batch_size, num_nodes, num_nodes, 2).type(torch.float)
y_pred = torch.nn.functional.softmax(y, dim=3)
y_pred = y_pred[:, :, :, 1]

In [32]:
y_pred = torch.zeros(batch_size, num_nodes, num_nodes).type(torch.float)

for i in range(num_nodes):
    j = (i + 1) % num_nodes
    k = (i + 2) % num_nodes
    y_pred[0][i][i] = 0
    y_pred[0][i][j] = 0.8
    y_pred[0][i][k] = 0.2

In [4]:
class Beamsearch:
    """
    Beam search procedure class.

    References:
        [1]: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam.py
        [2]: https://github.com/alexnowakvila/QAP_pt/blob/master/src/tsp/beam_search.py
        [3]: https://github.com/chaitjo/graph-convnet-tsp/blob/master/utils/beamsearch.py
    """
    def __init__(self, beam_width, trans_probs, num_vehicles=1, random_start=False):
        # all transition probabilities
        self.trans_probs = trans_probs
        self.batch_size = trans_probs.size(0)
        self.num_nodes = trans_probs.size(1)
        self.num_vehicles = num_vehicles
        
        assert len(trans_probs.shape) == 3, "transition probabilities need to be 3-dimensional"
        assert trans_probs.size(1) == trans_probs.size(2), "transition probabilities are not square"
        
        # Beamsearch parameters
        self.beam_width = beam_width

        # TODO: Move tensors to GPU device for faster computation
        # tensor data types and device
        self.device = None
        self.float = torch.float32
        self.long = torch.int64
        
        if random_start == True:
            # starting at random nodes
            start_nodes = torch.randint(0, self.num_nodes, (self.batch_size, self.beam_width))
        else:
            # starting at node zero
            start_nodes = torch.zeros(self.batch_size, self.beam_width)
        
        self.start_nodes = start_nodes.type(self.long)
        self.depot_visits_counter = torch.zeros(self.batch_size, self.beam_width)
        
        # TODO: could also mask self-connections
        # mask for removing visited nodes etc.
        self.mask = torch.ones(self.batch_size, self.beam_width, self.num_nodes).type(self.float)
        
        # start by masking the starting nodes
        self.update_mask(self.start_nodes)
        
        # transition probability scores up-until current timestep
        self.scores = torch.zeros(self.batch_size, self.beam_width).type(self.float)
        
        # pointers to parents for each timestep
        self.parent_pointer  = []
        
        # nodes at each timestep
        self.next_nodes = [self.start_nodes]

    def get_current_nodes(self):
        """
        Get the nodes to expand at the current timestep
        """
        current_nodes = self.next_nodes[-1]
        current_nodes = current_nodes.unsqueeze(2).expand(self.batch_size,
                                                          self.beam_width, 
                                                          self.num_nodes)
        return current_nodes
    
    @property
    def num_iterations(self):
        # -1 for num_nodes because we already start at depot
        # -1 to offset num_vehicles
        return self.num_nodes - 1 + self.num_vehicles - 1
    
    def start(self):
        """
        Start beam search
        """
        
        for step in range(self.num_iterations):
            self.step()

    def step(self):
        """
        Transition to the next timestep of the beam search
        """
        current_nodes = self.get_current_nodes()
        trans_probs = self.trans_probs.gather(1, current_nodes)
        
        if len(self.parent_pointer) == 0:
            # first transition, only use the starting nodes
            beam_prob = trans_probs
            beam_prob[:, 1:] = torch.zeros_like(beam_prob[:, 1:])
        else:
            # multiply the previous scores (probabilities) with the current ones
            expanded_scores = self.scores.unsqueeze(2).expand_as(trans_probs) # b x beam_width x num_nodes
            beam_prob = trans_probs * expanded_scores
        
        # mask out visited nodes
        beam_prob = beam_prob * self.mask
        
        beam_prob = beam_prob.view(beam_prob.size(0), -1) # flatten to (b x beam_width * num_nodes)
        
        # get k=beam_width best scores and indices
        best_scores, best_score_idxs = beam_prob.topk(self.beam_width,
                                                      dim=1, largest=True, sorted=True)
        
        self.scores = best_scores
        parent_index = torch.floor_divide(best_score_idxs, self.num_nodes).type(self.long)
        
        self.parent_pointer.append(parent_index)
        
        # next nodes
        next_node = best_score_idxs - (parent_index * self.num_nodes) # convert flat indices back to original
        self.next_nodes.append(next_node)
        
        # keep masked rows from parents (for next step)
        parent_mask = parent_index.unsqueeze(2).expand_as(self.mask)  # (batch_size, beam_size, num_nodes)
        self.mask = self.mask.gather(1, parent_mask)
        
        # keep depot counter from parent (for next step)
        self.depot_visits_counter = self.depot_visits_counter.gather(1, parent_index)
        
        # mask next nodes (newly added nodes)
        self.update_mask(next_node)

    def update_mask(self, new_nodes):
        """
        Sets indicies of new_nodes = 0 in the mask.
        
        Args:
            new_nodes: (batch_size, beam_width) of new node indicies
        """
        index = torch.arange(0, self.num_nodes, dtype=self.long).expand_as(self.mask)
        new_nodes = new_nodes.unsqueeze(2).expand_as(self.mask)
        
        # set the mask = 0 at the new_node_idx positions
        update_mask = torch.ne(index, new_nodes).type(self.float)
        
        # are we currently visiting the depot?
        is_visiting_depot = torch.eq(update_mask[:, :, 0], 0)
        not_visiting_depot = torch.logical_not(is_visiting_depot)

        # increment depot visit counter where visited
        self.depot_visits_counter += is_visiting_depot.type(self.float)
        is_depot_available = torch.lt(self.depot_visits_counter, self.num_vehicles)
        
        # allow another depot visit (as floats)
        allow_depot_visit = torch.logical_and(not_visiting_depot,
                                              is_depot_available).type(self.float)
        
        self.mask = self.mask * update_mask
        self.mask[:, :, 0] = allow_depot_visit

    def get_beam(self, beam_idx):
        """
        Construct the beam for the given index

        Args:
            beam_idx int: Index of the beam to construct (0 = best, ..., n = worst)
        """
        # TODO: Fix assertion (with vehicles)
        # assert self.num_nodes == len(self.parent_pointers) + 1
        
        prev_pointer = torch.ones(self.batch_size, 1).type(self.long) * beam_idx
        last_node = self.next_nodes[-1].gather(1, prev_pointer)
        
        path = [last_node]
        
        for i in range(len(self.parent_pointer) - 1, -1, -1):
            prev_pointer = self.parent_pointer[i].gather(1, prev_pointer)
            last_node = self.next_nodes[i].gather(1, prev_pointer)
            
            path.append(last_node)
        
        path = list(reversed(path))
        path = torch.cat(path, dim=-1)
        
        return path
    
    def validate(self, beam, batch_idx=None, beam_idx=None):
        bin_count = torch.bincount(beam)
        
        assert bin_count[0] <= self.num_vehicles, f"Batch={batch_idx}, beam={beam_idx}: too many depot visits {bin_count[0]} > {self.num_vehicles}\n{beam}"
        # want them seperate for sanity
        assert torch.all(bin_count[1:] <= 1), f"Batch={batch_idx}, beam={beam_idx}: too many node visits\n{beam}"
        assert torch.all(bin_count[1:] > 0), f"Batch={batch_idx}, beam={beam_idx}: not all node visited\n{beam}"
    
    def sanity_check(self):
        for beam_idx in range(self.beam_width):
            beams = self.get_beam(beam_idx)
            
            for batch_idx in range(beams.size(0)):
                beam = beams[batch_idx]
                
                self.validate(beam, batch_idx=batch_idx, beam_idx=beam_idx)

In [5]:
num_vehicles = 3

beamsearch = Beamsearch(beam_width, trans_probs=y_pred, num_vehicles=num_vehicles)
beamsearch.start()

In [6]:
shortest_tours = beamsearch.get_beam(0)
shortest_tours

tensor([[0, 4, 2, 5, 1, 0, 3, 0],
        [0, 3, 2, 0, 5, 0, 1, 4]])

In [21]:
beamsearch.sanity_check()

AssertionError: Batch=0, beam=4: too many node visits
tensor([0, 4, 1, 3, 2, 5, 0, 4])

The problem with the double visits is when the second to last node was the depot

In [7]:
for beam_idx in range(beamsearch.beam_width):
    beams = beamsearch.get_beam(beam_idx)
    print(beam_idx, beams)
    print()

0 tensor([[0, 4, 2, 5, 1, 0, 3, 0],
        [0, 3, 2, 0, 5, 0, 1, 4]])

1 tensor([[0, 4, 2, 1, 5, 0, 3, 0],
        [0, 5, 0, 3, 2, 0, 1, 4]])

2 tensor([[0, 4, 2, 3, 5, 0, 1, 0],
        [0, 3, 2, 0, 1, 4, 5, 0]])

3 tensor([[0, 4, 3, 2, 5, 0, 1, 0],
        [0, 1, 3, 2, 0, 5, 4, 0]])

4 tensor([[0, 4, 1, 3, 2, 5, 0, 4],
        [0, 5, 3, 2, 0, 1, 4, 0]])

5 tensor([[0, 4, 1, 3, 2, 5, 0, 5],
        [0, 3, 2, 0, 5, 4, 0, 1]])

6 tensor([[0, 4, 2, 1, 3, 5, 0, 0],
        [0, 3, 2, 0, 5, 4, 1, 0]])

7 tensor([[0, 4, 2, 1, 3, 5, 0, 1],
        [0, 1, 3, 2, 0, 5, 0, 4]])

8 tensor([[0, 4, 2, 1, 3, 5, 0, 2],
        [0, 5, 0, 3, 2, 4, 0, 1]])

9 tensor([[0, 4, 2, 1, 3, 5, 0, 3],
        [0, 5, 0, 3, 2, 4, 1, 0]])



## Tour distance

The code below can serve as a good starting point for calculating the distance of the tour found so far.

In [102]:
# offset by one to calculate total travel distance
idx = (torch.arange(0, shortest_tours.size(1)) + 1) % shortest_tours.size(1)
idx = idx.expand_as(shortest_tours)

In [61]:
shortest_tours.gather(1, x)

tensor([[4, 2, 3, 1, 0],
        [3, 2, 4, 1, 0],
        [4, 3, 1, 2, 0]])