### Helper notebook for running and explaining the UPFD GCNFN model implemented in https://github.com/safe-graph/GNN-FakeNews

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project's .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from util import run_model
from GNNFakeNews.utils.helpers.hyperparameter_factory import HparamFactory
from GNNFakeNews.utils.helpers.gnn_model_explainer import  GNNModelExplainer
from GNNFakeNews.utils.enums import GNNModelTypeEnum, GNNFeatureTypeEnum

  return torch._C._cuda_getDeviceCount() > 0


# 1.1. UPFD_GCNFN

## DATASET TYPE = POLITIFACT, FEATURE = BERT

In [3]:
# Run the UPFD GCNFN model. Hyperparameters come from HparamFactory; only the
# node feature type (BERT) is overridden here, following the UPFD paper for
# best performance.
# NOTE(review): the recorded run of this cell stopped with
# "ValueError: device cannot be DeviceTypeEnum.GPU, because CUDA is not available."
# How the device is selected is not visible in this notebook -- confirm how to
# request a CPU run before executing on a machine without CUDA.
model_type = GNNModelTypeEnum.UPFD_GCNFN
hparams = HparamFactory(model_type, feature=GNNFeatureTypeEnum.BERT)
model, dataset_manager = run_model(model_type, hparams=hparams)

#################################
-----> The hyperparameters are set!
model_type = GNNModelTypeEnum.UPFD_GCNFN
dataset = GNNDatasetTypeEnum.POLITIFACT
batch_size = 128
lr = 0.01
weight_decay = 0.001
n_hidden = 128
epochs = 60
transform = <GNNFakeNews.utils.data_loader.ToUndirected object at 0x7f362c096b50>
feature = GNNFeatureTypeEnum.BERT
concat = True
#################################


ValueError: device cannot be DeviceTypeEnum.GPU, because CUDA is not available.

In [None]:
# Draw random training samples with label 0 on the model's device, keep the
# first one, and wrap it together with the model in an explainer.
label_0_samples = dataset_manager.get_random_train_samples(
    device=model.m_args.device,
    label=0,
)
sample_data = label_0_samples[0]
explainer = GNNModelExplainer(model, sample_data)

In [None]:
# Display the size of the sample's `x` attribute
# (presumably the node-feature matrix, i.e. nodes x features -- TODO confirm).
sample_data.x.size()

In [None]:
# Render the explaining graph for the single sample wrapped by `explainer`.
explainer.visualize_explaining_graph()

In [None]:
# Length of the first entry of `x`
# (presumably the feature dimension of a single node -- TODO confirm).
len(sample_data.x[0])

### Randomly sample 10 fake news instances and explain them

In [None]:
# Sample 10 random fake-news (label=0) training graphs and explain each one.
# Side effect: overwrites the notebook-level `explainers` and `user_ids` lists.
sample_data_list = dataset_manager.get_random_train_samples(
    device=model.m_args.device,
    label=0,
    len_samples=10,
)
explainers, user_ids = [], []
for sample_graph in sample_data_list:
    sample_explainer = GNNModelExplainer(model, sample_graph)
    sample_explainer.visualize_explaining_graph()
    # Record the node ids of the explaining subgraph first, then keep the
    # explainer itself so it can be inspected in later cells.
    user_ids.append(sample_explainer.get_node_ids_of_explaining_subgraph())
    explainers.append(sample_explainer)

In [None]:
# Show the adjacency matrix via `explainer`, i.e. the single-sample explainer
# built earlier, not one of the per-sample explainers in `explainers`.
explainer.visualize_adjacency_matrix()

In [None]:
import numpy as np

# Reduce the per-sample node-id collections to the ids shared by all samples:
# start from the first collection and intersect with each remaining one.
arr = user_ids[0]
for node_ids in user_ids[1:]:
    arr = np.intersect1d(arr, node_ids)
print(arr)

### Randomly sample 10 real news instances and explain them

In [None]:
# Sample 10 random real-news (label=1) training graphs and explain each one.
# Side effect: overwrites the notebook-level `explainers` and `user_ids` lists.
sample_data_list = dataset_manager.get_random_train_samples(
    device=model.m_args.device,
    label=1,
    len_samples=10,
)
explainers, user_ids = [], []
for sample_graph in sample_data_list:
    sample_explainer = GNNModelExplainer(model, sample_graph)
    sample_explainer.visualize_explaining_graph()
    # Record the node ids of the explaining subgraph first, then keep the
    # explainer itself so it can be inspected in later cells.
    user_ids.append(sample_explainer.get_node_ids_of_explaining_subgraph())
    explainers.append(sample_explainer)

In [None]:
# Display the node ids collected from each explaining subgraph
# (one entry per sample explained by the most recently run sampling cell).
user_ids

In [None]:
from util import load_pkl_files

# Load the two lookup tables for the dataset the model was configured with
# (judging by the names: node id -> user id and node id -> time; the contents
# are not visible here -- TODO confirm).
dataset_type = model.m_hparams.dataset
node_id_user_id_map, node_id_time_map = load_pkl_files(dataset_type)

# Print the first element (the key) of every item in the explainer subgraph's nodes.
for node_entry in explainer.subgraph.nodes.items():
    print(node_entry[0])

In [None]:
from util import get_news_id_node_id_user_id_dict

# Build two lookups for the configured dataset: one keyed by news id (used
# below to get a per-news node/user dict) and one mapping a dataset index to
# its news id.
dataset_type = model.m_hparams.dataset
news_node_user_dict, index_news_id_dict = get_news_id_node_id_user_id_dict(dataset_type)

In [None]:
# Sanity check: for every training graph, the node count stored in the dataset
# must equal the number of entries in the node -> user dict plus one.
# Only mismatches are printed; a clean run produces no output.
train_subset = dataset_manager.train_set
for graph_idx in train_subset.indices:
    news_id = index_news_id_dict[graph_idx]
    node_user_dict = news_node_user_dict[news_id]
    # Node count as reported by the underlying dataset for this graph.
    num_nodes_ds = train_subset.dataset.get(graph_idx).num_nodes
    # The root node is not included in the dict, hence the +1.
    num_nodes_dict = len(node_user_dict) + 1
    if num_nodes_ds == num_nodes_dict:
        continue
    print(f'num nodes in dataset: {num_nodes_ds} ## num nodes in dict: {num_nodes_dict}')

In [None]:
# Fetch all fake news: label=0 is the fake-news class throughout this notebook.
# (Judging by the variable name the result is a torch dataset -- TODO confirm.)
fake_news_torch_ds = dataset_manager.fetch_all_news(label=0)
