In [1]:
import sys
import random
import itertools
import logging
from pprint import pprint
from collections import Counter

import numpy as np
import tensorflow as tf
import tensorflow.keras as ks
import matplotlib.pyplot as plt

from gnn_teacher_student.main import StudentTeacherExplanationAnalysis
from gnn_teacher_student.training import ReferenceStudentTraining, ExplanationLoss, ExplanationPreTraining
from gnn_teacher_student.students import StudentTemplate, GeneralAttentionStudent
from gnn_teacher_student.layers import (NodeImportanceSubNetwork,
                                        EdgeImportanceSubNetwork,
                                        ConvolutionalSubNetwork)
from gnn_teacher_student.data import generate_color_pairs_dataset
from gnn_teacher_student.visualization import (draw_colors_graph,
                                               draw_graph_node_importances,
                                               draw_graph_edge_importances)


# Disabling all kinds of warnings so that the cell outputs stay readable
import warnings
warnings.filterwarnings('ignore')
# Blunt fallback on top of the filter: `warnings.warn` itself is replaced for
# the whole kernel, so code calling it through the module attribute emits
# nothing even if a library resets the warning filters later on.
warnings.warn = lambda *args, **kwargs: 0
# Only TensorFlow log records of level ERROR and above are shown
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Enable eager execution (TensorFlow expects this call at program startup,
# before any other TensorFlow work is done)
tf.compat.v1.enable_eager_execution()

# Seed the random number generators so that "Restart Kernel -> Run All" is
# repeatable. The dataset callbacks further down draw from the global `random`
# generator; NumPy and TensorFlow are seeded as well because the imported
# project code may draw from them (not visible from this notebook).
# The TensorFlow seed is set after eager execution has been enabled so that it
# is applied to the eager context.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.compat.v1.set_random_seed(SEED)

In [5]:
# ~ Generating the dataset
# The two callbacks draw from the global `random` generator: an integer from
# 5 to 30 (inclusive) for the node count and from 1 to 3 (inclusive) for the
# additional edge count.

dataset = generate_color_pairs_dataset(
    length=2000,
    node_count_cb=lambda: random.randint(5, 30),
    additional_edge_count_cb=lambda: random.randint(1, 3),
    colors=[
        (1, 0, 0),  # red
        (0, 1, 0),  # green
        (0, 0, 1),  # blue
        #(1, 1, 0),  # yellow
        #(1, 0, 1),  # magenta
    ],
    exclude_empty=True
)

# ~ Printing information about dataset
# One entry per graph: the number of node indices and the label cast to int.
graph_sizes = [len(g['node_indices']) for g in dataset]
graph_labels = [int(g['graph_labels']) for g in dataset]
graph_labels_counter = Counter(graph_labels)

# Every node attribute row is converted to a tuple so it is hashable and can
# be counted. `chain.from_iterable` flattens the per-graph lists without
# unpacking one positional argument per graph.
graph_colors = [[tuple(color) for color in g['node_attributes']] for g in dataset]
graph_colors_combined = list(itertools.chain.from_iterable(graph_colors))
graph_colors_counter = Counter(graph_colors_combined)

print('DATASET SUMMARY')
print('==============================================================')
print(f'Number of Graphs: {len(dataset)}')
print('--------------------------------------------------------------')
print(f'Min Graph Size: {min(graph_sizes)}')
print(f'Max Graph Size: {max(graph_sizes)}')
print('--------------------------------------------------------------')
print(f'Number of distinct Colors: {len(graph_colors_counter)}')
print('Color Distribution:')
for color, count in graph_colors_counter.items():
    # Share of this color among all nodes of all graphs
    percentage = count / len(graph_colors_combined)
    print(f'  Color {color}: {count} total nodes ({percentage*100:.1f}%)')
print('--------------------------------------------------------------')
print('GT Labels: Total number of color pairs')
print('Label Distribution:')
# The counter keys are unique, so sorting the (label, count) pairs orders the
# rows by ascending label.
for label, count in sorted(graph_labels_counter.items()):
    percentage = count / len(graph_labels)
    print(f'   Label {label}: {count:<5} ({percentage*100:.1f}%)')

DATASET SUMMARY
Number of Graphs: 2000
--------------------------------------------------------------
Min Graph Size: 5
Max Graph Size: 30
--------------------------------------------------------------
Number of distinct Colors: 3
Color Distribution:
  Color (0.0, 0.0, 1.0): 11923 total nodes (32.9%)
  Color (1.0, 0.0, 0.0): 12107 total nodes (33.4%)
  Color (0.0, 1.0, 0.0): 12235 total nodes (33.7%)
--------------------------------------------------------------
GT Labels: Total number of color pairs
Label Distribution:
   Label 1: 319   (16.0%)
   Label 2: 457   (22.9%)
   Label 3: 433   (21.6%)
   Label 4: 332   (16.6%)
   Label 5: 226   (11.3%)
   Label 6: 124   (6.2%)
   Label 7: 65    (3.2%)
   Label 8: 32    (1.6%)
   Label 9: 7     (0.4%)
   Label 10: 3     (0.1%)
   Label 12: 2     (0.1%)


In [6]:
def create_node_importance_network():
    """Return a newly constructed ``NodeImportanceSubNetwork``.

    Each call builds a separate instance configured with ``unitss=[3]``,
    ``tanh`` activation, softmax enabled and bias disabled.
    """
    settings = dict(
        unitss=[3],
        activation='tanh',
        use_softmax=True,
        use_bias=False,
    )
    return NodeImportanceSubNetwork(**settings)

def create_edge_importance_network():
    """Return a newly constructed ``EdgeImportanceSubNetwork``.

    Each call builds a separate instance configured with ``unitss=[3]``,
    ``tanh`` activation, softmax enabled and bias disabled.
    """
    settings = dict(
        unitss=[3],
        activation='tanh',
        use_softmax=True,
        use_bias=False,
    )
    return EdgeImportanceSubNetwork(**settings)

def create_prediction_network():
    """Return a newly constructed ``ConvolutionalSubNetwork``.

    Each call builds a separate instance configured with ``unitss=[2, 2]``,
    bias disabled and ``tanh`` activation.
    """
    settings = dict(
        unitss=[2, 2],
        use_bias=False,
        activation='tanh',
    )
    return ConvolutionalSubNetwork(**settings)

# Student template named 'gas': combines the GeneralAttentionStudent class
# with the three sub-network factory functions. The factories are handed over
# uncalled, as callbacks.
attention_student = StudentTemplate(
    # identity of the student
    student_name='gas', student_class=GeneralAttentionStudent,
    # factory callbacks for the importance and prediction sub-networks
    lay_node_importance_cb=create_node_importance_network,
    lay_edge_importance_cb=create_edge_importance_network,
    lay_prediction_cb=create_prediction_network,
)

In [None]:
# ~ Training configuration
# Runs StudentTeacherExplanationAnalysis.fit for the two variants 'exp' and
# 'ref' of the attention student and stores the outcome in `results`.
EPOCHS = 5000
BATCH_SIZE = 200
LEARNING_RATE = 1e-3

# Mean squared error is tracked for the predictions, mean absolute error for
# the explanations.
student_teacher_analysis = StudentTeacherExplanationAnalysis(
    student_template=attention_student,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    optimizer=ks.optimizers.Adam(learning_rate=LEARNING_RATE),
    prediction_metric=ks.metrics.MeanSquaredError(),
    explanation_metric=ks.metrics.MeanAbsoluteError()
)

# Mirror the log records of the analysis object to stdout so that they show up
# as cell output.
# NOTE(review): if `.logger` is a shared, named logger, re-running this cell
# attaches one more handler each time and duplicates every line - confirm.
student_teacher_analysis.logger.addHandler(logging.StreamHandler(stream=sys.stdout))

# Transpose the list of per-graph dicts into a single dict of lists, one list
# per field. The field names are taken from the first graph, so every graph
# has to provide the same keys.
_dataset = {field: [g[field] for g in dataset] for field in dataset[0].keys()}

# Settings of the 'exp' variant. `epochs` evaluates to 1500 (30% of EPOCHS).
# The three losses are listed in the order prediction, node importance, edge
# importance, matching the order in the training log printed below.
# NOTE(review): `post_weights` presumably are the weights of these three
# losses once pre-training is over - confirm against ExplanationPreTraining.
explanation_pre_training = ExplanationPreTraining(
    epochs=int(0.3 * EPOCHS),
    loss=[
        ks.losses.MeanSquaredError(),
        ExplanationLoss(loss_function=ks.losses.mean_absolute_error),
        ExplanationLoss(loss_function=ks.losses.mean_absolute_error)
    ],
    post_weights=[1, 0.3, 0.3],
    lock_explanation=False
)

# Settings of the 'ref' variant, left at the class defaults.
reference_student_training = ReferenceStudentTraining()

# The whole fit is placed on the first CPU device.
with tf.device('/cpu:0'):
    results = student_teacher_analysis.fit(
        dataset=_dataset,
        train_split=0.7,
        # Both training objects are called here; the values they return are
        # what `fit` receives for each variant name.
        variant_kwargs={
            'exp': explanation_pre_training(),
            'ref': reference_student_training()
        },
        # Evaluates to 1000 (20% of EPOCHS).
        # NOTE(review): presumably the number of epochs between two progress
        # log messages - confirm against `fit`.
        log_progress=int(0.2 * EPOCHS)
    )


starting student training "gas:exp" [LOSS:] prediction="mean_squared_error*<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>" node_importance="explanation_loss*<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>" edge_importance="explanation_loss*<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>" [BATCHING:] batch_size=200 supports_batching="True" [MODEL:] parameters=57 [TRAINING:] epochs=5000 optimizer=Adam dataset_size=6) 
