In [1]:
# Test case
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from pyg_model import relative_positioning

# PairData class
class PairData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index1':
            return self.x1.size(0)
        if key == 'edge_index2':
            return self.x2.size(0)
        return super().__inc__(key, value, *args, **kwargs)


# Create a dummy graph using PyG's Data class
def create_sample_graph(num_nodes, num_node_features, num_edge_features):
    x1, x2 = torch.randn(num_nodes, num_node_features), torch.randn(num_nodes, num_node_features)  # Node features
    edge_index1 = torch.tensor([[0, 2, 3, 1, 4],
                                [2, 0, 1, 3, 3]], dtype=torch.long)
    edge_index2 = edge_index1
    num_edges = edge_index1.size()[1]
    edge_attr1, edge_attr2 = torch.randn((num_edges, num_edge_features)), torch.randn((num_edges, num_edge_features))   # Edge attributes
    
    y = torch.randint(0, 2, (1,))
    y = y.to(torch.float32)
    
    return PairData(x1=x1, edge_index1=edge_index1, edge_attr1=edge_attr1,
                    x2=x2, edge_index2=edge_index2, edge_attr2=edge_attr2,
                    y=y)

# Hyperparameters
num_node_features = 10
num_edge_features = 6
hidden_channels = 20
out_channels = 5

# Create dummy data
data_list = []
for i in range(8):
    data_list.append(create_sample_graph(5, 10, 6))

# Dataloader
loader = DataLoader(data_list, batch_size=4, follow_batch=['x1', 'x2'])

# Initialize the model
model = relative_positioning(num_node_features=num_node_features, num_edge_features=num_edge_features, 
                                   hidden_channels=hidden_channels, out_channels=out_channels)

for batch in loader:
    output = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                         batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch)
    print("Output shape:", output.shape)
    print("Output:", output)

Output shape: torch.Size([4])
Output: tensor([0.5921, 0.5860, 0.5816, 0.5837], grad_fn=<SqueezeBackward1>)
Output shape: torch.Size([4])
Output: tensor([0.5870, 0.5891, 0.5799, 0.5780], grad_fn=<SqueezeBackward1>)


In [3]:
# Training test
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

for epoch in range(1, 1000):
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x1, batch.edge_index1, batch.edge_attr1, batch.x1_batch,
                    batch.x2, batch.edge_index2, batch.edge_attr2, batch.x2_batch, mode="sigmoid")
        loss = criterion(out, batch.y.float())
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss.item()}')
    if loss.item() < 1e-5:
        print("Success! Looks like we're overfitting.")
        break

Epoch: 1, Loss: 0.10404607653617859
Epoch: 2, Loss: 0.10238519310951233
Epoch: 3, Loss: 0.10023936629295349
Epoch: 4, Loss: 0.09825339168310165
Epoch: 5, Loss: 0.09644834697246552
Epoch: 6, Loss: 0.094691701233387
Epoch: 7, Loss: 0.09297184646129608
Epoch: 8, Loss: 0.09135718643665314
Epoch: 9, Loss: 0.08962897211313248
Epoch: 10, Loss: 0.0879705622792244
Epoch: 11, Loss: 0.08628644049167633
Epoch: 12, Loss: 0.0844610258936882
Epoch: 13, Loss: 0.08224180340766907
Epoch: 14, Loss: 0.08501631021499634
Epoch: 15, Loss: 0.07691359519958496
Epoch: 16, Loss: 0.07275760918855667
Epoch: 17, Loss: 0.06486956775188446
Epoch: 18, Loss: 0.05056009814143181
Epoch: 19, Loss: 0.029799308627843857
Epoch: 20, Loss: 0.010220888070762157
Epoch: 21, Loss: 0.0024423860013484955
Epoch: 22, Loss: 0.0006682530511170626
Epoch: 23, Loss: 1.5460524082300253e-05
Epoch: 24, Loss: 1.907374780785176e-06
Success! Looks like we're overfitting.


In [2]:
# Load PairData class dataset of graph pairs
import torch
path = r"C:\Users\xmoot\Desktop\Data\ssl-seizure-detection\patient_pseudolabeled\relative_positioning\PyG\jh101_12s_7min_PairData.pt"
data = torch.load(path)

In [3]:
# Setup GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda
