In [1]:
import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as TVDatasets

import torch_geometric
from torch_geometric.data import Data as GraphData 

from torch_geometric.nn import GCNConv, GATConv, APPNP, SAGEConv
from torch_geometric.nn.models.label_prop import LabelPropagation
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx

from sklearn.metrics import classification_report
# from sklearn.calibration import CalibrationDisplay

import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Torch is running on {device}")

import os
import sys
from gnn_cp.data.data_manager import GraphDataManager
from gnn_cp.models.graph_models import GCN, GAT, APPNPNet, SAGE
from gnn_cp.models.model_manager import GraphModelManager
from gnn_cp.data.utils import make_dataset_instances
import gnn_cp.cp.transformations as cp_t
import gnn_cp.cp.graph_transformations as cp_gt
from gnn_cp.cp.graph_cp import GraphCP

Torch is running on cuda
Torch Graph Models are running on cuda
Torch Graph Models are running on cuda
Torch Graph Models are running on cuda
Torch Graph Models are running on cuda


In [2]:
# Path to the YAML experiment configuration. Despite the "_dir" suffix this is
# a file path: it is passed straight to open() when the config is loaded.
config_file_dir = "configs/config.yaml"
# Root directory for experiment artifacts; the "models", "datasets", "splits"
# and "cp_results" sub-directories are resolved relative to it.
results_dir = "results"
# Target directory for figures. Not referenced by the cells visible here.
figures_dir = "reports/figures"

In [3]:
# Read the baseline experiment settings from the YAML config file.
with open(config_file_dir, 'r') as f:
    # NOTE(review): FullLoader can construct more than plain YAML types; this
    # is acceptable only because the config is a trusted, project-local file.
    config = yaml.load(f, Loader=yaml.FullLoader)

# Everything used below lives under the top-level "baseline" key.
baseline_config = config.get("baseline", {})
general_dataset_config = baseline_config.get("general_dataset_config", {})

# All artifacts produced by the baseline run must already exist on disk.
assert os.path.isdir(results_dir), "The results path does not exist!"

models_cache_dir = os.path.join(results_dir, "models")
data_dir = os.path.join(results_dir, "datasets")
splits_dir = os.path.join(results_dir, "splits")
assert os.path.isdir(models_cache_dir), "Directory to trained models is not found! Maybe first tun the make_baselines.py file"
assert os.path.isdir(data_dir), "Directory to Data Files is not found!"
assert os.path.isdir(splits_dir), "Directory to Data Splits is not found!"

# The split settings are the same "general_dataset_config" section.
splits_config = baseline_config.get("general_dataset_config", {})

# Datasets and model classes that can be selected interactively.
dataset_names = list(baseline_config.get("datasets", {}).keys())
models_config = baseline_config.get("models", {})
model_classes = list(models_config.keys())

# Directory that stores the conformal-prediction results (created on demand).
cp_results_dir = os.path.join(results_dir, "cp_results")
os.makedirs(cp_results_dir, exist_ok=True)


# region
# Interactive choice of dataset and model class, followed by loading the
# dataset and the trained model instances for that choice.
def _numbered_listing(names):
    """Render `names` as one "<index>: <name>" line per entry."""
    return '\n'.join(f'{position}: {name}' for position, name in enumerate(names))


dataset_str_list = _numbered_listing(dataset_names)
dataset_name_idx = int(input(f"specify the dataset index:\n{dataset_str_list}\n"))
dataset_key = dataset_names[dataset_name_idx]

model_str_list = _numbered_listing(model_classes)
model_class_idx = int(input(f"specify the model index:\n{model_str_list}\n"))
model_class_name = model_classes[model_class_idx]

dataset_manager = GraphDataManager(data_dir, splits_dir)
dataset = dataset_manager.get_dataset_from_key(dataset_key).data

print(f"dataset = {dataset_key}")
instances = make_dataset_instances(
    data_dir,
    splits_dir,
    splits_config,
    models_cache_dir,
    dataset_key,
    model_class_name,
    models_config,
)

# Summarise the accuracy stored with each instance.
instances_accuracy = [entry["accuracy"] for entry in instances]
print(f"acc={np.mean(instances_accuracy)} +- {np.std(instances_accuracy)}")
best_model_accuracy = np.max(instances_accuracy)

# Model outputs on the whole dataset, one entry per instance.
instances_logits = [entry["model"].predict(dataset) for entry in instances]

specify the dataset index:
0: cora_ml
1: cora_ml_largest_component
2: cora_full
3: pubmed
4: citeseer
5: coauthor_cs
6: coauthor_physics
7: amazon_computers
8: amazon_photo
 0
specify the model index:
0: GCN
1: GAT
2: SAGE
3: MLP
4: APPNPNet
 0


dataset = cora_ml
Dataset Loaded Successfully!
Following labeled splits:
class 0: train=20, val=20
class 1: train=20, val=20
class 2: train=20, val=20
class 3: train=20, val=20
class 4: train=20, val=20
class 5: train=20, val=20
class 6: train=20, val=20
Loading Models
Loading Models GCN
Accuracy: 0.8231307550644568 +- 0.009074238748131642
acc=0.8231307550644568 +- 0.009074238748131642


In [4]:
# Locate and open the raw archive behind the selected dataset so that the
# per-node metadata stored in it (e.g. "attr_text") can be inspected.
dataset_class = dataset_manager.get_dataset_from_key(dataset_key)
# NOTE(review): this join only works if `raw_file_names` is a single string
# rather than a list of names -- confirm against the dataset class.
raw_data_path = os.path.join(dataset_class.raw_dir, dataset_class.raw_file_names)
# NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code; only use it on archives from a trusted source.
raw_data = np.load(raw_data_path, allow_pickle=True)

# Show the names of the arrays stored in the archive.
print(list(raw_data))
# Last expression of the cell: displays the number of "attr_text" entries.
raw_data["attr_text"].size

['idx_to_attr', 'attr_indices', 'attr_shape', 'idx_to_node', 'adj_shape', 'adj_indptr', 'adj_data', 'labels', 'attr_data', 'adj_indices', 'attr_indptr', 'idx_to_class', 'attr_text']


2995

In [5]:
# The case study looks at a single instance: the first one.
instance_idx = 0
instance = instances[instance_idx]
logits = instances_logits[instance_idx]

# Unpack the model, its accuracy and the node index splits of that instance.
model, accuracy = instance["model"], instance["accuracy"]
train_idx = instance["train_idx"]
val_idx = instance["val_idx"]
test_idx = instance["test_idx"]

# Boolean one-hot encoding of the labels: True at each node's true class.
true_mask = F.one_hot(dataset.y).eq(1)

In [6]:
# Fraction handed to `train_test_split` as `training_fraction` when the test
# nodes are divided into calibration and evaluation parts.
calib_fraction = 0.5
fixed_neigh_coef = 0.55

# Grid of candidate coefficients: 0.00, 0.05, ..., 1.50.
lambda_vals = np.arange(0, 1.51, 0.05).round(3)

# Coverage levels from the (rounded) model accuracy up to, but excluding, 1.0;
# the level in the middle of that grid is the one used for the case study.
coverage_values = np.arange(start=accuracy.round(2), stop=1.0, step=0.005)
selected_coverage = coverage_values[coverage_values.size // 2]


def singleton_hit(pred_set, true_mask):
    """Fraction of all rows whose prediction set is a singleton containing the true class.

    Args:
        pred_set: 2-D boolean tensor; row i marks the classes in the
            prediction set of sample i.
        true_mask: boolean tensor of the same shape marking the true class.
            Assumes exactly one True per row, so that `pred_set[true_mask]`
            has one entry per sample -- TODO confirm with callers.

    Returns:
        float: number of correct singleton sets divided by the number of rows.
    """
    n_rows = pred_set.shape[0]
    is_singleton = pred_set.sum(axis=1) == 1
    # Per sample: is the true class a member of the prediction set?
    true_class_included = pred_set[true_mask]
    n_hits = true_class_included[is_singleton].sum().item()
    return n_hits / n_rows

# Metric callables sharing one signature, (pred_set, true_mask) -> value, so
# that they can be stored and invoked uniformly.
def singleton_hit_metric(pred_set, true_mask):
    """Delegate to `singleton_hit`."""
    return singleton_hit(pred_set, true_mask)


def set_size_metric(pred_set, true_mask):
    """Delegate to `GraphCP.average_set_size`; `true_mask` is ignored."""
    return GraphCP.average_set_size(pred_set)


def coverage_metric(pred_set, true_mask):
    """Delegate to `GraphCP.coverage`."""
    return GraphCP.coverage(pred_set, true_mask)


def argmax_accuracy(pred_set, true_mask):
    """Delegate to `GraphCP.argmax_accuracy`."""
    return GraphCP.argmax_accuracy(pred_set, true_mask)


# Metrics keyed by the name under which they are reported.
metrics_dict = dict(
    empi_coverage=coverage_metric,
    average_set_size=set_size_metric,
    singleton_hit=singleton_hit_metric,
)

In [7]:
# Settings of the two score transformations compared with the baseline.
k_reg = 0
penalty = 0.5
lambda_val = 0.7
label_mask = F.one_hot(dataset.y).bool()

# Divide the test nodes into a calibration part and an evaluation part.
calib_idx, eval_idx, calib_mask, eval_mask = dataset_manager.train_test_split(
    test_idx, true_mask, training_fraction=calib_fraction
)


def _calibrate_and_predict(scores):
    """Calibrate a fresh GraphCP on the calibration nodes, then predict from `scores`.

    Returns the calibrated predictor together with its prediction sets.
    """
    predictor = GraphCP([], coverage_guarantee=selected_coverage)
    predictor.calibrate_from_scores(scores[calib_idx], label_mask[calib_idx])
    return predictor, predictor.predict_from_scores(scores)


# baseline: APS scores computed from the model logits
baseline_scores = cp_t.APSTransformation(softmax=True).pipe_transform(logits)
base_cp, baseline_pred_set = _calibrate_and_predict(baseline_scores)

# regular: baseline scores passed through the regularizer penalty
regular_scores = cp_t.RegularizerPenalty(k_reg=k_reg, penalty=penalty).pipe_transform(baseline_scores)
cp, reg_pred_set = _calibrate_and_predict(regular_scores)

# mixing: baseline scores passed through the vertex message-passing transformation
mixing_transformation = cp_gt.VertexMPTransformation(
    neigh_coef=lambda_val,
    edge_index=dataset.edge_index,
    n_vertices=dataset.x.shape[0],
)
mixing_scores = mixing_transformation.pipe_transform(baseline_scores)
cp, mix_pred_set = _calibrate_and_predict(mixing_scores)

In [8]:
# Collect the evaluation nodes for which BOTH methods output a prediction set
# of the same small size, but with different members.

# Size of the prediction sets that the case study compares.
case_study_set_size = 2

# Set membership is O(1). Testing `idx in eval_idx` directly scans the whole
# index container for every candidate node. `torch.as_tensor` accepts a
# tensor, an array or a list, so the container type does not matter.
eval_nodes = set(torch.as_tensor(eval_idx).tolist())


def _members(pred_row):
    """Class indices contained in one prediction-set row, as a flat list."""
    # flatten() (unlike squeeze()) keeps a list even for a single-class set.
    return pred_row.nonzero().flatten().tolist()


mix_indices = [
    idx
    for idx in torch.where(mix_pred_set.sum(dim=-1) == case_study_set_size)[0].tolist()
    if idx in eval_nodes
]

reg_indices = [
    idx
    for idx in torch.where(reg_pred_set.sum(dim=-1) == case_study_set_size)[0].tolist()
    if idx in eval_nodes
]

# Keep the (ascending) order of `mix_indices`; only nodes present in both
# lists whose two prediction sets differ are retained.
reg_nodes = set(reg_indices)
indices = [
    idx
    for idx in mix_indices
    if idx in reg_nodes and _members(mix_pred_set[idx]) != _members(reg_pred_set[idx])
]

In [9]:
# Write a plain-text report with, for every selected node, the class names in
# both prediction sets and the node's abstract.

# Read the lookup arrays once, outside the loop: an .npz archive returned by
# np.load is lazy, so every `raw_data[key]` access reads (and unpickles) the
# stored array again.
class_names = raw_data['idx_to_class'].tolist()
abstracts = raw_data['attr_text']

# Explicit UTF-8 so that non-ASCII characters in the abstracts cannot fail on
# platforms whose default text encoding is narrower.
with open('case_study.txt', 'w', encoding='utf-8') as file:
    for idx in indices:
        # mixing
        # flatten() (unlike squeeze()) still yields a list for a one-class set.
        mix_preds = mix_pred_set[idx].nonzero().flatten().tolist()
        mix_preds = [class_names[cls] for cls in mix_preds]

        # regular
        reg_preds = reg_pred_set[idx].nonzero().flatten().tolist()
        reg_preds = [class_names[cls] for cls in reg_preds]

        output = f"Node {idx}\nMixing predictions: ({mix_preds})\nRegular predictions: ({reg_preds})\nAbstract: {abstracts[idx]}\n\n"
        file.write(output)
