In [None]:
# Load IPython's autoreload extension so edited project modules can be picked up
# without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before each cell is executed.
%autoreload 2
# Ask the inline backend to render figures in 'retina' format.
%config InlineBackend.figure_format = 'retina' # 'retina'
import logging
import torch
import numpy as np
import pickle
from pathlib import Path
import pandas as pd
import copy
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import time

from data.simulation import run_mc_repeats_
from main.dgisp import DGISP, DiffusionPropagate, Identity
from main.models.MLP import MLPTransform
from main.models.GraphSAGE import SupervisedGraphSage
from main.models.GAT import GAT
from main.models.SGC import SGC
from main.utils import to_nparray, to_torch, sp2adj_lists
from main.training import train_model, train_monstor, get_predictions_new_seeds, get_predictions_monstor, PIteration, FeatureCons, PIteration2
from main.earlystopping import stopping_args
from main.utils import load_dataset, load_latest_ckpt
from im.influspread import IS

# Configure the root logger: timestamped messages at INFO level and above.
logging.basicConfig(
    format='%(asctime)s:%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO)

# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8*' and
# later releases removed the old names, so a bare 'seaborn' fails there.
# Prefer the new name when it exists and fall back to the old one otherwise.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
def me_op(x, y):
    """Mean error: the mean of the element-wise absolute difference |x - y|."""
    abs_diff = np.abs(x - y)
    return np.mean(abs_diff)


def te_op(x, y):
    """Total error: the absolute difference between sum(x) and sum(y)."""
    total_x = np.sum(x)
    total_y = np.sum(y)
    return np.abs(total_x - total_y)

# Load the dataset


### Load from saved SparseGraph object, with added prob_matrix and influ_mats

In [None]:
# Key parameters for this run.

# Dataset name handed to load_dataset().
# Options noted by the author: 'cora_ml', 'citeseer', 'ms_academic', 'pubmed'
dataset = "pubmed"

# Model identifier; 'dgisp' is just the DeepIS model.
model_name = "dgisp"

In [None]:
# Load the graph selected in the parameters cell and print its summary.
graph = load_dataset(dataset)
print(graph)

# Keep a shallow copy of the complete influence-matrix collection before the
# graph's own attribute is truncated on the next line.
influ_mat_list = copy.copy(graph.influ_mat_list)

# The graph object keeps only the first 50 entries
# (presumably the ones made available to training; confirm against train_model).
graph.influ_mat_list = graph.influ_mat_list[0:50]

# Display the truncated and the full shapes side by side.
(graph.influ_mat_list.shape, influ_mat_list.shape)

# Build model

In [None]:
# autoreload is already enabled by the first cell of the notebook, so it is not
# loaded a second time here (doing so only prints an "already loaded" notice).

# training parameters
niter = 4  # passed to FeatureCons; also fixes the MLP input width (niter + 1)

# Propagation callable given to DGISP: indexes x by y and ignores the middle argument.
propagate_model = lambda x, _, y: x[y]

# Feature constructor for the chosen model, bound to this graph's probability matrix.
fea_constructor = FeatureCons(model_name, niter=niter)
fea_constructor.prob_matrix = graph.prob_matrix

device = 'cpu' # 'cpu', 'cuda'
# Keyword arguments that the training cell forwards to train_model(**args_dict).
args_dict = {
    'learning_rate': 0.0001,
    'λ': 0,
    'γ': 0,
    'ckpt_dir': Path('./checkpoints'),
    'idx_split_args': {'ntraining': 1500, 'nstopping': 500, 'nval': 10, 'seed': 2413340114},
    'test': False,
    'device': device,
    'print_interval': 1,
    'batch_size': None,
}

if model_name == 'dgisp':
    # input_dim is niter + 1 (presumably one feature per iteration plus the
    # initial seed vector; confirm against FeatureCons).
    gnn_model = MLPTransform(input_dim=niter+1, hiddenunits=[64, 64], num_classes=1)
else:
    # Fail loudly: silently passing here would leave `gnn_model` unbound (or
    # reuse a stale one from an earlier run) and surface as a confusing error below.
    raise ValueError(
        "Unsupported model_name {!r}: this notebook only builds the 'dgisp' model.".format(model_name))

model = DGISP(gnn_model=gnn_model, propagate=propagate_model)


## Train model from scratch

In [None]:
# Train the model built above. The first argument is the run name
# '<model_name>_<dataset>' (presumably used to label checkpoints; confirm in
# main.training). The trained model replaces `model`.
model, result = train_model(
    '_'.join([model_name, dataset]),
    model,
    fea_constructor,
    graph,
    **args_dict
)

# Prediction on NEW SEEDS

In [None]:
# Switch to the graph used for prediction. This rebinds `dataset` and `graph`,
# so every later cell (and any earlier cell re-run after this one) sees cora_ml
# instead of the dataset chosen in the parameters cell.
dataset = 'cora_ml'
graph = load_dataset(dataset)
# Shallow copy of this graph's influence matrices; the prediction cell indexes into it.
influ_mat_list = copy.copy(graph.influ_mat_list)
# Bare expression so the notebook displays the loaded graph.
graph

### predict

In [None]:
%%time

influ_mat = influ_mat_list[58]
seed_vec = influ_mat[:, 0]
seed_idx = np.argwhere(seed_vec == 1) # used by PIteration
influ_vec = influ_mat[:, -1]

fea_constructor.prob_matrix = graph.prob_matrix
preds = get_predictions_new_seeds(model, fea_constructor, seed_vec, np.arange(len(seed_vec)))
final_preds = PIteration(graph.prob_matrix, preds, seed_idx, True, 2)

print('mean error:', me(influ_vec, final_preds))
print('total error:', te(influ_vec, final_preds))
