# VRP GraphNet

Model inputs from the paper:

| Variable             | Meaning                           | Dimensions                |
|----------------------|-----------------------------------|---------------------------|
| batch_edges          | Adj matrix special connections*   | B x num_nodes x num_nodes |
| batch_edges_values   | Distance Matrix                   | B x num_nodes x num_nodes |
| batch_edges_target   | Target adj matrix                 | B x num_nodes x num_nodes |
| batch_nodes          | Ones vector                       | B x num_nodes             |
| batch_nodes_coord    | Coordinates                       | B x num_nodes x 2         |
| *batch_nodes_target* | Value represents ordering in tour | B x num_nodes             |


\*Special connections (values stored in `batch_edges`):
* 1 - k-nearest neighbour
* 2 - self connections
* 0 - otherwise

In [None]:
try:
    from google.colab import drive

    drive.mount('/content/gdrive')

    %cd gdrive/My Drive/vrp-thesis
    %pip install -r requirements-colab.txt
    IN_COLAB = True
except:
    IN_COLAB = False

In [None]:
# On Colab, (re)load the TensorBoard notebook extension and open it inline,
# reading event files from the `runs` directory.
if IN_COLAB:
    %reload_ext tensorboard
    %tensorboard --logdir runs

In [None]:
from pathlib import Path

import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import GraphNet
from utils import load_config, get_metrics, get_device, save_checkpoint, _n, DotDict, load_checkpoint
from utils.beam_search import BeamSearch
from utils.data import load_and_split_dataset, process_datasets, sparse_matrix_from_routes, distance_from_sparse_matrix, distance_from_adj_matrix

# Apply seaborn's default theme to the plots drawn later in the notebook.
sns.set_theme()

## Load datasets

In [None]:
# Load the VRP instances from disk, split them (test_size=500) and turn the
# splits into the train/test datasets used by the rest of the notebook.
# NOTE(review): `k=6` is presumably the number of nearest neighbours used for
# the edge features described in the table at the top — confirm in utils.data.
dsets = load_and_split_dataset("data/vrp_20_3s_random_depot.pkl", test_size=500)
train_dataset, test_dataset = process_datasets(dsets, k=6)

# Sanity check: number of samples in each split.
print(len(train_dataset), len(test_dataset))

## Basic Config

In [None]:
# Device that models and batches are moved to in every cell below.
# NOTE(review): presumably CUDA when available and CPU otherwise — confirm in utils.get_device.
device = get_device()
print("Device", device)

In [None]:
# Baseline hyperparameters. Experiments further down start from this config
# and override individual fields.
default_config = load_config(hidden_dim=32, num_gcn_layers=5, num_mlp_layers=3,
                             learning_rate=0.001, train_batch_size=128, test_batch_size=256, num_epochs=50)
# Last expression of the cell: displays the config in the notebook output.
default_config

In [None]:
# Shuffled training loader and an untrained model built from the default
# config; both are used by the forward-pass smoke test below.
train_dataloader = DataLoader(train_dataset, batch_size=default_config.train_batch_size,
                              shuffle=True)

model = GraphNet(default_config).to(device)

## Test Forward Pass

In [None]:
# Smoke test: push a single training batch through the untrained model and
# inspect the shape of the output.
features, _ = next(iter(train_dataloader))

# Call the module itself instead of `.forward()`: `nn.Module.__call__` also
# runs any registered hooks, which a direct `forward` call silently skips.
y_pred = model(features["node_features"].to(device),
               features["dist_matrix"].to(device),
               features["edge_feat_matrix"].to(device))

y_pred.shape

## Validation loop

In [None]:
def adj_matrix_from_routes(routes, num_nodes):
    """
    Converts a batch of routes to a batch of symmetric adjacency matrices.
    Consecutive nodes of a route are connected, and the last node of each
    route is connected back to the first node of the same route.
    :param routes: (b, n) integer array, one route (sequence of node indices) per row
    :param num_nodes: Number of nodes, i.e. the side length of each adjacency matrix
    :return: (b, num_nodes, num_nodes) array with 1 for every edge used by the route, 0 otherwise
    """
    routes = np.asarray(routes)

    # Successor of every node within its own route. The roll must be done per
    # row (axis=1): without an axis numpy rolls the flattened array, which
    # links the last node of a route to the first node of the next route.
    successors = np.roll(routes, -1, axis=1)

    matrix = np.zeros((routes.shape[0], num_nodes, num_nodes))

    # Column vector of batch indices so it broadcasts against the (b, n) node indices
    batch_index = np.arange(routes.shape[0])[:, None]

    # Set both directions so the adjacency matrix is symmetric
    matrix[batch_index, routes, successors] = 1
    matrix[batch_index, successors, routes] = 1

    return matrix
    

def count_violations(tours, demands):
    """
    Count, for every tour, how many of its legs exceed the load limit of 1.
    A leg ends whenever node 0 is visited or the tour reaches its last
    position; the accumulated demand is checked and reset at that point.
    :param tours: (b, n) array of tours
    :param demands: (b, m) array of demands, indexed by node
    :return: (b,) array of count of violations
    """
    num_tours = tours.shape[0]
    violations = np.zeros(num_tours)

    for row in range(num_tours):
        last_step = tours.shape[1] - 1
        overloaded_legs = 0
        load = 0

        for step in range(last_step + 1):
            stop = tours[row, step]
            load += demands[row, stop]

            # Still in the middle of a leg: keep accumulating
            if not (stop == 0 or step == last_step):
                continue

            # End of a leg: check the load, then start the next leg empty
            if load > 1:
                overloaded_legs += 1
            load = 0

        violations[row] = overloaded_legs

    return violations


def shortest_valid_tour(y_preds, batch_dist_matrix, batch_node_features,
                        num_vehicles, beam_width=1024):
    """
    Run a beam search and pick, for every instance of the batch, the best tour
    among all beams. A tour with fewer violations always wins; among tours
    with the same number of violations the shorter one wins.
    :param y_preds: Batch of edge predictions; index 1 of the last axis is handed to the beam search
    :param batch_dist_matrix: Batch of distance matrices
    :param batch_node_features: Batch of node features; index 2 of the last axis is used as the demands
    :param num_vehicles: Number of vehicles, forwarded to BeamSearch
    :param beam_width: Beam width, forwarded to BeamSearch
    :return: Tuple (tour lengths, tours, violation counts) of the selected tours, one entry per instance
    """
    # Everything below is Python loops and comparisons, so work on the CPU
    edge_scores = y_preds.cpu()[..., 1]
    dist_matrix = batch_dist_matrix.cpu().numpy()
    node_features = batch_node_features.cpu().numpy()

    search = BeamSearch(edge_scores, beam_width=beam_width, num_vehicles=num_vehicles)
    search.search()

    batch_size = search.batch_size
    best_tour = np.zeros((batch_size, len(search.next_nodes)))
    # Start from +inf so that the first beam is always accepted
    best_length = np.full((batch_size,), np.inf)
    best_violations = np.full((batch_size,), np.inf)

    num_nodes = dist_matrix.shape[-1]
    demands = node_features[..., 2]

    for beam in range(search.beam_width):
        candidate = search.get_beam(beam).numpy()

        adjacency = adj_matrix_from_routes(candidate, num_nodes)
        lengths = distance_from_adj_matrix(adjacency, dist_matrix)
        violations = count_violations(candidate, demands)

        for row in range(batch_size):
            # Accept on strictly fewer violations, or on a tie with a shorter length
            if violations[row] < best_violations[row] or (
                    violations[row] == best_violations[row] and lengths[row] < best_length[row]):
                best_tour[row] = candidate[row]
                best_length[row] = lengths[row]
                best_violations[row] = violations[row]

    return best_length, best_tour, best_violations

def probable_tour_lengths(y_preds, batch_dist_matrix, num_vehicles, beam_width=1024):
    """
    Run a beam search and return the tours of beam 0 with their lengths.
    :param y_preds: Batch of edge predictions; index 1 of the last axis (probability of selecting the edge) is used
    :param batch_dist_matrix: Batch of distance matrices
    :param num_vehicles: Number of vehicles, forwarded to BeamSearch
    :param beam_width: Beam width, forwarded to BeamSearch
    :return: Tuple (tour lengths, tours)
    """
    search = BeamSearch(y_preds[..., 1], beam_width=beam_width, num_vehicles=num_vehicles)
    search.search()

    # Only the first beam is evaluated
    tours = search.get_beam(0).cpu().numpy()

    num_nodes = batch_dist_matrix.shape[-1]
    adjacency = adj_matrix_from_routes(tours, num_nodes)

    return distance_from_adj_matrix(adjacency, batch_dist_matrix), tours


def greedy_tour_lengths(y_preds, batch_dist_matrix, num_vehicles):
    """
    Decode tours greedily (beam search with a beam width of 1 and consecutive
    visits disallowed) and return them with their lengths.
    :param y_preds: Batch of edge predictions; index 1 of the last axis (probability of selecting the edge) is used
    :param batch_dist_matrix: Batch of distance matrices
    :param num_vehicles: Number of vehicles, forwarded to BeamSearch
    :return: Tuple (tour lengths, tours)
    """
    search = BeamSearch(y_preds[..., 1], beam_width=1, num_vehicles=num_vehicles, allow_consecutive_visits=False)
    search.search()

    # Beam 0 holds the most probable tours
    tours = search.get_beam(0).cpu().numpy()

    num_nodes = batch_dist_matrix.shape[-1]
    adjacency = adj_matrix_from_routes(tours, num_nodes)

    return distance_from_adj_matrix(adjacency, batch_dist_matrix), tours


def eval_model(batch_node_features, batch_dist_matrix, batch_edge_features, model):
    """
    Run the model in evaluation mode without gradient tracking and return
    its output normalised with a softmax over axis 3.
    :param batch_node_features: Batch of node features passed to the model
    :param batch_dist_matrix: Batch of distance matrices passed to the model
    :param batch_edge_features: Batch of edge features passed to the model
    :param model: Model to evaluate; it is left in eval mode
    :return: Softmax of the model output
    """
    model.eval()

    with torch.no_grad():
        raw_output = model(batch_node_features, batch_dist_matrix, batch_edge_features)
        probabilities = F.softmax(raw_output, dim=3)

    return probabilities


def validate(dataloader, model, criterion):
    """
    Evaluate the model on every batch of the dataloader.
    The loss is computed on the raw model output (the same input the training
    loop gives the criterion); the softmax of that output is then decoded
    greedily into tours to measure their length and number of violations.
    :param dataloader: Dataloader yielding (features dict, targets) batches
    :param model: Model to evaluate; it is left in eval mode
    :param criterion: Loss function, called through get_loss
    :return: Tuple (targets, predicted edge classes, mean loss per batch,
             mean greedy tour length, mean number of violations per tour)
    """
    running_loss = 0
    running_tour_lengths = []
    running_tour_violations = []
    targets = []
    predictions = []

    model.eval()

    for batch_features, batch_targets in dataloader:
        batch_node_features = batch_features["node_features"].to(device)
        batch_dist_matrix = batch_features["dist_matrix"].to(device)
        batch_edge_features = batch_features["edge_feat_matrix"].to(device)
        # Moved to the device because the masks derived from it index tensors that live there
        batch_num_vehicles = batch_features["num_vehicles"].to(device)
        batch_targets = batch_targets.to(device)

        with torch.no_grad():
            logits = model(batch_node_features, batch_dist_matrix, batch_edge_features)

            # The loss needs the unnormalised output: feeding it probabilities
            # would apply the softmax twice and make the validation loss
            # incomparable to the training loss.
            loss = get_loss(logits, batch_targets, criterion)
            y_preds = F.softmax(logits, dim=3)

        running_loss += loss.item()

        # Tours are decoded separately for every distinct number of vehicles in the batch
        for vehicles in torch.unique(batch_num_vehicles):
            mask = batch_num_vehicles == vehicles

            tour_lengths, tours = greedy_tour_lengths(y_preds[mask], batch_dist_matrix[mask],
                                                      num_vehicles=vehicles)

            # count_violations works on numpy arrays and needs the demands only,
            # which are stored at index 2 of the node features
            demands = batch_node_features[mask][..., 2].cpu().numpy()
            violations = count_violations(tours, demands)

            # distance_from_adj_matrix may hand back a tensor or an array
            if torch.is_tensor(tour_lengths):
                tour_lengths = tour_lengths.cpu().numpy()

            running_tour_violations.extend(np.asarray(violations))
            running_tour_lengths.extend(np.asarray(tour_lengths))

        # Predicted class per edge (argmax is the same for logits and probabilities)
        predictions.append(y_preds.argmax(dim=3).cpu().numpy())
        targets.append(batch_targets.cpu().numpy())

    targets = np.concatenate(targets)
    predictions = np.concatenate(predictions)
    mean_running_loss = running_loss / len(dataloader)

    running_tour_lengths = np.mean(running_tour_lengths)
    running_tour_violations = np.mean(running_tour_violations)

    return targets, predictions, mean_running_loss, running_tour_lengths, running_tour_violations

## Training Loop

In [None]:
def get_loss(preds, targets, criterion):
    """
    Apply the criterion to predictions whose last axis holds the classes.
    :param preds: 4-dimensional predictions with the class axis last
    :param targets: Targets passed unchanged to the criterion
    :param criterion: Loss function expecting the class axis at position 1
    :return: Result of the criterion
    """
    # (0, 1, 2, 3) -> (0, 3, 1, 2): move the class axis right after the batch axis
    return criterion(preds.permute(0, 3, 1, 2), targets)


def train_one_epoch(dataloader, model, optimizer, criterion):
    """
    Train the model for one pass over the dataloader.
    :param dataloader: Dataloader yielding (features dict, targets) batches
    :param model: Model to train; it is put into train mode
    :param optimizer: Optimizer stepped once per batch
    :param criterion: Loss function, called through get_loss
    :return: Sum of the batch losses (not averaged)
    """
    model.train()

    total_loss = 0

    for inputs, labels in dataloader:
        optimizer.zero_grad()

        node_features = inputs["node_features"].to(device)
        dist_matrix = inputs["dist_matrix"].to(device)
        edge_features = inputs["edge_feat_matrix"].to(device)
        labels = labels.to(device)

        batch_loss = get_loss(model(node_features, dist_matrix, edge_features), labels, criterion)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    return total_loss


def train(num_epochs, train_dl, test_dl, model, optimizer, criterion, writer):
    """
    Train the model for a number of epochs, validating, logging to TensorBoard
    and checkpointing after every epoch.
    Two checkpoints are kept in the writer's log directory:
    "best_validation_loss_model.pt" (lowest validation loss so far) and
    "last_epoch_model.pt" (overwritten every epoch).
    NOTE: the hyperparameters stored in the checkpoints are read from the
    module-level `config`, which must be defined before this function is called.
    :param num_epochs: Number of epochs to train
    :param train_dl: Dataloader for training
    :param test_dl: Dataloader for validation
    :param model: Model to train
    :param optimizer: Optimizer
    :param criterion: Loss function
    :param writer: TensorBoard SummaryWriter used for the scalars and as checkpoint location
    """
    best_loss = np.inf

    # `writer.log_dir` can be a plain string (e.g. when the writer was created
    # from a str path), so normalise it before joining file names with `/`.
    checkpoint_dir = Path(writer.log_dir)

    for epoch in range(num_epochs):
        # Train
        running_loss = train_one_epoch(train_dl, model=model, optimizer=optimizer, criterion=criterion)

        # Losses (train_one_epoch returns the sum over batches)
        epoch_loss = running_loss / len(train_dl)

        # Validation Metrics
        targets, predictions, validation_loss, tour_length, violations = validate(test_dl, model=model,
                                                                                  criterion=criterion)
        metrics = get_metrics(targets, predictions)

        writer.add_scalar("Metrics/accuracy", metrics.acc, epoch)
        writer.add_scalar("Metrics/bal. accuracy", metrics.bal_acc, epoch)
        writer.add_scalar("Metrics/precision", metrics.precision, epoch)
        writer.add_scalar("Metrics/recall", metrics.recall, epoch)
        writer.add_scalar("Metrics/f1 score", metrics.f1_score, epoch)
        writer.add_scalar("Metrics/tour length", tour_length, epoch)
        writer.add_scalar("Metrics/violations", violations, epoch)

        writer.add_scalar("Loss/train", epoch_loss, epoch)
        writer.add_scalar("Loss/test", validation_loss, epoch)

        # Both checkpoints of an epoch store exactly the same state
        checkpoint_state = dict(model=model, optimizer=optimizer, epoch=epoch,
                                config={**config}, train_loss=epoch_loss, test_loss=validation_loss)

        # Save (validation) checkpoint
        if validation_loss < best_loss:
            best_loss = validation_loss
            save_checkpoint(checkpoint_dir / "best_validation_loss_model.pt", **checkpoint_state)

        # Save (epoch) checkpoint
        save_checkpoint(checkpoint_dir / "last_epoch_model.pt", **checkpoint_state)

        print(f'Epoch: {epoch:02d}, Loss: {epoch_loss:.4f}')

## Baseline Model

In [None]:
# Baseline experiment: train a fresh GraphNet for 10 epochs, logging scalars
# and checkpoints to LOG_DIR.
LOG_DIR = Path(f"runs/exp_baseline_3")

# Start from the default hyperparameters and override single fields below.
# `config` is also read as a global by `train` when it writes checkpoints.
config = load_config(**default_config)
# config.hidden_dim = 128
# config.gcn_layers = 10
config.num_epochs = 10

train_dataloader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True)
# NOTE(review): shuffling the test set is not needed for the aggregated metrics — confirm it is intended.
test_dataloader = DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=True)

# Seed torch's global RNG before the model is created.
torch.manual_seed(0)

# Class weights provided by the training dataset, used to weight the edge classification loss.
edge_class_weights = train_dataset.class_weights().to(device)
model = GraphNet(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss(edge_class_weights)

writer = SummaryWriter(log_dir=LOG_DIR)

train(config.num_epochs, train_dl=train_dataloader, test_dl=test_dataloader,
      model=model, optimizer=optimizer, criterion=criterion, writer=writer)
# Make sure all pending events are written before the writer is closed.
writer.flush()
writer.close()

## Plot Results

In [None]:
# Restore a trained model: rebuild GraphNet from the config stored in the
# checkpoint, then load the saved weights into it.
# NOTE(review): this reads from exp_baseline_2 while the training cell above
# logs to exp_baseline_3 — confirm which run is meant.
MODEL_PATH = Path("runs/exp_baseline_2")

checkpoint = load_checkpoint(MODEL_PATH / "last_epoch_model.pt")
# This rebinds the notebook-global `config` to the checkpoint's hyperparameters.
config = DotDict(checkpoint['config'])
model = GraphNet(config).to(device)

model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Draw one (shuffled, hence random) batch from the test set for the plots below.
test_dataloader = DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=True)
batch_features, batch_targets = next(iter(test_dataloader))

In [None]:
# Move the sampled batch to the device and compute the model's softmax
# predictions for it.
batch_node_features = batch_features["node_features"].to(device)
batch_dist_matrix = batch_features["dist_matrix"].to(device)
batch_edge_features = batch_features["edge_feat_matrix"].to(device)
batch_num_vehicles = batch_features["num_vehicles"].to(device)
batch_targets = batch_targets.to(device)

preds = eval_model(batch_node_features, batch_dist_matrix, batch_edge_features,
                   model=model)

In [None]:
# Beam search with 4 vehicles and a beam width of 1024; the cell output shows
# the returned (tour lengths, tours, violation counts).
# NOTE(review): 4 vehicles is hard-coded for the whole batch although
# `batch_num_vehicles` holds a value per instance — confirm this is intended.
shortest_valid_tour(preds, batch_dist_matrix, batch_node_features, 4, 1024)

In [None]:
# Length of the target tours, computed from the target adjacency matrices and
# the distance matrices; displayed as the cell output.
ground_truth_distance = distance_from_adj_matrix(batch_targets, batch_dist_matrix)
ground_truth_distance

In [None]:
# 10 random instance indices of the batch, drawn with replacement (numpy's default).
choices = np.random.choice(len(batch_targets), 10)

In [None]:
from utils.plot import plot_graph, plot_heatmap, plot_beam_search_tour
from utils.data import distance_from_adj_matrix
import matplotlib.pyplot as plt

# Length of the target tours, shown in the title of the first panel.
ground_truth_distance = distance_from_adj_matrix(batch_targets, batch_dist_matrix)
# shortest_distance, tours = shortest_tour_lengths(preds.cpu(), batch_dist_matrix.cpu(), beam_width=1024,
#                                                  num_vehicles=torch.max(batch_num_vehicles))

# For each distinct number of vehicles in the batch, decode the tours greedily
# and plot 3 randomly chosen instances that use that many vehicles.

for v in torch.unique(batch_num_vehicles):
    # Selects the instances of the batch that use `v` vehicles.
    mask = batch_num_vehicles == v
    # route_distance, tours = probable_tour_lengths(preds, batch_dist_matrix, v)
    # NOTE(review): tours are decoded for the whole batch with `v` vehicles and
    # only masked afterwards, whereas `validate` masks the inputs first — confirm
    # that indexing the returned arrays with a (possibly CUDA) tensor mask works.
    route_distance, tours = greedy_tour_lengths(preds, batch_dist_matrix, v)

    # Indices are drawn with replacement, so an instance can be plotted more than once.
    for i in np.random.choice(torch.sum(mask).cpu(), 3):
        # Three panels: target tour, prediction heatmap, decoded tour.
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))

        # The first two node features are passed to the plot helpers as the node coordinates.
        plot_graph(_n(batch_node_features[mask][i, :, :2]), _n(batch_targets[mask][i]), ax=ax[0])
        plot_heatmap(_n(batch_node_features[mask][i, :, :2]), _n(batch_targets[mask][i]), _n(preds[mask][i][..., 1]), ax=ax[1])
        plot_beam_search_tour(_n(batch_node_features[mask][i, :, :2]), _n(batch_targets[mask][i]), _n(tours[mask][i]), ax=ax[2])

        ax[0].set_title(f"Ground truth ({ground_truth_distance[mask][i]:.2f})")
        ax[1].set_title("Predictions")
        # NOTE(review): the title says "Shortest tour" but the tour comes from the greedy decoding above.
        ax[2].set_title(f"Shortest tour ({route_distance[mask][i]:.2f})")
        fig.tight_layout()

        plt.show()