In [2]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, Linear, SAGEConv, GATv2Conv, GATConv
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import time
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


import networkx as nx
import numpy as np
from pathlib import Path
import scipy as sp
rng = np.random.default_rng()

In [53]:
class EdgeAwareMessagePassing(MessagePassing):
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super(EdgeAwareMessagePassing, self).__init__(aggr='add')
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        
        # Message creation networks
        self.message_nn = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Update networks
        self.node_update = nn.GRUCell(hidden_dim, node_dim)
        self.edge_update = nn.GRUCell(2 * hidden_dim, edge_dim)
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x, edge_index, edge_attr):
        # Propagate messages
        print(edge_index.shape)
        print(x.shape)
        print(edge_attr.shape)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # Combine source node, target node, and edge features
        tmp = torch.cat([x_i, x_j, edge_attr], dim=1)
        
        # Compute attention weights
        alpha = self.attention(tmp)
        
        # Create message
        message = self.message_nn(tmp)
        
        return alpha * message

    def update(self, aggr_out, x, edge_index, edge_attr):
        # Update node features
        x_new = self.node_update(aggr_out, x)
        
        # Update edge features
        edge_features = []
        for i in range(edge_index.size(1)):
            source, target = edge_index[:, i]
            combined = torch.cat([x_new[source], x_new[target]], dim=0)
            edge_features.append(combined)
        edge_features = torch.stack(edge_features, dim=0)
        
        edge_attr_new = self.edge_update(edge_features, edge_attr)
        
        return x_new, edge_attr_new

class CascadeGNN(nn.Module):
    def __init__(self, num_nodes, hidden_dim=64, num_layers=3):
        super(CascadeGNN, self).__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        #self.edge_dim = edge_dim
        
        # Initial embeddings
        self.node_embedding = nn.Embedding(num_nodes, hidden_dim)
        #self.edge_embedding = nn.Parameter(torch.randn(edge_dim))
        self.edge_embedding = nn.Parameter(torch.Tensor(self.num_nodes, self.num_nodes, hidden_dim))

        '''
        # Message passing layers
        self.conv_layers = nn.ModuleList([
            EdgeAwareMessagePassing(hidden_dim, edge_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        
        # Edge probability prediction layer
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim // 2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        ''' 
        self.convs = nn.ModuleList([
            GATConv(hidden_dim, hidden_dim) 
            for i in range(num_layers)
        ])

        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim // 2),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
    def forward(self, edge_index):
        # Get initial node embeddings
        x = self.node_embedding(torch.arange(self.num_nodes).to(edge_index.device))
        # Initialize edge features (one per edge)
        src, dst = edge_index
        edge_attr = self.edge_embedding[src, dst]
        
        # Apply GNN layers
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.gelu(x)
            x = torch.dropout(x, p=0.1, train=self.training)
            #x_new, edge_attr = conv(x, edge_index, edge_attr)
            #x = x_new + x  # Residual connection for nodes
            #edge_attr = edge_attr + edge_attr.clone()  # Residual connection for edges
            
        # Compute edge probabilities for all edges
        edge_repr = torch.cat([torch.add(x[src], x[dst]), edge_attr], dim=1)
        edge_probs = self.edge_mlp(edge_repr)
        edge_probabilities = {}
        for i in range(edge_index.size(1)):
            source, target = edge_index[:, i]
            edge_probabilities[(source.item(), target.item())] = edge_probs[i]
        
        return edge_probabilities



In [None]:
class GNNIndependentCascade(torch.nn.Module):
  def __init__(self, hidden_dim, n_nodes, num_layers=3):
    super(GNNIndependentCascade, self).__init__()
    self.n = n_nodes

    self.num_layers = num_layers

    self.node_embed = nn.Embedding(self.n, hidden_dim)
    self.edge_embed = nn.Parameter(torch.Tensor(self.n, self.n, hidden_dim))

    self.convs = nn.ModuleList([
      GATConv(hidden_dim, hidden_dim) 
      for i in range(num_layers)
    ])

    self.edge_predictor = nn.Sequential(
      nn.Linear(2*hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, 1),
    )
    #nn.Linear(hidden_dim, 1)

  def forward(self, edge_index):
    x = self.node_embed(torch.arange(self.n))

    src, dst = edge_index
    edge_emb = self.edge_embed[src, dst]

    # Node embedding
    for i in range(self.num_layers):
      #x = self.convs[i](x, edge_index, edge_emb)
      x = self.convs[i](x, edge_index)
      x = F.gelu(x)
      x = torch.dropout(x, p=0.1, train=self.training)

    # Edge probability prediction
    #edge_repr = torch.cat([x[src], x[dst]], dim=1)
    #edge_repr = torch.cat([torch.add(x[src], x[dst]), edge_emb], dim=1)
    #edge_probs = torch.sigmoid(self.edge_predictor(edge_repr))
    edge_probabilities = {}
    for i in range(edge_index.size(1)):
        source, target = edge_index[:, i]
        edge_repr = torch.cat([torch.add(x[source], x[target]), edge_emb[i]], dim=0)
        prob = torch.sigmoid(self.edge_predictor(edge_repr))
        edge_probabilities[(source.item(), target.item())] = prob
    
    print(edge_probabilities)
    return edge_probabilities
    #edge_probs = torch.sigmoid(torch.sum(x[row] * x[col], dim=1))



In [59]:
def compute_cascade_likelihood(num_nodes, edge_probs, cascade, eps=1e-4):
    """
    Compute the negative log likelihood of observing a cascade given edge probabilities
    
    Args:
        num_nodes: Number of nodes in the graph
        edge_probs: Dictionary mapping (source, target) tuples to probabilities
        cascade: List of lists, where cascade[i] contains nodes activated at time i
        eps: Small value to prevent log(0)
    
    Returns:
        Negative log likelihood of the cascade
    """
    log_likelihood = 0.0
    activated_nodes = set()
    
    # Process each time step
    for t in range(len(cascade)):
        prev_activated = cascade[t-1] if t-1 >= 0 else []
        curr_activated = cascade[t]
        next_activated = cascade[t+1] if t+1 < len(cascade) else []
        activated_nodes.update(curr_activated)

        #print(t)
        #print(prev_activated)
        #print(curr_activated)
        #print(next_activated)

        for v in curr_activated:
            # Probability of activation from parents
            if prev_activated:
                parents = set([u for u in range(num_nodes) if (u, v) in edge_probs and u in prev_activated])
                prob = [1 - edge_probs[(u, v)] for u in parents]
                prob = torch.cat(prob)
                prob_not_activated = torch.prod(prob)
                log_likelihood += torch.log(1 - prob_not_activated + eps)
            if next_activated:
                children = set([w for w in range(num_nodes) if (v, w) in edge_probs and w not in activated_nodes and w not in set(next_activated)])
                if not children:
                    continue
                prob = [1 - edge_probs[(v, w)] for w in children]
                prob = torch.cat(prob)
                #print(prob)
                prob_not_activated = torch.prod(prob)
                log_likelihood += torch.log(prob_not_activated + eps)
    
    return log_likelihood

def compute_loss(num_nodes, edge_probs, cascades):
  """
  Compute the negative log-likelihood loss for multiple cascades.
  
  Args:
    num_nodes: Number of nodes in the
    edge_probs: Tensor of predicted edge probabilities
    cascades: List of cascades, where each cascade is a list of lists of activated nodes
  
  Returns:
    loss: Negative log-likelihood loss
  """
  total_log_likelihood = 0.0
  for cascade in cascades:
    total_log_likelihood += compute_cascade_likelihood(num_nodes, edge_probs, cascade)
  
  # Return negative log-likelihood as the loss
  #print(-total_log_likelihood)
  return -total_log_likelihood

In [None]:
def train_cascade_gnn(model, num_nodes, edge_index, adj_list, cascades, num_epochs=100, lr=0.001, verbose=True):
    """
    Train the GNN model using the observed cascades
    
    Args:
        model: CascadeGNN model
        num_nodes: Number of nodes in the graph
        edge_index: Tensor of shape [2, num_edges] containing edge indices
        cascades: List of cascades, where each cascade is a list of lists
        num_epochs: Number of training epochs
        lr: Learning rate
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        
        # Get edge probabilities
        edge_probs = model.forward(edge_index)
        
        # Compute total negative log likelihood across all cascades
        total_loss = compute_loss(num_nodes, edge_probs, cascades)
            
        # Backward pass and optimization
        total_loss.backward()
        optimizer.step()
        
        if verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss.item():.4f}")
            
    return model

In [3]:
def create_dataset(G: nx.DiGraph):
  edge_index = torch.tensor(list(G.edges)).t().contiguous()
  adj_list = {v : (set(), set()) for v in G.nodes}
  for e in G.edges:
    u, v = e
    adj_list[u][1].add(v)
    adj_list[v][0].add(u)
  return edge_index, adj_list

n = 100
p = 0.1
gname = f"er_{n}_{str(p).replace('.', '')}"
path = Path(f"datasets/synthetic/{gname}")

with open(path / f"graph.mtx", "rb") as fh:
  G = nx.from_scipy_sparse_array(sp.io.mmread(fh), create_using=nx.DiGraph)

cascades = []
for i in range(250):
  with open(path / f"diffusions/timestamps/{i}.txt", "r") as fh:
    cascade = []
    for line in fh:
      cascade.append(list(map(int, line.strip().split())))
    cascades.append(cascade)

n = G.number_of_nodes()
m = G.number_of_edges()
edge_index, adj_list = create_dataset(G)

In [69]:
l1_errors = []
times = []
cascade_sizes = [50, 75, 100]
for k in cascade_sizes:
    start = time.time()
    #model = CascadeGNN(n, hidden_dim=64, num_layers=3)
    model = GNNIndependentCascade(32, n, num_layers=2)
    trained_model = train_cascade_gnn(model, n, edge_index, cascades[:k], num_epochs=40, lr=0.01, verbose=True)
    end = time.time()
    times.append(end-start)

    trained_model.eval()
    edge_probs = trained_model(edge_index)
    residuals = []
    for i, e in enumerate(G.edges()):
        u, v = e
        p = G[u][v]['weight']
        residuals.append(abs(p - edge_probs[e].item()))

    l1_errors.append(sum(residuals) / len(residuals))

print("M\tMAE\t\t\tTime")
for i, k in enumerate(cascade_sizes):
    print(f"{k}\t{l1_errors[i]}\t{times[i]}")

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [18]:
print(l1_errors)
print(times)

[0.11266553486120538, 0.11358482052435794, 0.11321353166009078]
[53.8030731678009, 100.08372330665588, 145.98774337768555]
