In [1]:
import networkx as nx
import numpy as np
import scanpy as sc
import squidpy as sq
from sklearn.metrics import r2_score
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.data import Data   # Create data containers
from torch_geometric.utils import from_networkx
from torch.utils.data import random_split
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_scipy_sparse_matrix
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

In [2]:
class Dataloader:

    # Constructor
    # file_path: path to the .h5ad file
    # image_col: column name of the image id
    # label_col: column name of the label
    # include_label: whether to include the label in the graph
    # radius: radius of the ego graph
    # node_level: number of node levels to include in the ego graph
    # batch_size: batch size for the data loader
    # split_percent: tuple of percentages for train, validation, and test sets
    def __init__(self, file_path, image_col ,label_col, include_label, radius,node_level, batch_size, split_percent):
        self.file_path = file_path
        self.image_col = image_col
        self.label_col = label_col
        self.node_level = node_level
        self.include_label = include_label
        self.radius = radius
        self.batch_size = batch_size
        self.split_percent = split_percent

    def load_data(self):
        raise NotImplementedError

    def construct_graph(self):
        raise NotImplementedError

    def split_data(self, loader):
        # Assuming split_percent is a tuple like (0.7, 0.2, 0.1)
        train_size = int(self.split_percent[0] * len(loader.dataset))
        val_size = int(self.split_percent[1] * len(loader.dataset))
        test_size = len(loader.dataset) - train_size - val_size
        
        print(train_size,val_size,test_size)
        
        
        train_data, val_data, test_data = random_split(loader.dataset, [train_size, val_size, test_size])

        # Create data loaders for each set
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader


from tqdm import tqdm

class Ego_net_dataloader(Dataloader):
    def __init__(self, *args, **kwargs):
        super(Ego_net_dataloader, self).__init__(*args, **kwargs)

    def load_data(self):
        # Load data from .h5ad file and return a scanpy AnnData object
        adata = sc.read(self.file_path)
        return adata

    def construct_graph(self, adata):
        # Constructing graph from coordinates using scanpy's spatial_neighbors function
        images = np.unique(adata.obs[self.image_col])
        print(images)
        sub_g_ensemble = []
        for image in images:
            sub_adata = adata[adata.obs[self.image_col] == image].copy()
            sq.gr.spatial_neighbors(adata=sub_adata, radius=self.radius, key_added="adjacency_matrix", coord_type="generic")
            edge_index, _ = from_scipy_sparse_matrix(sub_adata.obsp['adjacency_matrix_connectivities'])

            # Create subgraphs for each node
            G = nx.Graph()

            # Add nodes with features to the graph
            print('Adding nodes...')
            for i, features in tqdm(enumerate(adata.X.toarray())):
                G.add_node(i, features=features)

            # Add edges to the graph
            print('Adding edges...')
            G.add_edges_from(edge_index.t().tolist())

            # Create subgraphs for each node of G
            print('Creating subgraphs...')
            subgraphs = [nx.ego_graph(G, node, radius=self.node_level) for node in tqdm(G.nodes())]

            sub_g_dataset = [from_networkx(graph, group_node_attrs=['features']) for graph in tqdm(subgraphs)]

            # Extend the ensemble with the new subgraphs
            sub_g_ensemble.extend(sub_g_dataset)

        loader = DataLoader(sub_g_ensemble, batch_size=32, shuffle=True)
        return loader


class Full_image_dataloader(Dataloader):
    def __init__(self, *args, **kwargs):
        super(Full_image_dataloader, self).__init__(*args, **kwargs)

    def load_data(self):
        # Load data from .h5ad file and return a scanpy AnnData object
        adata = sc.read(self.file_path)
        return adata

    def construct_graph(self, adata):
        # Constructing graph from coordinates using scanpy's spatial_neighbors function
        images = np.unique(adata.obs[self.image_col])
        
        graph_dict = {}
        for image in tqdm(images, desc="Constructing Graphs"):
            sub_adata = adata[adata.obs[self.image_col] == image].copy()
            sq.gr.spatial_neighbors(adata=sub_adata, radius=self.radius, key_added="adjacency_matrix", coord_type="generic")
            edge_index, _ = from_scipy_sparse_matrix(sub_adata.obsp['adjacency_matrix_connectivities'])

            # Construct graph
            G = nx.Graph()
            # Adding nodes
            for i, features in enumerate(sub_adata.X.toarray()):
                G.add_node(i, features=features)
            # Adding edges
            G.add_edges_from(edge_index.t().tolist())
            
            # Convert networkx graph to PyG format
            graph = from_networkx(G)
            graph_dict[image] = graph

        return graph_dict

    def split_data(self, graph_dict):
        # split by entire images
        images = list(graph_dict.keys())
        train_images, test_images = train_test_split(images, test_size=(self.split_percent[1] + self.split_percent[2]), random_state=42)
        val_images, test_images = train_test_split(test_images, test_size=self.split_percent[2]/(self.split_percent[1] + self.split_percent[2]), random_state=42)
        
        train_data = [graph_dict[image] for image in train_images]
        val_data = [graph_dict[image] for image in val_images]
        test_data = [graph_dict[image] for image in test_images]
        
        # Create data loaders for each set
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader

In [4]:
# Create an instance of Ego_net_dataloader

#file_path = "../data/img_119670929.h5ad"
file_path = "../example_files/subset_6img_atlas_brain.h5ad"
dataloader = Ego_net_dataloader(file_path=file_path, image_col="section", label_col="class_id_label", include_label=False, radius=20,node_level = 1, batch_size=32, split_percent=(0.7, 0.2, 0.1))

# Load the data
adata = dataloader.load_data()

# Construct the graph
loader = dataloader.construct_graph(adata)

['1199650929' '1199650932' '1199650935' '1199650938' '1199650941'
 '1199650944']
Adding nodes...


240945it [00:00, 345479.36it/s]


Adding edges...
Creating subgraphs...


100%|██████████| 240945/240945 [00:12<00:00, 19048.12it/s]
  data[key] = torch.tensor(value)
100%|██████████| 240945/240945 [01:41<00:00, 2374.93it/s]


Adding nodes...


240945it [00:00, 590633.63it/s]


Adding edges...
Creating subgraphs...


100%|██████████| 240945/240945 [00:12<00:00, 19625.11it/s]
100%|██████████| 240945/240945 [01:47<00:00, 2243.57it/s]


Adding nodes...


240945it [00:00, 620614.71it/s]


Adding edges...
Creating subgraphs...


100%|██████████| 240945/240945 [00:11<00:00, 21069.02it/s]
100%|██████████| 240945/240945 [01:44<00:00, 2307.00it/s]


Adding nodes...


240945it [00:00, 545722.11it/s]


Adding edges...
Creating subgraphs...


100%|██████████| 240945/240945 [00:12<00:00, 19706.23it/s]
100%|██████████| 240945/240945 [01:52<00:00, 2148.05it/s]


Adding nodes...


240945it [00:00, 543060.26it/s]


Adding edges...
Creating subgraphs...


100%|██████████| 240945/240945 [00:13<00:00, 18114.74it/s]
100%|██████████| 240945/240945 [02:00<00:00, 1991.75it/s]


Adding nodes...


240945it [00:00, 550056.08it/s]


Adding edges...
Creating subgraphs...


100%|██████████| 240945/240945 [00:19<00:00, 12264.00it/s]
100%|██████████| 240945/240945 [02:04<00:00, 1936.54it/s]


In [62]:
train_loader, test_loader, val_loader = dataloader.split_data(loader)

18361 5246 2623


In [65]:
train_loader, test_loader, val_loader = dataloader.split_data(loader)

77742 22212 11106


In [12]:

# Print out the size of each set to verify
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Train size: 574
Validation size: 82
Test size: 164


In [8]:
# Create an instance of Ego_net_dataloader

#file_path = "../data/img_119670929.h5ad"
file_path = "../data/subset_6img_atlas_brain.h5ad"
dataloader = Full_image_dataloader(file_path=file_path, image_col="section", label_col="class_id_label", include_label=False, radius=20,node_level = 1, batch_size=32, split_percent=(0.7, 0.2, 0.1))

# Load the data
adata = dataloader.load_data()

# Construct the graph
loader = dataloader.construct_graph(adata)

Constructing Graphs: 100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


In [11]:
train_loader, test_loader, val_loader = dataloader.split_data(loader)

In [12]:
# Print out the size of each set to verify
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Train size: 4
Validation size: 1
Test size: 1
