# Test the performance of the pre-trained GNN model for the problem of *Graph Coloring*

In [1]:
%load_ext autoreload
%autoreload 2

In [30]:
import random
# Random suffix appended to the image URLs below so the browser does not serve cached plots
__counter__ = random.randint(0, int(2e9))

from IPython.display import HTML, display

In [2]:
import sys, os, shutil, json

gnn_path = "../classification/GNN/"
# adding GNN folder to the system path
sys.path.insert(0, gnn_path)

from train import testing
from data_loader_c import dataset_processing
from transfer_model import re_training, tune_parameters

Device is: cuda:0
Device is: cuda:0


A helper function that empties the dataset folders, so that the dataset is rebuilt from scratch each time the network is tested.

In [3]:
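def delete_folder_contents(folders):
    """Delete every file, symlink and sub-folder inside each of the given folders."""
    for folder in folders:
        # Skip folders that do not exist yet (e.g. on the very first run)
        if not os.path.isdir(folder):
            continue
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)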

### Test set from a graph coloring dataset

Previously, we trained a GNN model that solves the 3-SAT problem with all metrics above 90%. We now use this pre-trained model to decide the satisfiability of graph-coloring instances, treated as a classification task. <br>

All NP-complete problems can be transformed into one another by polynomial-time reductions. As an example, let us see how the 3-colorability of a graph can be expressed as a satisfiability problem. <br>

Suppose we want to check whether the following graph is 3-colorable, i.e. whether every vertex can be given a color different from those of its neighbors while using at most 3 colors in total.

<img src="./plots/graph.jpg" height=100 width=200>

To transform this problem into a satisfiability problem, we follow these steps in order:
<ol>
    <li>We assign <b>3 logic variables for every vertex</b> that represent whether this vertex is colored using the corresponding color. For example, for vertex $x_1$ we have: $x_{11}$ that is true if the vertex $x_1$ is colored using the first color and false otherwise, $x_{12}$ that is true if the vertex $x_1$ is colored using the second color and false otherwise and $x_{13}$ that is true if the vertex $x_1$ is colored using the third color and false otherwise. </li><br>
     <li> Given these new variables, we have to write the logical expressions that will be used to evaluate the colorability of the graph. More specifically,</li>
    <ul> 
        <li>A vertex should be colored with <b>at least one color</b>. For example, for vertex $x_1$ this statement translates to the following expression, $$x_{11} \vee x_{12} \vee x_{13}$$ </li> <br>
        <li>A vertex <b>cannot be colored with more than one color at the same time</b>. For example, for vertex $x_1$ this statement translates to the following expression, $$\neg ( x_{11} \wedge x_{12}) \wedge \neg ( x_{11} \wedge x_{13}) \wedge \neg ( x_{12} \wedge x_{13})$$ which, by De Morgan's law, is equivalent to: $$ ( \neg x_{11} \vee \neg x_{12}) \wedge ( \neg x_{11} \vee \neg x_{13}) \wedge ( \neg x_{12} \vee \neg x_{13})$$</li> <br>
         <li>A vertex <b>cannot be colored with the same color as its neighbor</b>. For example, for vertex $x_1$ and a neighboring vertex $x_2$ this statement translates to the following expression, $$\neg ( x_{11} \wedge x_{21}) \wedge \neg ( x_{12} \wedge x_{22}) \wedge \neg ( x_{13} \wedge x_{23})$$ which, by De Morgan's law, is equivalent to: $$ ( \neg x_{11} \vee \neg x_{21}) \wedge ( \neg x_{12} \vee \neg x_{22}) \wedge ( \neg x_{13} \vee \neg x_{23})$$</li> <br>
    </ul> 
    <li>Since all these expressions must hold simultaneously for every vertex and every edge, we combine them with the logical AND to create a <b>CNF formula</b>. This final CNF is the representation of the graph-coloring problem as a satisfiability problem. It is not a 3-SAT instance by nature (the "at most one color" and neighbor clauses contain only two literals), but it can be converted into one. For more details please refer to the report.
    </li><br>
    
</ol>
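The steps above can be sketched in a few lines of Python. This is only an illustration, not the code used to build the dataset: the function names `coloring_to_cnf` and `is_satisfiable`, the 0-based vertex numbering and the DIMACS-style integer literals are all assumptions made for this example. The brute-force check is only suitable for tiny graphs.

```python
from itertools import combinations, product


def coloring_to_cnf(num_vertices, edges, num_colors=3):
    """Encode graph coloring as a list of CNF clauses.

    Variable x_{v,c} ("vertex v has color c", both counted from 0) is
    numbered v * num_colors + c + 1; a negative integer is a negated variable.
    """
    def var(v, c):
        return v * num_colors + c + 1

    clauses = []
    for v in range(num_vertices):
        # Every vertex gets at least one color
        clauses.append([var(v, c) for c in range(num_colors)])
        # No vertex gets two colors at the same time
        for c1, c2 in combinations(range(num_colors), 2):
            clauses.append([-var(v, c1), -var(v, c2)])
    for u, v in edges:
        # Neighbors never share a color
        for c in range(num_colors):
            clauses.append([-var(u, c), -var(v, c)])
    return clauses


def is_satisfiable(clauses, num_vars):
    """Brute-force satisfiability check (exponential, for tiny examples only)."""
    for bits in product([False, True], repeat=num_vars):
        if all(any(bits[abs(lit) - 1] == (lit > 0) for lit in clause)
               for clause in clauses):
            return True
    return False


triangle = [(0, 1), (1, 2), (0, 2)]
k4 = list(combinations(range(4), 2))  # complete graph on 4 vertices

triangle_cnf = coloring_to_cnf(3, triangle)
k4_cnf = coloring_to_cnf(4, k4)

print(len(triangle_cnf))                # 21 clauses: 3 * (1 + 3) + 3 * 3
print(is_satisfiable(triangle_cnf, 9))  # True: a triangle is 3-colorable
print(is_satisfiable(k4_cnf, 12))       # False: K4 needs 4 colors
```

Each vertex contributes one 3-literal clause and three 2-literal clauses, and each edge contributes three 2-literal clauses, which is why the resulting formula is not a 3-SAT instance as it stands.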

Erase the previous dataset created by PyTorch

In [4]:
delete_folder_contents(["./raw", "./processed"])

**Create the dataset**: The dataset is created from the raw data in a format that can be loaded by the *DataLoader* module of PyTorch Geometric (`torch_geometric`). It is built in the same manner as the dataset for the 3-SAT problem. <br>

It should be noted that the original dataset consists only of satisfiable instances. ***To also train with negative examples, a simple idea was applied: in every positive instance we randomly select 4 vertices and connect them into a clique of size 4. Since the 4 vertices of such a clique need 4 distinct colors, this immediately makes the graph not 3-colorable.***
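The clique-planting idea can be illustrated as follows. This is a minimal sketch under stated assumptions, not the implementation inside `dataset_processing`: the function name `plant_clique` and the edge-list representation are hypothetical.

```python
import random
from itertools import combinations


def plant_clique(num_vertices, edges, clique_size=4, rng=random):
    """Return the edge list with a clique added on randomly chosen vertices.

    Hypothetical helper: also returns the chosen vertices for inspection.
    """
    chosen = sorted(rng.sample(range(num_vertices), clique_size))
    # Normalise every edge to (small, large) so duplicates collapse in the set
    new_edges = {tuple(sorted(edge)) for edge in edges}
    for u, v in combinations(chosen, 2):
        new_edges.add((u, v))
    return sorted(new_edges), chosen


# A 6-cycle is 2-colorable, so it is a positive (3-colorable) instance
cycle = [(i, (i + 1) % 6) for i in range(6)]

# Seeded generator so the example is reproducible
negative_edges, clique = plant_clique(6, cycle, rng=random.Random(0))
print(clique)          # the 4 vertices that now form a clique
print(negative_edges)  # original edges plus the clique edges
```

The original edges are kept, so the negative instance differs from its positive counterpart only by the (at most six) edges of the planted clique.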

In [5]:
pos_weight = dataset_processing()

Start the data processing...

Satisfiable CNFs   : 1699
Unsatisfiable CNFs : 1699

Ratio of SAT   : 0.5000
Ratio of UNSAT : 0.5000

Training set size: 2718
Test set size: 680
Dataset size: 3398

Processing completed.


<b>Get the tuned parameters of the pre-trained model.</b> 

In [6]:
# Load the parameters of the pre-trained GNN model
with open(gnn_path+'best_parameters_same_sets.txt') as f:
    data = f.read()

best_pre_trained_parameters = json.loads(data)

**Test the performance of the model *without* transfer learning**.

In [7]:
print("Test before retraining...\n")
testing(params=best_pre_trained_parameters, model_name=gnn_path+'final_model_same_sets.pth')

Test before retraining...

Dataset loading...


Processing...
100%|███████████████████████████████████████████████████████████████████████████████| 680/680 [00:01<00:00, 549.61it/s]
Done!


Dataset loading completed

Model loading...
Model loading completed


Test set metrics:

 Confusion matrix: 
 [[358 315]
 [  7   0]]
F1 Score  : 0.0000
Accuracy  : 0.5265
Precision : 0.0000
Recall    : 0.0000
ROC AUC   : 0.4904
Test Loss : 0.9309185567227277


In theory, these two problems are very similar. In terms of modelling for a GNN, however, they are quite different, as the above results show (accuracy of about 53% and zero recall). Thus, **using the pre-trained GNN model, we will apply transfer learning**. For more information about the process, please refer to the report.

## Training a GNN classifier for the Graph Coloring problem using Transfer Learning

**Tune parameters**: Tune the hyperparameters that concern the *non-frozen layers of the model*, as well as training hyperparameters such as the *learning rate* and the *weight decay*.
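How `tune_parameters` explores the hyperparameters is not shown in this notebook. As a rough illustration only, the sketch below assumes an exhaustive grid search that keeps the combination with the lowest validation loss. The search space, the helper names `grid` and `tune`, and the stand-in objective `dummy_validation_loss` are all hypothetical.

```python
from itertools import product

# Hypothetical search space; the values actually explored are not shown here
search_space = {
    "learning_rate": [0.1, 0.01, 0.001],
    "weight_decay": [0.001, 0.0001],
    "batch_size": [16, 32],
}


def grid(space):
    """Yield one dict per combination of the listed values."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield dict(zip(keys, values))


def tune(space, evaluate):
    """Evaluate every combination and keep the one with the lowest loss."""
    best_params, best_loss = None, float("inf")
    for test_number, params in enumerate(grid(space)):
        loss = evaluate(params)
        print(f"Test number {test_number} | loss {loss:.4f} | {params}")
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss


def dummy_validation_loss(params):
    """Stand-in objective; a real run would train and validate the model."""
    return (abs(params["learning_rate"] - 0.001)
            + params["weight_decay"]
            + params["batch_size"] / 1000)


best_params, best_loss = tune(search_space, dummy_validation_loss)
print(best_params)
```

In a real run, `evaluate` would re-train the non-frozen layers with the given combination and return the validation loss.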

In [8]:
# Tune the algo after applying transfer learning
best_parameters = tune_parameters(pos_weight=pos_weight,  model_name=gnn_path+'final_model_same_sets.pth',
                                  best_pre_trained=best_pre_trained_parameters)


Test number 0 | Start testing new parameter-combination...

Dataset loading...


Processing...
100%|█████████████████████████████████████████████████████████████████████████████| 2718/2718 [00:05<00:00, 515.14it/s]
Done!


Dataset loading completed

Model loading...
Model loading completed

EPOCH | 0
Training Loss   : 0.7279
Validation Loss : 0.6842

EPOCH | 1
Training Loss   : 0.6884
Validation Loss : 0.6771

EPOCH | 2
Training Loss   : 0.6812
Validation Loss : 0.6688

EPOCH | 3
Training Loss   : 0.6750
Validation Loss : 0.6667

EPOCH | 4
Training Loss   : 0.6700
Validation Loss : 0.6609

EPOCH | 5
Training Loss   : 0.6638
Validation Loss : 0.6565

EPOCH | 6
Training Loss   : 0.6579
Validation Loss : 0.6543

EPOCH | 7
Training Loss   : 0.6523
Validation Loss : 0.6509

EPOCH | 8
Training Loss   : 0.6470
Validation Loss : 0.6469

EPOCH | 9
Training Loss   : 0.6417
Validation Loss : 0.6407

EPOCH | 10
Training Loss   : 0.6372
Validation Loss : 0.6367

EPOCH | 11
Training Loss   : 0.6322
Validation Loss : 0.6351

EPOCH | 12
Training Loss   : 0.6272
Validation Loss : 0.6314

EPOCH | 13
Training Loss   : 0.6231
Validation Loss : 0.6266

EPOCH | 14
Training Loss   : 0.6183
Validation Loss : 0.6195

EPOCH | 15

[...]

Validation Loss : 0.6935

EPOCH | 15
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 16
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 17
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 18
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 19
Early stopping activated, with training and validation loss difference: 0.0001

Test number 5 | Start testing new parameter-combination...

Dataset loading...
Dataset loading completed

Model loading...
Model loading completed

EPOCH | 0
Training Loss   : 0.8025
Validation Loss : 0.6940

EPOCH | 1
Training Loss   : 0.6950
Validation Loss : 0.6932

EPOCH | 2
Training Loss   : 0.6941
Validation Loss : 0.6933

EPOCH | 3
Training Loss   : 0.6936
Validation Loss : 0.6934

EPOCH | 4
Training Loss   : 0.6933
Validation Loss : 0.6934

EPOCH | 5
Training Loss   : 0.6932
Validation Loss : 0.6934

EPOCH | 6
Training Loss   : 0.6934
Validation Loss : 0.6933

EPOCH | 7
Training Loss   : 0.6932
Validation Loss : 0.6934

[...]

Training Loss   : 0.6933
Validation Loss : 0.6934

EPOCH | 22
Training Loss   : 0.6933
Validation Loss : 0.6934

EPOCH | 23
Training Loss   : 0.6932
Validation Loss : 0.6934

EPOCH | 24
Training Loss   : 0.6932
Validation Loss : 0.6934

EPOCH | 25
Training Loss   : 0.6932
Validation Loss : 0.6934

EPOCH | 26
Training Loss   : 0.6931
Validation Loss : 0.6934

EPOCH | 27
Training Loss   : 0.6931
Validation Loss : 0.6934

EPOCH | 28
Training Loss   : 0.6931
Validation Loss : 0.6934

EPOCH | 29
Training Loss   : 0.6931
Validation Loss : 0.6934

EPOCH | 30
Training Loss   : 0.6930
Validation Loss : 0.6934

EPOCH | 31
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 32
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 33
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 34
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 35
Early stopping activated, with training and validation loss difference: 0.0000

Test number 9 | Start testing new parameter-combination...

[...]

Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 12
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 13
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 14
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 15
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 16
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 17
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 18
Training Loss   : 0.6929
Validation Loss : 0.6935

EPOCH | 19
Early stopping activated, with training and validation loss difference: 0.0001

Test number 14 | Start testing new parameter-combination...

Dataset loading...
Dataset loading completed

Model loading...
Model loading completed

EPOCH | 0
Training Loss   : 0.8433
Validation Loss : 0.6954

EPOCH | 1
Training Loss   : 0.6963
Validation Loss : 0.6932

EPOCH | 2
Training Loss   : 0.6935
Validation Loss : 0.6933

EPOCH | 3
Training Loss   : 0.6935
Validation Loss : 0.6938

EPOCH | 4
Training Loss   : 0.6[...]

Validation Loss : 0.6934

EPOCH | 24
Training Loss   : 0.6931
Validation Loss : 0.6934

EPOCH | 25
Training Loss   : 0.6930
Validation Loss : 0.6934

EPOCH | 26
Training Loss   : 0.6934
Validation Loss : 0.6935

EPOCH | 27
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 28
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 29
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 30
Training Loss   : 0.6930
Validation Loss : 0.6935

EPOCH | 31
Training Loss   : 0.6942
Validation Loss : 0.6936

EPOCH | 32
Training Loss   : 0.6929
Validation Loss : 0.6936

EPOCH | 33
Early stopping activated, with training and validation loss difference: 0.0000


In [9]:
# Show best parameters
print(f'Best hyperparameters were: {best_parameters}')
# Store best parameters
with open('./best_parameters_same_sets.txt', 'w') as f:
    f.write(json.dumps(best_parameters))

Best hyperparameters were: {'batch_size': 16, 'learning_rate': 0.001, 'weight_decay': 0.0001, 'pos_weight': 1.0, 'model_embedding_size': 64, 'model_attention_heads': 1, 'model_layers': 2, 'model_dropout_rate': 0.1, 'model_dense_neurons': 256}


**Re-train using the optimal parameters.**

In [26]:
# Access the best parameters in order to train final model
with open('best_parameters_same_sets.txt') as f:
    data = f.read()

best_parameters_loaded = json.loads(data)

# Do the transfer learning with the optimal parameters
re_training(params=best_parameters_loaded, best_pre_trained_params=best_pre_trained_parameters,
            model_name=gnn_path+'final_model_same_sets.pth')

Device is: cuda:0
Dataset loading...
Dataset loading completed

Model loading...
Model loading completed

EPOCH | 0
Training Loss   : 0.7272
Validation Loss : 0.6836

EPOCH | 1
Training Loss   : 0.6881
Validation Loss : 0.6786

EPOCH | 2
Training Loss   : 0.6838
Validation Loss : 0.6758

EPOCH | 3
Training Loss   : 0.6805
Validation Loss : 0.6682

EPOCH | 4
Training Loss   : 0.6737
Validation Loss : 0.6613

EPOCH | 5
Training Loss   : 0.6670
Validation Loss : 0.6578

EPOCH | 6
Training Loss   : 0.6627
Validation Loss : 0.6563

EPOCH | 7
Training Loss   : 0.6575
Validation Loss : 0.6514

EPOCH | 8
Training Loss   : 0.6527
Validation Loss : 0.6497

EPOCH | 9
Training Loss   : 0.6475
Validation Loss : 0.6471

EPOCH | 10
Training Loss   : 0.6429
Validation Loss : 0.6438

EPOCH | 11
Training Loss   : 0.6375
Validation Loss : 0.6387

EPOCH | 12
Training Loss   : 0.6323
Validation Loss : 0.6364

EPOCH | 13
Training Loss   : 0.6276
Validation Loss : 0.6288

EPOCH | 14
Training Loss   : 0.6237

[...]

0.6088223220670924

**Test the performance of the model *after* transfer learning.**

In [27]:
print("Test after retraining...\n")
testing(params=best_parameters_loaded, model_name=gnn_path+'final_model_same_sets_c.pth')

Device is: cuda:0
Test after retraining...

Dataset loading...
Dataset loading completed

Model loading...
Model loading completed


Test set metrics:

 Confusion matrix: 
 [[334 151]
 [ 31 164]]
F1 Score  : 0.6431
Accuracy  : 0.7324
Precision : 0.8410
Recall    : 0.5206
ROC AUC   : 0.7179
Test Loss : 0.5866465222003848
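The reported metrics can be recomputed directly from the two confusion matrices. The sketch below assumes that rows correspond to the predicted class and columns to the true class. This orientation is not stated in the notebook; it is inferred from the fact that it reproduces the reported precision and recall. The function name `metrics_from_confusion` is hypothetical.

```python
def metrics_from_confusion(cm):
    """Compute metrics from a 2x2 matrix.

    Assumed layout: cm[i][j] = samples predicted as class i whose true class is j.
    """
    tn, fn = cm[0]  # predicted negative: true negatives, false negatives
    fp, tp = cm[1]  # predicted positive: false positives, true positives
    total = tn + fn + fp + tp
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


before = metrics_from_confusion([[358, 315], [7, 0]])
after = metrics_from_confusion([[334, 151], [31, 164]])

for name, result in (("before", before), ("after", after)):
    print(name, {key: f"{value:.4f}" for key, value in result.items()})
```

Read this way, the model before re-training assigns almost every instance to class 0 (only 7 of 680 predictions are positive, none of them correct), while after re-training it recovers 164 of the 315 positive instances.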


The following **Figures** show:
<ol>
<li>The <b>confusion matrix</b> for the test-set predictions</li>
<li>The <b>ROC curve</b> (with its AUC) for the test-set predictions</li>
<li>The <b>precision-recall curve</b> for the test-set predictions</li>
</ol>

In [32]:
print("\n1.")
display(HTML('<img src="plots/cm.png?%d" height=500 width=500>' % __counter__))
print("2.")
display(HTML('<img src="plots/roc_auc.png?%d" height=450 width=450>' % __counter__))
print("3.")
display(HTML('<img src="plots/pr.png?%d" height=450 width=450>' % __counter__))


1.


2.


3.


Unfortunately, the amount of available data is not sufficient to re-train more layers of the GNN and achieve better results.