In [None]:
# Render matplotlib figures inline in the notebook.
# The second magic overrides the keyword arguments the inline backend uses when
# printing figures; presumably `bbox_inches=None` keeps the preview at the full
# requested figure size instead of a tight crop -- NOTE(review): confirm.
%matplotlib inline
%config InlineBackend.print_figure_kwargs = {"bbox_inches": None}

In [None]:
# Target medium for the figures; selects one of the size/font/marker presets
# below. Supported values: "paper" and "slide".
figure_destination = "paper"

if figure_destination == "paper":
    # Figure size in inches (width, height).
    figsize = (6.30045, 0.9*9.72632)
    fontsize_major = 9
    fontsize_minor = 7
    markersize_minor = 2
    markersize_major = 6

elif figure_destination == "slide":
    figsize = (6.10, 4.87)
    fontsize_major = 16
    fontsize_minor = 11
    markersize_minor = 4
    markersize_major = 8

else:
    # Fail here rather than with a confusing NameError on `figsize` much later.
    raise ValueError(
        f"Unknown figure_destination {figure_destination!r}; expected 'paper' or 'slide'."
    )


In [None]:
import os
import re
import pickle

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch_geometric
from torch_geometric.utils import assortativity, k_hop_subgraph
import networkx as nx
import numpy as np

def compute_graph_stats(graph):
    """Compute label-homophily and basic structural statistics for one graph.

    Args:
        graph: Mapping-like object supporting ``keys()`` and item access. The
            function reads ``edge_index``, ``num_nodes``, ``num_edges`` (the
            latter two via ``.item()``) and one ``label_{c}_locs`` entry per
            class ``c`` holding the indices of the nodes carrying that label.
            The number of classes is the number of keys containing ``"locs"``.

    Returns:
        dict with the keys ``homophily``, ``assortativity``,
        ``mean_rel_excess_homophily`` (list with one float per class, ``nan``
        for a class without any finite per-node value), ``num_nodes``,
        ``num_edges``, ``num_labels``, ``graph_density``, ``label_density``,
        ``degree_centrality`` and ``eigen_centrality`` (the two centralities
        are means over all label nodes).
    """
    num_classes = len([k for k in graph.keys() if "locs" in k])

    label_locs = [graph[f"label_{l}_locs"].flatten() for l in range(num_classes)]
    all_label_locs = torch.cat(label_locs).flatten()

    # Share of each class among all label nodes of this graph.
    label_counts = torch.tensor([len(locs) for locs in label_locs])
    label_prop = label_counts / label_counts.sum()

    edge_index = graph["edge_index"]

    # local_adj_matrix[l, ll]: number of class-ll label nodes found in the
    # 2-hop neighbourhoods of class-l label nodes, summed over those nodes.
    local_adj_matrix = torch.zeros((num_classes, num_classes))

    # Per class: relative excess homophily of every label node of that class.
    excess_homophily = [[] for _ in range(num_classes)]

    # For every label node, look at its 2-hop neighbourhood, i.e. the nodes
    # that can influence its representation, and count the label nodes of each
    # class in there.
    for l, lbl_locs in enumerate(label_locs):
        for label_loc in lbl_locs.tolist():
            subset, _, _, _ = k_hop_subgraph(
                node_idx=label_loc,
                num_hops=2,
                edge_index=edge_index,
            )

            subset_local_adj_matrix = torch.zeros((num_classes, ))

            for ll, lbl_locs2 in enumerate(label_locs):
                # `subset` contains the centre node itself, so the centre node
                # is counted towards its own class (no self-correction is
                # applied here).
                ll_neighbourhood_size = torch.isin(
                    subset,
                    lbl_locs2,
                    assume_unique=True
                ).sum()

                local_adj_matrix[l, ll] += ll_neighbourhood_size
                subset_local_adj_matrix[ll] = ll_neighbourhood_size

            # Fraction of same-class label nodes in the neighbourhood minus the
            # class proportion: the excess homophily of this label node ...
            subset_excess_homophily = (
                subset_local_adj_matrix[l] / subset_local_adj_matrix.sum() - label_prop[l]
            )

            # ... normalised so that a perfectly homophilic neighbourhood gives 1.
            subset_excess_homophily = subset_excess_homophily / (1 - label_prop[l])

            if not torch.isnan(subset_excess_homophily):
                excess_homophily[l].append(subset_excess_homophily)

    # Mean relative excess homophily per class. A class without any finite
    # value yields NaN instead of crashing `torch.stack` on an empty list.
    mean_rel_excess_homophily = [
        torch.stack(values).mean().item() if values else float("nan")
        for values in excess_homophily
    ]

    # Graph homophily: per-class same-class fraction minus the class
    # proportion, clipped at zero, summed and divided by (num_classes - 1).
    h = torch.diag(local_adj_matrix) / local_adj_matrix.sum(dim=1)
    graph_homophily = (torch.sum(torch.clip(h - label_prop, min=0)) / (num_classes - 1)).item()

    # Graph assortativity from the normalised mixing matrix:
    # (trace - <row marginals, column marginals>) / (1 - <...>).
    mixing_matrix = local_adj_matrix / local_adj_matrix.sum()

    dot_prod_marginals = torch.dot(torch.sum(mixing_matrix, dim=0), torch.sum(mixing_matrix, dim=1))
    # Named differently from the imported `assortativity` function to avoid
    # shadowing it inside this function.
    graph_assortativity = (
        (torch.trace(mixing_matrix) - dot_prod_marginals) / (1 - dot_prod_marginals)
    ).item()

    # Use networkx for some other standard graph statistics
    all_label_locs = all_label_locs.tolist()

    num_nodes = graph["num_nodes"].item()
    num_edges = graph["num_edges"].item()
    num_labels = len(all_label_locs)
    density = (2 * num_edges) / (num_nodes * (num_nodes - 1))

    networkx_graph = torch_geometric.utils.to_networkx(
        torch_geometric.data.Data(edge_index=edge_index, num_nodes=num_nodes),
        to_undirected=True
        )

    degree_centrality = nx.degree_centrality(networkx_graph)
    degree_centrality = np.mean([degree_centrality[i] for i in all_label_locs])

    # NOTE(review): max_iter=10 is a low iteration cap; confirm the eigenvector
    # solver converges for all graphs.
    eigen_centrality = nx.eigenvector_centrality_numpy(networkx_graph, max_iter=10)
    eigen_centrality = np.mean([eigen_centrality[i] for i in all_label_locs])

    graph_stats = {
            "homophily": graph_homophily,
            "assortativity": graph_assortativity,
            "mean_rel_excess_homophily": mean_rel_excess_homophily,
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "num_labels": num_labels,
            "graph_density": density,
            "label_density": num_labels / num_nodes,
            "degree_centrality": degree_centrality,
            "eigen_centrality": eigen_centrality,
        }

    return graph_stats


# Gossipcop

## Support

In [None]:
loc = "./data/structured/gossipcop/seed[942]_splits[5]_minlen[0]_filterisolated[True]_topk[30]_topexcl[1]_userdoc[30]_featuretype[one-hot]_vocab[joint][random][10000x768]_userfeatures[post][zeros]/0"

# Reset first so a directory found for a previously processed dataset cannot
# silently be reused when nothing matches here.
support_batches_loc = None
query_batches_loc = None

# Sub-directories directly below `loc`; their names encode the split
# ("split=..., ") and the meta-split ("version=...)").
subdirs = next(os.walk(loc))[1]
for subdir in subdirs:

    # Raw strings: "\=" and "\)" are invalid escape sequences in normal strings.
    split = re.search(r"split\=(.+?), ", subdir)
    version = re.search(r"version\=(.+?)\)", subdir)

    # Skip directories that do not carry both tags.
    if split is None or version is None:
        continue

    split = split.group(1)
    meta_split = version.group(1)

    if split == "train" and meta_split == "meta_train_support":
        support_batches_loc = loc + f"/{subdir}"
    elif split == "train" and meta_split == "meta_train_query":
        query_batches_loc = loc + f"/{subdir}"

if support_batches_loc is None:
    raise FileNotFoundError(f"No 'meta_train_support' batch directory found in {loc}")


In [None]:
# Compute the graph statistics of every stored Gossipcop support batch.
support_graph_stats = []

# Third element of the first os.walk() entry: the file names directly inside
# the support batch directory.
batch_locs = next(os.walk(support_batches_loc))[2]

for i, batched_graph in enumerate(tqdm(batch_locs)):
    
    batched_graph_loc = support_batches_loc + f"/{batched_graph}"
    
    # Map all tensors to the CPU regardless of the device they were saved from.
    graph = torch.load(batched_graph_loc, map_location="cpu")

    support_graph_stats += [compute_graph_stats(graph)]


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Long-format table with one row per (support graph, class): all class-0 rows
# first, then all class-1 rows.
records = [
    {
        "label": label,
        "excess_homophily": stats["mean_rel_excess_homophily"][label],
    }
    for label in (0, 1)
    for stats in support_graph_stats
]

df = pd.DataFrame.from_records(records)

# Half the paper width, a third of the paper height (inches).
fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# One filled, layered density curve per class.
ax = sns.kdeplot(
    data=df, x="excess_homophily", hue="label",
    bw_adjust=0.8, clip=(-10, 1),
    fill=True, multiple="layer",
    palette="tab10", ax=ax,
)

# Top-row panel: shared x-range, but no x tick labels or x-axis label.
ax.set_xlim([-1.25, 1.25])
ax.set_xticklabels([])
ax.set_xlabel("")

# The dataset name doubles as the row label; the density scale is hidden.
ax.set_ylabel("Gossipcop", fontsize=11)
ax.set_yticks([])

ax.set_title("Support", fontsize=11)


In [None]:
# Tighten the layout of the current `fig` and export it as PDF and PNG.
fig.tight_layout()
fig.savefig("../misc/figures/rel_excess_homophily/gossipcop_support.pdf")
fig.savefig("../misc/figures/rel_excess_homophily/gossipcop_support.png")

In [None]:
# Density of the per-graph "homophily" statistic over all support graphs.
# No `ax` is passed, so seaborn chooses the axes to draw on.
ax = sns.kdeplot(
    x=list(map(lambda x: x["homophily"], support_graph_stats)),
    bw_adjust=0.8,
    clip=(0, 1),
    fill=True,
    multiple="layer",
    )

ax.set_xlim([0, 1])

In [None]:
# Density of the per-graph "assortativity" statistic over all support graphs.
# No `ax` is passed, so seaborn chooses the axes to draw on.
ax = sns.kdeplot(
    x=list(map(lambda x: x["assortativity"], support_graph_stats)),
    bw_adjust=0.8,
    clip=(-1, 1),
    fill=True,
    multiple="layer",
    )

ax.set_xlim([-1, 1])

In [None]:
# Persist the list of per-graph support statistics for later reloading.
with open("../misc/stats/gossipcop_support.pickle", "wb") as f:
    pickle.dump(support_graph_stats, f)


## Query

In [None]:
import pickle

# Load the pickled Gossipcop graph dataset.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("./data/structured/gossipcop/seed[942]_splits[5]_minlen[0]_filterisolated[True]_topk[30]_topexcl[1]_userdoc[30]_featuretype[one-hot]_vocab[joint][random][10000x768]_userfeatures[post][zeros]/0/socialgraph(mode=inductive, split=train, keep_cc=largest).pickle", "rb") as f:
    graph_dataset = pickle.load(f)

num_classes = graph_dataset["num_classes"]

# Indices of the nodes whose entry in `y` equals each class id.
graph_labels = graph_dataset["graph"]["y"]
label_locs = [torch.where(graph_labels == l)[0] for l in range(num_classes)]
all_label_locs = torch.cat(label_locs)

# Share of each class among all labelled nodes.
label_prop = torch.tensor(list(map(lambda x: x.shape[0], label_locs)))
label_prop = label_prop / label_prop.sum()

num_nodes = graph_dataset["graph"].num_nodes

# One entry per processed labelled node (its 2-hop neighbourhood).
all_num_nodes = []
all_num_edges = []
all_density = []
all_num_labels = []

# Per class: relative excess homophily of each processed node of that class.
rel_excess_homophily = {l: [] for l in range(num_classes)}

# Size of the smallest class; used below to cap the nodes processed per class.
max_iterations = min(label_locs[l].shape[0] for l in range(num_classes))

for l in range(num_classes):
    for i, label_loc in enumerate(label_locs[l]):
        
        # Nodes and edges of the 2-hop neighbourhood around this labelled node.
        subset, edge_index, _, _ = k_hop_subgraph(
            node_idx=label_loc.unsqueeze(0),
            edge_index=graph_dataset["graph"]["edge_index"],
            num_hops=2,
        )

        subset_local_adj_matrix = torch.zeros((num_classes, ))

        for ll in range(num_classes):

            # Number of class-ll labelled nodes that lie inside the neighbourhood.
            ll_label_nodes = torch.isin(
                test_elements=subset,
                elements=label_locs[ll],
                assume_unique=True,
            ).sum()

            subset_local_adj_matrix[ll] += ll_label_nodes

        all_num_nodes += [subset.shape[0]]
        all_num_edges += [edge_index.shape[1]]
        all_density += [(2 * all_num_edges[-1]) / (all_num_nodes[-1] * (all_num_nodes[-1] - 1))]
        # Stored as a 0-dim tensor (result of `.sum()`), not a Python number.
        all_num_labels += [subset_local_adj_matrix.sum()]

        # Class fractions in the neighbourhood, then the excess over the class
        # proportions, scaled so a fully same-class neighbourhood gives 1.
        subset_local_adj_matrix = subset_local_adj_matrix / subset_local_adj_matrix.sum()
        local_rel_excess_homophily = (subset_local_adj_matrix - label_prop) / (1 - label_prop)

        rel_excess_homophily[l] += [local_rel_excess_homophily[l].item()]

        # Stop after as many nodes as the smallest class has, so every class
        # contributes the same number of values.
        if i == max_iterations-1:
            break

networkx_graph = torch_geometric.utils.to_networkx(
    graph_dataset["graph"],
    to_undirected=True
    )

# Centralities of all labelled nodes (not limited by `max_iterations`).
degree_centrality = nx.degree_centrality(networkx_graph)
degree_centrality = [degree_centrality[i] for i in all_label_locs.tolist()]

# NOTE(review): max_iter=10 is a low iteration cap; confirm the solver converges.
eigen_centrality = nx.eigenvector_centrality_numpy(networkx_graph, max_iter=10)
eigen_centrality = [eigen_centrality[i] for i in all_label_locs.tolist()]

# Drop the references to the full graph objects.
del graph_dataset, networkx_graph

# "mean_rel_excess_homophily" holds the per-node values keyed by class (a dict
# of lists), not pre-computed means.
graph_stats = {
    "mean_rel_excess_homophily": rel_excess_homophily,
    "num_nodes": all_num_nodes,
    "num_edges": all_num_edges,
    "num_labels": all_num_labels,
    "graph_density": all_density,
    "label_density": [(num_labels / num_nodes).item() for num_labels, num_nodes in zip(all_num_labels, all_num_nodes)],
    "degree_centrality": degree_centrality,
    "eigen_centrality": eigen_centrality,
    }


In [None]:
# Long-format table with one row per processed query node: its class and its
# relative excess homophily.
records = [
    {"label": l, "excess_homophily": metric_val}
    for l, metric_vals in rel_excess_homophily.items()
    for metric_val in metric_vals
    ]

df = pd.DataFrame.from_records(records)

fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# One filled, layered density curve per class.
ax = sns.kdeplot(
    data=df,
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=ax,
    palette="tab10",
    )

# No x tick labels, x-axis label or legend on this panel.
ax.set_xlim([-1.25, 1.25])
ax.set_xticklabels([])
ax.set_xlabel("")
#ax.set_xlabel("Rel. Excess Homophily")
ax.get_legend().remove()

ax.set_ylabel("", fontsize=11)
ax.set_yticks([])

ax.set_title("Query", fontsize=11)


In [None]:
# Tighten the layout of the current `fig` and export it as PDF and PNG.
fig.tight_layout()
fig.savefig("../misc/figures/rel_excess_homophily/gossipcop_query.pdf")
fig.savefig("../misc/figures/rel_excess_homophily/gossipcop_query.png")

In [None]:
# Persist the query statistics dictionary for later reloading.
with open("../misc/stats/gossipcop_query.pickle", "wb") as f:
    pickle.dump(graph_stats, f)


# Twitter Hate Speech

## Support

In [None]:
loc = "./data/structured/twitterHateSpeech/seed[942]_splits[5]_minlen[0]_filterisolated[True]_topk[30]_topexcl[0]_userdoc[100]_featuretype[one-hot]_vocab[joint][random][10000x768]_userfeatures[post][zeros]/0"

# Reset first so a directory found for a previously processed dataset cannot
# silently be reused when nothing matches here.
support_batches_loc = None
query_batches_loc = None

# Sub-directories directly below `loc`; their names encode the split
# ("split=..., ") and the meta-split ("version=...)").
subdirs = next(os.walk(loc))[1]
for subdir in subdirs:

    # Raw strings: "\=" and "\)" are invalid escape sequences in normal strings.
    split = re.search(r"split\=(.+?), ", subdir)
    version = re.search(r"version\=(.+?)\)", subdir)

    # Skip directories that do not carry both tags.
    if split is None or version is None:
        continue

    split = split.group(1)
    meta_split = version.group(1)

    if split == "train" and meta_split == "meta_train_support":
        support_batches_loc = loc + f"/{subdir}"
    elif split == "train" and meta_split == "meta_train_query":
        query_batches_loc = loc + f"/{subdir}"

if support_batches_loc is None:
    raise FileNotFoundError(f"No 'meta_train_support' batch directory found in {loc}")


In [None]:
# Compute the graph statistics of every stored Twitter Hate Speech support batch.
support_graph_stats = []

# Third element of the first os.walk() entry: the file names directly inside
# the support batch directory.
batch_locs = next(os.walk(support_batches_loc))[2]
for i, batched_graph in enumerate(tqdm(batch_locs)):
    
    batched_graph_loc = support_batches_loc + f"/{batched_graph}"
    
    # Map all tensors to the CPU regardless of the device they were saved from.
    graph = torch.load(batched_graph_loc, map_location="cpu")

    support_graph_stats += [compute_graph_stats(graph)]


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Long-format table with one row per (support graph, class): all class-0 rows
# first, then class 1, then class 2.
records = [
    {
        "label": label,
        "excess_homophily": stats["mean_rel_excess_homophily"][label],
    }
    for label in (0, 1, 2)
    for stats in support_graph_stats
]

df = pd.DataFrame.from_records(records)

# Half the paper width, a third of the paper height (inches).
fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# One filled, layered density curve per class.
ax = sns.kdeplot(
    data=df, x="excess_homophily", hue="label",
    bw_adjust=0.8, clip=(-10, 1),
    fill=True, multiple="layer",
    palette="tab10", ax=ax,
)

# This panel keeps its x tick labels and carries the x-axis label.
ax.set_xlim([-1.25, 1.25])
ax.set_xlabel("Rel. Excess Homophily")

# The dataset name doubles as the row label; the density scale is hidden.
ax.set_ylabel("Twitter Hate Speech", fontsize=11)
ax.set_yticks([])


In [None]:
# Tighten the layout of the current `fig` and export it as PDF and PNG.
fig.tight_layout()
fig.savefig("../misc/figures/rel_excess_homophily/twitterHateSpeech_support.pdf")
fig.savefig("../misc/figures/rel_excess_homophily/twitterHateSpeech_support.png")

In [None]:
# Persist the list of per-graph support statistics for later reloading.
with open("../misc/stats/twitterhatespeech_support.pickle", "wb") as f:
    pickle.dump(support_graph_stats, f)

## Query

In [None]:
import pickle

# Load the pickled Twitter Hate Speech graph dataset.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("./data/structured/twitterHateSpeech/seed[942]_splits[5]_minlen[0]_filterisolated[True]_topk[30]_topexcl[0]_userdoc[100]_featuretype[one-hot]_vocab[joint][random][10000x768]_userfeatures[post][zeros]/0/socialgraph(mode=inductive, split=train, keep_cc=largest).pickle", "rb") as f:
    graph_dataset = pickle.load(f)

num_classes = graph_dataset["num_classes"]

# Indices of the nodes whose entry in `y` equals each class id.
graph_labels = graph_dataset["graph"]["y"]
label_locs = [torch.where(graph_labels == l)[0] for l in range(num_classes)]
all_label_locs = torch.cat(label_locs)

# Share of each class among all labelled nodes.
label_prop = torch.tensor(list(map(lambda x: x.shape[0], label_locs)))
label_prop = label_prop / label_prop.sum()

num_nodes = graph_dataset["graph"].num_nodes

# One entry per labelled node (its 2-hop neighbourhood).
all_num_nodes = []
all_num_edges = []
all_density = []
all_num_labels = []

# Per class: relative excess homophily of each node of that class.
rel_excess_homophily = {l: [] for l in range(num_classes)}

# Size of the smallest class; not applied here because the cap below is
# commented out.
max_iterations = min(label_locs[l].shape[0] for l in range(num_classes))

for l in range(num_classes):
    for i, label_loc in enumerate(label_locs[l]):
        
        # Nodes and edges of the 2-hop neighbourhood around this labelled node.
        subset, edge_index, _, _ = k_hop_subgraph(
            node_idx=label_loc.unsqueeze(0),
            edge_index=graph_dataset["graph"]["edge_index"],
            num_hops=2,
        )

        subset_local_adj_matrix = torch.zeros((num_classes, ))

        for ll in range(num_classes):

            # Number of class-ll labelled nodes that lie inside the neighbourhood.
            ll_label_nodes = torch.isin(
                test_elements=subset,
                elements=label_locs[ll],
                assume_unique=True,
            ).sum()

            subset_local_adj_matrix[ll] += ll_label_nodes

        all_num_nodes += [subset.shape[0]]
        all_num_edges += [edge_index.shape[1]]
        all_density += [(2 * all_num_edges[-1]) / (all_num_nodes[-1] * (all_num_nodes[-1] - 1))]
        # Stored as a 0-dim tensor (result of `.sum()`), not a Python number.
        all_num_labels += [subset_local_adj_matrix.sum()]

        # Class fractions in the neighbourhood, then the excess over the class
        # proportions, scaled so a fully same-class neighbourhood gives 1.
        subset_local_adj_matrix = subset_local_adj_matrix / subset_local_adj_matrix.sum()
        local_rel_excess_homophily = (subset_local_adj_matrix - label_prop) / (1 - label_prop)

        rel_excess_homophily[l] += [local_rel_excess_homophily[l].item()]

        #if i == max_iterations-1:
        #    break

networkx_graph = torch_geometric.utils.to_networkx(
    graph_dataset["graph"],
    to_undirected=True
    )

# Centralities of all labelled nodes.
degree_centrality = nx.degree_centrality(networkx_graph)
degree_centrality = [degree_centrality[i] for i in all_label_locs.tolist()]

# NOTE(review): max_iter=10 is a low iteration cap; confirm the solver converges.
eigen_centrality = nx.eigenvector_centrality_numpy(networkx_graph, max_iter=10)
eigen_centrality = [eigen_centrality[i] for i in all_label_locs.tolist()]

# Drop the references to the full graph objects.
del graph_dataset, networkx_graph

# "mean_rel_excess_homophily" holds the per-node values keyed by class (a dict
# of lists), not pre-computed means.
graph_stats = {
    "mean_rel_excess_homophily": rel_excess_homophily,
    "num_nodes": all_num_nodes,
    "num_edges": all_num_edges,
    "num_labels": all_num_labels,
    "graph_density": all_density,
    "label_density": [(num_labels / num_nodes).item() for num_labels, num_nodes in zip(all_num_labels, all_num_nodes)],
    "degree_centrality": degree_centrality,
    "eigen_centrality": eigen_centrality,
    }


In [None]:
# Exploratory histogram of the per-node relative excess homophily per class.
# NOTE(review): the following cell rebinds `fig` before the figure is saved,
# so this histogram does not appear to be exported -- confirm.
records = [
    {"label": l, "excess_homophily": metric_val}
    for l, metric_vals in rel_excess_homophily.items()
    for metric_val in metric_vals
    ]

df = pd.DataFrame.from_records(records)

fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# Side-by-side ("dodge") density bars; `common_norm=False` normalises each
# class on its own.
ax = sns.histplot(
    data=df,
    x="excess_homophily",
    stat='density',
    #bw_adjust=0.8,
    #clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="dodge",
    ax=ax,
    palette=sns.color_palette("tab10"),
    common_norm=False,
    )

ax.set_xlim([-1.25, 1.25])
ax.set_xticklabels([])
ax.set_xlabel("")
#ax.set_xlabel("Rel. Excess Homophily")
ax.get_legend().remove()

ax.set_ylabel("", fontsize=11)
ax.set_yticks([])

ax.set_title("Query", fontsize=11)


In [None]:
# Long-format table with one row per query node: its class and its relative
# excess homophily.
records = [
    {"label": l, "excess_homophily": metric_val}
    for l, metric_vals in rel_excess_homophily.items()
    for metric_val in metric_vals
    ]

df = pd.DataFrame.from_records(records)

fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# One filled, layered density curve per class; `common_norm=False` normalises
# each class on its own.
ax = sns.kdeplot(
    data=df,
    x="excess_homophily",
    bw_adjust=1,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=ax,
    palette=sns.color_palette("tab10"),
    common_norm=False,
    )

# This panel keeps its x tick labels and carries the x-axis label.
ax.set_xlim([-1.25, 1.25])
#ax.set_xticklabels([])
#ax.set_xlabel("")
ax.set_xlabel("Rel. Excess Homophily")
ax.get_legend().remove()

ax.set_ylabel("", fontsize=11)
ax.set_yticks([])

#ax.set_title("Query", fontsize=11)


In [None]:
# Tighten the layout of the current `fig` and export it as PDF and PNG.
fig.tight_layout()
fig.savefig("../misc/figures/rel_excess_homophily/twitterHateSpeech_query.pdf")
fig.savefig("../misc/figures/rel_excess_homophily/twitterHateSpeech_query.png")

In [None]:
# Persist the query statistics dictionary for later reloading.
with open("../misc/stats/twitterhatespeech_query.pickle", "wb") as f:
    pickle.dump(graph_stats, f)


# CoAID

## Support

In [None]:
loc = "./data/structured/CoAID/seed[942]_splits[0]_minlen[0]_filterisolated[True]_topk[30]_topexcl[0]_userdoc[30]_featuretype[lm-embeddings]_vocab[external][roberta-base][NonexNone]_userfeatures[post][zeros]_version[transfer_77c5a6tu]/0"

# Reset first so a directory found for a previously processed dataset cannot
# silently be reused when nothing matches here.
support_batches_loc = None
query_batches_loc = None

# Sub-directories directly below `loc`; their names encode the split
# ("split=..., ") and the meta-split ("version=...)").
subdirs = next(os.walk(loc))[1]
for subdir in subdirs:

    # Raw strings: "\=" and "\)" are invalid escape sequences in normal strings.
    split = re.search(r"split\=(.+?), ", subdir)
    version = re.search(r"version\=(.+?)\)", subdir)

    # Skip directories that do not carry both tags.
    if split is None or version is None:
        continue

    split = split.group(1)
    meta_split = version.group(1)

    # Unlike the other datasets, the support batches are taken from the
    # "test" split here.
    if split == "test" and meta_split == "meta_train_support":
        support_batches_loc = loc + f"/{subdir}"
    elif split == "train" and meta_split == "meta_train_query":
        query_batches_loc = loc + f"/{subdir}"

if support_batches_loc is None:
    raise FileNotFoundError(f"No 'meta_train_support' batch directory found in {loc}")


In [None]:
# Compute the graph statistics of every stored CoAID support batch.
support_graph_stats = []

# Third element of the first os.walk() entry: the file names directly inside
# the support batch directory.
batch_locs = next(os.walk(support_batches_loc))[2]

for i, batched_graph in enumerate(tqdm(batch_locs)):
    
    batched_graph_loc = support_batches_loc + f"/{batched_graph}"
    
    # Map all tensors to the CPU regardless of the device they were saved from.
    graph = torch.load(batched_graph_loc, map_location="cpu")

    support_graph_stats += [compute_graph_stats(graph)]


In [None]:
# Persist the list of per-graph support statistics for later reloading.
with open("../misc/stats/coaid_support.pickle", "wb") as f:
    pickle.dump(support_graph_stats, f)

## Query

In [None]:
import pickle

# Load the pickled CoAID graph dataset.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("./data/structured/CoAID/seed[942]_splits[0]_minlen[0]_filterisolated[True]_topk[30]_topexcl[0]_userdoc[30]_featuretype[lm-embeddings]_vocab[external][roberta-base][NonexNone]_userfeatures[post][zeros]_version[transfer_77c5a6tu]/0/episodickhopneighbourhoodsocialgraph(mode=transductive, split=test, k_shot=4, prop_query=0.0, max_k_hop=5, budget=2048, doc_k_hop=2).pickle", "rb") as f:
    graph_dataset = pickle.load(f)

In [None]:
# Per-node neighbourhood statistics for the CoAID graph. `graph_dataset` is
# deleted at the end of this cell, so re-running it requires reloading the
# dataset first.
num_classes = graph_dataset["num_classes"]

# Indices of the nodes whose entry in `y` equals each class id.
graph_labels = graph_dataset["graph"]["y"]
label_locs = [torch.where(graph_labels == l)[0] for l in range(num_classes)]
all_label_locs = torch.cat(label_locs)

# Share of each class among all labelled nodes.
label_prop = torch.tensor(list(map(lambda x: x.shape[0], label_locs)))
label_prop = label_prop / label_prop.sum()

num_nodes = graph_dataset["graph"].num_nodes

# One entry per labelled node (its 2-hop neighbourhood).
all_num_nodes = []
all_num_edges = []
all_density = []
all_num_labels = []

# Per class: relative excess homophily of each node of that class.
rel_excess_homophily = {l: [] for l in range(num_classes)}

# Size of the smallest class; not applied here because the cap below is
# commented out.
max_iterations = min(label_locs[l].shape[0] for l in range(num_classes))

for l in range(num_classes):
    for i, label_loc in enumerate(label_locs[l]):
        
        # Nodes and edges of the 2-hop neighbourhood around this labelled node.
        subset, edge_index, _, _ = k_hop_subgraph(
            node_idx=label_loc.unsqueeze(0),
            edge_index=graph_dataset["graph"]["edge_index"],
            num_hops=2,
        )

        subset_local_adj_matrix = torch.zeros((num_classes, ))

        for ll in range(num_classes):

            # Number of class-ll labelled nodes that lie inside the neighbourhood.
            ll_label_nodes = torch.isin(
                test_elements=subset,
                elements=label_locs[ll],
                assume_unique=True,
            ).sum()

            subset_local_adj_matrix[ll] += ll_label_nodes

        all_num_nodes += [subset.shape[0]]
        all_num_edges += [edge_index.shape[1]]
        all_density += [(2 * all_num_edges[-1]) / (all_num_nodes[-1] * (all_num_nodes[-1] - 1))]
        # Stored as a 0-dim tensor (result of `.sum()`), not a Python number.
        all_num_labels += [subset_local_adj_matrix.sum()]

        # Class fractions in the neighbourhood, then the excess over the class
        # proportions, scaled so a fully same-class neighbourhood gives 1.
        subset_local_adj_matrix = subset_local_adj_matrix / subset_local_adj_matrix.sum()
        local_rel_excess_homophily = (subset_local_adj_matrix - label_prop) / (1 - label_prop)

        rel_excess_homophily[l] += [local_rel_excess_homophily[l].item()]

        #if i == max_iterations-1:
        #    break

networkx_graph = torch_geometric.utils.to_networkx(
    graph_dataset["graph"],
    to_undirected=True
    )

# Centralities of all labelled nodes.
degree_centrality = nx.degree_centrality(networkx_graph)
degree_centrality = [degree_centrality[i] for i in all_label_locs.tolist()]

# NOTE(review): max_iter=10 is a low iteration cap; confirm the solver converges.
eigen_centrality = nx.eigenvector_centrality_numpy(networkx_graph, max_iter=10)
eigen_centrality = [eigen_centrality[i] for i in all_label_locs.tolist()]

# Drop the references to the full graph objects.
del graph_dataset, networkx_graph

# "mean_rel_excess_homophily" holds the per-node values keyed by class (a dict
# of lists), not pre-computed means.
graph_stats = {
    "mean_rel_excess_homophily": rel_excess_homophily,
    "num_nodes": all_num_nodes,
    "num_edges": all_num_edges,
    "num_labels": all_num_labels,
    "graph_density": all_density,
    "label_density": [(num_labels / num_nodes).item() for num_labels, num_nodes in zip(all_num_labels, all_num_nodes)],
    "degree_centrality": degree_centrality,
    "eigen_centrality": eigen_centrality,
    }


In [None]:
# Quick check: mean of the per-node relative excess homophily values of class 1.
np.mean(graph_stats['mean_rel_excess_homophily'][1])

In [None]:
# Persist the CoAID query statistics. The aggregated section reloads
# "../misc/stats/coaid_query.pickle", which no other cell writes; this replaces
# a leftover debug display of `subset_local_adj_matrix`.
with open("../misc/stats/coaid_query.pickle", "wb") as f:
    pickle.dump(graph_stats, f)

# HealthStory

## Support

In [None]:
loc = "./data/structured/HealthStory/seed[942]_splits[5]_minlen[0]_filterisolated[True]_topk[20]_topexcl[0]_userdoc[30]_featuretype[one-hot]_vocab[joint][random][10000x768]_userfeatures[post][zeros]/0"

# Reset first so a directory found for a previously processed dataset cannot
# silently be reused when nothing matches here.
support_batches_loc = None
query_batches_loc = None

# Sub-directories directly below `loc`; their names encode the split
# ("split=..., ") and the meta-split ("version=...)").
subdirs = next(os.walk(loc))[1]
for subdir in subdirs:

    # Raw strings: "\=" and "\)" are invalid escape sequences in normal strings.
    split = re.search(r"split\=(.+?), ", subdir)
    version = re.search(r"version\=(.+?)\)", subdir)

    # Skip directories that do not carry both tags.
    if split is None or version is None:
        continue

    split = split.group(1)
    meta_split = version.group(1)

    if split == "train" and meta_split == "meta_train_support":
        support_batches_loc = loc + f"/{subdir}"
    elif split == "train" and meta_split == "meta_train_query":
        query_batches_loc = loc + f"/{subdir}"

if support_batches_loc is None:
    raise FileNotFoundError(f"No 'meta_train_support' batch directory found in {loc}")


In [None]:
# Compute the graph statistics of every stored HealthStory support batch.
support_graph_stats = []

# Third element of the first os.walk() entry: the file names directly inside
# the support batch directory.
batch_locs = next(os.walk(support_batches_loc))[2]
for i, batched_graph in enumerate(tqdm(batch_locs)):
    
    batched_graph_loc = support_batches_loc + f"/{batched_graph}"
    
    # Map all tensors to the CPU regardless of the device they were saved from.
    graph = torch.load(batched_graph_loc, map_location="cpu")

    support_graph_stats += [compute_graph_stats(graph)]


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Long-format table with one row per (support graph, class): all class-0 rows
# first, then all class-1 rows. Built from `support_graph_stats`, the list the
# HealthStory support loop fills; the previously referenced
# `healthstory_support_stats` is not defined anywhere in this notebook.
records = [
    {
        "label": label,
        "excess_homophily": stats["mean_rel_excess_homophily"][label],
    }
    for label in (0, 1)
    for stats in support_graph_stats
]

df = pd.DataFrame.from_records(records)

# Half the paper width, a third of the paper height (inches).
fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# One filled, layered density curve per class.
ax = sns.kdeplot(
    data=df,
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=ax,
    palette="tab10"
    )

# No x tick labels or x-axis label on this panel.
ax.set_xlim([-1.25, 1.25])
ax.set_xticklabels([])
ax.set_xlabel("")
#ax.set_xlabel("Rel. Excess Homophily")

# The dataset name doubles as the row label; the density scale is hidden.
ax.set_ylabel("HealthStory", fontsize=11)
ax.set_yticks([])

#ax.set_title("Support", fontsize=11)


In [None]:
# Tighten the layout of the current `fig` and export it as PDF and PNG.
fig.tight_layout()
fig.savefig("../misc/figures/rel_excess_homophily/healthstory_support.pdf")
fig.savefig("../misc/figures/rel_excess_homophily/healthstory_support.png")

In [None]:
# Persist the list of per-graph support statistics for later reloading.
with open("../misc/stats/healthstory_support.pickle", "wb") as f:
    pickle.dump(support_graph_stats, f)

## Query

In [None]:
import pickle

# Load the pickled HealthStory graph dataset.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("./data/structured/HealthStory/seed[942]_splits[5]_minlen[0]_filterisolated[True]_topk[20]_topexcl[0]_userdoc[30]_featuretype[one-hot]_vocab[joint][random][10000x768]_userfeatures[post][zeros]/0/socialgraph(mode=inductive, split=train, keep_cc=largest).pickle", "rb") as f:
    graph_dataset = pickle.load(f)

In [None]:
# Per-node neighbourhood statistics for the HealthStory graph. `graph_dataset`
# is deleted at the end of this cell, so re-running it requires reloading the
# dataset first.
num_classes = graph_dataset["num_classes"]

# Indices of the nodes whose entry in `y` equals each class id.
graph_labels = graph_dataset["graph"]["y"]
label_locs = [torch.where(graph_labels == l)[0] for l in range(num_classes)]
all_label_locs = torch.cat(label_locs)

# Share of each class among all labelled nodes.
label_prop = torch.tensor(list(map(lambda x: x.shape[0], label_locs)))
label_prop = label_prop / label_prop.sum()

num_nodes = graph_dataset["graph"].num_nodes

# One entry per labelled node (its 2-hop neighbourhood).
all_num_nodes = []
all_num_edges = []
all_density = []
all_num_labels = []

# Per class: relative excess homophily of each node of that class.
rel_excess_homophily = {l: [] for l in range(num_classes)}

# Size of the smallest class; not applied here because the cap below is
# commented out.
max_iterations = min(label_locs[l].shape[0] for l in range(num_classes))

for l in range(num_classes):
    for i, label_loc in enumerate(label_locs[l]):
        
        # Nodes and edges of the 2-hop neighbourhood around this labelled node.
        subset, edge_index, _, _ = k_hop_subgraph(
            node_idx=label_loc.unsqueeze(0),
            edge_index=graph_dataset["graph"]["edge_index"],
            num_hops=2,
        )

        subset_local_adj_matrix = torch.zeros((num_classes, ))

        for ll in range(num_classes):

            # Number of class-ll labelled nodes that lie inside the neighbourhood.
            ll_label_nodes = torch.isin(
                test_elements=subset,
                elements=label_locs[ll],
                assume_unique=True,
            ).sum()

            subset_local_adj_matrix[ll] += ll_label_nodes

        all_num_nodes += [subset.shape[0]]
        all_num_edges += [edge_index.shape[1]]
        all_density += [(2 * all_num_edges[-1]) / (all_num_nodes[-1] * (all_num_nodes[-1] - 1))]
        # Stored as a 0-dim tensor (result of `.sum()`), not a Python number.
        all_num_labels += [subset_local_adj_matrix.sum()]

        # Class fractions in the neighbourhood, then the excess over the class
        # proportions, scaled so a fully same-class neighbourhood gives 1.
        subset_local_adj_matrix = subset_local_adj_matrix / subset_local_adj_matrix.sum()
        local_rel_excess_homophily = (subset_local_adj_matrix - label_prop) / (1 - label_prop)

        rel_excess_homophily[l] += [local_rel_excess_homophily[l].item()]

        #if i == max_iterations-1:
        #    break


networkx_graph = torch_geometric.utils.to_networkx(
    graph_dataset["graph"],
    to_undirected=True
    )

# Centralities of all labelled nodes.
degree_centrality = nx.degree_centrality(networkx_graph)
degree_centrality = [degree_centrality[i] for i in all_label_locs.tolist()]

# NOTE(review): max_iter=10 is a low iteration cap; confirm the solver converges.
eigen_centrality = nx.eigenvector_centrality_numpy(networkx_graph, max_iter=10)
eigen_centrality = [eigen_centrality[i] for i in all_label_locs.tolist()]

# Drop the references to the full graph objects.
del graph_dataset, networkx_graph

# "mean_rel_excess_homophily" holds the per-node values keyed by class (a dict
# of lists), not pre-computed means.
graph_stats = {
    "mean_rel_excess_homophily": rel_excess_homophily,
    "num_nodes": all_num_nodes,
    "num_edges": all_num_edges,
    "num_labels": all_num_labels,
    "graph_density": all_density,
    "label_density": [(num_labels / num_nodes).item() for num_labels, num_nodes in zip(all_num_labels, all_num_nodes)],
    "degree_centrality": degree_centrality,
    "eigen_centrality": eigen_centrality,
    }


In [None]:
# Long-format table with one row per query node: its class and its relative
# excess homophily.
records = [
    {"label": l, "excess_homophily": metric_val}
    for l, metric_vals in rel_excess_homophily.items()
    for metric_val in metric_vals
    ]

df = pd.DataFrame.from_records(records)

fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# One filled, layered density curve per class; `common_norm=False` normalises
# each class on its own.
ax = sns.kdeplot(
    data=df,
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=ax,
    palette=sns.color_palette("tab10"),
    common_norm=False,
    )

# No x tick labels, x-axis label or legend on this panel.
ax.set_xlim([-1.25, 1.25])
ax.set_xticklabels([])
ax.set_xlabel("")
#ax.set_xlabel("Rel. Excess Homophily")
ax.get_legend().remove()

ax.set_ylabel("", fontsize=11)
ax.set_yticks([])

#ax.set_title("Query", fontsize=11)


In [None]:
# Tighten the layout of the current `fig` and export it as PDF and PNG.
fig.tight_layout()
fig.savefig("../misc/figures/rel_excess_homophily/healthstory_query.pdf")
fig.savefig("../misc/figures/rel_excess_homophily/healthstory_query.png")

In [None]:
# Persist the query statistics dictionary for later reloading.
with open("../misc/stats/healthstory_query.pickle", "wb") as f:
    pickle.dump(graph_stats, f)


In [None]:
# Exploratory histogram of the per-node relative excess homophily per class.
# NOTE(review): no save call follows this cell in the visible notebook, so the
# figure appears to be for inspection only -- confirm.
records = [
    {"label": l, "excess_homophily": metric_val}
    for l, metric_vals in rel_excess_homophily.items()
    for metric_val in metric_vals
    ]

df = pd.DataFrame.from_records(records)

fig, ax = plt.subplots(1, 1, figsize=(6.30 / 2, 9.72 / 3))

# Side-by-side ("dodge") probability bars; `common_norm=False` normalises each
# class on its own.
ax = sns.histplot(
    data=df,
    x="excess_homophily",
    hue="label",
    stat="probability",
    fill=True,
    multiple="dodge",
    ax=ax,
    palette=sns.color_palette("tab10"),
    common_norm=False,
    )

ax.set_xlim([-1.25, 1.25])
ax.set_xticklabels([])
ax.set_xlabel("")
#ax.set_xlabel("Rel. Excess Homophily")
ax.get_legend().remove()

ax.set_ylabel("", fontsize=11)
ax.set_yticks([])

#ax.set_title("Query", fontsize=11)


# Aggregated Table

In [None]:
import pickle

import numpy as np
import pandas as pd

In [None]:
# Reload the pickled Gossipcop and CoAID statistics from disk.
# NOTE(review): each file must already exist; confirm every one of them is
# written by a preceding cell before relying on Restart & Run All.
with open("../misc/stats/gossipcop_support.pickle", "rb") as f:
    gossipcop_support_stats = pickle.load(f)

with open("../misc/stats/gossipcop_query.pickle", "rb") as f:
    gossipcop_query_stats = pickle.load(f)

with open("../misc/stats/coaid_support.pickle", "rb") as f:
    coaid_support_stats = pickle.load(f)

with open("../misc/stats/coaid_query.pickle", "rb") as f:
    coaid_query_stats = pickle.load(f)


In [None]:
# Reload the pickled Twitter Hate Speech statistics from disk.
with open("../misc/stats/twitterhatespeech_support.pickle", "rb") as f:
    twitter_support_stats = pickle.load(f)

with open("../misc/stats/twitterhatespeech_query.pickle", "rb") as f:
    twitter_query_stats = pickle.load(f)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def make_support_df(stats, num_labels: int = 2):
    """Build a long-format frame of per-graph, per-class excess homophily.

    Args:
        stats: Iterable of per-graph dictionaries whose
            ``"mean_rel_excess_homophily"`` entry is indexable by class id.
        num_labels: Number of classes to extract (class ids ``0`` to
            ``num_labels - 1``).

    Returns:
        DataFrame with the columns ``label`` and ``excess_homophily``; the rows
        of class 0 come first, followed by those of class 1, and so on.
    """
    rows = [
        {
            "label": label,
            "excess_homophily": graph_stats["mean_rel_excess_homophily"][label],
        }
        for label in range(num_labels)
        for graph_stats in stats
    ]

    return pd.DataFrame.from_records(rows)

def make_query_df(stats):
    """Build a long-form DataFrame from per-label lists of excess homophily.

    Args:
        stats: Mapping with a ``"mean_rel_excess_homophily"`` entry, itself a
            mapping from label to an iterable of values.

    Returns:
        A DataFrame with one row per individual value and the columns
        ``label`` and ``excess_homophily``, in the iteration order of the
        inner mapping.
    """
    per_label = stats["mean_rel_excess_homophily"]

    rows = []
    for label, values in per_label.items():
        for value in values:
            rows.append({"label": label, "excess_homophily": value})

    return pd.DataFrame.from_records(rows)


In [None]:
# Hatch patterns applied to the filled KDE areas below, one per label, so the
# classes can be told apart without relying on colour alone.
hatches = ['', '\\\\', '//']

# 3 rows x 2 columns of axes; `figsize` comes from the configuration cell at
# the top of the notebook.
fig, axes = plt.subplots(3, 2, figsize=figsize)
flat_axes = np.ravel(axes)

# Column headers. NOTE(review): the font size is hard-coded to 11 here rather
# than taken from `fontsize_major` in the configuration cell -- confirm intent.
axes[0, 0].set_title("Support", fontsize=11)
axes[0, 1].set_title("Query", fontsize=11)

# Gossipcop (top row): support KDE on the left, query KDE on the right.
# `clip` limits where seaborn evaluates the density to [-10, 1].
axes[0, 0] = sns.kdeplot(
    data=make_support_df(gossipcop_support_stats, num_labels=2),
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=axes[0, 0],
    palette="tab10",
    # NOTE(review): only this kdeplot call passes `label`; the other five
    # panels do not. Its effect on the hue legend is unclear -- confirm.
    label="label",
    )
axes[0, 0].set_ylabel("Gossipcop", fontsize=11)

# Hatch each filled area and its legend handle with the same pattern.
# The collections are walked in reverse so they pair up with the legend
# handles -- presumably seaborn draws hue levels in reverse order; TODO confirm.
# `zip` stops at the shortest input, so unused hatch patterns are ignored.
# `legend_handles` is the attribute name used by recent matplotlib releases.
handles = []
for collection, handle, hatch in zip(axes[0, 0].collections[::-1], axes[0, 0].get_legend().legend_handles, hatches):
    collection.set_hatch(hatch)
    handle.set_hatch(hatch)
    
    handles.append(handle)

# Replace seaborn's legend with human-readable class names.
# Assumes the handle order corresponds to labels 0 -> "Real", 1 -> "Fake";
# verify against the dataset's label encoding.
axes[0, 0].legend(
    handles=handles,
    labels=["Real", "Fake"],
    loc='upper left',
    title=""
    )

axes[0, 1] = sns.kdeplot(
    data=make_query_df(gossipcop_query_stats),
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=axes[0, 1],
    palette="tab10",
    )
axes[0, 1].set_ylabel("", fontsize=11)
# The query panel has no legend of its own; its legend is removed.
axes[0, 1].get_legend().remove()

# Apply the hatch patterns to the query panel's filled areas as well.
for collection, hatch in zip(axes[0, 1].collections[::-1], hatches):
    collection.set_hatch(hatch)

# CoAID (middle row): support KDE on the left, query KDE on the right.
# `clip` limits where seaborn evaluates the density to [-10, 1].
axes[1, 0] = sns.kdeplot(
    data=make_support_df(coaid_support_stats, num_labels=2),
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=axes[1, 0],
    palette="tab10",
    )
axes[1, 0].set_ylabel("CoAID", fontsize=11)

# Hatch each filled area and its legend handle with the same pattern.
# The collections are walked in reverse so they pair up with the legend
# handles -- presumably seaborn draws hue levels in reverse order; TODO confirm.
# `zip` stops at the shortest input, so unused hatch patterns are ignored.
handles = []
for collection, handle, hatch in zip(axes[1, 0].collections[::-1], axes[1, 0].get_legend().legend_handles, hatches):
    collection.set_hatch(hatch)
    handle.set_hatch(hatch)
    
    handles.append(handle)

# Replace seaborn's legend with human-readable class names.
# Assumes the handle order corresponds to labels 0 -> "Real", 1 -> "Fake";
# verify against the dataset's label encoding.
axes[1, 0].legend(
    handles=handles,
    labels=["Real", "Fake"],
    loc='upper left',
    title=""
    )

axes[1, 1] = sns.kdeplot(
    data=make_query_df(coaid_query_stats),
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=axes[1, 1],
    palette="tab10",
    )
axes[1, 1].set_ylabel("", fontsize=11)
# The query panel has no legend of its own; its legend is removed.
axes[1, 1].get_legend().remove()

# Apply the hatch patterns to the query panel's filled areas as well.
for collection, hatch in zip(axes[1, 1].collections[::-1], hatches):
    collection.set_hatch(hatch)

# TwitterHateSpeech (bottom row): support KDE on the left, query KDE on the
# right. Three labels are requested here, so all three hatch patterns can be used.
# `clip` limits where seaborn evaluates the density to [-10, 1].
axes[2, 0] = sns.kdeplot(
    data=make_support_df(twitter_support_stats, num_labels=3),
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=axes[2, 0],
    palette="tab10",
    )
axes[2, 0].set_ylabel("TwitterHateSpeech", fontsize=11)

# Hatch each filled area and its legend handle with the same pattern.
# The collections are walked in reverse so they pair up with the legend
# handles -- presumably seaborn draws hue levels in reverse order; TODO confirm.
handles = []
for collection, handle, hatch in zip(axes[2, 0].collections[::-1], axes[2, 0].get_legend().legend_handles, hatches):
    collection.set_hatch(hatch)
    handle.set_hatch(hatch)
    
    handles.append(handle)

# Replace seaborn's legend with human-readable class names.
# Assumes the handle order corresponds to labels 0 -> "Racism",
# 1 -> "Sexism", 2 -> "None"; verify against the dataset's label encoding.
axes[2, 0].legend(
    handles=handles,
    labels=["Racism", "Sexism", "None"],
    loc='upper left',
    title=""
    )

axes[2, 1] = sns.kdeplot(
    data=make_query_df(twitter_query_stats),
    x="excess_homophily",
    bw_adjust=0.8,
    clip=(-10, 1),
    hue="label",
    fill=True,
    multiple="layer",
    ax=axes[2, 1],
    palette="tab10",
    )
axes[2, 1].set_ylabel("", fontsize=11)
# The query panel has no legend of its own; its legend is removed.
axes[2, 1].get_legend().remove()

# Apply the hatch patterns to the query panel's filled areas as well.
for collection, hatch in zip(axes[2, 1].collections[::-1], hatches):
    collection.set_hatch(hatch)

# Give every panel the same x range and tick positions, and strip the y
# ticks, tick labels and per-axes x labels.
for ax in flat_axes:
    ax.set_yticks([])
    ax.set_xlim([-1.25, 1.25])
    ax.set_xticklabels([])
    ax.set_xlabel("")
    ax.set_xticks([-1, -0.5, 0, 0.5, 1])

# Only the bottom row gets x tick labels back; a single figure-level x label
# is used for all panels.
axes[2, 0].set_xticklabels([-1, -0.5, 0, 0.5, 1])
axes[2, 1].set_xticklabels([-1, -0.5, 0, 0.5, 1])
fig.supxlabel("Rel. Excess Homophily")

fig.tight_layout()

# NOTE(review): hard-coded relative path that points two directories above the
# notebook's working directory; the target directory must already exist.
fig.savefig(
    "../../meta-learning-gnns-paper/emnlp2023-latex/figures/homophily_plot.png"
)


In [None]:
# NOTE(review): this only evaluates to the bound `legend` method (it is not
# called), so the cell just displays its repr -- looks like a debugging leftover.
axes[2, 0].legend

In [None]:
from collections import defaultdict

import numpy as np

def five_stats(values: list):
    """Summarise a collection of numbers with five descriptive statistics.

    Args:
        values: Numeric values; converted with ``np.array`` before use.

    Returns:
        Dict with the keys ``mean``, ``stddev`` (``np.std`` with its default
        settings, i.e. the population standard deviation), ``q25``, ``q50``
        and ``q75`` (quartiles from ``np.quantile``).
    """
    arr = np.array(values)

    # All three quartiles are computed in a single call.
    q25, q50, q75 = np.quantile(arr, q=[0.25, 0.50, 0.75])

    return {
        "mean": np.mean(arr),
        "stddev": np.std(arr),
        "q25": q25,
        "q50": q50,
        "q75": q75,
    }

def get_stats_table(stats_records):
    """Print a five-number summary per metric and copy it to the clipboard.

    Two input layouts are handled, selected by the type of the first element
    obtained when iterating ``stats_records``:

    * first element is a dict: ``stats_records`` is treated as an iterable of
      per-record dicts. Values are collected per key across records; a list
      value is split into one metric per position, named ``"<key>_<index>"``.
    * otherwise: ``stats_records`` is used as-is and must offer ``.items()``
      yielding ``metric -> values``.

    In the resulting mapping, a dict value is expanded into one row per inner
    key, named ``"<metric>_<inner key>"``; any other value yields one row.

    Args:
        stats_records: Statistics in one of the two layouts described above.

    Returns:
        None. The table is printed and written to the system clipboard.
    """
    first_item = next(iter(stats_records))

    if not isinstance(first_item, dict):
        # Already a metric -> values mapping.
        stats = stats_records
    else:
        # Pivot the per-record dicts into metric -> list of observed values.
        stats = defaultdict(list)
        for record in stats_records:
            for key, value in record.items():
                if isinstance(value, list):
                    for position, item in enumerate(value):
                        stats[f"{key}_{position}"].append(item)
                else:
                    stats[key].append(value)

    rows = []
    for metric, values in stats.items():
        if isinstance(values, dict):
            rows.extend(
                {"metric": f"{metric}_{sub_key}", **five_stats(sub_values)}
                for sub_key, sub_values in values.items()
            )
        else:
            rows.append({"metric": metric, **five_stats(values)})

    table = pd.DataFrame.from_records(rows)
    print(table.to_string())
    table.to_clipboard(excel=True)


In [None]:
# Summary table for `gossipcop_support_stats`; the call also overwrites the clipboard.
get_stats_table(gossipcop_support_stats)

In [None]:
# Summary table for `gossipcop_query_stats`; the call also overwrites the clipboard.
get_stats_table(gossipcop_query_stats)

In [None]:
# Summary table for `coaid_support_stats`; the call also overwrites the clipboard.
get_stats_table(coaid_support_stats)

In [None]:
# Summary table for `coaid_query_stats`; the call also overwrites the clipboard.
get_stats_table(coaid_query_stats)

In [None]:
# Summary table for `twitter_support_stats`; the call also overwrites the clipboard.
get_stats_table(twitter_support_stats)

In [None]:
# Summary table for `twitter_query_stats`; the call also overwrites the clipboard.
get_stats_table(twitter_query_stats)