#### Test Relative Positioning

In [3]:
import sys

import torch

# Make the project's `src/` modules (models, preprocess) importable from this notebook.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:  # guard so re-running the cell does not append duplicates
    sys.path.append(SRC_DIR)

from models import relative_positioning

In [2]:
from torch_geometric.loader import DataLoader
from preprocess import PairData

# Toy graph pair for smoke-testing the relative-positioning model.
# Node features: shape (num_nodes, num_node_features), dtype torch.float32.
x1 = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [1, 1, 1, 1, 1, 1, 1, 1, 1],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2],
                   [3, 3, 3, 3, 3, 3, 3, 3, 3]], dtype=torch.float32)
x2 = torch.tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4],
                   [5, 5, 5, 5, 5, 5, 5, 5, 5],
                   [6, 6, 6, 6, 6, 6, 6, 6, 6]], dtype=torch.float32)

# Edge indices: shape (2, num_edges), dtype torch.long.
# Both graphs are simple chains with every edge listed in both directions.
edge_index1 = torch.tensor([[0, 1, 1, 2, 2, 3],
                            [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_index2 = torch.tensor([[0, 1, 1, 2],
                            [1, 0, 2, 1]], dtype=torch.long)

# Edge features: shape (num_edges, num_edge_features), dtype torch.float32.
edge_attr1 = torch.tensor([[0, 0, 0],
                           [1, 1, 1],
                           [2, 2, 2],
                           [3, 3, 3],
                           [4, 4, 4],
                           [5, 5, 5]], dtype=torch.float32)
edge_attr2 = torch.tensor([[6, 6, 6],
                           [7, 7, 7],
                           [8, 8, 8],
                           [9, 9, 9]], dtype=torch.float32)

# Pair label: shape (1,), dtype torch.float32 (BCELoss requires floating-point targets).
y = torch.tensor([1], dtype=torch.float32)

# Sanity checks: edges reference existing nodes and there is one feature row per edge.
assert int(edge_index1.max()) < x1.size(0)
assert int(edge_index2.max()) < x2.size(0)
assert edge_attr1.size(0) == edge_index1.size(1)
assert edge_attr2.size(0) == edge_index2.size(1)

data = PairData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,  # Graph 1.
                x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,  # Graph 2.
                y=y)  # Graph-pair label.

In [3]:
# Four references to the same pair -> two mini-batches of two pairs each.
# `follow_batch` asks PyG to build the `x1_batch` / `x2_batch` graph-assignment
# vectors that the training loop passes to the model.
data_list = [data] * 4
dataloader = DataLoader(data_list, follow_batch=['x1', 'x2'], batch_size=2)

In [5]:
# Seed weight initialisation so the printed losses are reproducible across runs.
torch.manual_seed(0)

model = relative_positioning(num_node_features=9, num_edge_features=3, hidden_channels=[64, 128], out_channels=32)

# Move the model to the same device the batches are sent to below; without this,
# a CUDA machine would feed GPU tensors into CPU-resident parameters and fail.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Build the optimizer after the move so it holds the final parameter tensors.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

model.train()
for batch in dataloader:
    # PyG `Batch.to` moves every tensor attribute at once, including x1_batch / x2_batch.
    batch = batch.to(device)

    optimizer.zero_grad()

    out = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
    # BCELoss expects `out` to already be probabilities in [0, 1]
    # (presumably the model ends in a sigmoid — confirm in models.py).
    loss = criterion(out, batch.y)

    loss.backward()
    optimizer.step()

    print(loss.item())

2.248107671737671
0.33536654710769653


#### Test Temporal Shuffling

In [9]:
from torch_geometric.loader import DataLoader
from preprocess import PairData, TripletData

# Toy graph triplet for smoke-testing the temporal-shuffling model.
# Node features: shape (num_nodes, num_node_features), dtype torch.float32.
x1 = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [1, 1, 1, 1, 1, 1, 1, 1, 1],
                   [2, 2, 2, 2, 2, 2, 2, 2, 2],
                   [3, 3, 3, 3, 3, 3, 3, 3, 3]], dtype=torch.float32)
x2 = torch.tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4],
                   [5, 5, 5, 5, 5, 5, 5, 5, 5],
                   [6, 6, 6, 6, 6, 6, 6, 6, 6]], dtype=torch.float32)

# Edge indices: shape (2, num_edges), dtype torch.long.
# Both graphs are simple chains with every edge listed in both directions.
edge_index1 = torch.tensor([[0, 1, 1, 2, 2, 3],
                            [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_index2 = torch.tensor([[0, 1, 1, 2],
                            [1, 0, 2, 1]], dtype=torch.long)

# Edge features: shape (num_edges, num_edge_features), dtype torch.float32.
edge_attr1 = torch.tensor([[0, 0, 0],
                           [1, 1, 1],
                           [2, 2, 2],
                           [3, 3, 3],
                           [4, 4, 4],
                           [5, 5, 5]], dtype=torch.float32)
edge_attr2 = torch.tensor([[6, 6, 6],
                           [7, 7, 7],
                           [8, 8, 8],
                           [9, 9, 9]], dtype=torch.float32)

# Triplet label: shape (1,), dtype torch.float32 (BCELoss requires floating-point targets).
y = torch.tensor([1], dtype=torch.float32)

# Sanity checks: edges reference existing nodes and there is one feature row per edge.
assert int(edge_index1.max()) < x1.size(0)
assert int(edge_index2.max()) < x2.size(0)
assert edge_attr1.size(0) == edge_index1.size(1)
assert edge_attr2.size(0) == edge_index2.size(1)

# The third graph reuses Graph 2's tensors.
data = TripletData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,  # Graph 1.
                   x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,  # Graph 2.
                   x3=x2, edge_index3=edge_index2, edge_attr3=edge_attr2,  # Graph 3 (copy of Graph 2).
                   y=y)  # Label for the whole triplet.

In [10]:
# Four references to the same triplet -> two mini-batches of two triplets each.
# `follow_batch` asks PyG to build the `x1_batch` / `x2_batch` / `x3_batch`
# graph-assignment vectors that the training loop passes to the model.
data_list = [data] * 4
dataloader = DataLoader(data_list, follow_batch=['x1', 'x2', 'x3'], batch_size=2)

In [15]:
from models import temporal_shuffling

# Seed weight initialisation so the printed losses are reproducible across runs.
torch.manual_seed(0)

model = temporal_shuffling(num_node_features=9, num_edge_features=3, hidden_channels=[32, 64], out_channels=32)

# Move the model to the same device the batches are sent to below; without this,
# a CUDA machine would feed GPU tensors into CPU-resident parameters and fail.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Build the optimizer after the move so it holds the final parameter tensors.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

model.train()
for batch in dataloader:
    # PyG `Batch.to` moves every tensor attribute at once, including x1/x2/x3_batch.
    batch = batch.to(device)

    optimizer.zero_grad()

    out = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch,
                batch.x3, batch.edge_index3, batch.edge_attr3, batch.x3_batch)
    # BCELoss expects `out` to already be probabilities in [0, 1]
    # (presumably the model ends in a sigmoid — confirm in models.py).
    loss = criterion(out, batch.y)

    loss.backward()
    optimizer.step()

    print(loss.item())

0.6862651109695435
0.3853457570075989
