# Exercise 4
Due:  Tue November 19, 8:00am

## Node2Vec
1. Implement custom dataset that samples pq-walks
    - Use the utility function from torch_cluster that actually performs the walks
2. Implement Node2Vec module and training
	- Node2Vec essentially consists of a torch.Embedding module and a loss function
3. Evaluate node classification performance on Cora
4. Evaluate on Link Prediction: Cora, PPI
    - use different ways to combine the node two embeddings for link prediction

Bonus Question: are the predictions stable wrt to the random seeds of the walks?

In [93]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [94]:
import torch
import torch_geometric as pyg
from tqdm import tqdm
import torch_cluster
import sklearn

In [95]:
# find device
if torch.cuda.is_available(): # NVIDIA
    device = torch.device('cuda')
elif torch.backends.mps.is_available(): # apple M1/M2
    device = torch.device('mps') 
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [96]:
dataset = pyg.datasets.Planetoid(root='./dataset/cora', name='Cora')
cora = dataset[0]
dataset = pyg.datasets.PPI(root='./dataset/ppi')
ppi = dataset[0]

In [97]:
cora

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [98]:
ppi

Data(x=[1767, 50], edge_index=[2, 32318], y=[1767, 121])

## node2vec embedding training
Here the main training and everything on the graph level is happening.

It might be a good idea to create a dataset of walks (fixed for the whole training process) first to get the whole training process running before attempting to create a train_loader that on-demand samples those walks on-demand.

In [99]:
class PQWalkDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data,
        walk_length,
        walks_per_node=1,
        p=1,
        q=1,
        num_negative_samples=1,
        seed=42,
    ):
        self.data = data
        self.edge_index = self.data.edge_index
        self.walk_length = walk_length - 1
        self.walks_per_node = walks_per_node
        self.num_nodes = self.data.num_nodes
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples
        if seed is not None:
            torch.manual_seed(seed)
        self._start_nodes = torch.arange(self.num_nodes).repeat(
            self.walks_per_node
        )
        self._negative_start_nodes = torch.arange(self.num_nodes).repeat(
           self.walks_per_node * self.num_negative_samples
        )
        self._pos_samples = self._get_pos_samples()
        self._neg_samples = self._get_neg_samples()

    def _get_pos_samples(self):
        return torch_cluster.random_walk(
            self.edge_index[0],
            self.edge_index[1],
            start=self._start_nodes,
            walk_length=self.walk_length,
            p=self.p,
            q=self.q,
        )
    
    def _get_neg_samples(self):
        negative_samples = torch.randint(0, self.num_nodes, (self._negative_start_nodes.shape[0], self.walk_length))
        negative_samples = torch.cat([self._negative_start_nodes.view(-1, 1), negative_samples], dim=-1)
        return negative_samples


    def __len__(self):
        return len(self._pos_samples)

    def __getitem__(self, idx):
        walk = self._pos_samples[idx]
        neg_sample = self._neg_samples[idx]
        return walk, neg_sample



In [100]:
# pqwalkdataset test 
cora_pq_dataset = PQWalkDataset(
    data=cora,
    walk_length=4,
    walks_per_node=4,
    num_negative_samples=1,
    seed=42
)

for walk, neg_sample in cora_pq_dataset:
    print("Walk shape, Neg sample shape:", walk.shape, neg_sample.shape)

# test to count walks per node
walk_counts = torch.zeros(cora.num_nodes, dtype=torch.long)
for walk, neg_sample in torch.utils.data.DataLoader(cora_pq_dataset, batch_size=5, num_workers=2, shuffle=True):
    unique, counts = torch.unique(walk[:, 0], return_counts=True)
    for node, count in zip(unique, counts):
        walk_counts[node] += count

print("\nWalk counts per node:")
for node, count in enumerate(walk_counts):
    print(f"Node {node}: {count} walks")


Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk shape, Neg sample shape: torch.Size([4]) torch.Size([4])
Walk sha

In [101]:
class PQWalkIterableDataset(torch.utils.data.IterableDataset):
    def __init__(
        self,
        data,
        walk_length=10,
        walks_per_node=10,
        p=1,
        q=1,
        num_negative_samples=1,
        batch_size=32,
        seed=42,
    ):
        self.data = data
        self.edge_index = self.data.edge_index
        self.walk_length = walk_length - 1
        self.walks_per_node = walks_per_node
        self.num_nodes = self.data.num_nodes
        self.p = p
        self.q = q
        self.num_negative_samples = num_negative_samples
        self.batch_size = min(batch_size * self.walks_per_node, self.num_nodes * self.walks_per_node)
        self.seed = seed

    def _generate_negative_samples(self, batch_nodes, worker_id):
        # Repeat batch nodes for each negative sample
        batch = batch_nodes.repeat(self.num_negative_samples)
        print(f"{worker_id}: Batch shape: {batch.shape}")
        # Generate random walks for negative samples
        rw = torch.randint(
            self.num_nodes,
            (batch.size(0), self.walk_length),
            dtype=batch.dtype,
            device=batch.device
        )
        # Concatenate batch nodes with random walks
        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)
        return rw
    
    def __iter__(self):
        if self.seed is not None:
            torch.manual_seed(self.seed)

        worker_info = torch.utils.data.get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers

        print(f"\n{worker_id}: Worker {worker_id} calculations:")

        # Calculate nodes per worker
        nodes_per_worker = self.num_nodes // num_workers
        start_node = worker_id * nodes_per_worker
        end_node = (
            start_node + nodes_per_worker
            if worker_id < num_workers - 1
            else self.num_nodes
        )
        worker_nodes = end_node - start_node

        print(f"\n{worker_id}: Handling nodes [{start_node}, {end_node})")
        print(f"{worker_id}: Number of nodes for this worker: {worker_nodes}")
        print(f"{worker_id}: Walks per node: {self.walks_per_node}")

        # Generate start nodes array that ensures walks_per_node samples for each node
        start_nodes = torch.arange(start_node, end_node).repeat_interleave(
            self.walks_per_node
        )
        total_walks = len(start_nodes)
        num_batches = (
            total_walks + self.batch_size - 1
        ) // self.batch_size  # ceiling division

        print(f"{worker_id}: Total walks to generate: {total_walks}")
        print(f"{worker_id}: Batch size: {self.batch_size}")
        print(f"{worker_id}: Number of batches: {num_batches}")

        # Shuffle all start nodes
        perm = torch.randperm(total_walks)
        start_nodes = start_nodes[perm]

        # Generate walks in batches
        for batch_idx in range(num_batches):
            batch_start = batch_idx * self.batch_size
            batch_end = min(batch_start + self.batch_size, total_walks)

            batch_nodes = start_nodes[batch_start:batch_end]

            walks = torch_cluster.random_walk(
                self.edge_index[0],
                self.edge_index[1],
                start=batch_nodes,
                walk_length=self.walk_length,
                p=self.p,
                q=self.q,
            )

            neg_samples = self._generate_negative_samples(batch_nodes, worker_id)

            yield walks, neg_samples


In [102]:
# test walk counts iterable
def test_walk_counts_iterable():
    walks_per_node = 4
    walk_length = 3

    dataset = PQWalkIterableDataset(
        data=cora,
        walk_length=walk_length,
        walks_per_node=walks_per_node,
        num_negative_samples=1,
        seed=42
    )
    
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,
        num_workers=2
    )
    
    iter_count = 0
    # Count walks per node
    walk_counts = torch.zeros(cora.num_nodes, dtype=torch.long)
    for pos_sample, neg_sample in dataloader:
        iter_count += 1
        print(f"\nPos sample shape, Neg sample shape: {pos_sample.shape}, {neg_sample.shape}")
        unique, counts = torch.unique(pos_sample[:, 0], return_counts=True)
        for node, count in zip(unique, counts):
            walk_counts[node] += count
    
    print(f"\nWalk counts per node after {iter_count} iterations:")
    for node, count in enumerate(walk_counts):
        print(f"Node {node}: {count} walks")

test_walk_counts_iterable()


0: Worker 0 calculations:
1: Worker 1 calculations:


1: Handling nodes [1354, 2708)
0: Handling nodes [0, 1354)

0: Number of nodes for this worker: 13541: Number of nodes for this worker: 1354

0: Walks per node: 41: Walks per node: 4

0: Total walks to generate: 5416
1: Total walks to generate: 54160: Batch size: 128

1: Batch size: 1280: Number of batches: 43

1: Number of batches: 43
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([128])
0: Batch shape: torch.Size([128])
1: Batch shape: torch.Size([1

In [103]:
# test walk counts pyg
def test_walk_counts_pyg():
    # Create a small test graph
    walk_length = 3
    walks_per_node = 4
    # Initialize Node2Vec model
    model = Node2Vec(
        edge_index=cora.edge_index,
        embedding_dim=16,
        walk_length=walk_length,
        walks_per_node=walks_per_node,  # Same as walks_per_node
        p=1,
        q=1,
        context_size=walk_length,
    )
    
    # Create loader
    loader = model.loader(
        batch_size=32,
        shuffle=True,
        num_workers=2
    )
    iter_count = 0
    
    # Count walks per node
    walk_counts = torch.zeros(cora.num_nodes, dtype=torch.long)
    for pos_sample, neg_sample in loader:
        iter_count += 1
        print(f"\nPos sample shape, Neg sample shape: {pos_sample.shape}, {neg_sample.shape}")
        unique, counts = torch.unique(pos_sample[:, 0], return_counts=True)
        for node, count in zip(unique, counts):
            walk_counts[node] += count
    
    print(f"\nWalk counts per node after {iter_count} iterations:")
    for node, count in enumerate(walk_counts):
        if count > 0:  # Only print nodes that have walks
            print(f"Node {node}: {count} walks")
            
   
test_walk_counts_pyg()


Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: torch.Size([128, 3]), torch.Size([128, 3])

Pos sample shape, Neg sample shape: tor

In [104]:
edge_index = torch.tensor([
    [0, 1, 1, 2, 1, 3, 3, 4, 4, 2],  # Source nodes
    [1, 0, 2, 1, 3, 1, 4, 3, 2, 4],  # Target nodes
], dtype=torch.long)

# Create a dummy Data object
class DummyData:
    def __init__(self, edge_index, num_nodes):
        self.edge_index = edge_index
        self.num_nodes = num_nodes

# Create small graph with 5 nodes
small_graph = DummyData(edge_index, num_nodes=5)

# Initialize the dataset with small parameters
dataset = PQWalkIterableDataset(
    data=small_graph,
    walk_length=3,        # Short walks for demonstration
    walks_per_node=2,     # Generate 2 walks per node
    p=1,                  # Return parameter
    q=1,                  # In-out parameter
    num_negative_samples=1,
    batch_size=32,         # Small batch size
    seed=42
)

# Create a dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=None,  # Batch size is handled by the dataset
    num_workers=2
)

print("Graph structure:")
print("Nodes: 0, 1, 2, 3, 4")
print("Edges:", end=" ")
for i in range(edge_index.shape[1]):
    print(f"({edge_index[0][i]}-{edge_index[1][i]})", end=" ")
print("\n")

# Iterate through the batches
for batch_idx, (walks, neg_samples) in enumerate(dataloader):
    print(f"\nBatch {batch_idx + 1}:")
    print("\nPositive walks:")
    for i, walk in enumerate(walks):
        print(f"Walk {i + 1}: {walk.tolist()}")
    
    print("\nNegative samples:")
    for i, neg_sample in enumerate(neg_samples):
        print(f"Negative {i + 1}: {neg_sample.tolist()}")

Graph structure:
Nodes: 0, 1, 2, 3, 4
Edges: (0-1) (1-0) (1-2) (2-1) (1-3) (3-1) (3-4) (4-3) (4-2) (2-4) 


1: Worker 1 calculations:

1: Handling nodes [2, 5)

0: Worker 0 calculations:1: Number of nodes for this worker: 3


0: Handling nodes [0, 2)1: Walks per node: 2

0: Number of nodes for this worker: 2
1: Total walks to generate: 60: Walks per node: 2

1: Batch size: 10
0: Total walks to generate: 41: Number of batches: 1

0: Batch size: 10
0: Number of batches: 1
1: Batch shape: torch.Size([6])
0: Batch shape: torch.Size([4])

Batch 1:

Positive walks:
Walk 1: [1, 3, 1]
Walk 2: [1, 2, 1]
Walk 3: [0, 1, 3]
Walk 4: [0, 1, 3]

Negative samples:
Negative 1: [1, 4, 0]
Negative 2: [1, 4, 1]
Negative 3: [0, 2, 0]
Negative 4: [0, 0, 2]

Batch 2:

Positive walks:
Walk 1: [2, 4, 2]
Walk 2: [3, 4, 3]
Walk 3: [3, 1, 3]
Walk 4: [4, 3, 4]
Walk 5: [2, 4, 3]
Walk 6: [4, 2, 4]

Negative samples:
Negative 1: [2, 0, 2]
Negative 2: [3, 1, 4]
Negative 3: [3, 1, 3]
Negative 4: [4, 1, 4]
Negative 5: [

In [105]:
# import torch
# import torch_geometric as pyg
# from torch_geometric.nn import Node2Vec

# # Create the same small example graph
# # 0 -- 1 -- 2
# #      |     |
# #      3 -- 4
# edge_index = torch.tensor([
#     [0, 1, 1, 2, 1, 3, 3, 4, 4, 2],  # Source nodes
#     [1, 0, 2, 1, 3, 1, 4, 3, 2, 4],  # Target nodes
# ], dtype=torch.long)

# # Create a dummy Data object
# class DummyData:
#     def __init__(self, edge_index, num_nodes):
#         self.edge_index = edge_index
#         self.num_nodes = num_nodes

# # Create small graph with 5 nodes
# small_graph = DummyData(edge_index, num_nodes=5)

# # Initialize Node2Vec model
# model = Node2Vec(
#     edge_index=edge_index,
#     embedding_dim=16,     # Size of embeddings
#     walk_length=3,        # Same as our example
#     p=1,                  # Return parameter
#     q=1,                  # In-out parameter
#     walks_per_node=2,          # Same as our walks_per_node
#     context_size=3,
# )

# # Create loader with the same batch size
# loader = model.loader(batch_size=32, shuffle=True)

# print("Graph structure:")
# print("Nodes: 0, 1, 2, 3, 4")
# print("Edges:", end=" ")
# for i in range(edge_index.shape[1]):
#     print(f"({edge_index[0][i]}-{edge_index[1][i]})", end=" ")
# print("\n")

# # Iterate through the batches
# for batch_idx, (pos_rw, neg_rw) in enumerate(loader):
#     print(f"\nBatch {batch_idx + 1}:")
#     print("\nPositive random walks:")
#     for i, walk in enumerate(pos_rw):
#         print(f"Walk {i + 1}: {walk.tolist()}")
    
#     print("\nNegative samples:")
#     for i, neg_sample in enumerate(neg_rw):
#         print(f"Negative {i + 1}: {neg_sample.tolist()}")
    

# # You can also access the generated walks directly
# print("\nAll positive random walks:")
# pos_walks = model.pos_sample(batch=torch.arange(small_graph.num_nodes))
# print(pos_walks)

# print("\nAll negative random walks:")
# neg_walks = model.neg_sample(batch=torch.arange(small_graph.num_nodes))
# print(neg_walks)

## Node classification performance
just a small MLP or even linear layer on the embeddings to predict node classes. Accuracy should be above 60%. Please compare your results to those you achieved with GNNs.

In [106]:
# as the simple MLP is pretty straightforward
model = torch.nn.Sequential(
    torch.nn.Linear(embedding_dim, 256), # Input layer
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128), # Hidden layer 2
    torch.nn.ReLU(),
    torch.nn.Linear(128, cora_dataset.num_classes), # Output layer
)
model = model.to(device)


NameError: name 'embedding_dim' is not defined

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # define an optimizer
criterion = torch.nn.CrossEntropyLoss()  # define loss function

node2vec_embeddings = embedding.to(device)
cora = cora.to(device)

for epoch in range(100):  # 100 epochs
    model.train()
    optimizer.zero_grad()
    out = model(node2vec_embeddings[cora.train_mask])  # forward pass
    loss = criterion(out, cora.y[cora.train_mask]) 
    loss.backward()  
    optimizer.step()

    # print out loss info
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.3e}")

def get_accuracy(model, embeddings, y, mask):
    out = model(embeddings[mask])
    pred = out.argmax(dim=1)
    acc = sklearn.metrics.accuracy_score(y[mask].cpu().numpy(), pred.cpu().detach().numpy())
    return acc

train_acc = get_accuracy(model, node2vec_embeddings, cora.y, cora.train_mask)
val_acc = get_accuracy(model, node2vec_embeddings, cora.y, cora.val_mask)
test_acc = get_accuracy(model, node2vec_embeddings, cora.y, cora.test_mask)
    
print(f"node classification accuracy for cora: {test_acc:.2f} (train: {train_acc:.2f}, val: {val_acc:.2f})")

## link prediction on trained embeddings
this should only train simple MLPs.

Note: for link prediction to be worthwhile, one needs to train the embeddings on a subset of the graph (less edges, same nodes) instead of the whole graph.

In [32]:
# for link prediction, do something like the following
link_splitter = pyg.transforms.RandomLinkSplit(is_undirected=True)
train_data, val_data, test_data = link_splitter(cora)
train_data
# the positive and negative edges are in "edge_label_index" with "edge_label" 
# indicating whether an edge is a true edge or not.

Data(x=[2708, 1433], edge_index=[2, 7392], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[7392], edge_label_index=[2, 7392])

In [28]:
test_data

Data(x=[2708, 1433], edge_index=[2, 8446], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[2110], edge_label_index=[2, 2110])

In [None]:
# retrain node2vec on train_data

In [33]:
# use those (new) embeddings for link prediction