This release contains a working implementation of Neural Diving from the paper "Solving Mixed Integer Programs Using Neural Networks", with data generation built on top of the code from "Exact Combinatorial Optimization with Graph Convolutional Neural Networks".
The following gives a brief overview of the contents; more detailed documentation is available within each file:
- config_train.py: Configuration file for training parameters.
- data_generation.py: Generates the feature set for each MIP instance.
- data_utils.py: Utility functions for feature extraction.
- evaluate_solvers.py: Compares solver performance between warm starts with predicted solutions and cold starts.
- instance_generation.py: Generates each MIP instance.
- layer_norm.py: Model layer normalisation and dropout utilities.
- light_gnn.py: The GNN model used for training.
- sampling.py: Sampling strategies for Neural LNS.
- solvers.py: Neural diving and feature generation implementation.
- train.py: Training script for the neural diving model.
- data: Directory containing the tfrecord files used for training.
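To inspect the files in data/ without installing TensorFlow, the sketch below reads the TFRecord container format directly. Each record is framed as a little-endian uint64 payload length, a masked CRC-32C of those 8 length bytes, the payload itself, and a masked CRC-32C of the payload. All function names here (`crc32c`, `masked_crc32c`, `encode_tfrecord`, `iter_tfrecords`) are hypothetical and not part of this repository. The reader only returns raw payload bytes and makes no assumption about how data_generation.py serializes features.

```python
import io
import struct
from typing import BinaryIO, Iterator, List


def _make_crc32c_table() -> List[int]:
    """Lookup table for reflected CRC-32C (Castagnoli polynomial)."""
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    """Plain CRC-32C checksum of `data`."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data: bytes) -> int:
    """CRC rotated right by 15 bits plus a constant, as TFRecord stores it."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def encode_tfrecord(payload: bytes) -> bytes:
    """Frame one payload as a TFRecord (hypothetical helper for testing)."""
    length = struct.pack("<Q", len(payload))
    return (length + struct.pack("<I", masked_crc32c(length))
            + payload + struct.pack("<I", masked_crc32c(payload)))


def iter_tfrecords(stream: BinaryIO) -> Iterator[bytes]:
    """Yield raw record payloads, verifying both checksums of each record."""
    while True:
        header = stream.read(12)
        if not header:
            return  # clean end of file
        if len(header) != 12:
            raise ValueError("truncated record header")
        length, length_crc = struct.unpack("<QI", header)
        if length_crc != masked_crc32c(header[:8]):
            raise ValueError("corrupt length checksum")
        payload = stream.read(length)
        footer = stream.read(4)
        if len(payload) != length or len(footer) != 4:
            raise ValueError("truncated record")
        (data_crc,) = struct.unpack("<I", footer)
        if data_crc != masked_crc32c(payload):
            raise ValueError("corrupt payload checksum")
        yield payload


if __name__ == "__main__":
    # Round-trip two in-memory records through the writer and reader.
    buffer = io.BytesIO(encode_tfrecord(b"first") + encode_tfrecord(b"second"))
    print([record for record in iter_tfrecords(buffer)])  # [b'first', b'second']
```

To count the records in a real file, open it in binary mode and iterate, e.g. `with open(path, "rb") as f: n = sum(1 for _ in iter_tfrecords(f))`. The pure-Python CRC is slow on large files, so this is meant for sanity checks rather than for feeding training.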
This project borrows config_train.py, data_utils.py, layer_norm.py, light_gnn.py, sampling.py, and train.py from Neural Local Neighborhood Search, and instance_generation.py from Learn2Branch (respectively, the implementations for the two papers linked above). Our contributions are data_generation.py, evaluate_solvers.py, and solvers.py, which form the glue code for the borrowed files.
To install the dependencies of this implementation, please run:
conda env create -f environment.yml
conda activate neural_diving
Once the environment is active, the pipeline runs in five steps:
- Generate a collection of MIP instances with instance_generation.py.
- Generate the features and labels used to train the neural diving model with data_generation.py.
- Specify valid training and validation paths in config_train.py (i.e., replace <dataset_absolute_training_path> and <dataset_absolute_validation_path>).
- Train the neural diving model with train.py.
- Compare warm-starting Gurobi with neural diving predictions against cold-starting Gurobi with evaluate_solvers.py.
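In Neural Diving, the trained model predicts, for each binary variable, a probability that it equals 1 in a good solution, and only a subset of variables is actually assigned. The sketch below illustrates that idea in its simplest form: a hypothetical helper, `partial_assignment`, rounds predictions to 0/1 and keeps only those whose probability clears a confidence threshold, leaving the rest free for the solver. The fixed-threshold rule and all names are illustrative assumptions, not the selection logic in solvers.py, which may differ.

```python
from typing import Dict, Sequence


def partial_assignment(
    probs: Sequence[float], threshold: float = 0.9
) -> Dict[int, int]:
    """Round confident predictions into a partial 0/1 assignment.

    Hypothetical helper for illustration only. Variable i is assigned 1 when
    probs[i] >= threshold, assigned 0 when probs[i] <= 1 - threshold, and
    left unassigned (free for the solver) otherwise.
    """
    if not 0.5 < threshold <= 1.0:
        raise ValueError("threshold must lie in (0.5, 1.0]")
    assignment = {}
    for index, p in enumerate(probs):
        if p >= threshold:
            assignment[index] = 1
        elif p <= 1.0 - threshold:
            assignment[index] = 0
    return assignment


def coverage(assignment: Dict[int, int], num_vars: int) -> float:
    """Fraction of variables that received a value."""
    return len(assignment) / num_vars if num_vars else 0.0


if __name__ == "__main__":
    predicted = [0.97, 0.03, 0.55, 0.91, 0.40]
    fixed = partial_assignment(predicted, threshold=0.9)
    print(fixed)                            # {0: 1, 1: 0, 3: 1}
    print(coverage(fixed, len(predicted)))  # 0.6
```

Raising the threshold lowers coverage but makes each assigned value more trustworthy. For a warm start, such a partial assignment could be handed to the solver as a starting point (for example, through Gurobi's per-variable `Start` attribute); for a dive, the assigned variables would instead be fixed and the remaining sub-MIP solved.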