In [2]:
import networkx as nx
import torch

In [None]:
def label_propagation(G, labeled_nodes, labels, max_iter=1000, tol=1e-6):
    """Infer node classes by iterative label propagation over ``G``.

    Each iteration replaces every node's class-score row with the
    degree-normalised (row-stochastic) weighted sum of its neighbours' rows,
    then resets the rows of ``labeled_nodes`` to their initial one-hot values.
    Iteration stops after ``max_iter`` steps, or as soon as the Frobenius norm
    of the change between consecutive score matrices drops below ``tol``.

    Args:
        G: A networkx graph. Edge weights come from ``nx.adjacency_matrix``
            (its default ``"weight"`` attribute). Nodes may be any hashable
            objects; they do not need to be ``0..n-1``.
        labeled_nodes: Iterable of nodes of ``G`` whose score rows are clamped
            to their initial values on every iteration.
        labels: Mapping ``node -> class index`` (non-negative ints). Classes
            are ``0 .. max(labels.values())``. Every entry seeds that node's
            initial one-hot row, but only nodes also present in
            ``labeled_nodes`` stay fixed. A node in ``labeled_nodes`` with no
            entry in ``labels`` is clamped to an all-zero row.
        max_iter: Maximum number of propagation steps.
        tol: Convergence threshold on ``||Y_t - Y_{t-1}||_F``.

    Returns:
        Dict mapping every node of ``G`` to its predicted class index (argmax
        of its final score row; ties resolve to the lowest class). A node that
        no label ever reaches keeps an all-zero row and is therefore reported
        as class ``0``. An empty graph yields ``{}``.

    Raises:
        ValueError: if ``G`` has nodes but ``labels`` is empty.
        KeyError: if ``labels`` or ``labeled_nodes`` names a node not in ``G``.
    """
    # Map arbitrary hashable node ids to contiguous row indices. Indexing
    # tensors with the raw node ids is only correct when G's nodes are exactly
    # 0..n-1 in insertion order (the row order of adjacency_matrix).
    nodes = list(G.nodes())
    if not nodes:
        return {}
    if not labels:
        raise ValueError("labels must contain at least one labeled node")
    index = {node: i for i, node in enumerate(nodes)}
    n_nodes = len(nodes)

    # Dense adjacency whose rows/columns follow the order of ``nodes``.
    # .toarray() yields a plain ndarray regardless of whether networkx returns
    # a scipy sparse matrix or sparse array.
    A = torch.as_tensor(
        nx.adjacency_matrix(G, nodelist=nodes).toarray(), dtype=torch.float32
    )

    # Row-normalise: P[i, j] = A[i, j] / deg(i). Zero-degree (isolated) rows
    # get a zero inverse instead of inf, so they stay all-zero in P.
    D = A.sum(dim=1)
    D_inv = torch.where(D != 0, 1.0 / D, torch.zeros_like(D))
    P = D_inv.unsqueeze(1) * A

    # Initial one-hot label matrix, shape (n_nodes, n_classes).
    n_classes = max(labels.values()) + 1
    Y0 = torch.zeros((n_nodes, n_classes))
    for node, label in labels.items():
        Y0[index[node], label] = 1.0

    # Boolean mask of rows to clamp. An explicit long tensor keeps this valid
    # when ``labeled_nodes`` is empty.
    labeled_rows = torch.tensor(
        [index[node] for node in labeled_nodes], dtype=torch.long
    )
    labeled_mask = torch.zeros(n_nodes, dtype=torch.bool)
    labeled_mask[labeled_rows] = True

    Y = Y0.clone()
    for _ in range(max_iter):
        Y_next = P @ Y
        # Clamp labeled nodes back to their original one-hot rows.
        Y_next[labeled_mask] = Y0[labeled_mask]
        converged = torch.norm(Y_next - Y) < tol
        Y = Y_next
        if converged:
            break

    predicted = torch.argmax(Y, dim=1).tolist()
    return dict(zip(nodes, predicted))