In [4]:
import math

# for model_n in ["gcn", "gat", "sage", 'gin']:
import numpy as np
import torch
from scipy.spatial.distance import cosine, euclidean, correlation, chebyshev, braycurtis, canberra, cityblock, \
    sqeuclidean

from utils.attack_utils import concatenate_first_k_arrays, get_similarity_dict, append_dict_contents

DISTANCE_METRICS_LIST = [cosine, euclidean, correlation, chebyshev, braycurtis, canberra, cityblock, sqeuclidean]
DISTANCE_METRICS_LIST_NAME = ['cosine', 'euclidean', 'correlation', 'chebyshev', 'braycurtis', 'canberra', 'cityblock',
                              'sqeuclidean']
dataset_name = 'twitch'
attack_name = "dp2"
model_name = "gcn"

import random
from sklearn.metrics import recall_score, precision_score, f1_score


def get_average(data_list):
    return sum(data_list) / len(data_list)


def eval_batch(dataset, node):
    models = ['gat', 'gcn', 'gin', 'sage']
    for m in models:
        res = torch.load(f"outputs/model_outputs/{dataset}/{m}/uncons/4/dp2_random_dynamic_0.01_insertNode_{node}.pt")
        print(m)
        print(evaluate_dp2(res, 15, "con"))


def exp_threshold(data, num_positives, f_known):
    known_count = max(1, int(f_known * num_positives))

    # Randomly sample `known_count` scores from the first `num_positives`
    known_scores = random.sample(list(data[:num_positives]), known_count)

    # Use the minimum of known scores as the threshold
    threshold = min(known_scores)
    return threshold


def evaluate_with_threshold_given(data_list, num_connected, threshold):
    y_true = [1] * num_connected + [0] * (len(data_list) - num_connected)
    # threshold - 1 == the threshold-th largest number in the sorted list
    data_list = np.array(data_list)
    y_true = np.array(y_true)
    permutation = np.random.permutation(len(data_list))
    shuffled_probs = data_list[permutation]
    shuffled_labels = y_true[permutation]
    y_pred = []
    for cur in shuffled_probs:
        if cur >= threshold:
            y_pred.append(1)
        else:
            y_pred.append(0)
    y_pred = np.array(y_pred)
    recall = recall_score(shuffled_labels, y_pred)
    precision = precision_score(shuffled_labels, y_pred)

    f1 = f1_score(shuffled_labels, y_pred)
    return recall, precision, f1


def evaluate_attack_auc_ap(connected, unconnected):
    """
    Evaluates the attack performance using ROC AUC and average precision.

    Args:
        connected (list or array): Predicted scores for connected samples (positive class, label 0).
        unconnected (list or array): Predicted scores for unconnected samples (negative class, label 1).

    Returns:
        auc (float): Area Under the ROC Curve (ensured to be >= 0.5).
        ap (float): Average precision score.
    """
    # y = [0] * len(connected) + [1] * len(unconnected)
    # pred = connected + unconnected
    # fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=0)
    # auc = metrics.auc(fpr, tpr)
    # auc = max(auc, 1 - auc)
    # ap = metrics.average_precision_score(y, pred)
    return 0, 0


def get_average_metrics_nodes_new(data_list, num_connected_list, f_known=1):
    recalls = []
    precisions = []
    f1s = []
    aps = []
    aucs = []
    num_samples = len(data_list)

    for i in range(num_samples):
        n_degree = num_connected_list[i]
        cur_data = data_list[i]
        threshold = exp_threshold(cur_data, n_degree, f_known=f_known)

        recall, precision, f1 = evaluate_with_threshold_given(cur_data, n_degree, threshold)
        for num in range(len(cur_data)):
            if math.isnan(cur_data[num]):
                cur_data[num] = -100
        auc, ap = evaluate_attack_auc_ap(cur_data[:n_degree], cur_data[n_degree:])
        recalls.append(recall)
        precisions.append(precision)
        f1s.append(f1)
        aps.append(ap)
        aucs.append(auc)

    return get_average(recalls), get_average(precisions), get_average(f1s), get_average(
        aps), get_average(aucs)


def data_processor(data, mode="con", len=15):
    if mode == "con":
        return concatenate_first_k_arrays(data, len)
    elif mode == "mean":
        return np.mean(data, axis=0)
    elif mode == "median":
        return np.median(data, axis=0)


def evaluate_with_threshold(data_list, num_connected, threshold):
    print(data_list)
    print(num_connected)
    y_true = [1] * num_connected + [0] * (len(data_list) - num_connected)
    sorted_res = sorted(data_list, reverse=True)
    # threshold - 1 == the threshold-th largest number in the sorted list
    k_th_largest = sorted_res[threshold]
    data_list = np.array(data_list)
    y_true = np.array(y_true)
    permutation = np.random.permutation(len(data_list))
    shuffled_probs = data_list[permutation]
    shuffled_labels = y_true[permutation]
    y_pred = []
    cnt = 0
    for cur in shuffled_probs:
        if k_th_largest == 0:
            if cur > 0:
                y_pred.append(1)
                cnt += 1
            else:
                y_pred.append(0)
        elif cur > k_th_largest and cnt < threshold:
            y_pred.append(1)
            cnt += 1
        else:
            y_pred.append(0)
    y_pred = np.array(y_pred)
    recall = recall_score(shuffled_labels, y_pred)
    precision = precision_score(shuffled_labels, y_pred)
    # if recall != precision:
    #     print(recall)
    #     print(precision)
    #     print(shuffled_labels)
    #     print(y_pred)
    f1 = f1_score(shuffled_labels, y_pred)
    return recall, precision, f1


def get_average_metrics(data_list, num_connected_list, threshold_ratio=1):
    recalls = []
    precisions = []
    f1s = []
    aps = []
    aucs = []
    num_samples = len(data_list)

    for i in range(num_samples):
        n_degree = num_connected_list[i]
        if n_degree < 3:
            continue
        cur_data = data_list[i]
        if threshold_ratio > 1:
            threshold = math.ceil(threshold_ratio * n_degree)
        else:
            threshold = math.floor(threshold_ratio * n_degree)
        recall, precision, f1 = evaluate_with_threshold(data_list[i], n_degree, threshold)
        for num in range(len(cur_data)):
            if math.isnan(cur_data[num]):
                cur_data[num] = -100
        auc, ap = 0, 0
        recalls.append(recall)
        precisions.append(precision)
        f1s.append(f1)
        aps.append(ap)
        aucs.append(auc)
    return get_average(recalls), get_average(precisions), get_average(f1s), get_average(
        aps), get_average(aucs)


def evaluate_dp2(result, l=15, mode="con"):
    statistics = result['origin_output']
    average_recall_list, average_precision_list, average_f1_list, average_ap, average_auc = {}, {}, {}, {}, {}
    agg_dict_list = []
    for output in statistics:
        target_change_all_list = output["target"]
        perturbed_change_all_list = output["perturbed"]
        anchor_change_all_list = output["anchor"]
        all_simi = []
        for i in range(len(perturbed_change_all_list)):
            target_change = data_processor(target_change_all_list[i], mode, l)
            perturbed_change = data_processor(perturbed_change_all_list[i], mode, l)
            anchor_change = data_processor(anchor_change_all_list[i], mode, l)
            inf_change = perturbed_change - anchor_change
            simi_dict = get_similarity_dict(target_change, perturbed_change)
            simi_dict = {key: value for key, value in simi_dict.items()}
            all_simi.append(simi_dict)
        agg_dict = append_dict_contents(all_simi)
        agg_dict_list.append(agg_dict)
    metrics_outputs = {}
    for metrics_name in DISTANCE_METRICS_LIST_NAME:
        current_metric_outputs = []
        for output in agg_dict_list:
            current_metric_outputs.append(output[metrics_name])
        average_recall_list[metrics_name], average_precision_list[metrics_name], average_f1_list[
            metrics_name], average_ap[metrics_name], average_auc[metrics_name] = get_average_metrics(
            current_metric_outputs, result["num_direct_connections"],
            threshold_ratio=1)
        metrics_outputs[metrics_name] = current_metric_outputs
    return average_f1_list, metrics_outputs



In [13]:
mode = "con"
l = 15
dataset_name = 'lastfm'
attack_name = "dp2"
model_name = "gin"
for model_n in [model_name]:
    cur_result = torch.load(f"res/twitch/dynamic/{model_name}/3/dim_256_lr_0.001/prate_1.0/is_balanced_False/sm/dp2_random_dynamic_0.02_insertNode_same_neighborinsert_True_0.2.pt")
    # cur_result = torch.load(f"res/{dataset_name}/dynamic/{model_n}/3/dim_64_lr_0.01/prate_1.0_drate0.02/is_balanced_False/{attack_name}_random_dynamic_0.02_insertNode_same_neighborinsert_True_0.2.pt")
    statistics = cur_result["statistic"]
    metrics_name = "cosine"
    n_c = cur_result["num_direct_connections"]
    current_metric_outputs =[]
    print(cur_result.keys)
    if attack_name == "dp2":
        statistics = cur_result['origin_output']
        average_recall_list, average_precision_list, average_f1_list, average_ap, average_auc = {}, {}, {}, {}, {}
        agg_dict_list = []
        for output in statistics:
            target_change_all_list = output["target"]
            perturbed_change_all_list = output["perturbed"]
            anchor_change_all_list = output["anchor"]
            all_simi = []
            for i in range(len(perturbed_change_all_list)):
                target_change = data_processor(target_change_all_list[i], mode, l)
                perturbed_change = data_processor(perturbed_change_all_list[i], mode, l)
                anchor_change = data_processor(anchor_change_all_list[i], mode, l)
                inf_change = perturbed_change - anchor_change
                simi_dict = get_similarity_dict(target_change, perturbed_change)
                simi_dict = {key: value for key, value in simi_dict.items()}
                all_simi.append(simi_dict)
            agg_dict = append_dict_contents(all_simi)
            agg_dict_list.append(agg_dict)
        metrics_outputs = {}
        
        for output in agg_dict_list:
            current_metric_outputs.append(output[metrics_name])
        print(current_metric_outputs)
        current_metric_outputs = current_metric_outputs
    else:
        for output in statistics:
            if cur_result["attack_type"] == 'simi_1' or cur_result["attack_type"] == 'lsa_post':
                current_metric_outputs.append([(1 - i) for i in output[metrics_name]])
            elif cur_result["attack_type"] == "inf_3":
                local_out = []
                for influence_score_1, influence_score_3 in output:
                    if influence_score_3 != 0:
                        local_out.append(influence_score_1 / influence_score_3)
                    else:
                        local_out.append(influence_score_1)
                current_metric_outputs.append(local_out)
            else:
                current_metric_outputs.append(output[metrics_name])

    for f in [0.1, 0.2, 0.5]:
        recalls, precision, _, _, _ = get_average_metrics_nodes_new(current_metric_outputs, n_c, f)
        print(f"current fraction f: {f}")
        print(f"print recall: {recalls}")
        print(f"print precison: {precision}")


<built-in method keys of dict object at 0x000001A989569C80>
[[0.9391741752624512, 0.9838962554931641, 0.9337500929832458, 0.9562998414039612, 0.5338369011878967, 0.7818976640701294, 0.8447232246398926, 0.8106058835983276, 0.9498746395111084, 0.7919634580612183, -0.6331015825271606, -0.8743025064468384, -0.7348055243492126, -0.9182385802268982, 0.753666341304779, 0.16359052062034607, 0.8343740105628967, -0.9447759985923767, 0.8943159580230713, 0.46258971095085144, 0.5336700677871704, -0.2676227390766144, -0.880857527256012, -0.9417806267738342, -0.9115344882011414, 0.562484622001648, 0.7582706212997437, -0.8269256353378296, 0.8923115134239197, -0.9144995808601379, -0.32534259557724, -0.9737531542778015, 0.868882417678833, 0.8997818827629089, 0.8539026975631714, -0.8484165668487549, -0.8191038966178894, 0.9313744902610779, 0.5201218128204346, 0.9139025807380676, 0.35634469985961914, -0.6898310780525208, 0.6668199896812439, 0.3184705972671509, 0.935234546661377, 0.7994601726531982, 0.6656