In [1]:
import os
import math
import numpy as np
import pickle
import torch
from utils import load_model
from problems.tsp.problem_tsp import TSPDataset
import time
from mcts.mcts_utils import evaluate_tour
from eval import get_best

  from .autonotebook import tqdm as notebook_tqdm


In [47]:
# --- Experiment configuration: edit these values before running the notebook ---
n_nodes = 100      # graph size; also selects the pretrained model directory and the test-set file
n_graphs = 100     # number of test graphs each experiment loop iterates over
temperature = 1    # divisor applied to the logits before the softmax in the greedy loop
width = 6          # passed as beam_size to model.beam_search

In [48]:
# Restore the pretrained model that matches the configured graph size and
# put it into evaluation mode.
model, _ = load_model(F'pretrained/tsp_{n_nodes}/')
model.eval()
print("model loaded")

# Read the pickled test instances for the same graph size.
# If the file is missing, generate it first with:
#   python generate_data.py --problem all --name test --seed 1234
# NOTE: pickle.load can execute arbitrary code; only open files you generated yourself.
with open(F"data/tsp/tsp{n_nodes}_test_seed1234.pkl", "rb") as f:
    data = pickle.load(f)

# Create an empty TSPDataset shell and fill it by hand with (at most) the
# first 10000 instances, each converted to a FloatTensor.
dataset = TSPDataset(None, 0, 0, 0, None)
dataset.data = [torch.FloatTensor(instance) for instance in data[:10000]]
dataset.size = len(dataset.data)

# Separate list holding the same tensors; the experiment loops index into it.
graphs = list(dataset.data)

  [*] Loading model from pretrained/tsp_100/epoch-99.pt
model loaded


In [None]:
# Greedy experiments: for every test graph, build a tour by repeatedly moving
# to the unvisited node to which the model assigns the highest probability.
greedy_timestamps = []  # decoding time per graph (perf_counter difference, seconds)
greedy_results = []     # tour length per graph, as returned by evaluate_tour

for i in range(n_graphs):
    graph = graphs[i]
    graph_batch = graph[None]  # Add batch dimension
    tour = [0]  # Start at first node, unconventional, TODO: fix this
    t_s = time.perf_counter()
    with torch.no_grad():
        embeddings, _ = model.embedder(model._init_embed(graph_batch))

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        fixed = model._precompute(embeddings)
        for _ in range(graph.shape[0] - 1):
            tour_tensor = torch.tensor(tour).long()
            # The tour always contains at least the start node here, so the
            # step context is the embedding of the first node concatenated
            # with the embedding of the most recently visited node.
            # (If the tour ever starts empty, model.W_placeholder has to be
            # used as the context instead.)
            step_context = torch.cat((embeddings[0, tour_tensor[0]],
                                      embeddings[0, tour_tensor[-1]]), -1)
            query = fixed.context_node_projected + model.project_step_context(step_context[None, None, :])

            # Boolean mask over the nodes: True for every node already in the tour.
            mask = torch.zeros(graph_batch.shape[1], dtype=torch.bool)
            mask[tour_tensor] = True
            mask = mask[None, None, :]

            log_p, _ = model._one_to_many_logits(query, fixed.glimpse_key, fixed.glimpse_val, fixed.logit_key, mask)
            p = torch.softmax(log_p / temperature, -1)[0, 0]
            # Sanity checks: visited nodes carry no probability mass and the
            # distribution is normalised.
            assert (p[tour_tensor] == 0).all()
            assert (p.sum() - 1).abs() < 1e-5
            # Greedy choice; store a plain Python int so the tour list stays homogeneous.
            next_node = int(np.argmax(p.numpy()))
            tour.append(next_node)
        t_e = time.perf_counter()  # timing stops before the tour is closed and evaluated
        tour.append(0)  # Return to the starting position
        tour_len = evaluate_tour(graph.numpy(), tour)
        greedy_results.append(tour_len)
        greedy_timestamps.append(t_e - t_s)
        print("\nGraph", i)
        # Reuse the length computed above instead of evaluating the tour a second time.
        print(tour_len)

In [17]:
# Report the mean tour length and the mean decoding time of the greedy runs.
for label, values in (
    ("Average greedy results", greedy_results),
    ("Average duration", greedy_timestamps),
):
    print(label, np.mean(values))

Average greedy results 13.369455
Average duration 0.04313218084999974


In [18]:
# Bundle the per-graph greedy measurements under the keys used in the saved pickle.
greedy_results_dict = {}
greedy_results_dict["greedy_result"] = greedy_results
greedy_results_dict["greedy_timestamps"] = greedy_timestamps

In [19]:
# Persist the greedy measurements; the file name encodes the graph size.
save_string = F"experiments/greedy_{n_nodes}.pkl"
# Create the output directory if it is missing (e.g. on a fresh checkout);
# otherwise open() below raises FileNotFoundError.
os.makedirs(os.path.dirname(save_string), exist_ok=True)
with open(save_string, "wb") as f:
    pickle.dump(greedy_results_dict, f)

In [None]:
# Beam search experiments: decode every test graph with the model's beam
# search (beam size = width) and evaluate the best sequence it returns.
bs_timestamps = []  # search time per graph (perf_counter difference, seconds)
bs_results = []     # tour length per graph; math.inf when the search returned no sequence

for i in range(n_graphs):
    graph = graphs[i]
    graph_batch = graph[None]  # Add batch dimension
    t_s = time.perf_counter()
    with torch.no_grad():
        cum_log_p, sequences, costs, ids, batch_size = model.beam_search(graph_batch, beam_size=width)
        t_e = time.perf_counter()  # only the search itself is timed, not the post-processing
        if sequences is None:
            # The search produced no sequence. Record an infinite length for
            # this graph instead of indexing into None (which raised TypeError).
            tour = None
            tour_len = math.inf
        else:
            sequences, costs = get_best(
                sequences.cpu().numpy(), costs.cpu().numpy(),
                ids.cpu().numpy() if ids is not None else None,
                batch_size
            )
            # The batch holds a single graph, so take its (only) best sequence.
            tour = sequences[0].tolist()
            tour.append(tour[0])  # Return to the starting position
            tour_len = evaluate_tour(graph.numpy(), tour)
        bs_results.append(tour_len)
        bs_timestamps.append(t_e - t_s)
        print("\nGraph", i)
        print(tour_len)
        print(tour)

In [50]:
# Report the mean tour length and the mean search time of the beam search runs.
for label, values in (
    ("Average beam search results", bs_results),
    ("Average duration", bs_timestamps),
):
    print(label, np.mean(values))

Average beam search results 14.617053
Average duration 0.11620078358999933


In [51]:
# Bundle the per-graph beam search measurements under the keys used in the saved pickle.
bs_results_dict = {}
bs_results_dict["bs_result"] = bs_results
bs_results_dict["bs_timestamps"] = bs_timestamps

In [52]:
# Persist the beam search measurements; the file name encodes graph size and beam width.
save_string = F"experiments/beam_search_{n_nodes}_{width}.pkl"
# Create the output directory if it is missing (e.g. on a fresh checkout);
# otherwise open() below raises FileNotFoundError.
os.makedirs(os.path.dirname(save_string), exist_ok=True)
with open(save_string, "wb") as f:
    pickle.dump(bs_results_dict, f)