In [1]:
import os
import math
import numpy as np
import pickle
import torch
from utils import load_model
from problems.tsp.problem_tsp import TSPDataset
import time
from mcts.mcts_utils import evaluate_tour
from eval import get_best

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
### PARAMETERS TO BE SET ###
# Number of test graphs each experiment loop below iterates over.
n_graphs = 100
# Graph size; interpolated into the pretrained-model directory and the test-set file name.
n_nodes = 20
# Beam width, passed as `beam_size` to the model's beam search.
width = 3
# Divisor applied to the log-probabilities before the softmax in the greedy experiment.
temperature = 1

In [3]:
# Restore the pretrained model for the configured graph size and put it in
# evaluation mode.
model, _ = load_model(F'pretrained/tsp_{n_nodes}/')
model.eval()
print("model loaded")

# Read the pre-generated test instances for the configured graph size.
# If the file is missing, generate it first with:
#   python generate_data.py --problem all --name test --seed 1234
with open(F"data/tsp/tsp{n_nodes}_test_seed1234.pkl", "rb") as f:
    data = pickle.load(f)

# Build an empty dataset shell and fill it by hand with at most the first
# 10000 instances, each converted to a float tensor.
dataset = TSPDataset(None, 0, 0, 0, None)
dataset.data = [torch.FloatTensor(row) for row in data[:10000]]
dataset.size = len(dataset.data)

# Plain list of per-graph tensors that the experiment loops index into.
graphs = list(dataset.data)

  [*] Loading model from pretrained/tsp_20/epoch-99.pt
model loaded


In [8]:
# Greedy experiments
# For every test graph, build a tour by repeatedly moving to the unvisited node
# with the highest probability, then record the closed tour length and the
# wall-clock decoding time.
greedy_timestamps = []
greedy_results = []

for i in range(n_graphs):
    graph = graphs[i]
    graph_batch = graph[None]  # Add batch dimension
    n_graph_nodes = graph_batch.shape[1]
    tour = [0]  # Start at first node, unconventional, TODO: fix this
    t_s = time.perf_counter()
    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph_batch))

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        fixed = model._precompute(embeddings)
        for _ in range(n_graph_nodes - 1):
            tour_tensor = torch.tensor(tour).long()

            # The tour is seeded with the start node, so it is never empty here:
            # the step context is always the first node's embedding joined with
            # the most recently visited node's embedding.
            step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                      embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])

            # Boolean mask over nodes; True marks nodes already on the tour.
            mask = torch.zeros(n_graph_nodes, dtype=torch.bool)
            mask[tour_tensor] = True
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]
            # Sanity checks: visited nodes carry no probability mass and p is normalised.
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5

            # Store a plain Python int so the tour list stays homogeneous.
            next_node = int(np.argmax(p.numpy()))
            tour.append(next_node)
        t_e = time.perf_counter()

    tour.append(0)  # Return to the starting position
    tour_len = evaluate_tour(graph.numpy(), tour)
    greedy_results.append(tour_len)
    greedy_timestamps.append(t_e - t_s)
    print("\nGraph", i)
    print(tour_len)  # reuse the stored length instead of evaluating the tour twice


Graph 0
5.784773

Graph 1
4.342112

Graph 2
4.1173425

Graph 3
6.346519

Graph 4
5.6492496

Graph 5
4.553947

Graph 6
4.271065

Graph 7
5.3661923

Graph 8
5.0816627

Graph 9
4.361514

Graph 10
4.851467

Graph 11
4.20928

Graph 12
8.482102

Graph 13
3.9690795

Graph 14
5.7688513

Graph 15
4.524123

Graph 16
6.244032

Graph 17
4.79237

Graph 18
4.2341666

Graph 19
3.8638964

Graph 20
4.291158

Graph 21
5.8514915

Graph 22
6.1824718

Graph 23
5.085176

Graph 24
7.110321

Graph 25
6.004989

Graph 26
5.3201456

Graph 27
4.287364

Graph 28
6.0501404

Graph 29
4.5536556

Graph 30
4.703373

Graph 31
3.6298907

Graph 32
6.4400434

Graph 33
4.4730453

Graph 34
5.0265646

Graph 35
4.5558476

Graph 36
4.019388

Graph 37
4.776818

Graph 38
5.0708075

Graph 39
5.5861826

Graph 40
5.4362097

Graph 41
4.009011

Graph 42
4.3141522

Graph 43
3.7196767

Graph 44
4.679436

Graph 45
5.2841215

Graph 46
4.0583687

Graph 47
3.724585

Graph 48
5.0580688

Graph 49
8.100692

Graph 50
4.8545322

Graph 51
4.1499

In [9]:
# Report the mean tour length and mean per-graph decoding time of the greedy run.
print("Average greedy results", np.mean(greedy_results))
print("Average duration", np.mean(greedy_timestamps))

Average greedy results 4.9277196
Average duration 0.011673191109784965


In [6]:
# Bundle the greedy tour lengths and timings into a single record for pickling.
greedy_results_dict = dict(
    greedy_result=greedy_results,
    greedy_timestamps=greedy_timestamps,
)

In [7]:
# Persist the greedy results to disk.
# NOTE(review): "benmark" looks like a typo for "benchmark"; the file name is
# left unchanged so anything that already reads this path keeps working.
save_string = "experiments/benmark_greedy.pkl"
# Create the output directory if needed so a fresh checkout does not fail on open().
os.makedirs(os.path.dirname(save_string), exist_ok=True)
with open(save_string, "wb") as f:
    pickle.dump(greedy_results_dict, f)

In [4]:
# Beam search experiments
# For every test graph, run the model's beam search with the configured width,
# keep the best returned sequence, and record the closed tour length and the
# wall-clock time of the search call.
bs_timestamps = []
bs_results = []

for i in range(n_graphs):
    graph = graphs[i]
    graph_batch = graph[None]  # Add batch dimension
    t_s = time.perf_counter()
    with torch.no_grad():
        cum_log_p, sequences, costs, ids, batch_size = model.beam_search(graph_batch, beam_size=width)
        t_e = time.perf_counter()  # timing covers the search only, not the post-processing

    if sequences is None:
        # The search returned no sequences (presumably no complete tour was
        # found; confirm against beam_search). Record an infinite cost
        # instead of indexing into None, which would raise a TypeError.
        costs = math.inf
        tour = None
        tour_len = math.inf
    else:
        sequences, costs = get_best(
            sequences.cpu().numpy(), costs.cpu().numpy(),
            ids.cpu().numpy() if ids is not None else None,
            batch_size
        )
        tour = sequences[0].tolist()
        tour.append(tour[0])  # Return to the starting position
        tour_len = evaluate_tour(graph.numpy(), tour)

    bs_results.append(tour_len)
    bs_timestamps.append(t_e - t_s)
    print("\nGraph", i)
    print(tour_len)
    print(tour)

  flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) // ind_topk.size(-1)



Graph 0
4.558143
[0, 14, 10, 7, 5, 16, 2, 15, 12, 4, 9, 17, 6, 1, 13, 19, 18, 3, 11, 8, 0]

Graph 1
4.3081446
[0, 5, 10, 4, 18, 2, 15, 9, 8, 7, 13, 17, 1, 3, 19, 14, 12, 6, 16, 11, 0]

Graph 2
4.0959578
[0, 7, 4, 12, 17, 2, 11, 1, 19, 10, 3, 8, 6, 16, 14, 9, 13, 5, 15, 18, 0]

Graph 3
6.0702825
[0, 18, 15, 11, 8, 6, 14, 2, 12, 10, 3, 5, 9, 19, 1, 17, 4, 7, 16, 13, 0]

Graph 4
5.108242
[0, 7, 8, 1, 2, 4, 19, 11, 5, 14, 3, 10, 9, 12, 18, 15, 13, 6, 16, 17, 0]

Graph 5
4.387086
[0, 11, 19, 2, 16, 1, 9, 18, 6, 12, 5, 15, 3, 14, 4, 10, 13, 17, 8, 7, 0]

Graph 6
4.271065
[0, 15, 9, 13, 14, 18, 12, 10, 7, 17, 5, 11, 8, 4, 3, 1, 19, 2, 16, 6, 0]

Graph 7
5.2591357
[0, 16, 17, 11, 2, 4, 14, 19, 3, 6, 9, 12, 1, 7, 15, 8, 18, 10, 5, 13, 0]

Graph 8
5.0478
[0, 9, 8, 19, 2, 3, 16, 1, 12, 10, 18, 15, 6, 13, 17, 5, 7, 4, 14, 11, 0]

Graph 9
4.3290777
[0, 9, 15, 5, 14, 2, 17, 8, 11, 18, 3, 13, 12, 16, 6, 19, 1, 7, 10, 4, 0]

Graph 10
4.8199754
[0, 9, 16, 11, 19, 12, 8, 6, 2, 4, 10, 14, 5, 7, 13, 18, 

In [5]:
# Report the mean tour length and mean per-graph search time of the beam search run.
print("Average beam search results", np.mean(bs_results))
print("Average duration", np.mean(bs_timestamps))

Average beam search results 4.750436
Average duration 0.02011635949000003


In [6]:
# Bundle the beam search tour lengths and timings into a single record for pickling.
bs_results_dict = dict(
    bs_result=bs_results,
    bs_timestamps=bs_timestamps,
)

In [7]:
# Persist the beam search results; the beam width is part of the file name so
# runs with different widths do not overwrite each other.
save_string = F"experiments/beam_search_{width}.pkl"
# Create the output directory if needed so a fresh checkout does not fail on open().
os.makedirs(os.path.dirname(save_string), exist_ok=True)
with open(save_string, "wb") as f:
    pickle.dump(bs_results_dict, f)