In [2]:
import sys
import os
sys.path.append("../src")


import numpy as np
import torch
# OR to import specific functions:
from preprocess import graph_pairs, graph_triplets, pseudo_data, convert_to_TripletData, ef_to_edge_attr, adj_to_edge_attr, create_tensordata_new, build_K_n
from torch_geometric.data import Data



In [19]:
def test_graph_pairs():
    # Create a synthetic dataset with 10 graph representations
    # Note that the graph representations won't be in this format, but it doesn't matter as it only needs to work for general lists with entries [a,b,c,d].
    n = 4500
    data = [[[i], np.random.randint(2)] for i in range(n)]
    
    # Call the function with the synthetic dataset
    sample_ratio = 0.8
    result = graph_pairs(data, tau_pos=12//0.12, tau_neg=90//0.12, sample_ratio=sample_ratio)
    
    print(f"Number of samples: {len(result)}")

    # Print the result to see what the function returns
    k = 0
    for pair in result:
        print(f"Graph 1: {pair[0]}")
        print(f"Graph 2: {pair[1]}")
        print(f"Pseudo Label: {pair[2]}\n")
        k+=1
        if k==10:
            break

test_graph_pairs()


# 4500 with sample_ratio = 0.8 yields 570,000 (28s)
# 2500 with sample_ratio = 0.7 yields 240,000
# 2500 with sample_ratio = 0.8 yields 310,000
# 2200 with sample_ratio = 0.7 yields 210,000
# 2200 with sample_ratio = 0.8 yields 270,000
# 1500 with sample_ratio = 0.7 yields 185,000

Number of samples: 570050
Graph 1: [3324]
Graph 2: [3372]
Pseudo Label: 1

Graph 1: [542]
Graph 2: [593]
Pseudo Label: 1

Graph 1: [2737]
Graph 2: [2780]
Pseudo Label: 1

Graph 1: [2286]
Graph 2: [3890]
Pseudo Label: 0

Graph 1: [1329]
Graph 2: [1379]
Pseudo Label: 1

Graph 1: [1356]
Graph 2: [3313]
Pseudo Label: 0

Graph 1: [3297]
Graph 2: [4197]
Pseudo Label: 0

Graph 1: [334]
Graph 2: [2599]
Pseudo Label: 0

Graph 1: [2643]
Graph 2: [3706]
Pseudo Label: 0

Graph 1: [2939]
Graph 2: [4065]
Pseudo Label: 0



In [46]:
def test_graph_triplets():
    # Create a synthetic dataset with 10 graph representations
    # Note that the graph representations won't be in this format, but it doesn't matter as it only needs to work for general lists with entries [a,b,c,d].
    n = 1500
    data = [[[i], np.random.randint(2)] for i in range(n)]
    
    # Call the function with the synthetic dataset
    sample_ratio = 0.22
    result = graph_triplets_new(data, tau_pos=12//0.12, tau_neg=90//0.12, sample_ratio=sample_ratio)
    
    print(f"Number of samples: {len(result)}")

    # Print the result to see what the function returns
    k = 0
    for triplet in result:
        print(f"Graph 1: {triplet[0]}")
        print(f"Graph 2: {triplet[1]}")
        print(f"Graph 3: {triplet[2]}")
        print(f"Pseudo Label: {triplet[3]}\n")
        k+=1
        if k==10:
            break

test_graph_triplets()

# 2200 with sample_ratio = 0.3 yields 635,662 samples
# 2200 with sample_ratio = 0.22 yields 222,000 samples
# 1500 with sample_ratio = 0.22 yields 150,000 samples

Number of samples: 147634
Graph 1: [173]
Graph 2: [1129]
Graph 3: [198]
Pseudo Label: 0

Graph 1: [1300]
Graph 2: [117]
Graph 3: [1366]
Pseudo Label: 0

Graph 1: [567]
Graph 2: [581]
Graph 3: [618]
Pseudo Label: 1

Graph 1: [850]
Graph 2: [853]
Graph 3: [890]
Pseudo Label: 1

Graph 1: [707]
Graph 2: [758]
Graph 3: [759]
Pseudo Label: 1

Graph 1: [383]
Graph 2: [1186]
Graph 3: [393]
Pseudo Label: 0

Graph 1: [1137]
Graph 2: [1159]
Graph 3: [1222]
Pseudo Label: 1

Graph 1: [887]
Graph 2: [173]
Graph 3: [940]
Pseudo Label: 0

Graph 1: [475]
Graph 2: [1222]
Graph 3: [509]
Pseudo Label: 0

Graph 1: [571]
Graph 2: [1047]
Graph 3: [669]
Pseudo Label: 0



In [20]:
def test_pseudo_data():
    # Step 1: Create a sample data list
    data = [
        [["gr1", "gr1", "gr1"], np.random.randint(2)],
        [["gr2", "gr2", "gr2"], np.random.randint(2)],
        [["gr3", "gr3", "gr3"], np.random.randint(2)],
        [["gr4", "gr4", "gr4"], np.random.randint(2)],
        [["gr5", "gr5", "gr5"], np.random.randint(2)],
        [["gr6", "gr6", "gr6"], np.random.randint(2)]
        ]
    
    # Step 2: Test the function with various inputs
    # Test with default parameters
    try:
        pairs = pseudo_data(data, tau_pos = 2, tau_neg = 3, stats = True, save = False, patientid = "patient", logdir = None, model = "relative_positioning")
        assert isinstance(pairs, list), "Output is not a list"
        
        # Check the length of the output
        assert len(pairs) > 0, "Output list is empty"
        
        # Check the structure of the first element in the output
        assert len(pairs[0]) == 3, "Output elements do not have the correct structure"
        print(pairs)

    except Exception as e:
        print(f"Test with default parameters failed with error: {e}")
    
    
    # Test with a different model
    try:
        triplets = pseudo_data(data, tau_pos = 2, tau_neg = 4, stats = True, save = False, patientid = "patient", logdir = None, model = "temporal_shuffling")
        assert isinstance(triplets, list), "Output is not a list"
        
        # Check the length of the output
        assert len(triplets) > 0, "Output list is empty"
        
        # Check the structure of the first element in the output
        assert len(triplets[0]) == 4, "Output elements do not have the correct structure"
        print(triplets)
        
    except Exception as e:
        print(f"Test with temporal shuffling failed with error: {e}")

    
    print("Tests completed")

test_pseudo_data()

Number of examples: 6
y
1    3
0    3
Name: count, dtype: int64
[[['gr2', 'gr2', 'gr2'], ['gr3', 'gr3', 'gr3'], 1], [['gr2', 'gr2', 'gr2'], ['gr4', 'gr4', 'gr4'], 1], [['gr1', 'gr1', 'gr1'], ['gr5', 'gr5', 'gr5'], 0], [['gr2', 'gr2', 'gr2'], ['gr6', 'gr6', 'gr6'], 0], [['gr1', 'gr1', 'gr1'], ['gr6', 'gr6', 'gr6'], 0], [['gr4', 'gr4', 'gr4'], ['gr6', 'gr6', 'gr6'], 1]]
Number of examples: 38
y
1    19
0    19
Name: count, dtype: int64
[[['gr3', 'gr3', 'gr3'], ['gr5', 'gr5', 'gr5'], ['gr6', 'gr6', 'gr6'], 1], [['gr3', 'gr3', 'gr3'], ['gr5', 'gr5', 'gr5'], ['gr4', 'gr4', 'gr4'], 0], [['gr2', 'gr2', 'gr2'], ['gr3', 'gr3', 'gr3'], ['gr5', 'gr5', 'gr5'], 1], [['gr1', 'gr1', 'gr1'], ['gr6', 'gr6', 'gr6'], ['gr5', 'gr5', 'gr5'], 0], [['gr1', 'gr1', 'gr1'], ['gr5', 'gr5', 'gr5'], ['gr4', 'gr4', 'gr4'], 0], [['gr1', 'gr1', 'gr1'], ['gr2', 'gr2', 'gr2'], ['gr4', 'gr4', 'gr4'], 1], [['gr1', 'gr1', 'gr1'], ['gr3', 'gr3', 'gr3'], ['gr5', 'gr5', 'gr5'], 1], [['gr3', 'gr3', 'gr3'], ['gr6', 'gr6', 'gr6

In [22]:
class TripletData(Data):
    """
    Creates the torch_geometric data object for a triplets of graphs.
    
    """
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index1':
            return self.x1.size(0)
        if key == 'edge_index2':
            return self.x2.size(0)
        if key == 'edge_index3':
            return self.x3.size(0)
        return super().__inc__(key, value, *args, **kwargs)



def test_convert_to_TripletData():
    # Step 1: Create a sample data list
    data_list = [
        [
            [torch.tensor([[0, 1], [1, 2]]), torch.rand((3, 3)), torch.rand((2, 3))],
            [torch.tensor([[1, 0], [2, 1]]), torch.rand((3, 3)), torch.rand((2, 3))],
            [torch.tensor([[2, 0], [1, 2]]), torch.rand((3, 3)), torch.rand((2, 3))],
            torch.tensor([1])
        ] 
        for _ in range(3)
    ]
    
    # Step 2: Call the function with the sample data
    output = convert_to_TripletData(data_list, save=False)
    
    
    if torch.equal(output[0].x1,data_list[0][0][1]) and torch.equal(output[2].edge_index3, data_list[2][2][0]) and torch.equal(output[0].y, data_list[0][3]):
        print(output[0].x1)
        print(data_list[0][0][1])
        print(data_list[2][2][0]) 
        print(output[2].edge_index3)
        print(output[0].y)
        print(data_list[0][3])
        print("Test passed")

test_convert_to_TripletData()

tensor([[0.1727, 0.3077, 0.0150],
        [0.5427, 0.1907, 0.3594],
        [0.0342, 0.7222, 0.2739]])
tensor([[0.1727, 0.3077, 0.0150],
        [0.5427, 0.1907, 0.3594],
        [0.0342, 0.7222, 0.2739]])
tensor([[2, 0],
        [1, 2]])
tensor([[2, 0],
        [1, 2]])
tensor([1])
tensor([1])
Test passed


In [4]:
import numpy as np

def test_ef_to_edge_attr():
    # Define a 3-node complete graph with edge features
    edge_index = np.array([
        [0, 1, 2, 0, 2, 1],
        [1, 0, 0, 2, 1, 2]
    ])

    # Define edge features for 3 nodes (3 x 3 x 2 matrix)
    ef = np.array([
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        [[0.3, 0.4], [0.1, 0.2], [0.7, 0.8]],
        [[0.5, 0.6], [0.7, 0.8], [0.1, 0.2]]
    ])

    # Expected output (6 edges x 2 features)
    expected_edge_attr = np.array([
        [0.3, 0.4],
        [0.3, 0.4],
        [0.5, 0.6],
        [0.5, 0.6],
        [0.7, 0.8],
        [0.7, 0.8]
    ])

    if np.array_equal(ef_to_edge_attr(edge_index, ef), expected_edge_attr):
        print("Test passed")


test_ef_to_edge_attr()

Test passed


In [6]:
def test_create_tensordata_new(mode = "binary"):

    num_nodes = 4
    num_edges = int(num_nodes * (num_nodes - 1) / 2)
    num_node_features = 3
    num_edge_features = 2

    # Create data_list with the 3 examples with entries of the form [[x, ef], y]
    if mode == "multi":
        data_list = [[[np.random.rand(num_nodes, num_node_features), np.random.rand(num_nodes, num_nodes, num_edge_features)], np.random.randint(3)] for i in range(6)]

    if mode == "binary":
        data_list = [[[np.random.rand(num_nodes, num_node_features), np.random.rand(num_nodes, num_nodes, num_edge_features)], np.random.randint(2)] for i in range(6)]
    
    return create_tensordata_new(num_nodes, data_list, complete=True, save=False, logdir=None)

print(test_create_tensordata_new(mode = "binary"))

[[[tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]), tensor([[0.7062, 0.2018, 0.4384],
        [0.0263, 0.8856, 0.5984],
        [0.8255, 0.7409, 0.5824],
        [0.6624, 0.1274, 0.6451]]), tensor([[0.6877, 0.3812],
        [0.5184, 0.4200],
        [0.7134, 0.6546],
        [0.6486, 0.3674],
        [0.7353, 0.3583],
        [0.9322, 0.8115],
        [0.7168, 0.4183],
        [0.8734, 0.8871],
        [0.4216, 0.0684],
        [0.3000, 0.1026],
        [0.8075, 0.5920],
        [0.9838, 0.5842]])], tensor(0)], [[tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]), tensor([[0.5896, 0.6931, 0.1396],
        [0.6675, 0.2411, 0.4332],
        [0.3772, 0.8732, 0.8673],
        [0.6981, 0.6055, 0.0320]]), tensor([[0.9834, 0.0980],
        [0.4269, 0.3192],
        [0.3849, 0.5704],
        [0.9851, 0.7094],
        [0.4818, 0.4551],
        [0.7171, 0.1559],
        [0.3757, 0.6507],
        [0.8028, 0.814

In [24]:
def test_adj_to_edge_attr():

    # Adjacency matrix A (same for all cases)
    A = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]])

    # Edge index (same for all cases). Complete graph K_4.
    edge_index = build_K_n(4)

    # Test Case 1: No edge features
    case1_edge_attr = None

    # Test Case 2: Edge features in FCN format (shape = (num_nodes, num_nodes))
    case2_edge_attr = np.random.rand(4, 4, 2)

    # Test Case 3: Edge features in PyG format (shape = (num_edges, 1))
    case3_edge_attr = np.random.rand(12, 2)

    test_cases = {
        "Case 1": (A, edge_index, case1_edge_attr),
        "Case 2": (A, edge_index, case2_edge_attr),
        "Case 3": (A, edge_index, case3_edge_attr),
    }

    # Case 1: No edge features.
    print("Case 1: No edge features.")
    print(adj_to_edge_attr(A, edge_index))

    # Case 2: Edge features in FCN format shape = (num_nodes, num_nodes, num_edge_features).
    print("\nCase 2: Edge features in FCN format shape = (num_nodes, num_nodes, num_edge_features).")
    edge_attr_new = adj_to_edge_attr(A, edge_index, case2_edge_attr, "FCN")
    for i in range(4):
        for j in range(4):
            if i != j:
                print(f"A[{i}, {j}]: {A[i, j]}")
                print(f"Old edge_attr[{i}, {j}]: {case2_edge_attr[i, j]}")
    for k in range(6):
        print(f"New edge_attr[{k}]: {edge_attr_new[k]}")

    # Case 3: Edge features in PyG format shape = (num_edges, num_edge_features).
    print("\nCase 3: Edge features in PyG format shape = (num_edges, num_edge_features).")
    edge_attr_new = adj_to_edge_attr(A, edge_index, case3_edge_attr, "PyG")
    for k in range(12):
        print(f"New edge_attr[{k}]: {edge_attr_new[k]}")
        print(f"Old edge_attr[{k}]: {case3_edge_attr[k]}")

test_adj_to_edge_attr()

Case 1: No edge features.
[[1.]
 [0.]
 [0.]
 [1.]
 [1.]
 [0.]
 [0.]
 [1.]
 [1.]
 [0.]
 [0.]
 [1.]]

Case 2: Edge features in FCN format shape = (num_nodes, num_nodes, num_edge_features).
A[0, 1]: 1
Old edge_attr[0, 1]: [0.76202206 0.58994923]
A[0, 2]: 0
Old edge_attr[0, 2]: [0.67713691 0.93162194]
A[0, 3]: 0
Old edge_attr[0, 3]: [0.42088023 0.9501309 ]
A[1, 0]: 1
Old edge_attr[1, 0]: [0.0814149  0.04268138]
A[1, 2]: 1
Old edge_attr[1, 2]: [0.54536779 0.61504542]
A[1, 3]: 0
Old edge_attr[1, 3]: [0.67335006 0.53333583]
A[2, 0]: 0
Old edge_attr[2, 0]: [0.1271862  0.72529445]
A[2, 1]: 1
Old edge_attr[2, 1]: [0.39719874 0.63849772]
A[2, 3]: 1
Old edge_attr[2, 3]: [0.41139917 0.70986625]
A[3, 0]: 0
Old edge_attr[3, 0]: [0.41701962 0.75990651]
A[3, 1]: 0
Old edge_attr[3, 1]: [0.50577002 0.14245345]
A[3, 2]: 1
Old edge_attr[3, 2]: [0.79951818 0.98112661]
New edge_attr[0]: [1.         0.76202206 0.58994923]
New edge_attr[1]: [0.         0.67713691 0.93162194]
New edge_attr[2]: [0.         0.420

In [4]:
import sys
sys.path.append("../src")
from preprocess import create_data_loaders
import torch

def test_create_data_loaders():
    dataset = torch.load("/Users/xaviermootoo/Documents/Data/ssl-seizure-detection/patient_pyg/jh101/supervised/jh101_run1.pt")
    train_loader, val_loader, test_loader = create_data_loaders(dataset, data_size=1.0, val_ratio=0.2, test_ratio=0.1, batch_size=32, num_workers=4, model_id="supervised")
    for batch in train_loader:
        print(batch)
        break
    for batch in val_loader:
        print(batch)
        break
    for batch in test_loader:
        print(batch)
        break

test_create_data_loaders()


DataBatch(x=[3424, 9], edge_index=[2, 362944], edge_attr=[362944, 3], y=[32], batch=[3424], ptr=[33])
DataBatch(x=[3424, 9], edge_index=[2, 362944], edge_attr=[362944, 3], y=[32], batch=[3424], ptr=[33])

DataBatch(x=[3424, 9], edge_index=[2, 362944], edge_attr=[362944, 3], y=[32], batch=[3424], ptr=[33])
