# VRP GraphNet

Model inputs from the paper:

| Variable             | Meaning                           | Dimensions                |
|----------------------|-----------------------------------|---------------------------|
| batch_edges          | Adj matrix special connections*   | B x num_nodes x num_nodes |
| batch_edges_values   | Distance Matrix                   | B x num_nodes x num_nodes |
| batch_edges_target   | Target adj matrix                 | B x num_nodes x num_nodes |
| batch_nodes          | Ones vector                       | B x num_nodes             |
| batch_nodes_coord    | Coordinates                       | B x num_nodes x 2         |
| *batch_nodes_target* | Value represents ordering in tour | B x num_nodes             |


\*Special connections (the values stored in `batch_edges`):
* 1 - k-nearest neighbour
* 2 - self connections
* 0 - otherwise

In [None]:
try:
    from google.colab import drive

    drive.mount('/content/gdrive')

    %cd gdrive/My Drive/machine-learning-course
    %pip install -r requirements-colab.txt
    %load_ext tensorboard
    %tensorboard --logdir runs
    IN_COLAB = True
except:
    IN_COLAB = False

In [None]:
from pathlib import Path

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import utils.beam_search as beam_search
from model import GraphNet
from utils import load_config, get_metrics, get_device, save_checkpoint, _n
from utils.data import load_and_split_dataset, process_datasets, sparse_matrix_from_routes, distance_from_sparse_matrix

# Apply seaborn's default theme to the plots drawn later in the notebook.
sns.set_theme()

## Load datasets

In [None]:
# Load the VRP instances from disk; test_size=1000 is passed to the split helper
# (presumably the number of held-out instances — TODO confirm).
dsets = load_and_split_dataset("data/vrp_20_3s_random_depot.pkl", test_size=1000)
# NOTE(review): k=6 presumably is the k of the k-nearest-neighbour connections
# described at the top of the notebook — confirm against process_datasets.
train_dataset, test_dataset = process_datasets(dsets, k=6)

# Sanity check: size of each split.
print(len(train_dataset), len(test_dataset))

In [None]:
# Class weights reported by the training set; the experiment cell below passes
# them to nn.CrossEntropyLoss. The bare expression displays them as cell output.
edge_class_weights = train_dataset.class_weights()
edge_class_weights

## Basic Config

In [None]:
# Device (chosen by the project helper) that models and batches are moved to
# throughout the notebook.
device = get_device()
print("Device", device)

In [None]:
# Default hyperparameters. Experiment cells build their own config from these
# values and override individual fields.
default_config = load_config(hidden_dim=32, num_gcn_layers=5, num_mlp_layers=3,
                             learning_rate=0.001, train_batch_size=64, test_batch_size=64, num_epochs=50)
default_config

In [None]:
# Shuffled training batches, sized by the default config.
train_dataloader = DataLoader(train_dataset, batch_size=default_config.train_batch_size,
                              shuffle=True)

# Model built from the default config and moved to the selected device.
model = GraphNet(default_config).to(device)

## Test Forward Pass

In [None]:
# Smoke test: push a single training batch through the model and inspect the
# shape of its output.
features, _ = next(iter(train_dataloader))

# Call the module itself rather than `.forward()` so that registered hooks
# (if any) run as they would during training.
y_pred = model(features["node_features"].to(device),
               features["dist_matrix"].to(device),
               features["edge_feat_matrix"].to(device))

y_pred.shape

## Validation loop

In [None]:
def count_violations(tours, batch_node_features):
    """Count, for every tour, the trips whose accumulated load exceeds 1.

    A tour is read as a sequence of node indices in which node 0 ends the
    current trip and starts a new one (presumably node 0 is the depot — TODO
    confirm). The load of a trip is the sum of ``batch_node_features[i, node, 2]``
    over its nodes (feature column 2 is presumably the demand — TODO confirm).

    Args:
        tours: sequence of tours, one per instance; each tour is a sequence of
            node indices.
        batch_node_features: per-instance node features, indexed as
            ``[instance, node, feature]``.

    Returns:
        ``np.ndarray`` with one violation count per tour.
    """
    per_tour_counts = []

    for idx, route in enumerate(tours):
        load = 0
        overloaded_trips = 0

        # The appended 0 closes the last trip even when the tour does not end
        # on node 0 itself.
        for stop in [*route, 0]:
            if stop != 0:
                load += batch_node_features[idx, stop, 2]
                continue

            # Back at node 0: judge the finished trip, then start a new one.
            if load > 1:
                overloaded_trips += 1
            load = 0

        per_tour_counts.append(overloaded_trips)

    return np.array(per_tour_counts)


def greedy_tour_lengths(y_preds, batch_dist_matrix, num_vehicles):
    """Decode one tour per instance with a width-1 beam search and measure it.

    Args:
        y_preds: model output; index 1 of its last axis is used as the
            probability of selecting each edge.
        batch_dist_matrix: distance matrices; the size of the last axis is
            handed to the sparse-matrix helper.
        num_vehicles: forwarded unchanged to the beam search.

    Returns:
        Tuple ``(tour_lengths, tours)`` as produced by the distance helper and
        the beam search respectively.
    """
    # Drop every class except "edge selected".
    edge_probs = y_preds[..., 1]

    searcher = beam_search.BeamSearch(edge_probs, beam_width=1, num_vehicles=num_vehicles)
    searcher.search()

    # With a single beam, index 0 is the only (most probable) candidate.
    greedy_tours = searcher.get_beam(0)

    matrix_size = batch_dist_matrix.size(-1)
    routes_as_sparse = sparse_matrix_from_routes(greedy_tours, matrix_size)

    return distance_from_sparse_matrix(routes_as_sparse, batch_dist_matrix), greedy_tours


def shortest_tour_lengths(y_preds, batch_dist_matrix, num_vehicles, beam_width=1024):
    """Run a beam search and keep, per instance, the shortest tour over all beams.

    Args:
        y_preds: model output; index 1 of its last axis is used as the
            probability of selecting each edge.
        batch_dist_matrix: distance matrices; the first dimension is used as
            the batch size and the last one is handed to the sparse-matrix helper.
        num_vehicles: forwarded unchanged to the beam search.
        beam_width: number of beams to search; every beam is scored.

    Returns:
        Tuple ``(shortest_tour_distances, shortest_tours)`` with one entry per
        instance.
    """
    # only keep the probability of selecting the edge
    y_preds = y_preds[..., 1]

    beamsearch = beam_search.BeamSearch(y_preds, beam_width=beam_width, num_vehicles=num_vehicles)
    beamsearch.search()

    # NOTE(review): this buffer uses torch's default (float) dtype. If get_beam()
    # returns integer node indices, the torch.where below mixes dtypes — confirm
    # the resulting dtype is what the plotting code expects.
    shortest_tours = torch.zeros((batch_dist_matrix.size(0), len(beamsearch.next_nodes)))
    # Start at +inf so the first beam always replaces the placeholder tours.
    shortest_tour_distances = torch.full((batch_dist_matrix.size(0),), np.inf)

    for i in range(beamsearch.beam_width):
        tours = beamsearch.get_beam(i)

        # creating the sparse matrix is expensive
        sp_adj_matrix = sparse_matrix_from_routes(tours, batch_dist_matrix.size(-1))
        tour_lengths = distance_from_sparse_matrix(sp_adj_matrix, batch_dist_matrix)

        # keep the shortest tours: per-instance mask of "beam i is strictly better",
        # with a trailing axis added so it broadcasts over the tour positions.
        # This must happen before the distances are updated below.
        condition = torch.lt(tour_lengths, shortest_tour_distances).unsqueeze(-1)
        shortest_tours = torch.where(condition, tours, shortest_tours)

        # keep the shortest tour lengths
        shortest_tour_distances = torch.minimum(shortest_tour_distances, tour_lengths)

    return shortest_tour_distances, shortest_tours


def eval_model(batch_node_features, batch_dist_matrix, batch_edge_features, model):
    """Run `model` in evaluation mode and return class probabilities.

    The model is switched to eval mode (and left there), the forward pass runs
    without gradient tracking, and a softmax is taken over axis 3 of the output.

    Args:
        batch_node_features: first positional input of the model.
        batch_dist_matrix: second positional input of the model.
        batch_edge_features: third positional input of the model.
        model: module to evaluate.

    Returns:
        The model output normalised with softmax along axis 3.
    """
    model.eval()

    with torch.no_grad():
        logits = model(batch_node_features, batch_dist_matrix, batch_edge_features)
        probabilities = F.softmax(logits, dim=3)

    return probabilities


def validate(dataloader, model, criterion):
    """Evaluate `model` on every batch of `dataloader`.

    For each batch the loss is computed on the raw model output, one tour per
    instance is decoded greedily, and capacity violations of those tours are
    counted.

    Args:
        dataloader: yields ``(batch_features, batch_targets)`` pairs, where
            ``batch_features`` is a mapping with the keys ``"node_features"``,
            ``"dist_matrix"``, ``"edge_feat_matrix"`` and ``"num_vehicles"``.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss applied through ``get_loss``.

    Returns:
        Tuple ``(targets, predictions, mean_loss, mean_tour_length,
        mean_violations)``. ``targets`` and ``predictions`` are numpy arrays
        concatenated over all batches, ``predictions`` holding the argmax class
        per edge. ``mean_loss`` is averaged over batches; the other two means
        are taken over all instances.
    """
    running_loss = 0
    running_tour_lengths = []
    running_tour_violations = []
    targets = []
    predictions = []

    model.eval()

    with torch.no_grad():
        for batch_features, batch_targets in dataloader:
            batch_node_features = batch_features["node_features"].to(device)
            batch_dist_matrix = batch_features["dist_matrix"].to(device)
            batch_edge_features = batch_features["edge_feat_matrix"].to(device)
            batch_num_vehicles = batch_features["num_vehicles"].to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_node_features, batch_dist_matrix, batch_edge_features)

            # Loss on the raw logits, exactly as in training. The criterion used
            # in this notebook (nn.CrossEntropyLoss) normalises its input itself,
            # so feeding it softmax output would apply softmax twice and make
            # the validation loss incomparable with the training loss.
            loss = get_loss(logits, batch_targets, criterion)
            running_loss += loss.item()

            # Probabilities are only needed for decoding tours.
            y_preds = F.softmax(logits, dim=3)

            # The beam search takes a single num_vehicles value, so instances
            # are decoded in groups that share the same number of vehicles.
            for vehicles in torch.unique(batch_num_vehicles):
                mask = torch.eq(batch_num_vehicles, vehicles)

                tour_lengths, tours = greedy_tour_lengths(y_preds[mask].cpu(), batch_dist_matrix[mask].cpu(),
                                                          num_vehicles=vehicles)

                # count_violations walks the tours node by node in Python; move
                # the features to the CPU once instead of reading single
                # elements from the device.
                running_tour_violations.extend(count_violations(tours, batch_node_features[mask].cpu()))
                running_tour_lengths.extend(_n(tour_lengths))

            targets.append(batch_targets.cpu().numpy())
            # argmax is unaffected by softmax, so the hard predictions are unchanged.
            predictions.append(y_preds.argmax(dim=3).cpu().numpy())

    targets = np.concatenate(targets)
    predictions = np.concatenate(predictions)
    mean_running_loss = running_loss / len(dataloader)

    running_tour_lengths = np.mean(running_tour_lengths)
    running_tour_violations = np.mean(running_tour_violations)

    return targets, predictions, mean_running_loss, running_tour_lengths, running_tour_violations

## Training Loop

In [None]:
def get_loss(preds, targets, criterion):
    """Apply `criterion` to predictions whose class axis is last.

    Args:
        preds: 4-D predictions; axis 3 (the class axis) is moved to axis 1
            before the criterion is called.
        targets: targets passed to the criterion unchanged.
        criterion: callable invoked as ``criterion(permuted_preds, targets)``.

    Returns:
        Whatever `criterion` returns.
    """
    return criterion(preds.permute(0, 3, 1, 2), targets)


def train_one_epoch(dataloader, model, optimizer, criterion):
    """Run one optimisation pass over `dataloader`.

    Args:
        dataloader: yields ``(batch_features, batch_targets)`` pairs, where
            ``batch_features`` is a mapping with the keys ``"node_features"``,
            ``"dist_matrix"`` and ``"edge_feat_matrix"``.
        model: network to train; it is switched to train mode.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied through ``get_loss``.

    Returns:
        The sum of the per-batch loss values (not averaged).
    """
    model.train()

    total_loss = 0

    for features, labels in dataloader:
        node_features = features["node_features"].to(device)
        dist_matrix = features["dist_matrix"].to(device)
        edge_features = features["edge_feat_matrix"].to(device)
        labels = labels.to(device)

        # Clear the gradients of the previous step before the new backward pass.
        optimizer.zero_grad()

        batch_loss = get_loss(model(node_features, dist_matrix, edge_features), labels, criterion)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    return total_loss


def train(num_epochs, train_dl, test_dl, model, optimizer, criterion, writer, run_config=None):
    """Train `model`, log metrics to TensorBoard and checkpoint every epoch.

    Each epoch runs one training pass and one validation pass, writes the
    metrics and both losses to `writer`, and stores the model as
    ``last_epoch_model.pt``. Whenever the validation loss improves, the model is
    also stored as ``best_validation_loss_model.pt``. Both files are written to
    the writer's log directory.

    Args:
        num_epochs: number of epochs to run.
        train_dl: dataloader used for training.
        test_dl: dataloader used for validation.
        model: network to train.
        optimizer: optimizer for `model`.
        criterion: loss function.
        writer: TensorBoard ``SummaryWriter``; its ``log_dir`` is also the
            checkpoint directory.
        run_config: mapping stored inside every checkpoint. When omitted, the
            notebook-level ``config`` variable is used, as before.
    """
    # Resolve the config once, before any training happens, so a missing or
    # stale notebook global is noticed immediately rather than after an epoch.
    checkpoint_config = {**(run_config if run_config is not None else config)}

    # `log_dir` may be a plain string when the writer was created from one;
    # normalise it so the `/` joins below always work.
    log_dir = Path(writer.log_dir)

    best_loss = np.inf

    for epoch in range(num_epochs):
        # Train
        running_loss = train_one_epoch(train_dl, model=model, optimizer=optimizer, criterion=criterion)

        # train_one_epoch returns the summed batch losses; average per batch.
        epoch_loss = running_loss / len(train_dl)

        # Validation Metrics
        targets, predictions, validation_loss, tour_length, violations = validate(test_dl, model=model,
                                                                                  criterion=criterion)
        metrics = get_metrics(targets, predictions)

        scalars = {
            "Metrics/accuracy": metrics.acc,
            "Metrics/bal. accuracy": metrics.bal_acc,
            "Metrics/precision": metrics.precision,
            "Metrics/recall": metrics.recall,
            "Metrics/f1 score": metrics.f1_score,
            "Metrics/tour length": tour_length,
            "Metrics/violations": violations,
            "Loss/train": epoch_loss,
            "Loss/test": validation_loss,
        }
        for tag, value in scalars.items():
            writer.add_scalar(tag, value, epoch)

        # Both checkpoints of an epoch store the same state.
        checkpoint_state = dict(model=model, optimizer=optimizer, epoch=epoch,
                                config=dict(checkpoint_config),
                                train_loss=epoch_loss, test_loss=validation_loss)

        # Save (validation) checkpoint
        if validation_loss < best_loss:
            best_loss = validation_loss
            save_checkpoint(log_dir / "best_validation_loss_model.pt", **checkpoint_state)

        # Save (epoch) checkpoint
        save_checkpoint(log_dir / "last_epoch_model.pt", **checkpoint_state)

        print(f'Epoch: {epoch:02d}, Loss: {epoch_loss:.4f}')

## Baseline Model

In [None]:
# Baseline experiment: a smaller network (hidden_dim=16) trained for 20 epochs.
LOG_DIR = Path("runs/exp_baseline")

# Start from the default hyperparameters and override only what differs.
config = load_config(**default_config)
config.hidden_dim = 16
config.num_gcn_layers = 5
config.num_epochs = 20

train_dataloader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True)
# The test loader is shuffled too: the plotting cell takes its first batch as a
# random sample of test instances.
test_dataloader = DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=True)

# Seed before the model is built so the run is repeatable.
torch.manual_seed(0)

edge_class_weights = train_dataset.class_weights().to(device)
model = GraphNet(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss(edge_class_weights)

writer = SummaryWriter(log_dir=LOG_DIR)

try:
    train(config.num_epochs, train_dl=train_dataloader, test_dl=test_dataloader,
          model=model, optimizer=optimizer, criterion=criterion, writer=writer)
    writer.flush()
finally:
    # Close the event file even when training fails or is interrupted.
    writer.close()

## Plot Results

In [None]:
# get random instances to plot
batch_features, batch_targets = next(iter(test_dataloader))

num_plots = 10

choices = np.random.choice(len(batch_targets), num_plots)

In [None]:
from utils.plot import plot_graph, plot_heatmap, plot_beam_search_tour
from utils.data import distance_from_adj_matrix
import matplotlib.pyplot as plt

# Move the sampled batch to the device the model lives on.
batch_node_features = batch_features["node_features"].to(device)
batch_dist_matrix = batch_features["dist_matrix"].to(device)
batch_edge_features = batch_features["edge_feat_matrix"].to(device)
batch_num_vehicles = batch_features["num_vehicles"].to(device)
batch_targets = batch_targets.to(device)

# Class probabilities for every instance of the batch (softmax over axis 3).
preds = eval_model(batch_node_features, batch_dist_matrix, batch_edge_features,
                   model=model)

# Length of the target solution vs. the shortest tour found by a 1024-wide beam search.
# NOTE(review): a single num_vehicles (the batch maximum) is used for every
# instance here, whereas validate() searches each num_vehicles group
# separately — confirm this is intended.
ground_truth_distance = distance_from_adj_matrix(batch_targets, batch_dist_matrix)
shortest_distance, tours = shortest_tour_lengths(preds.cpu(), batch_dist_matrix.cpu(), beam_width=1024,
                                                 num_vehicles=torch.max(batch_num_vehicles))

# One figure per sampled instance with three panels: ground truth, predicted
# probability of the "edge selected" class, and the beam-search tour.
for i in choices:
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))

    # The first two node-feature columns are passed to the plot helpers
    # (presumably the x/y coordinates — TODO confirm).
    plot_graph(_n(batch_node_features[i, :, :2]), _n(batch_targets[i]), ax=ax[0])
    plot_heatmap(_n(batch_node_features[i, :, :2]), _n(batch_targets[i]), _n(preds[i][..., 1]), ax=ax[1])
    plot_beam_search_tour(_n(batch_node_features[i, :, :2]), _n(batch_targets[i]), _n(tours[i]), ax=ax[2])

    ax[0].set_title(f"Ground truth ({ground_truth_distance[i]:.2f})")
    ax[1].set_title("Predictions")
    ax[2].set_title(f"Shortest tour ({shortest_distance[i]:.2f})")
    fig.tight_layout()

    plt.show()