In [None]:
from data.graph_loader import load_highschool
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, to_undirected
import torch

# Load the high-school contact network and symmetrize it so every edge is
# stored in both directions.
num_nodes = 70  # NOTE(review): hard-coded node count for this dataset — confirm against load_highschool()
edge_index = to_undirected(load_highschool(), num_nodes=num_nodes)

# Node "features" are simply the node ids 0..num_nodes-1.
node_idx = torch.arange(num_nodes)
graph = Data(edge_index=edge_index, x=node_idx)
nx_graph = to_networkx(graph)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def greedy_maximize_neighbors(G, budget):
    """
    Greedily select up to `budget` nodes of G so that the number of unique
    neighbors covered by the selected nodes is maximized.

    At every step the unselected node that would add the most not-yet-covered
    neighbors is chosen (ties go to the first such node in ``G.nodes`` order).
    Selection stops early once no remaining node adds a new neighbor.

    Parameters:
    G (nx.Graph): A NetworkX graph (anything exposing ``nodes`` and ``neighbors``).
    budget (int): Maximum number of nodes to select; values <= 0 select nothing.

    Returns:
    selected_nodes (set): The selected nodes (at most ``budget`` of them).
    covered_neighbors (set): Neighbors of the selected nodes, plus the selected
        nodes themselves.
    """
    selected_nodes = set()
    covered_neighbors = set()

    for _ in range(budget):
        best_node = None
        best_increase = 0

        # Evaluate each unselected node by how many new neighbors it would add.
        for node in G.nodes:
            if node in selected_nodes:
                continue
            # Every selected node is already in covered_neighbors (added below),
            # so subtracting the covered set alone also excludes selected nodes.
            increase = len(set(G.neighbors(node)) - covered_neighbors)
            if increase > best_increase:
                best_increase = increase
                best_node = node

        if best_node is None:
            # No remaining node would cover a new neighbor.
            break

        selected_nodes.add(best_node)
        covered_neighbors.update(G.neighbors(best_node))
        # The chosen node itself also counts as covered from now on.
        covered_neighbors.add(best_node)

    return selected_nodes, covered_neighbors

from copy import deepcopy as dc
import networkx as nx
from scipy import sparse

def compute_eigenval(nx_graph, node_idx):
    """Largest (algebraic) adjacency eigenvalue after deleting some nodes.

    Parameters:
    nx_graph (nx.Graph): graph to evaluate; it is copied, never modified.
    node_idx (iterable): nodes to delete from the copy.

    Returns:
    numpy.ndarray: shape-(1,) array holding the largest eigenvalue of the
        remaining graph's adjacency matrix (0.0 if no edges remain).
    """
    # Explicit imports: `from scipy import sparse` alone does not guarantee the
    # `sparse.linalg` submodule is loaded on a fresh kernel.
    import numpy as np
    from scipy.sparse.linalg import eigsh

    remaining = nx_graph.copy()
    remaining.remove_nodes_from(node_idx)
    adj = nx.adjacency_matrix(remaining, dtype=float)
    if adj.nnz == 0:
        # All-zero (or empty) adjacency: its largest eigenvalue is 0, and
        # ARPACK should not be called on such a degenerate matrix.
        return np.zeros(1)
    after_eigenval, _ = eigsh(adj, k=1, which='LA')
    return after_eigenval

def SV(graph, node_index, action):
    """Shield value of deleting node(s) ``action`` from the residual graph.

    The residual graph is ``graph`` with every node flagged in ``node_index``
    removed (its adjacency row and column zeroed). With A the residual
    adjacency matrix, (lam, u) its leading eigenpair and S the node(s) given
    by ``action``, the score is

        SV(S) = 2 * lam * sum_{i in S} u_i**2 - sum_{i, j in S} A_ij * u_i * u_j

    Parameters:
    graph (nx.Graph): graph whose node order defines the matrix indices.
    node_index (torch.Tensor): one flag per node; entries equal to 1 mark nodes
        that are already removed. It is squeezed before use, so presumably of
        shape (N,) or (N, 1) -- confirm with callers.
    action (int): index of the node to score (anything usable as a tensor
        index, e.g. a list of indices, selects several nodes).

    Returns:
    float: the shield value; 0.0 when the residual graph has no edges left.
    """
    # Explicit import: `sparse.linalg` is not guaranteed to be loaded by
    # `from scipy import sparse` alone.
    from scipy.sparse.linalg import eigsh

    adj = nx.adjacency_matrix(graph, dtype=float).tolil()
    removed = (node_index == 1).squeeze().cpu().numpy()
    adj[removed, :] = 0
    adj[:, removed] = 0
    adj = adj.tocsr()
    adj.eliminate_zeros()

    # Indicator vector of the node set S being scored.
    selected = torch.zeros_like(node_index).squeeze()
    selected[action] = 1
    selected = selected.cpu().numpy().astype(float)

    if adj.nnz == 0:
        # Residual adjacency is all zeros: lam = 0 and A = 0, so both terms
        # vanish. Skip ARPACK on this degenerate matrix.
        return 0.0

    eigenval, eigenvec = eigsh(adj, k=1, which='LA')
    lam = eigenval.item()
    u = eigenvec[:, 0]

    term1 = 2 * lam * (selected * u ** 2).sum()
    # u_S^T A u_S via a sparse mat-vec instead of densifying the N x N matrix.
    u_sel = u * selected
    term2 = u_sel @ (adj @ u_sel)
    return float(term1 - term2)

In [None]:
# Greedy node deletion: at each step delete the still-present node with the
# largest shield value, then report the leading eigenvalue of what remains.
budget = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ob = torch.zeros(num_nodes, device=device)  # 1 marks a deleted node
actions = []
for _ in range(budget):
    # Start from -inf so a fresh node is always chosen; with 0 a round where
    # every score is 0 would silently reuse the previous round's `action`.
    best_res = float("-inf")
    action = None
    for candidate in torch.where(ob == 0)[0].tolist():
        score = SV(nx_graph, ob, candidate)
        if score > best_res:
            best_res = score
            action = candidate
    if action is None:
        # Every node has already been deleted.
        break
    ob[action] = 1
    actions.append(action)
    eigenval = compute_eigenval(nx_graph, actions)
    print(f"Node {action} deleted. Eigenval: {eigenval.item()}")

In [None]:
# Shield value of every node on the intact graph. The removal mask must be
# all zeros; an all-ones mask would zero the whole adjacency matrix.
num_nodes = nx_graph.number_of_nodes()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
no_nodes_removed = torch.zeros(num_nodes, device=device)
shield_values = torch.zeros(num_nodes)
for i in range(num_nodes):
    shield_values[i] = SV(nx_graph, no_nodes_removed, i)