# GIN Experiment
This notebook implements the evaluation pipeline: it trains a GIN classifier on the graph datasets and then uses the trained model's graph embeddings to compute the FID of generated graphs.

## Setup

In [6]:
%load_ext autoreload
%autoreload 2
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pickle
from tqdm import tqdm
from util import load_graph_list
from util import load_data, load_synth_data, separate_data , load_graph_asS2Vgraph
from models.graphcnn import GraphCNN


# Classification loss shared by the notebook; train() reads it as a module-level global.
criterion = nn.CrossEntropyLoss()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
def train(iters_per_epoch, batch_size, model, device, train_graphs, optimizer, epoch):
    """Run one training epoch made of randomly sampled mini-batches.

    Each of the ``iters_per_epoch`` iterations draws ``batch_size`` graphs
    (without replacement within the batch) from ``train_graphs``, feeds them to
    ``model`` and scores the output against the graphs' ``label`` attribute
    with the module-level ``criterion``.

    Args:
        iters_per_epoch: number of mini-batches to process.
        batch_size: number of graphs drawn per mini-batch.
        model: callable taking a list of graphs and returning class scores.
        device: device the label tensor is moved to before computing the loss.
        train_graphs: indexable collection of graphs exposing ``.label``.
        optimizer: optimizer to step; when ``None`` the loss is only measured.
        epoch: epoch number, used only for the progress-bar caption.

    Returns:
        The loss averaged over the ``iters_per_epoch`` mini-batches.
    """
    model.train()

    progress = tqdm(range(iters_per_epoch), unit='batch')
    n_graphs = len(train_graphs)

    running_loss = 0
    for _ in progress:
        # A fresh permutation per step: the batch is a random subset of the data.
        picked = np.random.permutation(n_graphs)[:batch_size]
        minibatch = [train_graphs[i] for i in picked]

        scores = model(minibatch)
        targets = torch.LongTensor([g.label for g in minibatch]).to(device)
        batch_loss = criterion(scores, targets)

        # Parameters are only updated when an optimizer is supplied.
        if optimizer is not None:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        running_loss += batch_loss.detach().cpu().numpy()

        progress.set_description('epoch: %d' % (epoch))

    mean_loss = running_loss / iters_per_epoch
    print("loss training: %f" % (mean_loss))

    return mean_loss

###pass data to model with minibatch during testing to avoid memory overflow (does not perform backpropagation)
def pass_data_iteratively(model, graphs, minibatch_size = 64):
    """Evaluate ``model`` on ``graphs`` in consecutive mini-batches.

    The model is switched to eval mode and every batch is run with autograd
    disabled, so no computation graph is kept alive between batches.

    Args:
        model: module called with a list of graphs, returning a tensor whose
            first dimension indexes the graphs of the batch.
        graphs: indexable sequence of graphs (must support ``len`` and ``[]``).
        minibatch_size: maximum number of graphs per forward pass.

    Returns:
        The per-batch outputs concatenated along dimension 0, in input order.

    Note:
        An empty ``graphs`` produces no batches, so ``torch.cat`` is called
        with an empty list, which PyTorch rejects with an error.
    """
    model.eval()
    chunks = []
    total = len(graphs)
    # Inference only: without no_grad each forward pass would record an
    # autograd graph, defeating the purpose of batching to save memory.
    with torch.no_grad():
        for start in range(0, total, minibatch_size):
            stop = min(start + minibatch_size, total)
            batch = [graphs[j] for j in range(start, stop)]
            chunks.append(model(batch).detach())
    return torch.cat(chunks, 0)

def test(model, device, train_graphs, test_graphs, epoch):
    model.eval()

    output = pass_data_iteratively(model, train_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_train = correct / float(len(train_graphs))

    output = pass_data_iteratively(model, test_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_test = correct / float(len(test_graphs))

    print("accuracy train: %f test: %f" % (acc_train, acc_test))

    return acc_train, acc_test

## Training settings

In [7]:
# Name of a real dataset for load_data(); None makes the training cell load the synthetic data instead.
dataset = None
# CUDA device index; the training cell turns it into "cuda:<index>" when CUDA is available.
device = 0 
# Graphs per mini-batch and mini-batches per epoch, both passed to train().
batch_size = 32
iters_per_epoch = 50
# Number of training epochs per (num_layer, hidden_dim) configuration.
epochs = 60
# Learning rate handed to the Adam optimizer.
lr = 0.01
# Seed passed to separate_data() for the train/test split.
seed = 0
# NOTE(review): overwritten by the `for fold_idx in ...` loop of the training cell.
fold_idx = 1
# Candidate values; the training cell iterates over their Cartesian product.
num_layers = [5]
num_mlp_layers = 3
hidden_dims = [64]
# The following five values are forwarded unchanged to the GraphCNN constructor.
final_dropout = 0.5
graph_pooling_type = "sum"
neighbor_pooling_type = "sum"
learn_eps = False
# Only used when a real dataset is loaded through load_data().
degree_as_tag = True
# NOTE(review): not read by any visible cell.
filename = ""
# Flag forwarded to load_synth_data(), GraphCNN and load_graph_asS2Vgraph(); meaning defined by those helpers.
random = 0
# Forwarded to load_synth_data() and, in get_fid(), to load_graph_asS2Vgraph().
onehot=True

## Training 

In [8]:
from itertools import product

# Per-configuration results: (mean accuracy over all recorded epochs, num_layer, hidden_dim).
# NOTE(review): a later cell defines a function that is also called `test_acc`,
# which replaces this list once that cell has run.
train_acc,test_acc=[],[]

#set up seeds and gpu device
# Use the configured `seed` everywhere instead of a second, hard-coded value.
torch.manual_seed(seed)
np.random.seed(seed)
# Only convert the integer index once, so re-running this cell does not try to
# build the invalid device string "cuda:cuda:0" from an existing torch.device.
if not isinstance(device, torch.device):
    device = torch.device("cuda:" + str(device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

if dataset is not None:
    # NOTE(review): this branch does not define `tagset`/`lentagset`, which get_fid() reads.
    graphs, num_classes = load_data(dataset, degree_as_tag)
else:
    graphs, num_classes, tagset, lentagset = load_synth_data(True, random, onehot, True)

for num_layer, hidden_dim in product(num_layers, hidden_dims):
    fold_test_accuracy = []
    fold_train_accuracy = []
    for fold_idx in range(1, 2):
        train_graphs, test_graphs = separate_data(graphs, seed, fold_idx)

        model = GraphCNN(num_layer, num_mlp_layers, train_graphs[0].node_features.shape[1], hidden_dim, num_classes, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, random, device).to(device)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

        for epoch in range(1, epochs + 1):
            avg_loss = train(iters_per_epoch, batch_size, model, device, train_graphs, optimizer, epoch)
            acc_train, acc_test = test(model, device, train_graphs, test_graphs, epoch)
            fold_test_accuracy.append(acc_test)
            fold_train_accuracy.append(acc_train)

            # The scheduler must advance after the epoch's optimizer steps;
            # stepping it first shifts the decay schedule by one epoch and
            # triggers a PyTorch warning.
            scheduler.step()
    # Accuracies are averaged over every epoch of every fold, not just the last epoch.
    train_acc.append((np.mean(fold_train_accuracy), num_layer, hidden_dim))
    test_acc.append((np.mean(fold_test_accuracy), num_layer, hidden_dim))

RuntimeError: CUDA error: device-side assert triggered

In [None]:
# Model outputs for every graph of the dataset, in dataset order; reused by the t-SNE cell below.
output = pass_data_iteratively(model, graphs)

In [None]:
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
%matplotlib inline

import seaborn as sns
# Global seaborn appearance for the figures drawn after this cell.
sns.set_style('darkgrid')
sns.set_palette('muted')
# "notebook" context with fonts scaled by 1.5 and a line width of 2.5.
sns.set_context("notebook", font_scale=1.5,
                rc={"lines.linewidth": 2.5})

def fashion_scatter(x, colors):
    """Scatter-plot 2-D points coloured by class and label each class cluster.

    Args:
        x: array whose first two columns are the x/y coordinates of the points.
        colors: 1-D array with one class label per point. The labels are used
            directly as palette indices and compared against
            ``range(num_classes)``, so they are expected to be 0..k-1.

    Returns:
        ``(figure, axes, scatter_artist, text_artists)``.
    """
    # choose a color palette with seaborn.
    num_classes = len(np.unique(colors))
    palette = np.array(sns.color_palette("hls", num_classes))

    # create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # The `np.int` alias was removed in NumPy 1.24; the builtin `int` selects
    # the same dtype and works on every NumPy version.
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40, c=palette[colors.astype(int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('tight')

    # add the labels for each digit corresponding to the label
    txts = []

    for i in range(num_classes):

        # Position of each label at median of data points.

        xtext, ytext = np.median(x[colors == i, :], axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        # White outline keeps the class number readable on top of the points.
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)

    return f, ax, sc, txts

In [None]:
# Model outputs for every graph, moved to host memory for scikit-learn.
X=output.cpu()
# Class label of each graph, in the same order as the rows of X.
y = np.array([g.label for g in graphs])
############################################################
from sklearn.manifold import TSNE
import time
# NOTE(review): time_start is recorded but never read in the visible cells.
time_start = time.time()

# Embed the model outputs with t-SNE (fixed random_state) and plot them coloured by label.
fashion_tsne = TSNE(random_state=0).fit_transform(X)
fashion_scatter(fashion_tsne, y)

In [None]:
# Sanity check: list the distinct class labels present in the dataset.
np.unique(y)

In [None]:
import networkx as nx
from scipy import linalg
## Coming from https://github.com/mseitzer/pytorch-fid
def compute_FID(mu1, mu2, cov1, cov2, eps = 1e-6):
    """Frechet distance between two Gaussians N(mu1, cov1) and N(mu2, cov2).

    Computes ``||mu1 - mu2||^2 + Tr(cov1) + Tr(cov2) - 2 * Tr(sqrt(cov1 @ cov2))``.

    Args:
        mu1, mu2: mean vectors of equal length (scalars are promoted to 1-D).
        cov1, cov2: covariance matrices of equal shape (scalars, as returned
            by ``np.cov`` for one-dimensional features, are promoted to 1x1).
        eps: value added to both covariance diagonals when the matrix square
            root of the product is not finite.

    Returns:
        The Frechet distance as a scalar.

    Raises:
        AssertionError: if the means or the covariances differ in shape.
        ValueError: if the matrix square root has a non-negligible imaginary
            part on its diagonal.
    """
    # Promote scalars so one-dimensional embeddings (0-d covariance from
    # np.cov) go through the same matrix code path as everything else.
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    cov1 = np.atleast_2d(cov1)
    cov2 = np.atleast_2d(cov2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert cov1.shape == cov2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2
    # Product might be almost singular
    covmean, _ = linalg.sqrtm(cov1.dot(cov2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
                'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(cov1.shape[0]) * eps
        covmean = linalg.sqrtm((cov1 + offset).dot(cov2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(cov1) +
            np.trace(cov2) - 2 * tr_covmean)

def _embedding_stats(model, graph_list):
    """Return (mean vector, covariance matrix) of the model's graph embeddings."""
    embeddings = model.get_graph_embed_sum(graph_list).cpu().detach().numpy()
    # Rows are graphs, columns are embedding dimensions (hence rowvar=False).
    return np.mean(embeddings, axis=0), np.cov(embeddings, rowvar=False)


def compute_fid(ref_graph,pred_graph,model):
    """Frechet distance between the embeddings of two sets of graphs.

    Both sets are embedded with ``model.get_graph_embed_sum``; the mean and
    covariance of each embedding set are then compared with ``compute_FID``.

    Args:
        ref_graph: reference graphs, in the format the model accepts.
        pred_graph: graphs to compare against the reference.
        model: object exposing ``get_graph_embed_sum(graphs)`` returning a tensor.

    Returns:
        The value returned by ``compute_FID`` for the two embedding sets.

    NOTE(review): unlike ``test``/``pass_data_iteratively`` this function does
    not call ``model.eval()``; presumably callers evaluate the model first —
    confirm.
    """
    # No gradients are needed to extract embeddings.
    with torch.no_grad():
        mu_ref, cov_ref = _embedding_stats(model, ref_graph)
        mu_pred, cov_pred = _embedding_stats(model, pred_graph)

    return compute_FID(mu_ref, mu_pred, cov_ref, cov_pred)

def test_acc(model, device, train_graphs):
    model.eval()

    output = pass_data_iteratively(model, train_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc = correct / float(len(train_graphs))


    #print("accuracy : %f" % (acc))

    return acc

In [None]:

def get_fid(filename,label,model):
    """Print the FID and the GIN accuracy of the generated graphs in ``filename``.

    The generated graphs are loaded, converted with ``load_graph_asS2Vgraph``
    (using the notebook-level ``random``, ``tagset``, ``lentagset`` and
    ``onehot``) and compared against the graphs of the notebook-level
    ``graphs`` list whose label equals ``label``.

    Args:
        filename: path of the file holding the generated graphs.
        label: class label assigned to the generated graphs and used to pick
            the reference graphs.
        model: trained model used for both the embeddings and the accuracy.
    """
    generated = load_graph_list(filename, False)
    generated_s2v, _ = load_graph_asS2Vgraph(generated, label, random, tagset, lentagset, onehot=onehot)

    # Reference set: the real graphs of the same class.
    reference = [g for g in graphs if g.label == label]
    fid_value = compute_fid(reference, generated_s2v, model)
    print(filename+" fid : ", fid_value)

    accuracy = test_acc(model, device, generated_s2v)
    print(filename+" GIN accuracy", accuracy)

In [None]:
# Grid graph
# Generated grid graphs are scored against the reference graphs labelled 4.
for grid_file in (
    '../../generated_graphs/grid_GRANMixtureBernoulli_DFS.p',
    '../../generated_graphs/grid_RNN_BFS.p',
    '../../generated_graphs/grid_RNN_MLP_BFS.p',
):
    get_fid(grid_file, 4, model)

In [None]:
#Barabasi
# Generated Barabasi graphs are scored against the reference graphs labelled 1.
for barabasi_file in (
    '../../generated_graphs/barabasi_GRANMixtureBernoulli_BFS.p',
    '../../generated_graphs/barabasi_GRANMixtureBernoulli_DFS.p',
    '../../generated_graphs/barabasi_GRANMixtureBernoulli_degree_decent.p',
    '../../generated_graphs/barabasi_GRANMixtureBernoulli_k_core.p',
    '../../generated_graphs/barabasi_GRANMixtureBernoulli_no_order.p',
    '../../generated_graphs/barabasi_RNN_BFS.p',
    '../../generated_graphs/barabasi_RNN_BFSMAX.p',
    '../../generated_graphs/barabasi_RNN_DFS.p',
    '../../generated_graphs/barabasi_RNN_nobfs.p',
):
    get_fid(barabasi_file, 1, model)

In [None]:
#watts
# Generated Watts small-world graphs are scored against the reference graphs labelled 0.
for watts_file in (
    '../../generated_graphs/wattsSW_GRANMixtureBernoulli_BFS.p',
    '../../generated_graphs/wattsSW_GRANMixtureBernoulli_DFS.p',
    '../../generated_graphs/wattsSW_GRANMixtureBernoulli_degree_descent.p',
    '../../generated_graphs/wattsSW_GRANMixtureBernoulli_k_core.p',
    '../../generated_graphs/wattsSW_GRANMixtureBernoulli_no_order.p',
    '../../generated_graphs/wattsSW_RNN_BFS.p',
    '../../generated_graphs/wattsSW_RNN_BFSMAX.p',
    '../../generated_graphs/wattsSW_RNN_DFS.p',
    '../../generated_graphs/wattsSW_RNN_nobfs.p',
):
    get_fid(watts_file, 0, model)

In [None]:
#community2
# Generated two-community graphs are scored against the reference graphs labelled 2.
for community_file in (
    '../../generated_graphs/community2small_GRANMixtureBernoulli_BFS.p',
    '../../generated_graphs/community2small_GRANMixtureBernoulli_DFS.p',
    '../../generated_graphs/community2small_GRANMixtureBernoulli_degree_decent.p',
    '../../generated_graphs/community2small_GRANMixtureBernoulli_k_core.p',
    '../../generated_graphs/community2small_GRANMixtureBernoulli_degree_accent.p',
    '../../generated_graphs/community2small_RNN_BFS.p',
    '../../generated_graphs/community2small_RNN_BFSMAX.p',
    '../../generated_graphs/community2small_RNN_DFS.p',
    '../../generated_graphs/community2small_RNN_nobfs.p',
):
    get_fid(community_file, 2, model)



In [None]:
def test_acc(model, device, train_graphs):
    model.eval()

    output = pass_data_iteratively(model, train_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc = float(correct / float(len(train_graphs)))


    print("accuracy : %f" % (acc))

    return acc

def n_community(c_sizes, p_inter=0.01):
    """Build a graph made of dense random communities joined by sparse edges.

    One G(n, 0.7) random graph is generated per entry of ``c_sizes`` (seeded
    with the entry's index) and all of them are placed side by side. Every
    pair of connected components of that union is then linked: each cross
    pair of nodes gets an edge with probability ``p_inter``, and if no edge
    was drawn the first nodes of the two components are connected.

    Args:
        c_sizes: sequence with the number of nodes of each random graph.
        p_inter: probability of an edge between nodes of different components.

    Returns:
        The resulting networkx graph.
    """
    blocks = [nx.gnp_random_graph(size, 0.7, seed=idx) for idx, size in enumerate(c_sizes)]
    G = nx.disjoint_union_all(blocks)

    # Communities are the connected components of the union, so a random
    # graph that happens to be disconnected contributes several of them.
    member_lists = [list(G.subgraph(comp).nodes()) for comp in nx.connected_components(G)]

    for first, left_nodes in enumerate(member_lists):
        for right_nodes in member_lists[first + 1:]:
            linked = False
            for a in left_nodes:
                for b in right_nodes:
                    if np.random.rand() < p_inter:
                        G.add_edge(a, b)
                        linked = True
            # Guarantee at least one bridge between every pair of communities.
            if not linked:
                G.add_edge(left_nodes[0], right_nodes[0])
    return G

# NOTE(review): `g_list_pred_rnn` is not defined in any visible cell; it has to
# be created before this line can run.
# Use the notebook-wide `device` rather than a hard-coded 'cuda:0' so the call
# also works when the notebook fell back to the CPU.
test_acc(model,device,g_list_pred_rnn)

## Perturbation Test

In [None]:
def perturb_new(graph_list, p):
    ''' Perturb each graph by rewiring a random fraction of its edges.

    Every edge is removed independently with probability ``p``; the same
    number of edges is then re-inserted between uniformly drawn pairs of
    distinct, currently unconnected nodes of the graph.

    Args:
        graph_list: iterable of networkx-style graphs (the methods ``copy``,
            ``edges``, ``nodes``, ``has_edge``, ``add_edge`` and
            ``remove_edge`` are used).
        p: probability of removing (and later replacing) each edge.
    Returns:
        A list of perturbed copies; the input graphs are left untouched.
    '''
    perturbed_graph_list = []
    for G_original in graph_list:
        G = G_original.copy()
        edge_remove_count = 0
        for (u, v) in list(G.edges()):
            if np.random.rand() < p:
                G.remove_edge(u, v)
                edge_remove_count += 1

        # Draw endpoints from the graph's real node labels. Sampling bare
        # integers only works when nodes are labelled 0..n-1; for other
        # labels (e.g. the (row, col) tuples of grid graphs) it silently
        # created brand-new integer nodes.
        node_pool = list(G.nodes())
        n_nodes = len(node_pool)
        if set(node_pool) == set(range(n_nodes)):
            # Keep index == label so integer-labelled graphs are perturbed
            # exactly as before for a given random state.
            node_pool = list(range(n_nodes))

        # randomly add the edges back
        for _ in range(edge_remove_count):
            while True:
                u = node_pool[np.random.randint(0, n_nodes)]
                v = node_pool[np.random.randint(0, n_nodes)]
                if (not G.has_edge(u, v)) and (u != v):
                    break
            G.add_edge(u, v)
        perturbed_graph_list.append(G)
    return perturbed_graph_list

# Build 100 grid graphs: one i x j lattice for every 10 <= i, j < 20.
# NOTE(review): despite the `ba` in the variable names these are grid graphs.
graph_ba_regen=[]
for i in range(10,20):
    for j in range(10,20):
        g=nx.grid_2d_graph(i,j)
        graph_ba_regen.append(g)
# Convert the clean graphs and perturbed copies (edge rewiring probability 0.1 ... 0.7).
# NOTE(review): the label passed here is 1, while the FID cell compares against the
# reference graphs labelled 4 — confirm the label does not influence the embeddings.
# NOTE(review): unlike get_fid(), `onehot` is not forwarded, so the helper's default
# applies — confirm it matches the features the model was trained on.
g_list_ref_ba_regen0,_ = load_graph_asS2Vgraph(graph_ba_regen,1,random, tagset , lentagset)
g_list_ref_ba_regen10,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.1),1,random, tagset , lentagset)
g_list_ref_ba_regen20,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.2),1,random, tagset , lentagset)
g_list_ref_ba_regen30,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.3),1,random, tagset , lentagset)
g_list_ref_ba_regen40,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.4),1,random, tagset , lentagset)
g_list_ref_ba_regen50,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.5),1,random, tagset , lentagset)
g_list_ref_ba_regen60,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.6),1,random, tagset , lentagset)
g_list_ref_ba_regen70,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.7),1,random, tagset , lentagset)

In [None]:

# Remaining perturbation levels (edge rewiring probability 0.8, 0.9 and 1.0) of the same grid graphs.
g_list_ref_ba_regen80,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.8),1,random, tagset , lentagset)
g_list_ref_ba_regen90,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.9),1,random, tagset , lentagset)
g_list_ref_ba_regen100,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,1.0),1,random, tagset , lentagset)


In [None]:
# NOTE(review): repeats the 0.7 assignment from an earlier cell; running it draws a new
# random perturbation and overwrites the previous g_list_ref_ba_regen70.
g_list_ref_ba_regen70,_ = load_graph_asS2Vgraph(perturb_new(graph_ba_regen,0.7),1,random, tagset , lentagset)


In [None]:
fids3=[]
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen0,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen10,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen20,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen30,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen40,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen50,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen70,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen80,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen90,model))
fids3.append(compute_fid([graphs[i] for i in range(len(graphs)) if graphs[i].label ==4],g_list_ref_ba_regen100,model))




In [None]:

import matplotlib.pyplot as plt
perturb = np.arange(0,0.8,0.1)
plt.figure(figsize=(10,7))
plt.xlabel("Noise perturbation")
plt.ylabel("Fréchet Distance to the ref graph")
plt.plot(perturb,fids3,label='grid')
plt.draw()
plt.legend()