# Dataset Comparison Warcraft

This notebook checks the optimal solutions for the Warcraft dataset.
It shows that the solutions from the original repository were not optimal for about 10% of the instances.

In [None]:
# Load the data
import sys
from torch.utils.data import DataLoader, Dataset
from typing import Literal
import os
import numpy as np
import sys
sys.path.insert(0, '/home/schaetz/robust-dfl')
from helpers import seed_all
from warcraft.comb_modules.dijkstra import get_solver
from dotenv import load_dotenv
from warcraft.Trainer.utils import shortest_pathsolution_np

In [None]:
# Load environment variables from a `.env` file (python-dotenv).
# NOTE(review): none of the cells visible here read the environment directly;
# presumably the project imports consume these variables -- confirm.
load_dotenv()

In [None]:
# --- Reproducibility -------------------------------------------------------
seed = 0
# `seed_all` is a project helper; its return value is kept under both names.
generator = g = seed_all(seed)

# --- Data loading ----------------------------------------------------------
batch_size = 128
num_workers = 0
use_test_set = True  # also load and compare the test split

# --- Dataset variant -------------------------------------------------------
img_sizes = [12, 18, 24, 30]  # map sizes; each lives in "<size>x<size>/"
neighbourhood_fn = "8-grid"   # passed to `get_solver`
normalization = "zscore"      # the only scheme the comparison cell accepts

# TODO: Adjust this to your data directory
data_directory = ""

In [None]:
# Compare the shortest paths shipped with the dataset against paths recomputed
# with the solver, and report how often the recomputed path is strictly cheaper.

# NOTE: this override narrows the comparison to the 24x24 maps, replacing the
# `img_sizes` list configured earlier; remove it to compare every size.
img_sizes = [24]

# File-name components: arrays are stored as "<prefix>_<suffix>.npy".
train_prefix = "train"
val_prefix = "val"
test_prefix = "test"
data_suffix = "maps"


def _load_split(data_dir, prefix, suffix, dtype=None):
    """Load ``<prefix>_<suffix>.npy`` from *data_dir*, casting to *dtype* if given."""
    array = np.load(os.path.join(data_dir, f"{prefix}_{suffix}.npy"))
    return array if dtype is None else array.astype(dtype)


def _load_maps_channel_first(data_dir, prefix, suffix):
    """Load one split's map images as float32 and move axis 3 to position 1."""
    return _load_split(data_dir, prefix, suffix, np.float32).transpose(0, 3, 1, 2)


def _path_costs(weights, paths):
    """Return one cost per instance: ``weights * paths`` summed over axes 1 and 2."""
    # Assumes both arrays are (N, H, W) with `paths` a 0/1 mask -- TODO confirm.
    return (weights * paths).sum(axis=(1, 2))


def _report_solver_wins(split, scores_solver, scores_file):
    """Print the share of instances where the solver's path is strictly cheaper."""
    # NOTE(review): this is an exact float32 comparison without a tolerance;
    # confirm that rounding cannot turn equal-cost paths into "wins".
    fraction = np.sum(scores_solver < scores_file) / len(scores_solver)
    # `:.2%` scales by 100 so the printed value matches the "Percentage" label.
    print(f"Percentage of where the solver yields a better solutions than the one loaded from the file for {split}: {fraction:.2%}")


# Explicit raise instead of `assert`, which is stripped under `python -O`.
if normalization != "zscore":
    raise ValueError("This dataset only supports zscore normalization")

# The solver depends only on the neighbourhood, so build it once.
solver = get_solver(neighbourhood_fn)

for img_size in img_sizes:
    data_dir = os.path.join(data_directory, f"{img_size}x{img_size}")

    # Map images (float32, channel first).
    train_inputs = _load_maps_channel_first(data_dir, train_prefix, data_suffix)
    val_inputs = _load_maps_channel_first(data_dir, val_prefix, data_suffix)
    if use_test_set:
        test_inputs = _load_maps_channel_first(data_dir, test_prefix, data_suffix)

    # Per-channel statistics come from the training split only and are
    # accumulated in float64; `mean`/`std` stay available for denormalization.
    mean, std = (
        np.mean(train_inputs, axis=(0, 2, 3), dtype=np.float64, keepdims=True),
        np.std(train_inputs, axis=(0, 2, 3), dtype=np.float64, keepdims=True),
    )
    splits_to_normalize = [train_inputs, val_inputs]
    if use_test_set:
        splits_to_normalize.append(test_inputs)
    for inputs in splits_to_normalize:
        # In-place so the arrays bound above are the ones that get normalized.
        inputs -= mean
        inputs /= std

    # Ground-truth vertex weights.
    train_true_weights = _load_split(data_dir, train_prefix, "vertex_weights", np.float32)
    val_true_weights = _load_split(data_dir, val_prefix, "vertex_weights", np.float32)
    # Raw validation maps; not used by this cell, kept for any later cells.
    val_full_images = _load_split(data_dir, val_prefix, "maps")
    if use_test_set:
        test_true_weights = _load_split(data_dir, test_prefix, "vertex_weights", np.float32)

    # Solutions stored alongside the dataset.
    train_labels = _load_split(data_dir, train_prefix, "shortest_paths")
    val_labels = _load_split(data_dir, val_prefix, "shortest_paths")
    if use_test_set:
        test_labels = _load_split(data_dir, test_prefix, "shortest_paths")

    # Solutions recomputed from the true weights with the solver.
    train_labels_solver = shortest_pathsolution_np(solver, train_true_weights)
    val_labels_solver = shortest_pathsolution_np(solver, val_true_weights)
    if use_test_set:
        test_labels_solver = shortest_pathsolution_np(solver, test_true_weights)

    # Path cost under the true weights; lower is better.
    scores_file_train = _path_costs(train_true_weights, train_labels)
    scores_solver_train = _path_costs(train_true_weights, train_labels_solver)
    scores_file_val = _path_costs(val_true_weights, val_labels)
    scores_solver_val = _path_costs(val_true_weights, val_labels_solver)
    if use_test_set:
        scores_file_test = _path_costs(test_true_weights, test_labels)
        scores_solver_test = _path_costs(test_true_weights, test_labels_solver)

    print(f"Size: {img_size}")
    _report_solver_wins("train", scores_solver_train, scores_file_train)
    _report_solver_wins("val", scores_solver_val, scores_file_val)
    if use_test_set:
        _report_solver_wins("test", scores_solver_test, scores_file_test)