In [None]:
import h5py
import numpy as np
import networkx as nx
import pickle
from joblib import Parallel, delayed
from scipy.stats import entropy
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx

In [3]:
def bin_to_tick(burst_bin, bin_width_ms=10, tick_ms=0.1):
    """
    Convert burst start bin number to tick.
    
    Parameters:
    - burst_bin: int, the burst start bin index.
    - bin_width_ms: float, duration of each bin in milliseconds (default 10ms).
    - tick_ms: float, duration of one tick in milliseconds (default 0.1ms).
    
    Returns:
    - tick: int, corresponding tick number.
    """
    ticks_per_bin = bin_width_ms / tick_ms
    return int(burst_bin * ticks_per_bin)

In [None]:


def get_pre_burst_spikes_and_features(h5_path, node_ids, burst_tick, k=20):
    """
    For each node in node_ids, extract last k spikes before burst_tick and compute features.
    Returns a dict: node_id -> {'spikes': np.array[k], **features_dict}
    """
    result = {}
    with h5py.File(h5_path, 'r') as h5:
        for n in node_ids:
            ds = f"Neuron_{n}"
            if ds not in h5:
                arr = np.zeros(k, dtype=float)
                feats = get_spike_features(arr)
            else:
                spikes = np.sort(h5[ds][()]).astype(float)
                idx0 = np.searchsorted(spikes, burst_tick)
                last_k = spikes[max(0, idx0-k):idx0]
                if len(last_k) < k:
                    last_k = np.pad(last_k, (k-len(last_k), 0), 'constant')
                arr = last_k
                feats = get_spike_features(arr)
            entry = {'spikes': arr}
            entry.update(feats)
            result[n] = entry
    return result

def get_non_burst_spikes_and_features(h5_path, node_ids, burst_tick, k=20, gap = 100):
    """
    Extract the last k spike times before burst_tick for each node.
    Returns a dict: node_id -> 1D np.array of length k.
    """
    result = {}
    with h5py.File(h5_path, 'r') as h5:
        for n in node_ids:
            ds = f"Neuron_{n}"
            if ds not in h5:
                arr = np.zeros(k, dtype=float)
                feats = get_spike_features(arr)
            else:
                spikes = np.sort(h5[ds][()]).astype(float)
                idx0 = np.searchsorted(spikes, burst_tick)
                last_k = spikes[max(0, idx0 - k - gap):idx0 - gap]
                if len(last_k) < k:
                    last_k = np.pad(last_k, (k-len(last_k), 0), 'constant')
                arr = last_k
                feats = get_spike_features(arr)
            entry = {'spikes': arr}
            entry.update(feats)
            result[n] = entry
    return result

def get_spike_features(spikes):
    """
    Convert spike times to features.
    Here we simply return the spike times as features.
    """
    # Example feature extraction: normalize and scale
    isis = np.diff(spikes)
    m, s = np.nanmean(isis), np.nanstd(isis)

    return {
       'mean_isi': m,
       'entropy_isi': entropy(isis)       
    }

# assume these helper functions are defined elsewhere in your script:
# - bin_to_tick(burst_bin, bin_width_ms, tick_ms)
# - get_spikes_and_features(h5_path, node_ids, burst_tick, k)
# - get_non_burst_spikes_and_features(h5_path, node_ids, burst_tick, k, gap)

h5_path    = "/DATA/hdhanu/GNN/Burst_Data/tR_1.0--fE_0.98_10000.h5"
pkl_path   = "/DATA/hdhanu/GNN/Graphs_and_Features/Sub_graph/subgraphs_final.pkl"
entries    = pickle.load(open(pkl_path, 'rb'))

# list of summary feature names (must match get_spike_features keys)
feat_names = [
    'mean_isi','entropy_isi'
]

data_list = []
def process_entry(entry):
    originbin  = entry['burstoriginbin']
    burst_tick = bin_to_tick(originbin, bin_width_ms=10, tick_ms=0.1)
    G = entry['subgraph']
    nodes = list(G.nodes())
    edge_index = from_networkx(G).edge_index

    pre_dict = get_pre_burst_spikes_and_features(h5_path, nodes, burst_tick, k=20)
    non_dict = get_non_burst_spikes_and_features(h5_path, nodes, burst_tick, k=20, gap=100)

    # Build feature matrices in node order
    pre_feats = [[pre_dict[n][f] for f in feat_names] for n in nodes]
    non_feats = [[non_dict[n][f] for f in feat_names] for n in nodes]

    pre_data = Data(
        x=torch.tensor(pre_feats, dtype=torch.float),
        edge_index=edge_index,
        y=torch.tensor([1])
    )
    non_data = Data(
        x=torch.tensor(non_feats, dtype=torch.float),
        edge_index=edge_index,
        y=torch.tensor([0])
    )
    return pre_data, non_data

if __name__ == '__main__':
    # Parallelize over all entries using all available cores
    results = Parallel(n_jobs=-1, verbose=10)(
        delayed(process_entry)(entry) for entry in entries
    )

    # Unpack into a single list
    data_list = []
    for pre_data, non_data in results:
        data_list.extend([pre_data, non_data])

    # Save processed dataset
    torch.save(data_list, '/DATA/hdhanu/GNN/Manual_check/processed_data.pt')

    # Wrap in a DataLoader for training
    loader = DataLoader(data_list, batch_size=32, shuffle=True)

    # Verify
    for batch in loader:
        print(batch.x.shape, batch.edge_index.shape, batch.y.shape)
        break

Processed Burst 6135 with 107 nodes.
Processed Burst 4434 with 132 nodes.
Processed Burst 5132 with 122 nodes.
Processed Burst 7521 with 108 nodes.
Processed Burst 7021 with 145 nodes.
Processed Burst 2586 with 134 nodes.
Processed Burst 8089 with 145 nodes.
Processed Burst 3821 with 140 nodes.
Processed Burst 1073 with 132 nodes.
Processed Burst 2432 with 122 nodes.
Processed Burst 2318 with 145 nodes.
Processed Burst 8173 with 107 nodes.
Processed Burst 6014 with 107 nodes.
Processed Burst 170 with 133 nodes.
Processed Burst 5114 with 122 nodes.
Processed Burst 8977 with 130 nodes.
Processed Burst 5422 with 108 nodes.
Processed Burst 4959 with 97 nodes.
Processed Burst 5547 with 141 nodes.
Processed Burst 6944 with 108 nodes.
Processed Burst 532 with 128 nodes.
Processed Burst 6051 with 108 nodes.
Processed Burst 3103 with 141 nodes.
Processed Burst 1645 with 138 nodes.
Processed Burst 7093 with 99 nodes.
Processed Burst 481 with 133 nodes.
Processed Burst 3084 with 132 nodes.
Proces

KeyboardInterrupt: 

In [3]:
import torch

# 1) Load your processed graphs (turn off weights_only safety)
data_list = torch.load(
    '/DATA/hdhanu/GNN/Manual_check/processed_data.pt',
    weights_only=False
)

total_graphs     = len(data_list)
graphs_with_nan  = 0
total_nan_counts = 0

# 2) Iterate through each graph’s node‐features
for idx, data in enumerate(data_list):
    # data.x is [num_nodes, num_features]
    nan_mask = torch.isnan(data.x)
    n_nans   = nan_mask.sum().item()
    if n_nans > 0:
        graphs_with_nan  += 1
        total_nan_counts += n_nans
        print(f"Graph {idx} has {n_nans} NaNs")

# 3) Summary
print(f"\nFound NaNs in {graphs_with_nan}/{total_graphs} graphs, " 
      f"total NaN entries: {total_nan_counts}")


Graph 4761 has 21 NaNs
Graph 5323 has 2 NaNs
Graph 6399 has 1 NaNs
Graph 12820 has 118 NaNs
Graph 14663 has 22 NaNs
Graph 18283 has 2 NaNs

Found NaNs in 6/18316 graphs, total NaN entries: 166


In [4]:
clean_list   = []
dropped_bursts = 0

# data_list was built as [pre_0, non_0, pre_1, non_1, …]
for i in range(0, len(data_list), 2):
    pre, non = data_list[i], data_list[i+1]
    # if either graph has NaNs, drop the whole pair
    if torch.isnan(pre.x).any() or torch.isnan(non.x).any():
        dropped_bursts += 1
        continue
    clean_list.extend([pre, non])

print(f"Dropped {dropped_bursts} bursts ({dropped_bursts*2} graphs) with NaNs")
data_list = clean_list
torch.save(data_list, '/DATA/hdhanu/GNN/Manual_check/clean_processed_data.pt')



Dropped 6 bursts (12 graphs) with NaNs


In [7]:
# 2) Stack all node‐features into F
#    Each data.x is [n_nodes_i, n_feats]
from sklearn.discriminant_analysis import StandardScaler
import numpy as np

F = np.vstack([ data.x.numpy() for data in data_list ])  # shape [total_nodes, n_feats]

# 3) Fit the StandardScaler
scaler = StandardScaler()
scaler.fit(F)

# 4) Transform each graph’s x in place
for data in data_list:
    X = data.x.numpy()           # [n_nodes, n_feats]
    X_scaled = scaler.transform(X)
    data.x = torch.tensor(X_scaled, dtype=torch.float)

# 5) Save the scaled dataset and the scaler object
torch.save(data_list, '/DATA/hdhanu/GNN/Manual_check/clean_scaled_processed_data.pt')