# Comparison between GraphRNN and GRAN

## Setup

In [4]:
import os
import sys
import torch
import logging
import traceback
import numpy as np
from pprint import pprint
import pandas as pd
from runner.train_runners import *
from utils.logger import setup_logging
from utils.arg_helper import parse_arguments, get_config
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
torch.set_printoptions(profile='full')


## Useful Functions

In [5]:
df = pd.read_csv("save_model_learning.csv")
df

Unnamed: 0.1,Unnamed: 0,Date,dataset_name,dataset_num,model_name,num_epochs,file_dir
0,0,2021-Aug-15-01-56-10,community2,500,GRANMixtureBernoulli,5,exp/GRAN/community2\GRANMixtureBernoulli_commu...
1,0,2021-Aug-15-02-09-00,community4,500,GRANMixtureBernoulli,5,exp/GRAN/community4\GRANMixtureBernoulli_commu...
2,0,2021-Aug-15-02-29-46,community8,500,GRANMixtureBernoulli,5,exp/GRAN/community8\GRANMixtureBernoulli_commu...
3,0,2021-Aug-16-14-28-27,watts,500,GRANMixtureBernoulli,5,exp/GRAN/watts\GRANMixtureBernoulli_watts_2021...
4,0,2021-Aug-17-16-56-29,barabasi,500,GRANMixtureBernoulli,15,exp/GRAN/barabasi\GRANMixtureBernoulli_barabas...
5,0,2021-Aug-17-22-00-39,barabasi,500,GRANMixtureBernoulli,50,exp/GRAN/barabasi\GRANMixtureBernoulli_barabas...
6,0,2021-Aug-18-13-17-52,community2,500,RNN,1000,exp/GraphRNN/rnn/community2\RNN_community2_202...
7,0,2021-Aug-18-16-05-01,community4,500,RNN,1000,exp/GraphRNN/rnn/community4\RNN_community4_202...
8,0,2021-Aug-18-17-43-34,community4,500,RNN,1000,exp/GraphRNN/rnn/community4\RNN_community4_202...
9,0,2021-Aug-18-22-08-24,community8,500,RNN,1000,exp/GraphRNN/rnn/community8\RNN_community8_202...


## Research Questions
1) Which one of the model is the better for each dataset?

2) About GRAN, which node ordering gives better results ?

3) Is there a bias from GRAN to generates communities ? ( How many ? Scalability ? Robustness ?)

4) What are optimal M-parametes for GraphRNN for each dataset ?
Does the parameter tweaking change the efficiency significantly ?

5) Does the SotA autoregressive model are able to retain the small-world propertie from the graph ( = avg length btw 2 nodes is proportional to the log of the number of nodes N )
(create GNN classifier "small-world" or GAN)

## Experiments
####  1) Which one of the model is the better for each dataset?
####  2) About GRAN, which node ordering gives better results ?
####  3) Is there a bias from GRAN to generates communities ? ( How many ? Scalability ? Robustness ?)
####  4) What are optimal M-parametes for GraphRNN for each dataset ? Does the parameter tweaking change the efficiency significantly

## Datasets

### Erdos Renyi dataset
Parameters used : 500 graphs btw 100-200 nodes with p=0.1

Node ordering (GRAN) :

### Barabasi Albert Dataset
Parameters used :Parameters used : 500 graphs btw 100-200 nodes with k=4/5

Node ordering (GRAN) :

### Watts Strogatz Dataset
Parameters used :Parameters used : 500 graphs btw 100-200 nodes with p=0.01

Node ordering (GRAN) : DFS


### Community Dataset
Parameters used : graphs of 2/4/8 communities of between 12 to 17 nodes

Node ordering (GRAN) : DFS


## Results

In [6]:
def get_stats_from_trained_model(config,seed):
    """Return all mmd statistical results from
    generated graph by the trained model, in the form of a dict"""

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    config.use_gpu = config.use_gpu and torch.cuda.is_available()
    torch.cuda.empty_cache()

    runner = eval(config.runner)(config)

    mmd_degree_test, mmd_clustering_test, mmd_4orbits_test, mmd_spectral_test = runner.test()
    return {"mmd_degree_test": mmd_degree_test, "mmd_clustering_test": mmd_clustering_test,
            "mmd_4orbits_test": mmd_4orbits_test, "mmd_spectral_test": mmd_spectral_test}

max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.81s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:27<00:00,  2.77s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:27<00:00,  2.78s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:29<00:00,  2.91s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.88s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.88s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.89s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.90s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.88s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574


100%|██████████| 10/10 [00:28<00:00,  2.88s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:37<00:00,  9.75s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:35<00:00,  9.56s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:36<00:00,  9.69s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:37<00:00,  9.76s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:37<00:00,  9.77s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:37<00:00,  9.76s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:37<00:00,  9.78s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:35<00:00,  9.52s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:35<00:00,  9.53s/it]


max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038


100%|██████████| 10/10 [01:35<00:00,  9.50s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:06<00:00, 36.62s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:03<00:00, 36.30s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:12<00:00, 37.28s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [05:58<00:00, 35.89s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:00<00:00, 36.09s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [05:55<00:00, 35.54s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [05:59<00:00, 35.97s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:00<00:00, 36.09s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:10<00:00, 37.09s/it]


max # nodes = 115 || mean # nodes = 115.0
max # edges = 614 || mean # edges = 593.484


100%|██████████| 10/10 [06:14<00:00, 37.47s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:36<00:00, 16.84s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:34<00:00, 16.73s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:37<00:00, 16.86s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:41<00:00, 17.09s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:36<00:00, 16.85s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:45<00:00, 17.30s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:49<00:00, 17.45s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:41<00:00, 17.06s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:41<00:00, 17.09s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:43<00:00, 17.20s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:19<00:00, 30.96s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:17<00:00, 30.88s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:18<00:00, 30.93s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:19<00:00, 31.00s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:20<00:00, 31.05s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:15<00:00, 30.79s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:14<00:00, 30.71s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:17<00:00, 30.88s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:21<00:00, 31.07s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:13<00:00, 30.70s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:14<00:00, 30.71s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:14<00:00, 30.71s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:15<00:00, 30.76s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:17<00:00, 30.88s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:18<00:00, 30.95s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:14<00:00, 30.70s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:14<00:00, 30.75s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:22<00:00, 31.13s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:08<00:00, 30.43s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 780 || mean # edges = 582.0


100%|██████████| 20/20 [10:18<00:00, 30.94s/it]


max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574




max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 32 || mean # nodes = 32.0
max # edges = 170 || mean # edges = 162.574
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 

100%|██████████| 20/20 [05:50<00:00, 17.52s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:54<00:00, 17.74s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:47<00:00, 17.35s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:44<00:00, 17.23s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:44<00:00, 17.21s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:54<00:00, 17.71s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:50<00:00, 17.50s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:50<00:00, 17.52s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:43<00:00, 17.18s/it]


max # nodes = 199 || mean # nodes = 149.5
max # edges = 199 || mean # edges = 149.5


100%|██████████| 20/20 [05:47<00:00, 17.36s/it]


In [47]:
row_list = []
for training_path in df['file_dir']:

    try :
        config_path = os.path.join(training_path, 'config.yaml')
        config = get_config(config_path)
    except :
        continue

    for i in range(10):
        if training_path.find('mlp') == -1:
            dict_results = {"dataset_name": config.dataset.name, "model_name": config.model.name,
                    "num_epochs": config.train.max_epoch}
        else :
            dict_results = {"dataset_name": config.dataset.name, "model_name": config.model.name+"_MLP",
                    "num_epochs": config.train.max_epoch}
        dict_stats = get_stats_from_trained_model(config,11*(i^3))
        dict_results.update(dict_stats)
        row_list.append(dict_results)
        torch.cuda.empty_cache()

result_df=pd.DataFrame(row_list)
torch.cuda.empty_cache()

TypeError: get_stats_from_trained_model() takes 1 positional argument but 2 were given

In [7]:
result_df
result_df.to_csv("statsResults.csv")

Unnamed: 0,dataset_name,model_name,num_epochs,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
0,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
1,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
2,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
3,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
4,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
...,...,...,...,...,...,...,...
195,watts_ring,GRANMixtureBernoulli,50,0.001969,0.000000,0.000016,0.039995
196,watts_ring,GRANMixtureBernoulli,50,0.001969,0.000000,0.000016,0.039995
197,watts_ring,GRANMixtureBernoulli,50,0.001969,0.000000,0.000016,0.039995
198,watts_ring,GRANMixtureBernoulli,50,0.001969,0.000000,0.000016,0.039995


In [8]:
result_df.style

Unnamed: 0,dataset_name,model_name,num_epochs,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
0,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
1,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
2,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
3,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
4,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
5,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
6,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
7,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
8,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
9,community2,GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724


Here are the mmd (with emd) metrics results for each data
#### 2-Community Dataset stats

In [31]:
result_df[result_df['dataset_name']=="community2"].groupby(['model_name', 'num_epochs']).mean().style.highlight_min(color = 'lightblue', axis = 0)

Unnamed: 0_level_0,Unnamed: 1_level_0,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
model_name,num_epochs,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
GRANMixtureBernoulli,5,0.017374,0.220196,0.139863,0.072724
RNN,1000,0.033613,0.250731,0.369406,0.121066
RNN_MLP,1000,0.051372,0.250667,0.493045,0.133577


#### 4-Community Dataset stats

In [32]:
result_df[result_df['dataset_name']=="community4"].groupby(['model_name', 'num_epochs']).mean().style.highlight_min(color = 'lightblue', axis = 0)

Unnamed: 0_level_0,Unnamed: 1_level_0,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
model_name,num_epochs,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
GRANMixtureBernoulli,5,0.008479,0.027914,0.455951,0.018261
RNN,1000,0.022312,0.057453,0.573198,0.027762
RNN_MLP,1000,0.051044,0.057449,0.571152,0.039798


#### 8-Community Dataset stats

In [33]:
result_df[result_df['dataset_name']=="community8"].groupby(['model_name', 'num_epochs']).mean().style.highlight_min(color = 'lightblue', axis = 0)

Unnamed: 0_level_0,Unnamed: 1_level_0,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
model_name,num_epochs,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
GRANMixtureBernoulli,5,0.017635,0.014175,0.518267,0.004593
RNN,1000,0.062131,0.048314,0.665375,0.058272
RNN_MLP,1000,0.102437,0.051067,0.702279,0.066848


#### Barabasi Dataset stats

In [34]:
result_df[result_df['dataset_name']=="barabasi"].groupby(['model_name', 'num_epochs']).mean().style.highlight_min(color = 'lightblue', axis = 0)

Unnamed: 0_level_0,Unnamed: 1_level_0,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
model_name,num_epochs,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
GRANMixtureBernoulli,15,0.049913,0.084001,0.030208,0.001003
GRANMixtureBernoulli,50,0.030368,0.081585,0.028486,0.004365
RNN,1000,0.143479,0.365482,0.180747,0.027119
RNN_MLP,1000,0.14542,0.342212,0.179807,0.026639


#### Watts-Strogatz (p=0.05 Graph mode) Dataset stats

In [35]:
result_df[result_df['dataset_name']=="watts"].groupby(['model_name', 'num_epochs']).mean().style.highlight_min(color = 'lightblue', axis = 0)


Unnamed: 0_level_0,Unnamed: 1_level_0,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
model_name,num_epochs,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
GRANMixtureBernoulli,5,0.121858,0.029865,0.00175,0.179487
RNN,1000,0.000193,1.4e-05,5e-06,0.011747
RNN_MLP,1000,7.4e-05,2e-05,0.0,0.009267


#### Watts-Strogatz (Ring p =0.0 Graph mode) Dataset stats

In [36]:
result_df[result_df['dataset_name']=="watts_ring"].groupby(['model_name', 'num_epochs']).mean().style.highlight_min(color = 'lightblue', axis = 0)

Unnamed: 0_level_0,Unnamed: 1_level_0,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
model_name,num_epochs,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
GRANMixtureBernoulli,50,0.001969,0.0,1.6e-05,0.039995
RNN,1000,0.014658,2.3e-05,2.5e-05,0.07543
RNN_MLP,1000,0.005352,0.0,6.6e-05,0.051669


#### Summary measurements

In [38]:
result_df.groupby(['dataset_name', 'model_name', 'num_epochs']).agg(['mean', 'std'])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mmd_degree_test,mmd_degree_test,mmd_clustering_test,mmd_clustering_test,mmd_4orbits_test,mmd_4orbits_test,mmd_spectral_test,mmd_spectral_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,mean,std,mean,std,mean,std
dataset_name,model_name,num_epochs,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
barabasi,GRANMixtureBernoulli,15,0.049913,0.000481,0.084001,0.002337,0.03020837,0.0009096069,0.001003,4.3e-05
barabasi,GRANMixtureBernoulli,50,0.030368,0.000551,0.081585,0.001243,0.02848619,0.00125836,0.004365,8e-05
barabasi,RNN,1000,0.143479,0.000818,0.365482,0.002339,0.1807474,0.00111748,0.027119,0.000196
barabasi,RNN_MLP,1000,0.14542,0.002246,0.342212,0.032601,0.1798068,0.02338711,0.026639,0.004585
community2,GRANMixtureBernoulli,5,0.017374,0.0,0.220196,0.0,0.1398629,0.0,0.072724,0.0
community2,RNN,1000,0.033613,0.0,0.250731,0.0,0.3694064,0.0,0.121066,0.0
community2,RNN_MLP,1000,0.051372,0.0,0.250667,0.0,0.4930449,0.0,0.133577,0.0
community4,GRANMixtureBernoulli,5,0.008479,1e-06,0.027914,3e-06,0.4559513,3.842114e-05,0.018261,7e-06
community4,RNN,1000,0.022312,0.0,0.057453,0.0,0.5731978,0.0,0.027762,0.0
community4,RNN_MLP,1000,0.051044,0.0,0.057449,0.0,0.5711516,0.0,0.039798,0.0


Testing repeated measurements with the different seed ?

In [48]:
row_list = []
for training_path in df['file_dir']:

    try :
        config_path = os.path.join(training_path, 'config.yaml')
        config = get_config(config_path)
    except :
        continue
    if config.dataset.name == "community4" and config.model.name == "RNN":
        for i in range(10):
            if training_path.find('mlp') == -1:
                dict_results = {"dataset_name": config.dataset.name, "model_name": config.model.name,
                        "num_epochs": config.train.max_epoch}
            else :
                dict_results = {"dataset_name": config.dataset.name, "model_name": config.model.name+"_MLP",
                        "num_epochs": config.train.max_epoch}
            dict_stats = get_stats_from_trained_model(config)
            dict_results.update(dict_stats)
            row_list.append(dict_results)
            torch.cuda.empty_cache()

result_community_df=pd.DataFrame(row_list)
torch.cuda.empty_cache()

max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038




max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 335.038
max # nodes = 64 || mean # nodes = 64.0
max # edges = 347 || mean # edges = 

In [46]:
result_community_df


Unnamed: 0,dataset_name,model_name,num_epochs,mmd_degree_test,mmd_clustering_test,mmd_4orbits_test,mmd_spectral_test
0,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
1,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
2,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
3,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
4,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
5,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
6,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
7,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
8,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762
9,community4,RNN,1000,0.022312,0.057453,0.573198,0.027762


In [49]:
result_community_df.groupby(['dataset_name', 'model_name', 'num_epochs']).agg(['mean', 'std'])



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mmd_degree_test,mmd_degree_test,mmd_clustering_test,mmd_clustering_test,mmd_4orbits_test,mmd_4orbits_test,mmd_spectral_test,mmd_spectral_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,mean,std,mean,std,mean,std
dataset_name,model_name,num_epochs,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
community4,RNN,1000,0.022312,0.0,0.057453,0.0,0.573198,0.0,0.027762,0.0
community4,RNN_MLP,1000,0.051044,0.0,0.057449,0.0,0.571152,0.0,0.039798,0.0


In [51]:
result_df.groupby(['dataset_name', 'model_name', 'num_epochs']).describe()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mmd_degree_test,mmd_degree_test,mmd_degree_test,mmd_degree_test,mmd_degree_test,mmd_degree_test,mmd_degree_test,mmd_degree_test,mmd_clustering_test,mmd_clustering_test,...,mmd_4orbits_test,mmd_4orbits_test,mmd_spectral_test,mmd_spectral_test,mmd_spectral_test,mmd_spectral_test,mmd_spectral_test,mmd_spectral_test,mmd_spectral_test,mmd_spectral_test
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
dataset_name,model_name,num_epochs,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2
barabasi,GRANMixtureBernoulli,15,10.0,0.049913,0.000481,0.04924,0.049667,0.049841,0.050379,0.050537,10.0,0.084001,...,0.03082363,0.03183272,10.0,0.001003,4.3e-05,0.000923,0.000984,0.001006,0.001025,0.001076
barabasi,GRANMixtureBernoulli,50,10.0,0.030368,0.000551,0.029278,0.030204,0.03037,0.03065,0.031299,10.0,0.081585,...,0.02954832,0.03037812,10.0,0.004365,8e-05,0.004231,0.004317,0.004375,0.004427,0.004477
barabasi,RNN,1000,10.0,0.143479,0.000818,0.141956,0.143162,0.143391,0.143955,0.14502,10.0,0.365482,...,0.1813814,0.182174,10.0,0.027119,0.000196,0.026826,0.027006,0.027119,0.027187,0.027548
barabasi,RNN_MLP,1000,20.0,0.14542,0.002246,0.142889,0.14322,0.145417,0.147277,0.149144,20.0,0.342212,...,0.2023374,0.2038438,20.0,0.026639,0.004585,0.021974,0.022166,0.026762,0.031085,0.031266
community2,GRANMixtureBernoulli,5,10.0,0.017374,0.0,0.017374,0.017374,0.017374,0.017374,0.017374,10.0,0.220196,...,0.1398629,0.1398629,10.0,0.072724,0.0,0.072724,0.072724,0.072724,0.072724,0.072724
community2,RNN,1000,10.0,0.033613,0.0,0.033613,0.033613,0.033613,0.033613,0.033613,10.0,0.250731,...,0.3694064,0.3694064,10.0,0.121066,0.0,0.121066,0.121066,0.121066,0.121066,0.121066
community2,RNN_MLP,1000,10.0,0.051372,0.0,0.051372,0.051372,0.051372,0.051372,0.051372,10.0,0.250667,...,0.4930449,0.4930449,10.0,0.133577,0.0,0.133577,0.133577,0.133577,0.133577,0.133577
community4,GRANMixtureBernoulli,5,10.0,0.008479,1e-06,0.008477,0.008477,0.008479,0.00848,0.00848,10.0,0.027914,...,0.4559877,0.4559877,10.0,0.018261,7e-06,0.018254,0.018254,0.018261,0.018267,0.018267
community4,RNN,1000,10.0,0.022312,0.0,0.022312,0.022312,0.022312,0.022312,0.022312,10.0,0.057453,...,0.5731978,0.5731978,10.0,0.027762,0.0,0.027762,0.027762,0.027762,0.027762,0.027762
community4,RNN_MLP,1000,10.0,0.051044,0.0,0.051044,0.051044,0.051044,0.051044,0.051044,10.0,0.057449,...,0.5711516,0.5711516,10.0,0.039798,0.0,0.039798,0.039798,0.039798,0.039798,0.039798
