In [1]:
import os
import numpy as np
import pickle
import torch
from utils import load_model
from problems.tsp.problem_tsp import TSPDataset
import time
from mcts.mcts import MCTS_TSP
from mcts.mcts_utils import evaluate_tour

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
### PARAMETERS TO BE SET ###
n_graphs = 10
n_nodes = 20

n_expansions = 100
n_rollouts = 100
eval_selection = "best"
eval_rollout = "best"

temperature = 1

In [3]:
# Load model 20 nodes and set temperature
model, _ = load_model(F'pretrained/tsp_{n_nodes}/')
model.eval()
print("model loaded")
# Load test set graphs 20 nodes
# If this block does not work, make sure you called:
# python generate_data.py --problem all --name test --seed 1234
with open(F"data/tsp/tsp{n_nodes}_test_seed1234.pkl", "rb") as f:
    data = pickle.load(f)
    dataset = TSPDataset(None, 0, 0, 0, None)
    dataset.data = [torch.FloatTensor(row) for row in (data[0:0+10000])]
    dataset.size = len(dataset.data)
    graphs = []
    for sample in dataset.data:
        graphs.append(sample)

  [*] Loading model from pretrained/tsp_20/epoch-99.pt
model loaded


In [4]:
total_len = 0
total_len_best_seen = 0

mcts_timestamps = []
mcts_results = []
best_seen_results = []

for i in range(n_graphs):
    graph = graphs[i]
    graph_batch = graph[None] # Add batch dimension
    tour = [0] # Start at first node, unconventional, TODO: fix this
    mcts_20_nodes = MCTS_TSP(graph.numpy(), 0, n_expansions, n_rollouts, eval_selection="best", eval_rollout="mean")
    t_s = time.perf_counter()
    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph_batch))

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        fixed = model._precompute(embeddings)
        for visit in range(graph.shape[0] - 1):
            tour_tensor = torch.tensor(tour).long()
            if len(tour_tensor) == 0:
                step_context = model.W_placeholder
            else:
                step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                        embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])
            mask = torch.zeros(graph_batch.shape[1], dtype=torch.uint8) > 0
            mask[tour_tensor] = 1
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5
            p = p.numpy()
            mcts_20_nodes.update_priors(p)
            next_node = mcts_20_nodes.mcts_decide()
            # print("N children", mcts_20_nodes.root._children.keys())
            # print(p)
            tour.append(next_node)
            mcts_20_nodes.move_to(next_node)
        t_e = time.perf_counter()
        tour.append(0) # Return to the starting position
        tour_len = evaluate_tour(graph.numpy(), tour)
        mcts_results.append(tour_len)
        best_seen_results.append(mcts_20_nodes.best_seen_length)
        mcts_timestamps.append(t_e - t_s)
        print("\nGraph", i)
        print(evaluate_tour(graph.numpy(), tour))
        print("Best tour seen", mcts_20_nodes.best_seen_length)


Graph 0
4.8627996
Best tour seen 4.727939

Graph 1
5.3296275
Best tour seen 5.174877

Graph 2
5.0051584
Best tour seen 4.987064

Graph 3
5.063277
Best tour seen 5.063277

Graph 4
4.360171
Best tour seen 4.2643023

Graph 5
4.9554367
Best tour seen 4.9554367

Graph 6
5.14106
Best tour seen 4.830803

Graph 7
4.5489144
Best tour seen 4.5489144

Graph 8
5.4290514
Best tour seen 5.4290514

Graph 9
4.7659144
Best tour seen 4.765914


In [6]:
print("Average MCTS results", np.mean(mcts_results))
print("Average best seen results", np.mean(best_seen_results))
print("Average duration", np.mean(mcts_timestamps))

Average MCTS results 4.946141
Average best seen results 4.874758
Average duration 7.104339689999999


In [None]:
results_dict = {
    "mcts_result" : mcts_results,
    "best_seen_results" : best_seen_results,
    "mcts_timestamps" : mcts_timestamps
}

In [None]:
save_string = F"experiments/{n_nodes}_{n_graphs}_{n_expansions}_{n_rollouts}_{eval_selection}_{eval_rollout}.pkl"
# print(save_string)
with open(save_string, "wb") as f:
    pickle.dump(results_dict, f)

experiments/20_10_100_100_best_best.pkl
