In [1]:
import os
import numpy as np
import torch
import pickle
from utils import load_model
from problems.tsp.problem_tsp import TSPDataset
from utils.mcts_utils import evaluate_tour
from mcts.mcts import MCTS_TSP

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained 20-node TSP model and switch it to eval mode.
# load_model returns a pair; only the model itself is used in this notebook.
model, _ = load_model('pretrained/tsp_20/')
model.eval()

# Softmax temperature applied to the model's logits (log_p / temperature)
# in the evaluation cells below; 1 leaves the logits unscaled.
temperature = 1

print("model loaded")

  [*] Loading model from pretrained/tsp_20/epoch-99.pt
model loaded


In [3]:
# Load the first 10,000 graphs of the 20-node TSP test set (seed 1234).
# If the file is missing, generate it first with:
# python generate_data.py --problem all --name test --seed 1234
# NOTE(review): pickle.load can execute arbitrary code; only load files you generated yourself.
with open("data/tsp/tsp20_test_seed1234.pkl", "rb") as f:
    data = pickle.load(f)

# Wrap the raw rows in a TSPDataset. data and size are overwritten right away.
# NOTE(review): the constructor arguments look like placeholders; confirm that
# TSPDataset(None, 0, 0, 0, None) does not load or generate data on its own.
dataset = TSPDataset(None, 0, 0, 0, None)
dataset.data = [torch.FloatTensor(row) for row in data[:10000]]
dataset.size = len(dataset.data)

# One tensor per graph, indexed by the evaluation cells below.
graphs = list(dataset.data)

In [6]:
# Greedy decoding baseline. For each of the first n test graphs:
# - build a tour by repeatedly choosing the unvisited node the model assigns
#   the highest probability;
# - close the tour back to node 0 and print its length.
# Finally print the mean tour length over the n graphs.
n = min(100, len(graphs))  # never index past the loaded test set
total_len = 0
for i in range(n):
    graph = graphs[i]
    graph_batch = graph[None]  # add a batch dimension of size 1
    num_nodes = graph_batch.shape[1]  # equals graph.shape[0]
    # TODO: with an empty tour the model would use W_placeholder as the step
    # context and choose the first node itself; here every tour is forced to
    # start at node 0 instead (unconventional).
    tour = [0]
    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph_batch))

        # Compute keys, values for the glimpse and keys for the logits once,
        # as they can be reused in every decoding step.
        fixed = model._precompute(embeddings)
        for _ in range(num_nodes - 1):
            tour_tensor = torch.tensor(tour).long()
            # tour is never empty (it starts at node 0), so the step context is
            # always the concatenated embeddings of the first and current node.
            step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                      embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])

            # Mask already-visited nodes so they cannot be chosen again.
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[tour_tensor] = True
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5
            p = p.numpy()
            # argmax is unaffected by any positive temperature. int() keeps the
            # tour a list of plain Python ints, so it prints as [0, 14, ...]
            # even under NumPy >= 2.
            tour.append(int(np.argmax(p)))
        tour.append(0)  # return to the starting node
        tour_len = evaluate_tour(graph.numpy(), tour)  # evaluate once, reuse
        print(i, tour_len, tour)
        total_len += tour_len
print(total_len / n)

0 5.784772947430611 [0, 14, 7, 10, 13, 1, 6, 17, 15, 2, 11, 5, 3, 16, 8, 12, 4, 9, 19, 18, 0]
1 4.342111557722092 [0, 5, 10, 4, 18, 2, 15, 9, 8, 7, 13, 17, 1, 3, 19, 14, 12, 16, 6, 11, 0]
2 4.117342567071319 [0, 7, 4, 12, 2, 17, 11, 1, 19, 10, 3, 8, 6, 16, 14, 9, 13, 5, 15, 18, 0]
3 6.3465187959373 [0, 15, 11, 8, 6, 14, 2, 12, 9, 5, 3, 17, 18, 1, 19, 10, 4, 7, 16, 13, 0]
4 5.649249363690615 [0, 7, 8, 1, 2, 4, 19, 11, 5, 14, 3, 10, 9, 17, 16, 13, 6, 12, 18, 15, 0]
5 4.553947269916534 [0, 19, 2, 11, 16, 1, 9, 18, 6, 12, 5, 15, 3, 14, 4, 10, 13, 17, 8, 7, 0]
6 4.271065201610327 [0, 15, 9, 13, 14, 18, 12, 10, 7, 17, 5, 11, 8, 4, 3, 1, 19, 2, 16, 6, 0]
7 5.366192378103733 [0, 16, 17, 11, 2, 4, 14, 19, 3, 6, 9, 12, 7, 1, 15, 8, 18, 10, 5, 13, 0]
8 5.081662759184837 [0, 9, 19, 2, 8, 12, 10, 18, 15, 6, 13, 17, 5, 7, 4, 14, 16, 1, 3, 11, 0]
9 4.361514372751117 [0, 9, 5, 15, 14, 2, 17, 8, 11, 18, 3, 13, 12, 16, 6, 19, 1, 7, 10, 4, 0]
10 4.851467456668615 [0, 9, 16, 11, 19, 8, 12, 6, 2, 4, 10, 14

In [4]:
# MCTS decoding. For each of the first n test graphs:
# - at every step, pass the model's next-node probabilities to an MCTS_TSP
#   search, which decides the next node;
# - close the tour back to node 0 and print its length.
# Finally print the mean tour length over the n graphs.
n = min(10, len(graphs))  # never index past the loaded test set
total_len = 0
for i in range(n):
    graph = graphs[i]
    graph_batch = graph[None]  # add a batch dimension of size 1
    num_nodes = graph_batch.shape[1]  # equals graph.shape[0]
    # TODO: every tour is forced to start at node 0 instead of letting the
    # model choose the first node (unconventional).
    tour = [0]
    # NOTE(review): the positional arguments (0 -- presumably the start node,
    # matching tour[0]; 50, 50; "best") are defined by mcts.mcts.MCTS_TSP.
    # Give them names once their meaning is confirmed.
    mcts_20_nodes = MCTS_TSP(graph.numpy(), 0, 50, 50, model, "best")

    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph_batch))

        # Compute keys, values for the glimpse and keys for the logits once,
        # as they can be reused in every decoding step.
        fixed = model._precompute(embeddings)
        for _ in range(num_nodes - 1):
            tour_tensor = torch.tensor(tour).long()
            # tour is never empty (it starts at node 0), so the step context is
            # always the concatenated embeddings of the first and current node.
            step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                      embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])

            # Mask already-visited nodes so they receive zero probability.
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[tour_tensor] = True
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5
            p = p.numpy()

            # Give the search the model's probabilities, take its decision, then
            # tell it which node was chosen (presumably re-rooting the tree there).
            mcts_20_nodes.update_priors(p)
            next_node = mcts_20_nodes.mcts_decide()
            tour.append(next_node)
            mcts_20_nodes.move_to(next_node)

        tour.append(0)  # return to the starting node
        tour_len = evaluate_tour(graph.numpy(), tour)  # evaluate once, reuse
        print(i, tour_len, tour)
        total_len += tour_len
print(total_len / n)

0 6.146226655691862 [0, 19, 1, 13, 11, 3, 15, 6, 8, 2, 16, 12, 4, 9, 17, 18, 10, 7, 5, 14, 0]
1 4.463504403829575 [0, 19, 5, 3, 1, 2, 15, 8, 9, 7, 13, 17, 10, 18, 4, 11, 16, 6, 12, 14, 0]
2 5.051968475803733 [0, 16, 8, 6, 1, 19, 2, 17, 12, 11, 10, 3, 14, 9, 13, 5, 15, 18, 4, 7, 0]
3 5.829832036048174 [0, 14, 2, 7, 16, 10, 3, 9, 5, 4, 12, 6, 13, 8, 11, 15, 1, 19, 17, 18, 0]
4 4.874682184308767 [0, 4, 1, 2, 19, 11, 6, 7, 16, 8, 13, 9, 5, 10, 3, 14, 15, 18, 12, 17, 0]
5 6.993110969662666 [0, 19, 9, 15, 6, 2, 1, 3, 12, 14, 10, 7, 17, 8, 13, 4, 5, 18, 16, 11, 0]
6 5.77114150300622 [0, 11, 8, 5, 15, 13, 9, 1, 4, 3, 19, 2, 10, 14, 18, 12, 17, 7, 6, 16, 0]
7 4.75518112257123 [0, 16, 17, 11, 7, 2, 14, 4, 12, 1, 9, 15, 19, 3, 6, 8, 18, 10, 5, 13, 0]
8 6.1930660754442215 [0, 16, 1, 2, 11, 7, 4, 14, 17, 5, 13, 15, 6, 18, 10, 12, 8, 19, 3, 9, 0]


KeyboardInterrupt: 

In [None]:
# NOTE(review): two bare, unlabeled numbers. They may be recorded average tour
# lengths from earlier runs of the evaluation cells above, but nothing here says
# which run or which method. Confirm, then label them in a markdown results note.
# When executed, this cell evaluates both literals but displays only the last one.
4.927719319625758
5.103555862233042