In [1]:
import os
import numpy as np
import torch
import pickle
from utils import load_model
from problems.tsp.problem_tsp import TSPDataset
from utils.mcts import MCTS_TSP, evaluate_tour

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Restore the pretrained model for 20-node TSP and put it in evaluation mode
model, _ = load_model('pretrained/tsp_20/')
model.eval()

# Divisor applied to the model's logits before the softmax in the evaluation cells
temperature = 1

print("model loaded")

In [None]:
# Read the 20-node TSP test instances from disk.
# If this cell fails, make sure the data file was generated first with:
# python generate_data.py --problem all --name test --seed 1234
# NOTE: pickle can execute arbitrary code while loading, so only open files you generated yourself.
with open("data/tsp/tsp20_test_seed1234.pkl", "rb") as f:
    data = pickle.load(f)

# Build a TSPDataset and overwrite its data/size with (at most) the first 10000 instances,
# each converted to a float tensor
dataset = TSPDataset(None, 0, 0, 0, None)
dataset.data = [torch.FloatTensor(row) for row in data[0:10000]]
dataset.size = len(dataset.data)

# Plain list holding the same tensors; the evaluation cells index into it
graphs = list(dataset.data)

In [None]:
# Perform greedy evaluation on the first n graphs: at every step the next node is the
# argmax of the model's masked, temperature-scaled softmax distribution.
n = 10

for i in range(n):
    graph = graphs[i][None]  # Add batch dimension
    num_nodes = graph.shape[1]
    tour = [0]  # Start at first node, unconventional, TODO: fix this
    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph))

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        fixed = model._precompute(embeddings)

        # Only the nodes that are not yet in the tour still have to be chosen.
        # Running one iteration per node (the previous behaviour) entered a final
        # step in which every node was already masked; no valid distribution
        # exists there, so the assertions below could not hold.
        for _step in range(num_nodes - len(tour)):
            tour_tensor = torch.tensor(tour).long()
            if len(tour_tensor) == 0:
                # Not reached while the tour is seeded with node 0; kept for the TODO above
                step_context = model.W_placeholder
            else:
                # Step context: embeddings of the first and the most recent tour node, concatenated
                step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                          embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])

            # Boolean mask over the nodes, set for every node already in the tour
            mask = torch.zeros(num_nodes, dtype=torch.uint8) > 0
            mask[tour_tensor] = 1
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]

            # Sanity checks: visited nodes carry no probability mass and p is normalised
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5
            p = p.numpy()

            # Store a plain Python int so the tour stays a homogeneous list of ints
            tour.append(int(np.argmax(p)))
    print(evaluate_tour(tour), tour)


In [None]:
# Perform MCTS evaluation on the first n graphs: at every step the model's masked,
# temperature-scaled softmax distribution is handed to the tree search as priors,
# and the search decides which node to visit next.
n = 10

for i in range(n):
    graph = graphs[i][None]  # Add batch dimension
    num_nodes = graph.shape[1]
    tour = [0]  # Start at first node, unconventional, TODO: fix this
    # NOTE(review): the positional arguments (0, 100, 500, "best") are defined by
    # utils.mcts.MCTS_TSP, which is not visible here - confirm their meaning there.
    mcts_20_nodes = MCTS_TSP(graph.numpy(), 0, 100, 500, model, "best")
    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph))

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        fixed = model._precompute(embeddings)

        # Only the nodes that are not yet in the tour still have to be chosen.
        # Running one iteration per node (the previous behaviour) entered a final
        # step in which every node was already masked; no valid distribution
        # exists there, so the assertions below could not hold.
        for _step in range(num_nodes - len(tour)):
            tour_tensor = torch.tensor(tour).long()
            if len(tour_tensor) == 0:
                # Not reached while the tour is seeded with node 0; kept for the TODO above
                step_context = model.W_placeholder
            else:
                # Step context: embeddings of the first and the most recent tour node, concatenated
                step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                          embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])

            # Boolean mask over the nodes, set for every node already in the tour
            mask = torch.zeros(num_nodes, dtype=torch.uint8) > 0
            mask[tour_tensor] = 1
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]

            # Sanity checks: visited nodes carry no probability mass and p is normalised
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5
            p = p.numpy()

            # Feed the policy output to the search as priors, then let it pick the next node
            mcts_20_nodes.update_priors(p)
            next_node = mcts_20_nodes.mcts_decide()
            tour.append(next_node)

    print(evaluate_tour(tour), tour)