In [223]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, Linear, SAGEConv, GATv2Conv, GATConv
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import torch.nn.functional as F

import networkx as nx
import numpy as np
from pathlib import Path
import scipy as sp
rng = np.random.default_rng()

In [224]:
class GNNIndependentCascade(torch.nn.Module):
  def __init__(self, num_node_features, hidden_dim, num_layers=2):
    super(GNNIndependentCascade, self).__init__()
    self.num_layers = num_layers
    self.convs = nn.ModuleList([
      GATConv(num_node_features if i == 0 else hidden_dim, hidden_dim) 
      for i in range(num_layers)]
    )

    self.edge_predictor = nn.Sequential(
      nn.Linear(2 * hidden_dim, hidden_dim),
      nn.ELU(),
      nn.Linear(hidden_dim, 1),
      nn.ELU()
    )

  def forward(self, data):
    x, edge_index = data.x, data.edge_index

    # Node embedding
    for i in range(self.num_layers):
      x = self.convs[i](x, edge_index)
      x = torch.relu(x)
      x = torch.dropout(x, p=0.1, train=self.training)

    # Edge probability prediction
    row, col = edge_index
    edge_features = torch.cat([x[row], x[col]], dim=1)
    edge_probs = torch.sigmoid(self.edge_predictor(edge_features).squeeze())
    #edge_probs = torch.sigmoid(torch.sum(x[row] * x[col], dim=1))

    return edge_probs
    
class CascadeDataset(Dataset):
  def __init__(self, cascades):
    self.cascades = cascades

  def __len__(self):
    return len(self.cascades)

  def __getitem__(self, idx):
    return self.cascades[idx]

In [199]:
def compute_cascade_likelihood(edge_probs, edge_index, cascade, epsilon=1e-8):
  """
  Compute the likelihood of observing a single cascade given edge probabilities.
  
  Args:
  - edge_probs: Tensor of predicted edge probabilities
  - edge_index: Tensor of shape [2, num_edges] containing edge indices
  - cascade: List of lists, where each inner list contains nodes activated at that time step
  - epsilon: Small value to avoid log(0)
  
  Returns:
  - log_likelihood: Log-likelihood of the cascade
  """
  device = edge_probs.device
  num_nodes = edge_index.max().item() + 1
  activated = torch.zeros(num_nodes, dtype=torch.bool, device=device)
  log_likelihood = 0.0

  src, dst = edge_index

  for t in range(len(cascade)):
    prev_activated = torch.tensor(cascade[t-1] if t-1 >= 0 else [], device=device)
    curr_activated = torch.tensor(cascade[t], device=device)
    next_activated = torch.tensor(cascade[t+1] if t+1 < len(cascade) else [], device=device)
    activated[curr_activated] = True
    
    # Probability of activation from parents
    for v in curr_activated:
      parents = src[(dst == v) & activated[src]]
      activated_parents = parents[torch.isin(parents, prev_activated)]
      if len(activated_parents) > 0:
        prob_v_activated = 1 - torch.prod(1 - edge_probs[torch.isin(src, activated_parents) & (dst == v)])
        log_likelihood += torch.log(prob_v_activated + epsilon)

    # Probability of non-activation of children
    for v in curr_activated:
      children = dst[(src == v) & ~activated[dst]]
      non_activated_children = children[~torch.isin(children, next_activated)]
      if len(non_activated_children) > 0:
        prob_children_not_activated = torch.prod(1 - edge_probs[(src == v) & torch.isin(dst, non_activated_children)])
        log_likelihood += torch.log(prob_children_not_activated + epsilon)

  return log_likelihood

  '''
  device = edge_probs.device
  num_nodes = edge_index.max().item() + 1
  activated = torch.zeros(num_nodes, dtype=torch.bool, device=device)
  log_likelihood = 0.0

  for t, activated_nodes in enumerate(cascade):
    if t == 0:
      activated[activated_nodes] = True
      continue

    # Compute activation probabilities for this step
    src, dst = edge_index
    mask = activated[src] & ~activated[dst]
    relevant_probs = edge_probs[mask]
    relevant_dst = dst[mask]

    # Compute likelihood of activations and non-activations
    new_activations = torch.tensor(activated_nodes, device=device)
    activated_probs = relevant_probs[torch.isin(relevant_dst, new_activations)]
    non_activated_probs = relevant_probs[~torch.isin(relevant_dst, new_activations)]

    log_likelihood += torch.sum(torch.log(activated_probs + epsilon))
    log_likelihood += torch.sum(torch.log(1 - non_activated_probs + epsilon))

    # Update activated nodes
    activated[activated_nodes] = True

  return log_likelihood
  '''

def compute_loss(edge_probs, edge_index, cascades):
  """
  Compute the negative log-likelihood loss for multiple cascades.
  
  Args:
  - edge_probs: Tensor of predicted edge probabilities
  - edge_index: Tensor of shape [2, num_edges] containing edge indices
  - cascades: List of cascades, where each cascade is a list of lists of activated nodes
  
  Returns:
  - loss: Negative log-likelihood loss
  """
  total_log_likelihood = 0.0
  for cascade in cascades:
    total_log_likelihood += compute_cascade_likelihood(edge_probs, edge_index, cascade)
  
  # Return negative log-likelihood as the loss
  return -total_log_likelihood

In [193]:
def create_dataset(G: nx.DiGraph, features):
  # Create a PyG Data object from the networkx graph
  edge_index = torch.tensor(list(G.edges)).t().contiguous()
  #x = torch.tensor(features, dtype=torch.float)
  data = Data(x=features, edge_index=edge_index)
  return data

n = 100
p = 0.1
gname = f"er_{n}_{str(p).replace('.', '')}"
path = Path(f"datasets/synthetic/{gname}")

with open(path / f"{gname}.mtx", "rb") as fh:
  G = nx.from_scipy_sparse_array(sp.io.mmread(fh), create_using=nx.DiGraph)
with open(path / "feats.npy", "rb") as fh:
  features_npy = torch.tensor(np.load(fh), dtype=torch.float)

cascades = []
idxes = rng.choice(500, 100, replace=False)
for i in idxes:
  with open(path / f"diffusions/timestamps/{i}.txt", "r") as fh:
    cascade = []
    for line in fh:
      cascade.append(list(map(int, line.strip().split())))
    cascades.append(cascade)

In [225]:
def train_model(model, optimizer, data, cascades, num_epochs, batch_size = 50):
  model.train()
  #batches = DataLoader(cascades, batch_size=10, shuffle=True)
  #print(batches)

  for epoch in range(num_epochs):
    loss = 0.0
    rng.shuffle(cascades)
    batches = [cascades[i:i+batch_size] for i in range(0, len(cascades), batch_size)]

    for batch in batches:
      optimizer.zero_grad()
      edge_probs = model.forward(data)

      batch_loss = compute_loss(edge_probs, data.edge_index, batch)
      batch_loss.backward()
      optimizer.step()
      loss += batch_loss.item()
    
    if epoch % 10 == 0 or epoch == num_epochs - 1:
      print(f"Epoch {epoch+1}/{num_epochs}, Total Loss: {loss:.4f}")
      #print(edge_probs[0])
      #print(edge_probs[1])
      #print('\n')

features_eye = torch.eye(G.number_of_nodes())
data = create_dataset(G, features_eye)
model = GNNIndependentCascade(data.num_features, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
train_model(model, optimizer, data, cascades, 100)

Epoch 1/100, Total Loss: 13293.3374
Epoch 11/100, Total Loss: 11944.2715
Epoch 21/100, Total Loss: 11815.2622
Epoch 31/100, Total Loss: 11790.8296
Epoch 41/100, Total Loss: 11790.7979
Epoch 51/100, Total Loss: 11784.6089
Epoch 61/100, Total Loss: 11782.4673
Epoch 71/100, Total Loss: 11778.4082
Epoch 81/100, Total Loss: 11770.3862
Epoch 91/100, Total Loss: 11757.4048
Epoch 100/100, Total Loss: 11732.8413


In [226]:
model.eval()
l1_error = 0
l2_error = 0
edge_probs = model(data)

for i, e in enumerate(G.edges()):
  u, v = e
  p = G[u][v]['weight']
  l1_error += abs(p - edge_probs[i].item())
  l2_error += (p - edge_probs[i].item())**2

print(l1_error)
print(l1_error / G.number_of_edges())
print(edge_probs)

188.68733220664117
0.19553091420377325
tensor([0.3130, 0.3316, 0.3023, 0.3137, 0.3223, 0.3287, 0.3517, 0.3292, 0.3487,
        0.3243, 0.3074, 0.3337, 0.3208, 0.3147, 0.3186, 0.3236, 0.3267, 0.3151,
        0.3063, 0.3131, 0.2978, 0.2986, 0.2992, 0.2973, 0.3013, 0.2907, 0.3048,
        0.2893, 0.2945, 0.2914, 0.2923, 0.3044, 0.3038, 0.2982, 0.2981, 0.3100,
        0.3063, 0.3125, 0.3108, 0.3281, 0.3000, 0.3360, 0.3293, 0.3127, 0.3456,
        0.3186, 0.3429, 0.2895, 0.2868, 0.2925, 0.2868, 0.2857, 0.2868, 0.2919,
        0.3024, 0.3148, 0.3226, 0.3259, 0.3138, 0.3263, 0.3199, 0.3425, 0.3090,
        0.3180, 0.3378, 0.3296, 0.3035, 0.3245, 0.3183, 0.2897, 0.3025, 0.3230,
        0.3011, 0.3040, 0.3204, 0.3114, 0.3144, 0.3110, 0.3033, 0.2991, 0.3038,
        0.3179, 0.3385, 0.3564, 0.3123, 0.3280, 0.3373, 0.3131, 0.3244, 0.3216,
        0.3153, 0.3072, 0.3215, 0.3228, 0.3325, 0.3475, 0.3306, 0.3317, 0.3495,
        0.3421, 0.3561, 0.3265, 0.3030, 0.3384, 0.3041, 0.3262, 0.3357, 0.3467,
 