In [None]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO

Time consumption as a function of the number of model layers

In [None]:
# Benchmark: for every saved checkpoint in `model_dirs`, measure the mean
# wall-clock time per graph of the three subgraph-relevance implementations.
# Result: `model_times[k]` is a list of three mean durations (seconds) for
# checkpoint k, in the order given by `methods` below.
graphs, pos_idx, neg_idx = load_data('BA-2motif')

model_dirs = ['gin-2-ba2motif.torch',
              'gin-3-ba2motif.torch',
              'gin-4-ba2motif.torch',
              'gin-5-ba2motif.torch',
              'gin-6-ba2motif.torch',
              'gin-7-ba2motif.torch']

S = [0,1,2,3]  # third argument of every method; presumably the subgraph's node indices -- TODO confirm
alpha = 0.
verbose = False
num_samples = 50
seed = 0  # fixed so that the same graphs are timed on every run
sample_idx = np.random.default_rng(seed).choice(len(graphs), num_samples, replace=False)

# The order here fixes the column order of `model_times`.
methods = (subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook)

model_times = []

for model_dir in tqdm(model_dirs):
    nn = torch.load('models/'+model_dir)

    ts = []
    for method in methods:
        t_temp = 0.
        # Iterate the index array directly so tqdm knows the total length.
        for i in tqdm(sample_idx):
            g = graphs[i]
            # perf_counter is the clock intended for measuring short intervals.
            timea = time.perf_counter()
            method(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
            timeb = time.perf_counter()
            t_temp += timeb - timea
        ts.append(t_temp / num_samples)

    model_times.append(ts)

In [None]:
# time & model layers
# Hard-coded timing table: running this cell replaces whatever `model_times`
# held before. It has six rows and three columns.
# NOTE(review): presumably these are mean seconds per call recorded from an
# earlier run of the layer benchmark cell -- one row per checkpoint, columns in
# the order subgraph_original, subgraph_mp_transcription,
# subgraph_mp_forward_hook. Confirm before relying on it.
model_times = [[3.79432821e-02, 2.85856771e-02, 2.11679459e-03],
       [2.24222307e-01, 5.61089563e-02, 4.21986103e-03],
       [1.21142797e+00, 8.27891207e-02, 5.42837620e-03],
       [6.06633543e+00, 9.72466040e-02, 6.44898415e-03],
       [3.09617811e+01, 1.18584792e-01, 8.30366135e-03],
       [1.42194685e+02, 1.39745464e-01, 9.81472492e-03]]

In [None]:
# Broken-axis figure of runtime versus L.
# ax1 (tall, top) shows the large values of column 0; ax2 (short, bottom)
# shows columns 0 and 2 in the range close to zero. Both panels share x.
num_layers = np.arange(1, len(model_times) + 1)
model_times = np.array(model_times)
fig, (ax1, ax2) = plt.subplots(
    2, 1,
    sharex=True,
    gridspec_kw={'height_ratios': [3, 1]},
    figsize=(3, 4),
)
fig.subplots_adjust(hspace=0.05)  # keep the two panels close together

plt.rc('legend', fontsize=12)

# One font size for the title, axis labels and tick labels of both panels.
for panel in (ax1, ax2):
    text_items = [panel.title, panel.xaxis.label, panel.yaxis.label]
    text_items = text_items + panel.get_xticklabels() + panel.get_yticklabels()
    for text_item in text_items:
        text_item.set_fontsize(15)

ax1.set_ylabel("Time (s)")
ax2.set_xlabel(r'$L$')

# Only odd positions get a visible tick label; the rest are blank.
tick_text = ['' if pos % 2 != 1 else str(pos)
             for pos in range(1, len(model_times) + 1)]
plt.xticks(num_layers, tick_text)

# Upper panel: column 0 only.
ax1.plot(num_layers, model_times[:, 0], 'r-')
# Temporary all-zero 'b-.' line so the legend gets a second entry; it is
# removed from the axes as soon as the legend has been created.
legend_proxy, = ax1.plot(num_layers, [0] * len(num_layers), 'b-.')
ax1.legend(['GNN-LRP naive', 'sGNN-LRP'])
legend_proxy.remove()

# Lower panel: columns 0 and 2.
ax2.plot(num_layers, model_times[:, 0], 'r-')
ax2.plot(num_layers, model_times[:, 2], 'b-.')

# Hide the spines facing the gap and move ax1's ticks to its top edge.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # ticks at the top, but without labels

ax1.set_ylim(0.005)      # lower bound only
ax2.set_ylim(0, 0.013)

# Slanted cut marks at the left and right ends of the break between panels.
slant = .5  # vertical-to-horizontal proportion of the slanted mark
break_marker = {
    'marker': [(-1, -slant), (1, slant)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **break_marker)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **break_marker)

fig.savefig('imgs/time_consumption_L.eps', dpi=600, format='eps', bbox_inches='tight')

Time consumption as a function of subgraph size

In [None]:
# Benchmark: with one fixed checkpoint, measure the mean wall-clock time per
# graph of the three subgraph-relevance implementations while S grows.
# Result: `model_times[k]` is a list of three mean durations (seconds) for
# S = list(range(k)), in the order given by `methods` below.
graphs, pos_idx, neg_idx = load_data('BA-2motif')

model_dir = 'gin-3-ba2motif.torch'

alpha = 0.
verbose = False
num_samples = 50
seed = 0  # fixed so that the same graphs are timed on every run
sample_idx = np.random.default_rng(seed).choice(len(graphs), num_samples, replace=False)
nn = torch.load('models/'+model_dir)

# The order here fixes the column order of `model_times`.
methods = (subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook)

model_times = []

# size_S starts at 0, so the first row is measured with an empty S.
for size_S in tqdm(range(25)):
    S = list(range(size_S))

    ts = []
    for method in methods:
        t_temp = 0.
        # Iterate the index array directly so tqdm knows the total length.
        for i in tqdm(sample_idx):
            g = graphs[i]
            # perf_counter is the clock intended for measuring short intervals.
            timea = time.perf_counter()
            method(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
            timeb = time.perf_counter()
            t_temp += timeb - timea
        ts.append(t_temp / num_samples)

    model_times.append(ts)

In [None]:
# time & subgraph size
# Hard-coded timing table: running this cell replaces whatever `model_times`
# held before. It has 19 rows and three columns.
# NOTE(review): presumably these are mean seconds per call recorded from an
# earlier run of the subgraph-size benchmark cell, columns in the order
# subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook.
# That cell iterates range(25), which would give 25 rows, so confirm which
# subgraph sizes these 19 rows correspond to.
model_times = [[0.0046125841140747074, 0.07880849838256836, 0.004756159782409668],
 [0.03629368782043457, 0.0721881341934204, 0.0038048934936523436],
 [0.09700009346008301, 0.06509144306182861, 0.004636597633361816],
 [0.22644277095794677, 0.06653767108917236, 0.0055991697311401365],
 [0.44647380352020266, 0.06331422805786133, 0.004125766754150391],
 [0.7955206298828125, 0.0697042989730835, 0.004573965072631836],
 [1.244990677833557, 0.05818732738494873, 0.0046390676498413086],
 [1.9092957782745361, 0.06346342086791992, 0.004363369941711426],
 [2.6129654121398924, 0.05599023818969726, 0.004363474845886231],
 [3.686415991783142, 0.05271221160888672, 0.004402637481689453],
 [4.913209948539734, 0.06820619106292725, 0.004435920715332031],
 [6.259419693946838, 0.055459322929382326, 0.0032729005813598635],
 [8.146011185646056, 0.06872797012329102, 0.003718433380126953],
 [10.04942915916443, 0.06456290721893311, 0.0037899065017700196],
 [12.10968173980713, 0.05984879970550537, 0.0038155269622802733],
 [15.019109363555907, 0.06313271999359131, 0.0058210515975952145],
 [18.107507333755493, 0.07190500736236573, 0.003952789306640625],
 [22.159293155670166, 0.06416292667388916, 0.004039020538330078],
 [25.479938478469847, 0.06957473278045655, 0.00435152530670166]]

In [None]:
# Broken-axis figure of runtime versus |S|, with the y axis on the right.
# ax1 (tall, top) shows the large values of column 0; ax2 (short, bottom)
# shows columns 0 and 2 in the range close to zero. Both panels share x.
# NOTE(review): despite its name, `num_layers` holds the x positions
# 1..len(model_times) used for the subgraph-size axis here. The benchmark loop
# starts at an empty S, so the positions may be shifted by one -- confirm.
num_layers = np.arange(1, len(model_times) + 1)
model_times = np.array(model_times)
fig, (ax1, ax2) = plt.subplots(
    2, 1,
    sharex=True,
    gridspec_kw={'height_ratios': [3, 1]},
    figsize=(3, 4),
)
fig.subplots_adjust(hspace=0.05)  # keep the two panels close together

# Put the y ticks (and ax1's y label position) on the right-hand side.
ax1.yaxis.tick_right()
ax1.yaxis.set_label_position("right")
ax2.yaxis.tick_right()

plt.rc('legend', fontsize=12)

# One font size for the title, axis labels and tick labels of both panels.
for panel in (ax1, ax2):
    text_items = [panel.title, panel.xaxis.label, panel.yaxis.label]
    text_items = text_items + panel.get_xticklabels() + panel.get_yticklabels()
    for text_item in text_items:
        text_item.set_fontsize(15)

ax2.set_xlabel(r'$|\mathcal{S}|$')

# Every third position (1, 4, 7, ...) gets a visible tick label; the rest are blank.
tick_text = ['' if pos % 3 != 1 else str(pos)
             for pos in range(1, len(model_times) + 1)]
plt.xticks(num_layers, tick_text)

# Upper panel: column 0 only.
ax1.plot(num_layers, model_times[:, 0], 'r-')
# NOTE(review): this all-zero line is added and removed straight away and no
# legend is drawn in this figure; it looks like a leftover from the layers
# plot. It is kept because it may still influence ax1's autoscaled limits.
placeholder_line, = ax1.plot(num_layers, [0] * len(num_layers), 'b-.')
placeholder_line.remove()

# Lower panel: columns 0 and 2.
ax2.plot(num_layers, model_times[:, 0], 'r-')
ax2.plot(num_layers, model_times[:, 2], 'b-.')

# Hide the spines facing the gap and move ax1's ticks to its top edge.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # ticks at the top, but without labels

ax1.set_ylim(0.05)      # lower bound only
ax2.set_ylim(0, 0.01)

# Slanted cut marks at the left and right ends of the break between panels.
slant = .5  # vertical-to-horizontal proportion of the slanted mark
break_marker = {
    'marker': [(-1, -slant), (1, slant)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **break_marker)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **break_marker)

fig.savefig('imgs/time_consumption_S.eps', dpi=600, format='eps', bbox_inches='tight')

Compare the time consumption of the three methods

In [None]:
# For each checkpoint, run the three subgraph-relevance methods on the first
# 50 graphs (three repetitions each), capture what they print to stdout, parse
# the numbers out of every printed line and store the per-column mean.
# NOTE(review): `graphs` and `verbose` are taken from earlier cells. This cell
# only yields data if each call prints exactly one tab-separated line, so
# confirm that the inherited `verbose` value makes the methods print.
alpha = 0.
S = np.arange(5)

model_dirs = ['gin-3-ba2motif.torch',
              'gin-5-ba2motif.torch',
              'gin-7-ba2motif.torch']

efficiency_result_originals = []
efficiency_result_mp_transcs = []
efficiency_result_forward_hooks = []

for model_dir in model_dirs:
    # Load from models/, the directory the other cells of this notebook load
    # the same checkpoint files from.
    nn = torch.load('models/'+model_dir)

    # Redirect stdout into a buffer while the methods run. The finally clause
    # restores the real stdout even if one of the calls raises.
    s = StringIO()
    old_stdout = sys.stdout
    sys.stdout = s
    try:
        for i in tqdm(range(50)):
            g = graphs[i]
            for _ in range(3):
                subgraph_original(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
                subgraph_mp_transcription(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
                subgraph_mp_forward_hook(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
    finally:
        sys.stdout = old_stdout

    lists = s.getvalue().splitlines()

    # Per line: split on tabs, drop the first field, and from each remaining
    # field take the text after the last ': ' up to the first ',' as a float.
    lists_ = [[float(data.split(': ')[-1].split(',')[0]) for data in l.split('\t')[1:]] for l in lists]

    # The calls are made in the fixed order original -> mp_transcription ->
    # forward_hook, so every third line belongs to the same method.
    efficiency_result_original = pd.DataFrame(lists_[::3],columns=['nbnodes','layers','overhead','subrel'])
    efficiency_result_mp_transc = pd.DataFrame(lists_[1::3],columns=['nbnodes','layers','overhead','subrel'])
    efficiency_result_forward_hook = pd.DataFrame(lists_[2::3],columns=['nbnodes','layers','forward1','backward1'])

    efficiency_result_originals.append(efficiency_result_original.mean(axis=0))
    efficiency_result_mp_transcs.append(efficiency_result_mp_transc.mean(axis=0))
    efficiency_result_forward_hooks.append(efficiency_result_forward_hook.mean(axis=0))