In [10]:
import networkx as nx
import numpy as np
import scanpy as sc
import squidpy as sq
from sklearn.metrics import r2_score
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.data import Data   # Create data containers
from torch_geometric.utils import from_networkx
from torch.utils.data import random_split
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_scipy_sparse_matrix
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

In [7]:
class Dataloader:
    def __init__(self, file_path, image_col ,label_col, include_label, radius,node_level, batch_size, split_percent):
        self.file_path = file_path
        self.image_col = image_col
        self.label_col = label_col
        self.node_level = node_level
        self.include_label = include_label
        self.radius = radius
        self.batch_size = batch_size
        self.split_percent = split_percent

    def load_data(self):
        raise NotImplementedError

    def construct_graph(self):
        raise NotImplementedError

    def split_data(self, loader):
        # Assuming split_percent is a tuple like (0.7, 0.2, 0.1)
        train_size = int(self.split_percent[0] * len(loader.dataset))
        val_size = int(self.split_percent[1] * len(loader.dataset))
        test_size = len(loader.dataset) - train_size - val_size
        
        print(train_size,val_size,test_size)
        
        
        train_data, val_data, test_data = random_split(loader.dataset, [train_size, val_size, test_size])

        # Create data loaders for each set
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader


from tqdm import tqdm

class Ego_net_dataloader(Dataloader):
    def __init__(self, *args, **kwargs):
        super(Ego_net_dataloader, self).__init__(*args, **kwargs)

    def load_data(self):
        # Load data from .h5ad file and return a scanpy AnnData object
        adata = sc.read(self.file_path)
        return adata

    def construct_graph(self, adata):
        # Constructing graph from coordinates using scanpy's spatial_neighbors function
        images = np.unique(adata.obs[self.image_col])
        print(images)
        sub_g_ensemble = []
        for image in images:
            sub_adata = adata[adata.obs[self.image_col] == image].copy()
            sq.gr.spatial_neighbors(adata=sub_adata, radius=self.radius, key_added="adjacency_matrix", coord_type="generic")
            edge_index, _ = from_scipy_sparse_matrix(sub_adata.obsp['adjacency_matrix_connectivities'])

            # Create subgraphs for each node
            G = nx.Graph()

            # Add nodes with features to the graph
            print('Adding nodes...')
            for i, features in tqdm(enumerate(adata.X.toarray())):
                G.add_node(i, features=features)

            # Add edges to the graph
            print('Adding edges...')
            G.add_edges_from(edge_index.t().tolist())

            # Create subgraphs for each node of G
            print('Creating subgraphs...')
            subgraphs = [nx.ego_graph(G, node, radius=self.node_level) for node in tqdm(G.nodes())]

            sub_g_dataset = [from_networkx(graph, group_node_attrs=['features']) for graph in tqdm(subgraphs)]

            # Extend the ensemble with the new subgraphs
            sub_g_ensemble.extend(sub_g_dataset)

        loader = DataLoader(sub_g_ensemble, batch_size=32, shuffle=True)
        return loader


class Full_image_dataloader(Dataloader):
    def __init__(self, *args, **kwargs):
        super(Full_image_dataloader, self).__init__(*args, **kwargs)

    def load_data(self):
        # Load data from .h5ad file and return a scanpy AnnData object
        adata = sc.read(self.file_path)
        return adata

    def construct_graph(self, adata):
        # Constructing graph from coordinates using scanpy's spatial_neighbors function
        images = np.unique(adata.obs[self.image_col])
        
        graph_dict = {}
        for image in tqdm(images, desc="Constructing Graphs"):
            sub_adata = adata[adata.obs[self.image_col] == image].copy()
            sq.gr.spatial_neighbors(adata=sub_adata, radius=self.radius, key_added="adjacency_matrix", coord_type="generic")
            edge_index, _ = from_scipy_sparse_matrix(sub_adata.obsp['adjacency_matrix_connectivities'])

            # Construct graph
            G = nx.Graph()
            # Adding nodes
            for i, features in enumerate(sub_adata.X.toarray()):
                G.add_node(i, features=features)
            # Adding edges
            G.add_edges_from(edge_index.t().tolist())
            
            # Convert networkx graph to PyG format
            graph = from_networkx(G)
            graph_dict[image] = graph

        return graph_dict

    def split_data(self, graph_dict):
        # split by entire images
        images = list(graph_dict.keys())
        train_images, test_images = train_test_split(images, test_size=(self.split_percent[1] + self.split_percent[2]), random_state=42)
        val_images, test_images = train_test_split(test_images, test_size=self.split_percent[2]/(self.split_percent[1] + self.split_percent[2]), random_state=42)
        
        train_data = [graph_dict[image] for image in train_images]
        val_data = [graph_dict[image] for image in val_images]
        test_data = [graph_dict[image] for image in test_images]
        
        # Create data loaders for each set
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader

In [64]:
# Create an instance of Ego_net_dataloader

#file_path = "../data/img_119670929.h5ad"
file_path = "../data/img_119670929_1199650932.h5ad"
dataloader = Ego_net_dataloader(file_path=file_path, image_col="section", label_col="class_id_label", include_label=False, radius=20,node_level = 1, batch_size=32, split_percent=(0.7, 0.2, 0.1))

# Load the data
adata = dataloader.load_data()

# Construct the graph
loader = dataloader.construct_graph(adata)

['1199650929' '1199650932']
Adding nodes...



55530it [00:00, 908983.30it/s]

Adding edges...





Creating subgraphs...



  0%|          | 0/55530 [00:00<?, ?it/s][A
  3%|▎         | 1761/55530 [00:00<00:03, 17608.67it/s][A
  6%|▋         | 3522/55530 [00:00<00:04, 12917.82it/s][A
  9%|▉         | 4888/55530 [00:00<00:04, 10531.33it/s][A
 11%|█         | 6005/55530 [00:00<00:04, 10028.84it/s][A
 13%|█▎        | 7040/55530 [00:00<00:05, 9583.68it/s] [A
 15%|█▍        | 8073/55530 [00:00<00:04, 9792.51it/s][A
 18%|█▊        | 10184/55530 [00:00<00:03, 13028.59it/s][A
 23%|██▎       | 12535/55530 [00:00<00:02, 16067.10it/s][A
 27%|██▋       | 14882/55530 [00:01<00:02, 18233.63it/s][A
 30%|███       | 16767/55530 [00:01<00:02, 17327.05it/s][A
 34%|███▍      | 19011/55530 [00:01<00:01, 18778.21it/s][A
 38%|███▊      | 21262/55530 [00:01<00:01, 19854.50it/s][A
 42%|████▏     | 23286/55530 [00:01<00:01, 19816.25it/s][A
 46%|████▌     | 25294/55530 [00:01<00:01, 17467.44it/s][A
 51%|█████     | 28144/55530 [00:01<00:01, 20429.90it/s][A
 63%|██████▎   | 35134/55530 [00:01<00:00, 34146.42it/s][A
 

Adding nodes...



55530it [00:00, 888229.26it/s]

Adding edges...





Creating subgraphs...



  0%|          | 0/55530 [00:00<?, ?it/s][A
  3%|▎         | 1800/55530 [00:00<00:02, 17999.29it/s][A
  6%|▋         | 3600/55530 [00:00<00:03, 14493.84it/s][A
  9%|▉         | 5092/55530 [00:00<00:04, 11469.61it/s][A
 11%|█▏        | 6306/55530 [00:00<00:04, 10573.68it/s][A
 13%|█▎        | 7400/55530 [00:00<00:04, 9709.45it/s] [A
 15%|█▌        | 8392/55530 [00:00<00:05, 9245.51it/s][A
 17%|█▋        | 9326/55530 [00:00<00:05, 8961.60it/s][A
 20%|██        | 11146/55530 [00:01<00:03, 11459.98it/s][A
 24%|██▍       | 13308/55530 [00:01<00:02, 14289.58it/s][A
 28%|██▊       | 15560/55530 [00:01<00:02, 16632.17it/s][A
 31%|███       | 17289/55530 [00:01<00:02, 16621.01it/s][A
 34%|███▍      | 19088/55530 [00:01<00:02, 17016.73it/s][A
 38%|███▊      | 21264/55530 [00:01<00:01, 18403.11it/s][A
 42%|████▏     | 23503/55530 [00:01<00:01, 19577.41it/s][A
 46%|████▌     | 25595/55530 [00:01<00:01, 19974.54it/s][A
 50%|████▉     | 27609/55530 [00:01<00:01, 17858.79it/s][A
 53

In [62]:
train_loader, test_loader, val_loader = dataloader.split_data(loader)

18361 5246 2623


In [65]:
train_loader, test_loader, val_loader = dataloader.split_data(loader)

77742 22212 11106


In [12]:

# Print out the size of each set to verify
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Train size: 574
Validation size: 82
Test size: 164


In [8]:
# Create an instance of Ego_net_dataloader

#file_path = "../data/img_119670929.h5ad"
file_path = "../data/subset_6img_atlas_brain.h5ad"
dataloader = Full_image_dataloader(file_path=file_path, image_col="section", label_col="class_id_label", include_label=False, radius=20,node_level = 1, batch_size=32, split_percent=(0.7, 0.2, 0.1))

# Load the data
adata = dataloader.load_data()

# Construct the graph
loader = dataloader.construct_graph(adata)

Constructing Graphs: 100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


In [11]:
train_loader, test_loader, val_loader = dataloader.split_data(loader)

In [12]:
# Print out the size of each set to verify
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Train size: 4
Validation size: 1
Test size: 1
