# TSP Solver avec GCN + Beam Search (Kaggle - Double GPU)

Ce notebook intègre tous les modules nécessaires pour résoudre le Travelling Salesman Problem (TSP) en utilisant un Residual Gated Graph Convolutional Network avec Beam Search.

**Référence:** [An Efficient Graph ConvNet for the Travelling Salesman Problem](https://arxiv.org/pdf/1711.07553v2.pdf)

## 1. Installation des dépendances

In [None]:
!pip install -q tensorboardX fastprogress

## 2. Imports et configuration

In [None]:
import os
import json
import time
import glob

import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.utils import shuffle
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx

from tensorboardX import SummaryWriter
from fastprogress import master_bar, progress_bar

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from scipy.sparse import SparseEfficiencyWarning
warnings.simplefilter('ignore', SparseEfficiencyWarning)

# Render matplotlib figures inline in the notebook, as PNG images.
%matplotlib inline
from IPython import get_ipython
get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'png'")

# Report the PyTorch version and the CUDA devices visible to this process.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Nombre de GPU: {torch.cuda.device_count()}")
    # NOTE(review): querying device names initializes CUDA, so a later change
    # of CUDA_VISIBLE_DEVICES in this process presumably has no effect — confirm.
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

## 3. Configuration (Settings)

In [None]:
class Settings(dict):
    """Experiment configuration options.

    Thin wrapper around the built-in ``dict`` that also exposes its keys as
    attributes, so ``config.num_nodes`` and ``config['num_nodes']`` are
    interchangeable. Attribute assignment and deletion are routed to the
    underlying dict items.

    Args:
        config_dict: Mapping whose key/value pairs seed the settings.
    """

    def __init__(self, config_dict):
        super().__init__()
        for key in config_dict:
            self[key] = config_dict[key]

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Unknown names must
        # raise AttributeError (not KeyError): hasattr(), getattr(obj, name,
        # default), copy.deepcopy() and pickle probe optional dunder
        # attributes and only tolerate AttributeError.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None

    def __setitem__(self, key, value):
        return super().__setitem__(key, value)

    def __setattr__(self, key, value):
        return self.__setitem__(key, value)

    def __delattr__(self, key):
        # Mirror __getattr__: deleting a missing attribute is an AttributeError.
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None

## 4. Data Reader (Google TSP Reader)

In [None]:
class DotDict(dict):
    """Dictionary whose keys double as attributes.

    The instance's attribute namespace *is* the dict itself, so ``d.x`` and
    ``d['x']`` read and write the same entry.
    """
    def __init__(self, **kwds):
        super().__init__(**kwds)
        self.__dict__ = self


class GoogleTSPReader(object):
    """Iterator that reads TSP dataset files and yields mini-batches.

    Format expected as in Vinyals et al., 2015: each line holds the
    ``2 * num_nodes`` node coordinates, the token ``output`` and the 1-based
    tour. NOTE(review): after ``output`` the parser drops the final
    whitespace-split token and then the final tour node; this presumes lines
    end with a space before the newline and that the tour repeats its start
    node — confirm against the data files.

    Args:
        num_nodes: Number of nodes per instance.
        num_neighbors: k of the k-nearest-neighbour adjacency, or -1 for a
            fully connected graph.
        batch_size: Number of instances per yielded mini-batch.
        filepath: Path of the dataset file.
    """

    def __init__(self, num_nodes, num_neighbors, batch_size, filepath):
        self.num_nodes = num_nodes
        self.num_neighbors = num_neighbors
        self.batch_size = batch_size
        self.filepath = filepath
        # Read eagerly and close the handle deterministically instead of
        # relying on garbage collection (a reader is built every epoch).
        with open(filepath, "r") as f:
            self.filedata = shuffle(f.readlines())
        # Trailing lines that do not fill a whole batch are never yielded.
        self.max_iter = (len(self.filedata) // batch_size)

    def __iter__(self):
        for batch in range(self.max_iter):
            start_idx = batch * self.batch_size
            end_idx = (batch + 1) * self.batch_size
            yield self.process_batch(self.filedata[start_idx:end_idx])

    def process_batch(self, lines):
        """Convert raw dataset lines into a mini-batch.

        Returns:
            DotDict of stacked numpy arrays: ``edges`` (adjacency, 2 on the
            diagonal), ``edges_values`` (pairwise Euclidean distances),
            ``edges_target`` (symmetric 0/1 tour adjacency), ``nodes``,
            ``nodes_target`` (visit position of each node), ``nodes_coord``,
            ``tour_nodes`` (0-based tour) and ``tour_len``.
        """
        batch_edges = []
        batch_edges_values = []
        batch_edges_target = []
        batch_nodes = []
        batch_nodes_target = []
        batch_nodes_coord = []
        batch_tour_nodes = []
        batch_tour_len = []

        for line in lines:
            line = line.split(" ")
            nodes = np.ones(self.num_nodes)
            nodes_coord = []
            for idx in range(0, 2 * self.num_nodes, 2):
                nodes_coord.append([float(line[idx]), float(line[idx + 1])])

            W_val = squareform(pdist(nodes_coord, metric='euclidean'))

            if self.num_neighbors == -1:
                W = np.ones((self.num_nodes, self.num_nodes))
            else:
                W = np.zeros((self.num_nodes, self.num_nodes))
                # The num_neighbors + 1 smallest distances of each row (the
                # node itself included, at distance 0) become edges.
                knns = np.argpartition(W_val, kth=self.num_neighbors, axis=-1)[:, self.num_neighbors::-1]
                for idx in range(self.num_nodes):
                    W[idx][knns[idx]] = 1
            # Self-loops get their own edge type (2).
            np.fill_diagonal(W, 2)

            # Tour tokens follow 'output'; convert them to 0-based node ids.
            tour_nodes = [int(node) - 1 for node in line[line.index('output') + 1:-1]][:-1]

            tour_len = 0
            nodes_target = np.zeros(self.num_nodes)
            edges_target = np.zeros((self.num_nodes, self.num_nodes))
            for idx in range(len(tour_nodes) - 1):
                i = tour_nodes[idx]
                j = tour_nodes[idx + 1]
                nodes_target[i] = idx
                edges_target[i][j] = 1
                edges_target[j][i] = 1
                tour_len += W_val[i][j]

            # Close the tour (last node back to the first). Use tour_nodes[-1]
            # explicitly: the loop variable `j` would be stale (left over from
            # the previous line) or undefined for a single-node tour.
            first, last = tour_nodes[0], tour_nodes[-1]
            nodes_target[last] = len(tour_nodes) - 1
            edges_target[last][first] = 1
            edges_target[first][last] = 1
            tour_len += W_val[last][first]

            batch_edges.append(W)
            batch_edges_values.append(W_val)
            batch_edges_target.append(edges_target)
            batch_nodes.append(nodes)
            batch_nodes_target.append(nodes_target)
            batch_nodes_coord.append(nodes_coord)
            batch_tour_nodes.append(tour_nodes)
            batch_tour_len.append(tour_len)

        batch = DotDict()
        batch.edges = np.stack(batch_edges, axis=0)
        batch.edges_values = np.stack(batch_edges_values, axis=0)
        batch.edges_target = np.stack(batch_edges_target, axis=0)
        batch.nodes = np.stack(batch_nodes, axis=0)
        batch.nodes_target = np.stack(batch_nodes_target, axis=0)
        batch.nodes_coord = np.stack(batch_nodes_coord, axis=0)
        batch.tour_nodes = np.stack(batch_tour_nodes, axis=0)
        batch.tour_len = np.stack(batch_tour_len, axis=0)
        return batch

## 5. Graph Utilities

In [None]:
def tour_nodes_to_W(nodes):
    """Convert an ordered list of tour nodes to a symmetric adjacency matrix.

    Consecutive nodes are connected, and the last node is connected back to
    the first to close the tour.

    Args:
        nodes: Sequence (list or numpy array) of node ids in visiting order.

    Returns:
        ``(len(nodes), len(nodes))`` float array with 1 on tour edges and 0
        elsewhere. With fewer than two nodes there is no edge to mark.
    """
    num = len(nodes)
    W = np.zeros((num, num))
    if num < 2:
        # Nothing to connect; also avoids reading an unbound loop variable.
        return W
    for idx in range(num):
        i = int(nodes[idx])
        # Wrap around so the final iteration closes the tour.
        j = int(nodes[(idx + 1) % num])
        W[i][j] = 1
        W[j][i] = 1
    return W


def tour_nodes_to_tour_len(nodes, W_values):
    """Calculate the length of a closed tour given its ordered nodes.

    Args:
        nodes: Sequence (list or numpy array) of node ids in visiting order.
        W_values: Square distance matrix indexable as ``W_values[i][j]``.

    Returns:
        Sum of the distances between consecutive nodes plus the closing edge
        from the last node back to the first; 0 for fewer than two nodes.
    """
    tour_len = 0
    if len(nodes) < 2:
        # No edge exists; also avoids reading an unbound loop variable.
        return tour_len
    for i, j in zip(nodes[:-1], nodes[1:]):
        tour_len += W_values[i][j]
    tour_len += W_values[nodes[-1]][nodes[0]]
    return tour_len


def W_to_tour_len(W, W_values):
    """Return the tour length encoded by an adjacency matrix.

    Every entry of ``W`` equal to 1 contributes its distance from
    ``W_values``; a symmetric adjacency lists each undirected edge twice, so
    the total is halved.
    """
    n_rows, n_cols = W.shape[0], W.shape[1]
    total = sum(
        W_values[r][c]
        for r in range(n_rows)
        for c in range(n_cols)
        if W[r][c] == 1
    )
    return total / 2


def is_valid_tour(nodes, num_nodes):
    """Return True when ``nodes`` is a permutation of ``0 .. num_nodes - 1``."""
    expected = list(range(num_nodes))
    return sorted(nodes) == expected


def mean_tour_len_edges(x_edges_values, y_pred_edges):
    """Return the batch-mean tour length implied by edge predictions.

    Each edge takes the arg-max class of its softmax probabilities (0 or 1 in
    the two-class setup); that index multiplies the edge's distance. Per
    instance, the products are summed and halved (a symmetric adjacency
    counts every edge twice), then averaged over the batch.
    """
    probs = F.softmax(y_pred_edges, dim=3)
    predicted = probs.argmax(dim=3).float()
    per_instance = (predicted * x_edges_values.float()).sum(dim=1).sum(dim=1) / 2
    total = per_instance.sum().to(dtype=torch.float).item()
    return total / per_instance.numel()


def mean_tour_len_nodes(x_edges_values, bs_nodes):
    """Compute the batch-mean length of closed tours given as node orderings.

    Args:
        x_edges_values: Tensor of per-instance distance matrices, indexed as
            ``[batch][i][j]``.
        bs_nodes: Tensor of tours, one row of node ids per instance in
            visiting order (e.g. the output of beam search).

    Returns:
        Mean over the batch of each tour's length, including the closing edge
        from the last node back to the tour's *first* node.
    """
    tours = bs_nodes.cpu().numpy()
    W_val = x_edges_values.cpu().numpy()
    running_tour_len = 0
    for batch_idx in range(tours.shape[0]):
        tour = tours[batch_idx]
        # np.roll pairs each node with its successor and the last node with
        # tour[0]. Closing to node 0 unconditionally is only right when the
        # tour starts at node 0, which random_start beam search breaks.
        running_tour_len += W_val[batch_idx][tour, np.roll(tour, -1)].sum()
    return running_tour_len / tours.shape[0]

## 6. GCN Layers

In [None]:
class BatchNormNode(nn.Module):
    """Batch normalization over the feature axis of node embeddings.

    Input and output are ``(batch, N, hidden_dim)``; running statistics are
    disabled, so normalization always uses the current batch.
    """
    def __init__(self, hidden_dim):
        super(BatchNormNode, self).__init__()
        self.batch_norm = nn.BatchNorm1d(hidden_dim, track_running_stats=False)

    def forward(self, x):
        # BatchNorm1d expects channels on axis 1: move features there and back.
        channels_first = x.transpose(1, 2).contiguous()
        normalized = self.batch_norm(channels_first)
        return normalized.transpose(1, 2).contiguous()


class BatchNormEdge(nn.Module):
    """Batch normalization over the feature axis of edge embeddings.

    Expects a 4-D tensor with features on the last axis,
    ``(batch, d1, d2, hidden_dim)``; running statistics are disabled, so
    normalization always uses the current batch.
    """
    def __init__(self, hidden_dim):
        super(BatchNormEdge, self).__init__()
        self.batch_norm = nn.BatchNorm2d(hidden_dim, track_running_stats=False)

    def forward(self, e):
        # BatchNorm2d wants (batch, channels, H, W): move features to axis 1,
        # normalize, then move them back to the last axis.
        channels_first = e.permute(0, 3, 1, 2).contiguous()
        normalized = self.batch_norm(channels_first)
        return normalized.permute(0, 2, 3, 1).contiguous()


class NodeFeatures(nn.Module):
    """Convnet features for nodes.

    Using `mean` aggregation: x_i = U*x_i + ( sum_j [ gate_ij * (V*x_j) ] / sum_j [ gate_ij] )
    Using `sum` aggregation:  x_i = U*x_i + sum_j [ gate_ij * (V*x_j) ]

    Args:
        hidden_dim: Feature size of the node embeddings.
        aggregation: ``"mean"`` or ``"sum"``.

    Raises:
        ValueError: If ``aggregation`` is not a supported mode.
    """
    _AGGREGATIONS = ("mean", "sum")

    def __init__(self, hidden_dim, aggregation="mean"):
        super(NodeFeatures, self).__init__()
        # Fail fast at construction instead of an UnboundLocalError on the
        # result variable inside forward().
        if aggregation not in self._AGGREGATIONS:
            raise ValueError(
                f"aggregation must be one of {self._AGGREGATIONS}, got {aggregation!r}"
            )
        self.aggregation = aggregation
        self.U = nn.Linear(hidden_dim, hidden_dim, True)
        self.V = nn.Linear(hidden_dim, hidden_dim, True)

    def forward(self, x, edge_gate):
        """Aggregate gated neighbour features into every node.

        Args:
            x: Node features, presumably ``(batch, N, hidden_dim)``.
            edge_gate: Gates indexed ``[batch, i, j, feature]`` (summed over
                the neighbour axis j = dim 2).
        """
        Ux = self.U(x)
        # Insert axis 1 so V*x_j broadcasts across every row i of the gates.
        Vx = self.V(x).unsqueeze(1)
        gateVx = edge_gate * Vx
        if self.aggregation == "mean":
            # The tiny epsilon guards against an all-zero gate sum.
            return Ux + torch.sum(gateVx, dim=2) / (1e-20 + torch.sum(edge_gate, dim=2))
        return Ux + torch.sum(gateVx, dim=2)


class EdgeFeatures(nn.Module):
    """Convnet features for edges: e_ij = U*e_ij + V*(x_i + x_j)"""
    def __init__(self, hidden_dim):
        super(EdgeFeatures, self).__init__()
        self.U = nn.Linear(hidden_dim, hidden_dim, True)
        self.V = nn.Linear(hidden_dim, hidden_dim, True)

    def forward(self, x, e):
        edge_term = self.U(e)
        projected = self.V(x)
        # Assuming x is (B, N, H): V*x_i broadcasts along columns (axis 2),
        # V*x_j along rows (axis 1).
        row_node = projected.unsqueeze(2)
        col_node = projected.unsqueeze(1)
        return edge_term + row_node + col_node


class ResidualGatedGCNLayer(nn.Module):
    """Convnet layer with gating and residual connection.

    The edge update, passed through a sigmoid, gates the node aggregation.
    Both updates are batch-normalized, passed through ReLU and added to
    their inputs.
    """
    def __init__(self, hidden_dim, aggregation="sum"):
        super(ResidualGatedGCNLayer, self).__init__()
        self.node_feat = NodeFeatures(hidden_dim, aggregation)
        self.edge_feat = EdgeFeatures(hidden_dim)
        self.bn_node = BatchNormNode(hidden_dim)
        self.bn_edge = BatchNormEdge(hidden_dim)

    def forward(self, x, e):
        edge_update = self.edge_feat(x, e)
        node_update = self.node_feat(x, torch.sigmoid(edge_update))
        # Normalize, apply the non-linearity, then add the residual shortcut.
        e_out = e + F.relu(self.bn_edge(edge_update).contiguous())
        x_out = x + F.relu(self.bn_node(node_update).contiguous())
        return x_out, e_out


class MLP(nn.Module):
    """Multi-layer Perceptron for output prediction.

    ``L - 1`` hidden ``Linear + ReLU`` layers of width ``hidden_dim``,
    followed by a final linear projection to ``output_dim``.
    """
    def __init__(self, hidden_dim, output_dim, L=2):
        super(MLP, self).__init__()
        self.L = L
        self.U = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, True) for _ in range(self.L - 1)]
        )
        self.V = nn.Linear(hidden_dim, output_dim, True)

    def forward(self, x):
        hidden = x
        for layer in self.U:
            hidden = F.relu(layer(hidden))
        return self.V(hidden)

## 7. Loss Functions & Beam Search

In [None]:
# ===================== LOSS FUNCTIONS =====================

def loss_edges(y_pred_edges, y_edges, edge_cw):
    """Class-weighted negative log-likelihood of the edge predictions.

    ``y_pred_edges`` holds class logits on its last axis; they are turned
    into log-probabilities and moved to axis 1, where ``nn.NLLLoss`` expects
    the class dimension.
    """
    criterion = nn.NLLLoss(edge_cw)
    log_probs = F.log_softmax(y_pred_edges, dim=3).permute(0, 3, 1, 2).contiguous()
    return criterion(log_probs, y_edges)


def edge_error(y_pred, y_target, x_edges):
    """Computes edge error metrics for given batch prediction and targets.

    Returns percentage errors computed with three masks: (1) the input
    adjacency ``x_edges`` used as weights, (2) the ground-truth tour edges and
    (3) the union of predicted and ground-truth tour edges. These are followed
    by the per-instance error flags for (2) and (3).
    """
    predicted = F.softmax(y_pred, dim=3).argmax(dim=3)

    err_edges, _ = _edge_error(predicted, y_target, x_edges.long())
    err_tour, err_idx_tour = _edge_error(predicted, y_target, y_target)
    union_mask = ((y_target + predicted) > 0).long()
    err_tsp, err_idx_tsp = _edge_error(predicted, y_target, union_mask)

    return 100 * err_edges, 100 * err_tour, 100 * err_tsp, err_idx_tour, err_idx_tsp


def _edge_error(y, y_target, mask):
    """Helper method to compute edge errors."""
    acc = (y == y_target).long()
    acc = (acc * mask)
    acc = acc.sum(dim=1).sum(dim=1).to(dtype=torch.float) / mask.sum(dim=1).sum(dim=1).to(dtype=torch.float)
    err_idx = (acc < 1.0)
    acc = acc.sum().to(dtype=torch.float).item() / acc.numel()
    err = 1.0 - acc
    return err, err_idx


# ===================== BEAM SEARCH =====================

class Beamsearch(object):
    """Class for managing internals of beamsearch procedure.

    For each of ``batch_size`` instances it keeps ``beam_size`` partial tours,
    their scores, the back-pointers needed to rebuild them and a mask of the
    nodes each partial tour has already visited.

    Args:
        beam_size: Number of hypotheses kept per instance.
        batch_size: Number of instances searched in parallel.
        num_nodes: Number of nodes per instance.
        dtypeFloat: Tensor type for scores and masks (CPU or CUDA).
        dtypeLong: Tensor type for node / beam indices (CPU or CUDA).
        probs_type: ``'raw'`` when transition scores are probabilities
            (combined by multiplication) or ``'logits'`` when they are
            log-probabilities (combined by addition).
        random_start: Start every beam at a random node instead of node 0.
    """

    def __init__(self, beam_size, batch_size, num_nodes,
                 dtypeFloat=torch.FloatTensor, dtypeLong=torch.LongTensor,
                 probs_type='raw', random_start=False):
        self.batch_size = batch_size
        self.beam_size = beam_size
        self.num_nodes = num_nodes
        self.probs_type = probs_type
        self.dtypeFloat = dtypeFloat
        self.dtypeLong = dtypeLong
        # Start node of every beam: node 0 unless random_start.
        self.start_nodes = torch.zeros(batch_size, beam_size).type(self.dtypeLong)
        if random_start:
            self.start_nodes = torch.randint(0, num_nodes, (batch_size, beam_size)).type(self.dtypeLong)
        # mask[b, k, n] stays 1 while beam k of instance b has not visited
        # node n; visited entries become 0 ('raw') or 1e20 ('logits').
        self.mask = torch.ones(batch_size, beam_size, num_nodes).type(self.dtypeFloat)
        self.update_mask(self.start_nodes)
        self.scores = torch.zeros(batch_size, beam_size).type(self.dtypeFloat)
        self.all_scores = []
        # Per step: beam each hypothesis came from, and the node it appended.
        self.prev_Ks = []
        self.next_nodes = [self.start_nodes]

    def get_current_state(self):
        """Index tensor ``(batch, beam, num_nodes)`` repeating each beam's last node."""
        current_state = (self.next_nodes[-1].unsqueeze(2)
                         .expand(self.batch_size, self.beam_size, self.num_nodes))
        return current_state

    def get_current_origin(self):
        """Back-pointers produced by the latest ``advance`` call."""
        return self.prev_Ks[-1]

    def advance(self, trans_probs):
        """Extend every beam by one unvisited node and keep the top hypotheses.

        Args:
            trans_probs: ``(batch, beam, num_nodes)`` transition scores from
                each beam's last node.
        """
        if len(self.prev_Ks) > 0:
            if self.probs_type == 'raw':
                beam_lk = trans_probs * self.scores.unsqueeze(2).expand_as(trans_probs)
            elif self.probs_type == 'logits':
                beam_lk = trans_probs + self.scores.unsqueeze(2).expand_as(trans_probs)
        else:
            beam_lk = trans_probs
            # First step: all beams share the same start, so only beam 0 may
            # propose candidates (avoids duplicate hypotheses).
            if self.probs_type == 'raw':
                beam_lk[:, 1:] = torch.zeros(beam_lk[:, 1:].size()).type(self.dtypeFloat)
            elif self.probs_type == 'logits':
                beam_lk[:, 1:] = -1e20 * torch.ones(beam_lk[:, 1:].size()).type(self.dtypeFloat)
        beam_lk = beam_lk * self.mask
        # Flatten (beam, node) so topk ranks all candidate extensions together.
        beam_lk = beam_lk.view(self.batch_size, -1)
        bestScores, bestScoresId = beam_lk.topk(self.beam_size, 1, True, True)
        self.scores = bestScores
        prev_k = bestScoresId // self.num_nodes
        self.prev_Ks.append(prev_k)
        new_nodes = bestScoresId - prev_k * self.num_nodes
        self.next_nodes.append(new_nodes)
        # Carry over the visited-mask of the beam each survivor came from.
        perm_mask = prev_k.unsqueeze(2).expand_as(self.mask)
        self.mask = self.mask.gather(1, perm_mask)
        self.update_mask(new_nodes)

    def update_mask(self, new_nodes):
        """Mark ``new_nodes`` (``(batch, beam)``) as visited in ``self.mask``."""
        arr = (torch.arange(0, self.num_nodes).unsqueeze(0).unsqueeze(1)
               .expand_as(self.mask).type(self.dtypeLong))
        new_nodes = new_nodes.unsqueeze(2).expand_as(self.mask)
        update_mask = 1 - torch.eq(arr, new_nodes).type(self.dtypeFloat)
        self.mask = self.mask * update_mask
        if self.probs_type == 'logits':
            self.mask[self.mask == 0] = 1e20

    def sort_best(self):
        """Sort each instance's beam by decreasing score.

        Returns:
            ``(scores, ids)``, both ``(batch_size, beam_size)``, where ``ids``
            are beam indices. Sorting is along the beam axis (dim 1); sorting
            along dim 0 would mix scores of different instances.
        """
        return torch.sort(self.scores, 1, True)

    def get_best(self):
        """Return the score and beam index of the best hypothesis per instance.

        Returns:
            ``(scores, ids)``, each of shape ``(batch_size,)``. (Taking row 1
            of the sorted scores selected an instance rather than a beam and
            raised IndexError for ``batch_size == 1``.)
        """
        scores, ids = self.sort_best()
        return scores[:, 0], ids[:, 0]

    def get_hypothesis(self, k):
        """Rebuild the full tours ending in beam slots ``k`` (``(batch, 1)``).

        Follows the back-pointers from the last step to the start node.
        """
        assert self.num_nodes == len(self.prev_Ks) + 1
        hyp = -1 * torch.ones(self.batch_size, self.num_nodes).type(self.dtypeLong)
        for j in range(len(self.prev_Ks) - 1, -2, -1):
            hyp[:, j + 1] = self.next_nodes[j + 1].gather(1, k).view(1, self.batch_size)
            k = self.prev_Ks[j].gather(1, k)
        return hyp


def beamsearch_tour_nodes(y_pred_edges, beam_size, batch_size, num_nodes,
                          dtypeFloat, dtypeLong, probs_type='raw', random_start=False):
    """Performs beamsearch on edge prediction matrices and returns possible TSP tours.

    Returns:
        Tensor with, for every instance, the hypothesis in beam slot 0 after
        ``num_nodes - 1`` expansion steps.
    """
    if probs_type == 'raw':
        y = F.softmax(y_pred_edges, dim=3)[:, :, :, 1]
    elif probs_type == 'logits':
        y = F.log_softmax(y_pred_edges, dim=3)[:, :, :, 1]
        y[y == 0] = -1e-20
    search = Beamsearch(beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type, random_start)
    # Each expansion appends one node, reading the rows of the beams' last nodes.
    for _ in range(num_nodes - 1):
        search.advance(y.gather(1, search.get_current_state()))
    first_slot = torch.zeros(batch_size, 1).type(dtypeLong)
    return search.get_hypothesis(first_slot)


def beamsearch_tour_nodes_shortest(y_pred_edges, x_edges_values, beam_size, batch_size, num_nodes,
                                   dtypeFloat, dtypeLong, probs_type='raw', random_start=False):
    """Beamsearch with shortest tour heuristic.

    Runs the beam search, then for every instance scans all ``beam_size``
    final hypotheses and keeps the shortest valid tour. The beam-slot-0
    hypothesis is the initial candidate and is not itself validity-checked.

    Returns:
        Tensor with one tour (row of node ids) per instance.
    """
    if probs_type == 'raw':
        y = F.softmax(y_pred_edges, dim=3)[:, :, :, 1]
    elif probs_type == 'logits':
        y = F.log_softmax(y_pred_edges, dim=3)[:, :, :, 1]
        y[y == 0] = -1e-20
    beamsearch = Beamsearch(beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type, random_start)
    trans_probs = y.gather(1, beamsearch.get_current_state())
    for step in range(num_nodes - 1):
        beamsearch.advance(trans_probs)
        trans_probs = y.gather(1, beamsearch.get_current_state())
    ends = torch.zeros(batch_size, 1).type(dtypeLong)
    shortest_tours = beamsearch.get_hypothesis(ends)
    # Copy the distance matrices to host memory once: converting them inside
    # the loops cost one device->host transfer per (beam slot, instance).
    edges_values_np = x_edges_values.cpu().numpy()
    shortest_lens = [
        tour_nodes_to_tour_len(tour, edges_values_np[idx])
        for idx, tour in enumerate(shortest_tours.cpu().numpy())
    ]
    for pos in range(1, beam_size):
        ends = pos * torch.ones(batch_size, 1).type(dtypeLong)
        hyp_tours = beamsearch.get_hypothesis(ends)
        # One transfer per beam slot instead of one per instance.
        hyp_tours_np = hyp_tours.cpu().numpy()
        for idx, hyp_nodes in enumerate(hyp_tours_np):
            hyp_len = tour_nodes_to_tour_len(hyp_nodes, edges_values_np[idx])
            if hyp_len < shortest_lens[idx] and is_valid_tour(hyp_nodes, num_nodes):
                shortest_tours[idx] = hyp_tours[idx]
                shortest_lens[idx] = hyp_len
    return shortest_tours


def update_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    The optimizer is modified in place and also returned.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return optimizer

## 8. GCN Model

In [None]:
class ResidualGatedGCNModel(nn.Module):
    """Residual Gated GCN Model for outputting predictions as edge adjacency matrices.
    Reference: https://arxiv.org/pdf/1711.07553v2.pdf

    Args:
        config: Settings-like mapping (attribute and item access) with the
            model hyper-parameters.
        dtypeFloat: Tensor type used to build the class-weight tensor.
        dtypeLong: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, config, dtypeFloat, dtypeLong):
        super(ResidualGatedGCNModel, self).__init__()
        self.dtypeFloat = dtypeFloat
        self.dtypeLong = dtypeLong
        self.num_nodes = config.num_nodes
        self.node_dim = config.node_dim
        self.voc_nodes_in = config['voc_nodes_in']
        # NOTE(review): reads 'num_nodes', not 'voc_nodes_out'; the attribute
        # is not used by this edge-only model.
        self.voc_nodes_out = config['num_nodes']
        self.voc_edges_in = config['voc_edges_in']
        self.voc_edges_out = config['voc_edges_out']
        self.hidden_dim = config['hidden_dim']
        self.num_layers = config['num_layers']
        self.mlp_layers = config['mlp_layers']
        self.aggregation = config['aggregation']
        half_dim = self.hidden_dim // 2
        # Node coordinates are embedded to hidden_dim; an edge embedding is the
        # concatenation of a distance embedding and an edge-type embedding.
        self.nodes_coord_embedding = nn.Linear(self.node_dim, self.hidden_dim, bias=False)
        self.edges_values_embedding = nn.Linear(1, half_dim, bias=False)
        self.edges_embedding = nn.Embedding(self.voc_edges_in, half_dim)
        self.gcn_layers = nn.ModuleList(
            [ResidualGatedGCNLayer(self.hidden_dim, self.aggregation) for _ in range(self.num_layers)]
        )
        self.mlp_edges = MLP(self.hidden_dim, self.voc_edges_out, self.mlp_layers)

    def forward(self, x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw):
        """Embed the graph, run the GCN stack and classify every edge.

        ``x_nodes`` is accepted for interface compatibility but not used.

        Returns:
            ``(y_pred_edges, loss)``: per-edge class scores from the MLP and
            the class-weighted edge loss.
        """
        x = self.nodes_coord_embedding(x_nodes_coord)
        e = torch.cat(
            (self.edges_values_embedding(x_edges_values.unsqueeze(3)),
             self.edges_embedding(x_edges)),
            dim=3,
        )
        for gcn_layer in self.gcn_layers:
            x, e = gcn_layer(x, e)
        y_pred_edges = self.mlp_edges(e)
        class_weights = torch.Tensor(edge_cw).type(self.dtypeFloat)
        return y_pred_edges, loss_edges(y_pred_edges, y_edges, class_weights)

## 9. Plot Utilities

In [None]:
def plot_tsp(p, x_coord, W, W_val, W_target, title="default"):
    """Plot TSP tours on the matplotlib axes ``p``.

    Draws the nodes at ``x_coord`` (node 0 in green, the others in blue), the
    edges of adjacency ``W`` faintly, and the edges of ``W_target`` in red.
    Drawing targets ``p`` explicitly; networkx would otherwise draw on
    matplotlib's *current* axes, which need not be ``p``.

    Returns:
        The axes ``p``.
    """
    def _edges_to_node_pairs(adj):
        # (row, col) pairs of every entry equal to 1.
        size = len(adj)
        return [(r, c) for r in range(size) for c in range(size) if adj[r][c] == 1]

    G = nx.from_numpy_array(W_val)
    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))
    adj_pairs = _edges_to_node_pairs(W)
    target_pairs = _edges_to_node_pairs(W_target)
    colors = ['g'] + ['b'] * (len(x_coord) - 1)
    nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=50, ax=p)
    nx.draw_networkx_edges(G, pos, edgelist=adj_pairs, alpha=0.3, width=0.5, ax=p)
    nx.draw_networkx_edges(G, pos, edgelist=target_pairs, alpha=1, width=1, edge_color='r', ax=p)
    p.set_title(title)
    return p


def plot_tsp_heatmap(p, x_coord, W_val, W_pred, title="default"):
    """Plot predicted TSP edges on the axes ``p``, coloured by confidence.

    Only edges whose value in ``W_pred`` exceeds 0.25 are drawn, shaded with
    the ``Reds`` colormap by that value. Drawing targets ``p`` explicitly
    rather than matplotlib's current axes.

    Returns:
        The axes ``p``.
    """
    def _edges_to_node_pairs(W):
        # Pairs and values of every entry above the 0.25 display threshold.
        pairs = []
        edge_preds = []
        for r in range(len(W)):
            for c in range(len(W)):
                if W[r][c] > 0.25:
                    pairs.append((r, c))
                    edge_preds.append(W[r][c])
        return pairs, edge_preds

    G = nx.from_numpy_array(W_val)
    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))
    node_pairs, edge_color = _edges_to_node_pairs(W_pred)
    node_color = ['g'] + ['b'] * (len(x_coord) - 1)
    nx.draw_networkx_nodes(G, pos, node_color=node_color, node_size=50, ax=p)
    nx.draw_networkx_edges(G, pos, edgelist=node_pairs, edge_color=edge_color,
                           edge_cmap=plt.cm.Reds, width=0.75, ax=p)
    p.set_title(title)
    return p


def plot_predictions_beamsearch(x_nodes_coord, x_edges, x_edges_values, y_edges, y_pred_edges, bs_nodes, num_plots=3):
    """Plots groundtruth TSP tour vs. predicted tours (with beamsearch).

    For up to ``num_plots`` randomly chosen instances, shows side by side the
    ground-truth tour, a heatmap of the class-1 edge probabilities and the
    beam-search tour, with tour lengths in the titles.
    """
    # Class-1 probability of every edge. (The arg-max over the whole batch
    # was computed but never used, so it is no longer computed.)
    y_probs = F.softmax(y_pred_edges, dim=3)[:, :, :, 1]
    num_instances = len(y_pred_edges)
    chosen = np.random.choice(num_instances, min(num_plots, num_instances), replace=False)
    for f_idx, idx in enumerate(chosen):
        f = plt.figure(f_idx, figsize=(15, 5))
        x_coord = x_nodes_coord[idx].cpu().numpy()
        W = x_edges[idx].cpu().numpy()
        W_val = x_edges_values[idx].cpu().numpy()
        W_target = y_edges[idx].cpu().numpy()
        W_sol_probs = y_probs[idx].cpu().numpy()
        W_bs = tour_nodes_to_W(bs_nodes[idx].cpu().numpy())
        plt1 = f.add_subplot(131)
        plot_tsp(plt1, x_coord, W, W_val, W_target, 'GT: {:.3f}'.format(W_to_tour_len(W_target, W_val)))
        plt2 = f.add_subplot(132)
        plot_tsp_heatmap(plt2, x_coord, W_val, W_sol_probs, 'Heatmap')
        plt3 = f.add_subplot(133)
        plot_tsp(plt3, x_coord, W, W_val, W_bs, 'BS: {:.3f}'.format(W_to_tour_len(W_bs, W_val)))
        plt.show()

## 10. Configuration de l'expérience

**Modifiez les paramètres ci-dessous selon vos besoins** (TSP10 ou TSP20).

In [None]:
# ============================================================
# PARAMETERS TO EDIT FOR YOUR EXPERIMENT
# ============================================================
# Choose: 'tsp10' or 'tsp20'
TSP_SIZE = 'tsp10'

# Location of the data on Kaggle (adapt if needed).
# Data usually lives in /kaggle/input/<dataset-name>/
DATA_DIR = '/kaggle/input'  # Root data folder

# List every candidate TSP file below DATA_DIR. Walking DATA_DIR (rather than
# a hard-coded path) makes the setting above actually take effect.
print(f"Fichiers disponibles dans {DATA_DIR}:")
for root, dirs, files in os.walk(DATA_DIR):
    for f in files:
        if 'tsp' in f.lower():
            print(f"  {os.path.join(root, f)}")

In [None]:
# ============================================================
# AUTOMATIC DETECTION OF THE DATA PATHS
# ============================================================
def find_data_file(pattern, root='/kaggle/input'):
    """Find a data file anywhere below ``root``.

    Args:
        pattern: File name (glob wildcards allowed) to look for.
        root: Directory searched recursively; defaults to ``/kaggle/input``.

    Returns:
        The first match in sorted order, so the choice is deterministic when
        several datasets contain the same file.

    Raises:
        FileNotFoundError: If nothing below ``root`` matches.
    """
    # With recursive=True, '**' also matches zero directories, so this single
    # search covers files directly inside `root` as well as nested ones.
    search = os.path.join(glob.escape(root), '**', pattern)
    matches = sorted(glob.glob(search, recursive=True))
    if matches:
        return matches[0]
    raise FileNotFoundError(f"Fichier '{pattern}' non trouv√© dans {root}/")

# TSP10 and TSP20 share every hyper-parameter; only the node count and the
# data files differ, so build the configuration from one template.
_TSP_NUM_NODES = {'tsp10': 10, 'tsp20': 20}
if TSP_SIZE not in _TSP_NUM_NODES:
    # Clear error instead of a NameError on `config_dict` further down.
    raise ValueError(f"Unsupported TSP_SIZE {TSP_SIZE!r}; expected one of {sorted(_TSP_NUM_NODES)}")

config_dict = {
    "expt_name": TSP_SIZE,
    "gpu_id": "0,1",
    "train_filepath": find_data_file(f"{TSP_SIZE}_train_concorde.txt"),
    "val_filepath": find_data_file(f"{TSP_SIZE}_val_concorde.txt"),
    "test_filepath": find_data_file(f"{TSP_SIZE}_test_concorde.txt"),
    "num_nodes": _TSP_NUM_NODES[TSP_SIZE],
    "num_neighbors": -1,
    "node_dim": 2,
    "voc_nodes_in": 2,
    "voc_nodes_out": 2,
    "voc_edges_in": 3,
    "voc_edges_out": 2,
    "beam_size": 1280,
    "hidden_dim": 300,
    "num_layers": 30,
    "mlp_layers": 3,
    "aggregation": "mean",
    "max_epochs": 1500,
    "val_every": 5,
    "test_every": 100,
    "batch_size": 20,
    "batches_per_epoch": 500,
    "accumulation_steps": 1,
    "learning_rate": 0.001,
    "decay_rate": 1.01
}

config = Settings(config_dict)
print(f"Configuration charg√©e pour {TSP_SIZE}:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 11. Configuration GPU (Double GPU)

In [None]:
# Configuration to use both Kaggle GPUs.
# NOTE(review): CUDA reads CUDA_VISIBLE_DEVICES when it is initialized. The
# imports cell already queries device names (which initializes CUDA), so these
# assignments presumably do not change the GPUs this process sees; set them
# before importing torch if that matters.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)  # "0,1"

if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"CUDA disponible - {num_gpus} GPU(s) d√©tect√©(s)")
    for i in range(num_gpus):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)} - "
              f"{torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")
    dtypeFloat = torch.cuda.FloatTensor
    dtypeLong = torch.cuda.LongTensor
else:
    print("CUDA non disponible, utilisation du CPU")
    dtypeFloat = torch.FloatTensor
    dtypeLong = torch.LongTensor

# torch.manual_seed seeds the CPU generator *and* every CUDA device.
# torch.cuda.manual_seed only seeded the current GPU and left the CPU
# generator unseeded (it serves e.g. weight initialization of modules created
# on the CPU).
torch.manual_seed(1)

## 12. Exploration des données

In [None]:
# Build a reader over the training file and time the creation of one batch.
num_nodes = config.num_nodes
num_neighbors = config.num_neighbors
batch_size = config.batch_size
train_filepath = config.train_filepath

dataset = GoogleTSPReader(num_nodes, num_neighbors, batch_size, train_filepath)
print(f"Nombre de batches de taille {batch_size}: {dataset.max_iter}")

t = time.time()
batch = next(iter(dataset))
print(f"G√©n√©ration du batch: {time.time() - t:.3f} sec")

# Shapes of every array stored in the batch.
print(f"\nFormes des tenseurs:")
print(f"  edges:         {batch.edges.shape}")
print(f"  edges_values:  {batch.edges_values.shape}")
print(f"  edges_targets: {batch.edges_target.shape}")
print(f"  nodes:         {batch.nodes.shape}")
print(f"  nodes_target:  {batch.nodes_target.shape}")
print(f"  nodes_coord:   {batch.nodes_coord.shape}")
print(f"  tour_nodes:    {batch.tour_nodes.shape}")
print(f"  tour_len:      {batch.tour_len.shape}")

# Visualize one example (ground-truth tour edges drawn in red by plot_tsp).
idx = 0
f = plt.figure(figsize=(5, 5))
a = f.add_subplot(111)
plot_tsp(a, batch.nodes_coord[idx], batch.edges[idx], batch.edges_values[idx], batch.edges_target[idx])
# plt.title targets the current axes (a), replacing the "default" title set by plot_tsp.
plt.title(f"Exemple TSP{num_nodes} - Tour length: {batch.tour_len[idx]:.3f}")
plt.show()

## 13. Training & Evaluation Functions

In [None]:
def train_one_epoch(net, optimizer, config, master_bar):
    """Train ``net`` for one epoch over ``config.train_filepath``.

    Args:
        net: Model whose ``forward(x_edges, x_edges_values, x_nodes,
            x_nodes_coord, y_edges, edge_cw)`` returns ``(y_preds, loss)``.
        optimizer: Optimizer stepped once every ``config.accumulation_steps``
            batches.
        config: Settings providing the data and batching options.
        master_bar: Parent fastprogress bar for the per-batch progress bar.

    Returns:
        ``(epoch_seconds, mean_loss, 0, 0, 0, mean_pred_tour_len,
        mean_gt_tour_len)``; the three middle entries are constant zeros.
    """
    net.train()

    # Report the device(s) used
    if torch.cuda.is_available():
        print(f"üöÄ Training sur {torch.cuda.device_count()} GPU(s): {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")
    else:
        print("‚ö†Ô∏è Training sur CPU")

    num_nodes = config.num_nodes
    num_neighbors = config.num_neighbors
    batch_size = config.batch_size
    batches_per_epoch = config.batches_per_epoch
    accumulation_steps = config.accumulation_steps
    train_filepath = config.train_filepath

    dataset = GoogleTSPReader(num_nodes, num_neighbors, batch_size, train_filepath)
    # -1 means "the whole file"; otherwise never exceed what the file holds.
    if batches_per_epoch != -1:
        batches_per_epoch = min(batches_per_epoch, dataset.max_iter)
    else:
        batches_per_epoch = dataset.max_iter

    dataset = iter(dataset)
    edge_cw = None

    running_loss = 0.0
    running_pred_tour_len = 0.0
    running_gt_tour_len = 0.0
    running_nb_data = 0

    # Start from clean gradients: an incomplete accumulation window at the end
    # of a previous epoch would otherwise leak into this epoch's first step.
    optimizer.zero_grad()

    start_epoch = time.time()
    for batch_num in progress_bar(range(batches_per_epoch), parent=master_bar):
        try:
            batch = next(dataset)
        except StopIteration:
            break

        # Build the input tensors. dtypeLong/dtypeFloat are CUDA types when
        # CUDA is available, so the .cuda() calls below then change nothing.
        x_edges = torch.LongTensor(batch.edges).type(dtypeLong)
        x_edges_values = torch.FloatTensor(batch.edges_values).type(dtypeFloat)
        x_nodes = torch.LongTensor(batch.nodes).type(dtypeLong)
        x_nodes_coord = torch.FloatTensor(batch.nodes_coord).type(dtypeFloat)
        y_edges = torch.LongTensor(batch.edges_target).type(dtypeLong)

        if torch.cuda.is_available():
            x_edges = x_edges.cuda()
            x_edges_values = x_edges_values.cuda()
            x_nodes = x_nodes.cuda()
            x_nodes_coord = x_nodes_coord.cuda()
            y_edges = y_edges.cuda()

        # Compute the class weights once. The former `type(edge_cw) !=
        # torch.Tensor` test was always true (compute_class_weight returns a
        # numpy array), so they were recomputed, with a device->host copy,
        # every batch. With fixed-size instances every target holds the same
        # number of tour entries, so the weights do not change across batches.
        if edge_cw is None:
            edge_labels = y_edges.cpu().numpy().flatten()
            edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        y_preds, loss = net.forward(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
        # Presumably reduces per-GPU losses to a scalar when the model is
        # wrapped for multi-GPU execution; a no-op on a scalar loss.
        loss = loss.mean()
        loss = loss / accumulation_steps
        loss.backward()

        if (batch_num + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        pred_tour_len = mean_tour_len_edges(x_edges_values, y_preds)
        gt_tour_len = np.mean(batch.tour_len)

        running_nb_data += batch_size
        # Undo the accumulation scaling so the reported loss is per batch.
        running_loss += batch_size * loss.item() * accumulation_steps
        running_pred_tour_len += batch_size * pred_tour_len
        running_gt_tour_len += batch_size * gt_tour_len

        result = ('loss:{loss:.4f} pred:{pred:.3f} gt:{gt:.3f}'.format(
            loss=running_loss / running_nb_data,
            pred=running_pred_tour_len / running_nb_data,
            gt=running_gt_tour_len / running_nb_data))
        master_bar.child.comment = result

    loss = running_loss / running_nb_data
    pred_tour_len = running_pred_tour_len / running_nb_data
    gt_tour_len = running_gt_tour_len / running_nb_data

    return time.time() - start_epoch, loss, 0, 0, 0, pred_tour_len, gt_tour_len


def test(net, config, master_bar, mode='test'):
    """Evaluate ``net`` on the validation or test set with beam-search decoding.

    Args:
        net: Model (possibly wrapped in ``nn.DataParallel``) whose ``forward``
            returns ``(y_preds, loss)``.
        config: ``Settings`` providing ``num_nodes``, ``num_neighbors``,
            ``batch_size``, ``beam_size``, ``val_filepath`` and ``test_filepath``.
        master_bar: fastprogress parent bar for the per-batch progress bar.
        mode: ``'val'`` reads ``config.val_filepath`` and decodes with
            ``beamsearch_tour_nodes``; ``'test'`` reads ``config.test_filepath``
            and decodes with ``beamsearch_tour_nodes_shortest``, which also
            receives the edge distances.

    Returns:
        ``(elapsed_seconds, loss, 0, 0, 0, pred_tour_len, gt_tour_len)``.
        The three zeros are placeholders for error metrics that are not
        computed here. Loss and tour lengths are averaged per instance.

    Raises:
        ValueError: If ``mode`` is neither ``'val'`` nor ``'test'``.
        RuntimeError: If the dataset yields no batch, so nothing can be averaged.
    """
    # Validate the mode up front. Otherwise an unknown mode only fails later,
    # with an obscure unbound-variable error.
    if mode == 'val':
        filepath = config.val_filepath
    elif mode == 'test':
        filepath = config.test_filepath
    else:
        raise ValueError(f"mode must be 'val' or 'test', got {mode!r}")

    net.eval()

    # Report which device(s) the evaluation runs on.
    if torch.cuda.is_available():
        print(f"üöÄ Evaluation ({mode}) sur {torch.cuda.device_count()} GPU(s): {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")
    else:
        print(f"‚ö†Ô∏è Evaluation ({mode}) sur CPU")

    num_nodes = config.num_nodes
    batch_size = config.batch_size
    beam_size = config.beam_size

    reader = GoogleTSPReader(num_nodes, config.num_neighbors, batch_size=batch_size, filepath=filepath)
    batches_per_epoch = reader.max_iter
    dataset = iter(reader)

    # Edge-loss class weights. They are computed once, from the first batch.
    # `compute_class_weight` returns a NumPy array, so a `torch.Tensor` type
    # check never matched and recomputed them on every batch. Reusing the
    # first batch's weights is presumably equivalent, because every
    # ground-truth tour contributes the same share of positive edges.
    # TODO confirm for datasets with mixed graph sizes.
    edge_cw = None

    running_loss = 0.0
    running_pred_tour_len = 0.0
    running_gt_tour_len = 0.0
    running_nb_data = 0

    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        start_test = time.time()
        for _ in progress_bar(range(batches_per_epoch), parent=master_bar):
            try:
                batch = next(dataset)
            except StopIteration:
                break

            # Build the input tensors and move them to the GPU when available.
            x_edges = torch.LongTensor(batch.edges).type(dtypeLong)
            x_edges_values = torch.FloatTensor(batch.edges_values).type(dtypeFloat)
            x_nodes = torch.LongTensor(batch.nodes).type(dtypeLong)
            x_nodes_coord = torch.FloatTensor(batch.nodes_coord).type(dtypeFloat)
            y_edges = torch.LongTensor(batch.edges_target).type(dtypeLong)

            if use_cuda:
                x_edges = x_edges.cuda()
                x_edges_values = x_edges_values.cuda()
                x_nodes = x_nodes.cuda()
                x_nodes_coord = x_nodes_coord.cuda()
                y_edges = y_edges.cuda()

            if edge_cw is None:
                edge_labels = y_edges.cpu().numpy().flatten()
                edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

            y_preds, loss = net.forward(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
            # DataParallel may return one loss per replica; reduce to a scalar.
            loss = loss.mean()

            if mode == 'val':
                bs_nodes = beamsearch_tour_nodes(
                    y_preds, beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type='logits')
            else:
                bs_nodes = beamsearch_tour_nodes_shortest(
                    y_preds, x_edges_values, beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type='logits')

            pred_tour_len = mean_tour_len_nodes(x_edges_values, bs_nodes)
            gt_tour_len = np.mean(batch.tour_len)

            # Weight batch means by batch size to get per-instance averages.
            running_nb_data += batch_size
            running_loss += batch_size * loss.item()
            running_pred_tour_len += batch_size * pred_tour_len
            running_gt_tour_len += batch_size * gt_tour_len

            result = ('loss:{loss:.4f} pred:{pred:.3f} gt:{gt:.3f}'.format(
                loss=running_loss / running_nb_data,
                pred=running_pred_tour_len / running_nb_data,
                gt=running_gt_tour_len / running_nb_data))
            master_bar.child.comment = result

    if running_nb_data == 0:
        raise RuntimeError(f"No batch was read for mode={mode!r} ({filepath}); cannot average metrics.")

    loss = running_loss / running_nb_data
    pred_tour_len = running_pred_tour_len / running_nb_data
    gt_tour_len = running_gt_tour_len / running_nb_data

    return time.time() - start_test, loss, 0, 0, 0, pred_tour_len, gt_tour_len


def metrics_to_str(epoch, time, learning_rate, loss, err_edges, err_tour, err_tsp, pred_tour_len, gt_tour_len):
    """Render one line of epoch metrics for the training log.

    ``time`` is given in seconds and printed in hours. ``err_edges``,
    ``err_tour`` and ``err_tsp`` are accepted for signature compatibility
    but do not appear in the output.
    """
    hours = time / 3600
    fields = [
        f"epoch:{epoch:0>2d}",
        f"time:{hours:.1f}h",
        f"lr:{learning_rate:.2e}",
        f"loss:{loss:.4f}",
        f"pred_tour_len:{pred_tour_len:.3f}",
        f"gt_tour_len:{gt_tour_len:.3f}",
    ]
    return " ".join(fields)

## 14. Pipeline d'entraînement complet (Double GPU avec DataParallel)

In [None]:
def _save_checkpoint(path, epoch, net, optimizer, train_loss, val_loss):
    """Save model/optimizer state and the given losses to ``path`` via ``torch.save``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, path)


def main(config):
    """Train the residual gated GCN, validating and testing periodically.

    When CUDA is available, the model is moved to the GPU and wrapped in
    ``nn.DataParallel``. Scalars are logged to TensorBoard under
    ``/kaggle/working/logs/<expt_name>/``, and these checkpoints are written
    there:

    * ``best_val_checkpoint.tar`` whenever the validation predicted tour
      length improves;
    * ``last_train_checkpoint.tar`` after every epoch;
    * ``checkpoint_epoch<N>.tar`` every ``config.checkpoint_every`` epochs
      (default 250) and at the last epoch (never at epoch 0).

    The learning rate is divided by ``config.decay_rate`` whenever a
    validation loss is not at least 1% below the previous one.

    Args:
        config: ``Settings`` with model, data and training options.

    Returns:
        The trained network object. Checkpoints store its ``state_dict()``.
    """
    net = ResidualGatedGCNModel(config, dtypeFloat, dtypeLong)

    # Move the model to the GPU *before* wrapping it: DataParallel requires
    # the module's parameters to live on the first device.
    if torch.cuda.is_available():
        net = net.cuda()
        net = nn.DataParallel(net)
        print(f"Mod√®le distribu√© sur {torch.cuda.device_count()} GPU(s)")
    else:
        print("Utilisation du CPU")

    print(net)

    # Number of parameters.
    nb_param = sum(p.numel() for p in net.parameters())
    print(f'Nombre de param√®tres: {nb_param:,}')

    # Create the log directory and record the configuration. The `with`
    # block guarantees the file is flushed and closed.
    log_dir = f"/kaggle/working/logs/{config.expt_name}/"
    os.makedirs(log_dir, exist_ok=True)
    with open(f"{log_dir}/config.json", "w") as config_file:
        json.dump(dict(config), config_file, indent=4)
    writer = SummaryWriter(log_dir)

    # Training parameters.
    max_epochs = config.max_epochs
    val_every = config.val_every
    test_every = config.test_every
    checkpoint_every = config.get('checkpoint_every', 250)
    learning_rate = config.learning_rate
    decay_rate = config.decay_rate
    val_loss_old = 1e6
    best_pred_tour_len = 1e6
    # Most recent validation loss, stored in the per-epoch checkpoints
    # (0 until the first validation pass has run).
    last_val_loss = 0

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    print(optimizer)

    epoch_bar = master_bar(range(max_epochs))
    try:
        for epoch in epoch_bar:
            writer.add_scalar('learning_rate', learning_rate, epoch)

            # Train
            train_time, train_loss, _, _, _, train_pred_tour_len, train_gt_tour_len = \
                train_one_epoch(net, optimizer, config, epoch_bar)
            epoch_bar.write('t: ' + metrics_to_str(epoch, train_time, learning_rate, train_loss,
                                                    0, 0, 0, train_pred_tour_len, train_gt_tour_len))
            writer.add_scalar('loss/train_loss', train_loss, epoch)
            writer.add_scalar('pred_tour_len/train_pred_tour_len', train_pred_tour_len, epoch)
            writer.add_scalar('optimality_gap/train_opt_gap', train_pred_tour_len / train_gt_tour_len - 1, epoch)

            if epoch % val_every == 0 or epoch == max_epochs - 1:
                # Validate
                val_time, val_loss, _, _, _, val_pred_tour_len, val_gt_tour_len = \
                    test(net, config, epoch_bar, mode='val')
                last_val_loss = val_loss
                epoch_bar.write('v: ' + metrics_to_str(epoch, val_time, learning_rate, val_loss,
                                                        0, 0, 0, val_pred_tour_len, val_gt_tour_len))
                writer.add_scalar('loss/val_loss', val_loss, epoch)
                writer.add_scalar('pred_tour_len/val_pred_tour_len', val_pred_tour_len, epoch)
                writer.add_scalar('optimality_gap/val_opt_gap', val_pred_tour_len / val_gt_tour_len - 1, epoch)

                # Keep the weights that produce the shortest validation tours.
                if val_pred_tour_len < best_pred_tour_len:
                    best_pred_tour_len = val_pred_tour_len
                    _save_checkpoint(log_dir + "best_val_checkpoint.tar",
                                     epoch, net, optimizer, train_loss, val_loss)

                # Decay the learning rate when validation loss did not drop by at least 1%.
                if val_loss > 0.99 * val_loss_old:
                    learning_rate /= decay_rate
                    optimizer = update_learning_rate(optimizer, learning_rate)
                val_loss_old = val_loss

            if epoch % test_every == 0 or epoch == max_epochs - 1:
                # Test
                test_time, test_loss, _, _, _, test_pred_tour_len, test_gt_tour_len = \
                    test(net, config, epoch_bar, mode='test')
                epoch_bar.write('T: ' + metrics_to_str(epoch, test_time, learning_rate, test_loss,
                                                        0, 0, 0, test_pred_tour_len, test_gt_tour_len))
                writer.add_scalar('loss/test_loss', test_loss, epoch)
                writer.add_scalar('pred_tour_len/test_pred_tour_len', test_pred_tour_len, epoch)
                writer.add_scalar('optimality_gap/test_opt_gap', test_pred_tour_len / test_gt_tour_len - 1, epoch)

            # Rolling checkpoint, overwritten every epoch.
            _save_checkpoint(log_dir + "last_train_checkpoint.tar",
                             epoch, net, optimizer, train_loss, last_val_loss)

            # Periodic snapshot, plus one at the final epoch (skipped at epoch 0).
            if epoch != 0 and (epoch % checkpoint_every == 0 or epoch == max_epochs - 1):
                _save_checkpoint(log_dir + f"checkpoint_epoch{epoch}.tar",
                                 epoch, net, optimizer, train_loss, last_val_loss)
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    return net

## 15. Lancer l'entraînement

In [None]:
# Run the full training pipeline. The returned `net` is reused by the
# checkpoint-loading, visualization and evaluation cells below.
net = main(config)

## 16. Charger le meilleur checkpoint et évaluer

In [None]:
# Restore the weights of the best validation checkpoint written by main().
log_dir = f"/kaggle/working/logs/{config.expt_name}/"
best_ckpt_path = log_dir + "best_val_checkpoint.tar"
# Without CUDA, remap tensors saved on the GPU onto the CPU. With CUDA,
# keep torch.load's default placement (map_location=None).
checkpoint = torch.load(
    best_ckpt_path,
    map_location=None if torch.cuda.is_available() else 'cpu',
)

net.load_state_dict(checkpoint['model_state_dict'])
epoch, train_loss, val_loss = (
    checkpoint['epoch'],
    checkpoint['train_loss'],
    checkpoint['val_loss'],
)
print(f"Checkpoint charg√© depuis l'epoch {epoch}")
print(f"  Train loss: {train_loss:.4f}")
print(f"  Val loss: {val_loss:.4f}")

## 17. Visualisation des prédictions

In [None]:
# Decode the first `viz_batch_size` test instances with the shortest-tour
# beam search, report the optimality gap, check tour validity, and plot.
net.eval()

viz_batch_size = 10
num_nodes = config.num_nodes
num_neighbors = config.num_neighbors
beam_size = config.beam_size
test_filepath = config.test_filepath

dataset = iter(GoogleTSPReader(num_nodes, num_neighbors, viz_batch_size, test_filepath))
batch = next(dataset)

with torch.no_grad():
    # Plain tensors: `Variable` has been a deprecated wrapper that simply
    # returns a Tensor since PyTorch 0.4, and gradients are already disabled
    # by `no_grad()`.
    x_edges = torch.LongTensor(batch.edges).type(dtypeLong)
    x_edges_values = torch.FloatTensor(batch.edges_values).type(dtypeFloat)
    x_nodes = torch.LongTensor(batch.nodes).type(dtypeLong)
    x_nodes_coord = torch.FloatTensor(batch.nodes_coord).type(dtypeFloat)
    y_edges = torch.LongTensor(batch.edges_target).type(dtypeLong)

    # Balanced class weights for the edge loss, computed from this batch.
    edge_labels = y_edges.cpu().numpy().flatten()
    edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)
    print(f"Class weights: {edge_cw}")

    y_preds, loss = net.forward(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
    # DataParallel may return one loss per replica; reduce to a scalar.
    loss = loss.mean()

    bs_nodes = beamsearch_tour_nodes_shortest(
        y_preds, x_edges_values, beam_size, viz_batch_size, num_nodes,
        dtypeFloat, dtypeLong, probs_type='logits')

    pred_tour_len = mean_tour_len_nodes(x_edges_values, bs_nodes)
    gt_tour_len = np.mean(batch.tour_len)
    print(f"Tour pr√©dit (moyenne): {pred_tour_len:.3f}")
    print(f"Tour GT (moyenne):     {gt_tour_len:.3f}")
    print(f"Gap d'optimalit√©:      {(pred_tour_len/gt_tour_len - 1)*100:.2f}%")

    # Report any decoded tour that fails the validity check.
    for idx, nodes in enumerate(bs_nodes):
        if not is_valid_tour(nodes, num_nodes):
            print(f"  Tour invalide #{idx}: {nodes}")

    # Plot predictions against the ground truth.
    plot_predictions_beamsearch(x_nodes_coord, x_edges, x_edges_values, y_edges,
                                y_preds, bs_nodes, num_plots=viz_batch_size)

## 18. Évaluation finale (Greedy, Beam Search, BS*)

In [None]:
# Final evaluation on the test set with three decoders: greedy (beam width 1),
# vanilla beam search, and beam search with the shortest-tour heuristic (BS*).
learning_rate = config.learning_rate
epoch_bar = master_bar(range(epoch + 1, epoch + 2))

# Evaluation-only overrides go on a copy, so the shared `config` used for
# training is not silently modified by this cell.
config_eval = Settings(dict(config))
config_eval.batch_size = 200  # larger batch for evaluation
# test(mode='val') reads val_filepath. Point it at the test set so all three
# decoders are compared on the same instances.
config_eval.val_filepath = config.test_filepath

for ep in epoch_bar:
    # Greedy search (beam_size=1)
    config_eval.beam_size = 1
    t = time.time()
    val_time, val_loss, _, _, _, val_pred_tour_len, val_gt_tour_len = test(net, config_eval, epoch_bar, mode='val')
    print(f"Greedy time: {time.time()-t:.1f}s")
    epoch_bar.write('Greedy: ' + metrics_to_str(ep, val_time, learning_rate, val_loss,
                                                 0, 0, 0, val_pred_tour_len, val_gt_tour_len))

    # Vanilla beam search
    config_eval.beam_size = 1280
    t = time.time()
    val_time, val_loss, _, _, _, val_pred_tour_len, val_gt_tour_len = test(net, config_eval, epoch_bar, mode='val')
    print(f"BS time: {time.time()-t:.1f}s")
    epoch_bar.write('BS:     ' + metrics_to_str(ep, val_time, learning_rate, val_loss,
                                                 0, 0, 0, val_pred_tour_len, val_gt_tour_len))

    # Beam search with the shortest-tour heuristic (same beam width as BS)
    config_eval.beam_size = 1280
    t = time.time()
    test_time, test_loss, _, _, _, test_pred_tour_len, test_gt_tour_len = test(net, config_eval, epoch_bar, mode='test')
    print(f"BS* time: {time.time()-t:.1f}s")
    epoch_bar.write('BS*:    ' + metrics_to_str(ep, test_time, learning_rate, test_loss,
                                                 0, 0, 0, test_pred_tour_len, test_gt_tour_len))

---
**Notes:**
- Le modèle utilise `nn.DataParallel` pour distribuer automatiquement les batches sur les 2 GPU Kaggle.
- Les checkpoints sont sauvegardés dans `/kaggle/working/logs/<expt_name>/`.
- Modifiez `TSP_SIZE` dans la cellule 10 pour basculer entre TSP10 et TSP20.
- Les données doivent être dans `/kaggle/input/` (ajoutez le dataset via l'interface Kaggle).