In [61]:
from collections import Counter
import pandas as pd
import numpy as np
import pickle
import tqdm
import math
import random

from torch_geometric.sampler import BaseSampler, NodeSamplerInput, EdgeSamplerInput
from torch_geometric.loader import NeighborLoader

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit

In [49]:
# Seed every RNG module this notebook imports so a Restart & Run All is
# repeatable. NumPy's global RNG is also the default randomness source for
# scikit-learn utilities that are called without an explicit random_state.
# NOTE(review): NeighborLoader's neighbor sampling may draw from torch's RNG,
# which is not seeded here -- confirm whether torch needs seeding as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

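Graph sampling implemented with `NeighborLoader` (torch_geometric). Subgraphs are generated from each sample's graph, with more subgraphs drawn for under-represented classes so that the classes in the resulting dataset are balanced.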

#### For multiclass labels

In [50]:
# Load the sample IDs, attach each sample's BRAAK_AD label, and decide how many
# subgraphs to draw per sample of each class so that rarer classes get more.
sample_ids_file = "./graph_with_edgeW_sample_id.pkl"
# NOTE: pickle.load can execute arbitrary code; only load trusted artifacts.
with open(sample_ids_file, 'rb') as file:  # context manager closes the handle
    sample_ids = pickle.load(file)

labels = pd.read_csv("./metadata.csv")
labels.index = labels.SubID
labels_sub = labels[['SubID', 'BRAAK_AD']]
# Keep only the samples listed in sample_ids, in that order.
labels_sub = labels_sub.loc[sample_ids]

label_sub_val_counts = labels_sub["BRAAK_AD"].value_counts()
largest_class_size = label_sub_val_counts.max()
count = 0
subgraphs_per_gr = {}
print ("braak","#samples","#subgraphs_to_be_generated")
# The loop assumes the labels are exactly 0 .. n_classes-1; a missing label
# raises KeyError here.
for i in range(len(label_sub_val_counts)):
    # .loc makes the lookup explicitly label-based (value_counts is ordered by
    # frequency, so a positional lookup would return the wrong class).
    n_samples = label_sub_val_counts.loc[i]
    # Twice the (truncated) ratio of the largest class to this class.
    n_subgraphs = int(largest_class_size / n_samples) * 2
    print(i, n_samples, n_subgraphs)
    count += n_samples * n_subgraphs
    subgraphs_per_gr[i] = n_subgraphs
print (f'total graphs: {count}')

braak #samples #subgraphs_to_be_generated
0 36 12
1 55 6
2 74 4
3 75 4
4 52 8
5 62 6
6 218 2
total graphs: 2582


In [52]:
# Parallel accumulators: entry k of each list describes the k-th subgraph
# (the subgraph itself, its Braak label, and the ID of its source sample).
all_graphs, all_graphs_braak, all_graphs_sample_ids = [], [], []

In [53]:
# Split every sample's graph into neighbor-sampled subgraphs, drawing the
# class-dependent number of subgraphs stored in subgraphs_per_gr.
# `graphs` must already be in memory: it is loaded from GRAPH_PKL_FILE by the
# cell in the "For binary labels" section. graphs[k] is taken to be the graph
# of sample_ids[k].
# enumerate gives the position directly; sample_ids.index(s_id) was a linear
# scan per sample and returned the first match for a repeated ID.
for graph_idx, s_id in enumerate(sample_ids):
    br = labels_sub.loc[s_id]["BRAAK_AD"]
    NUM_SUBGRAPHS = subgraphs_per_gr[br]
    data = graphs[graph_idx]
    loader = NeighborLoader(
        data,
        num_neighbors=[10] * 3,
        # number of nodes to keep in each sampled graph
        batch_size=math.ceil(data.x.shape[0] / NUM_SUBGRAPHS),
    )
    # Defensive guard: never keep more subgraphs than planned for this class.
    # NOTE(review): with batch_size = ceil(n / k) the loader should yield at
    # most k batches (possibly fewer), so this is not expected to trigger.
    if len(loader) > NUM_SUBGRAPHS:
        print(f"More than {NUM_SUBGRAPHS} subgraphs produced ...")
        print("Skipping this graph ...")
        continue
    for d in loader:
        all_graphs.append(d)
        all_graphs_braak.append(br)
        all_graphs_sample_ids.append(s_id)



In [54]:
# Sanity check: (#subgraphs, #source graphs, #labels); first and last must match.
tuple(map(len, (all_graphs, graphs, all_graphs_braak)))

(2582, 572, 2582)

Save all subgraphs and their Braak levels.

In [None]:
# Persist every subgraph together with its per-subgraph metadata.
with open('./graph_with_edgeW_all_subgraphs.pkl', 'wb') as f:
    pickle.dump(all_graphs, f)

# The "braak_levels" file must hold the labels; it previously received the
# sample IDs instead.
with open('./graph_with_edgeW_all_braak_levels.pkl', 'wb') as f:
    pickle.dump(all_graphs_braak, f)

# Keep the source-sample IDs as well, under a name that says what they are.
with open('./graph_with_edgeW_all_subgraph_sample_ids.pkl', 'wb') as f:
    pickle.dump(all_graphs_sample_ids, f)

In [55]:
# Stratified 80/20 split of the subgraphs by Braak label, then save each part.
# random_state pins the shuffle explicitly (same value as the np.random.seed
# call at the top of the notebook) instead of relying on whatever state the
# global NumPy RNG is in when this cell runs.
# NOTE(review): the split is over subgraphs, so subgraphs sampled from the same
# source sample can land in both train and test -- possible leakage; consider
# splitting by sample ID (all_graphs_sample_ids) instead.
X_train, X_test, y_train, y_test = train_test_split(
    all_graphs,
    all_graphs_braak,
    stratify=all_graphs_braak,
    test_size=0.2,
    random_state=42,
)

split_outputs = {
    './graph_with_edgeW_train_subgraphs.pkl': X_train,
    './graph_with_edgeW_test_subgraphs.pkl': X_test,
    './graph_with_edgeW_train_subgraphs_braak_levels.pkl': y_train,
    './graph_with_edgeW_test_subgraphs_braak_levels.pkl': y_test,
}
for out_path, payload in split_outputs.items():
    with open(out_path, 'wb') as f:
        pickle.dump(payload, f)

In [62]:
# Number of generated subgraphs per Braak label (one label is appended per subgraph).
Counter(all_graphs_braak)

Counter({1.0: 330, 0.0: 432, 5.0: 372, 6.0: 436, 2.0: 296, 3.0: 300, 4.0: 416})

In [64]:
# Sizes of the train and test partitions.
tuple(map(len, (X_train, X_test)))

(2065, 517)

In [63]:
# Per-label subgraph counts in the train and test partitions.
tuple(map(Counter, (y_train, y_test)))

(Counter({4.0: 333,
          0.0: 345,
          2.0: 237,
          3.0: 240,
          6.0: 349,
          1.0: 264,
          5.0: 297}),
 Counter({4.0: 83, 2.0: 59, 0.0: 87, 1.0: 66, 6.0: 87, 3.0: 60, 5.0: 75}))

#### For binary labels

In [5]:
# Load the list of per-sample graphs.
GRAPH_PKL_FILE="./graph_with_edgeW.pkl"
# NOTE: pickle.load can execute arbitrary code; only load trusted artifacts.
with open(GRAPH_PKL_FILE, 'rb') as file:  # context manager closes the handle
    graphs = pickle.load(file)

In [3]:
# Sizes the NeighborLoader batches in the loop below so that each graph yields
# at most this many subgraphs.
NUM_SUBGRAPHS=2

In [6]:
# Start from an empty accumulator (this rebinds all_graphs, discarding any
# subgraphs collected earlier under the same name).
all_graphs = list()

In [7]:
# Split every graph into at most NUM_SUBGRAPHS neighbor-sampled subgraphs.
for data in graphs:
    loader = NeighborLoader(
        data,
        num_neighbors=[10] * 3,
        # number of nodes to keep in each sampled graph
        batch_size=math.ceil(data.x.shape[0] / NUM_SUBGRAPHS),
    )
    # Defensive guard: never keep more subgraphs than requested. The message
    # now reports the actual limit instead of a hard-coded "3".
    if len(loader) > NUM_SUBGRAPHS:
        print(f"More than {NUM_SUBGRAPHS} subgraphs produced ...")
        print("Skipping this graph ...")
        continue
    for d in loader:
        all_graphs.append(d)



In [8]:
# Sanity check: subgraphs actually produced vs. the expected NUM_SUBGRAPHS per graph.
(len(all_graphs), NUM_SUBGRAPHS * len(graphs))

(1242, 1242)

In [16]:
# Persist the subgraphs for the binary-label experiments.
with open('./graph_with_edgeW_subgraphs_2_each.pkl', mode='wb') as f:
    pickle.dump(all_graphs, file=f)