In [2]:
import networkx as nx
import numpy as np
import scanpy as sc
import squidpy as sq
from sklearn.metrics import r2_score
from torch_geometric.nn import GCNConv, Sequential
from torch_geometric.data import Data   # Create data containers
from torch_geometric.utils import from_networkx

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import from_scipy_sparse_matrix
from tqdm.auto import tqdm

In [3]:
adata = sc.read("../data/img_119670929.h5ad")

In [4]:
sq.gr.spatial_neighbors(adata=adata, radius=20, key_added="adjacency_matrix", coord_type="generic")
edge_index, edge_weight = from_scipy_sparse_matrix(adata.obsp["adjacency_matrix_connectivities"])
x = torch.tensor(adata.X.toarray(), dtype=torch.double)
print(f"mean node degree: {edge_index.shape[1]/len(adata):.1f}")

mean node degree: 3.9


In [5]:
data = Data(x=x, edge_index=edge_index)
#data
#edge_index

In [6]:
subgraph(torch.tensor([0, 10, 33]), edge_index=edge_index)

(tensor([], size=(2, 0), dtype=torch.int64), None)

In [7]:
# we want to create small subgraph using each node as the center

In [8]:
# Create a large graph
G = nx.Graph()

# Add nodes with features to the graph
for i, features in enumerate(adata.X.toarray()):
    G.add_node(i, features=features)

# Add edges to the graph
G.add_edges_from(edge_index.t().tolist())

In [9]:
# create subgraphs from each node of G using networkx
subgraphs = []
for node in tqdm(G.nodes()):
    subgraphs.append(nx.ego_graph(G, node, radius = 1))

100%|██████████| 26230/26230 [00:04<00:00, 6029.69it/s] 


In [43]:
# mean number of nodes per subgraph
np.mean([graph.number_of_nodes() for graph in subgraphs])

4.924590163934426

In [44]:
#torch.tensor(list(subgraphs[0].edges)).t()

In [18]:
# create pytorch geometric dataset from subgraphs
#datasss = [Data(x=torch.tensor(graph.nodes(data="features"), dtype=torch.double), edge_index=torch.tensor(list(graph.edges)).t()) for graph in tqdm(subgraphs)]

#list(subgraphs[0].features)
daata = [from_networkx(graph, group_node_attrs=['features']) for graph in tqdm(subgraphs)]
loader = DataLoader(daata, batch_size=32, shuffle=True)
#for daat in daata:
#loader = DataLoader([from_networkx(graph, group_node_attrs=['features'], dtype=torch.double) for graph in tqdm(subgraphs)], batch_size=32, shuffle=True)


#loader = DataLoader([Data(x=x, edge_index=torch.tensor(list(subgraphs[0].edges)).t()) for graph in subgraphs], batch_size=32)

100%|██████████| 26230/26230 [00:27<00:00, 947.63it/s] 


In [21]:
from torch.utils.data import random_split

# Split the data into training, validation, and test sets
train_size = int(0.8 * len(loader.dataset))  # 80% of the data for training
val_size = int(0.1 * len(loader.dataset))  # 10% of the data for validation
test_size = len(loader.dataset) - train_size - val_size  # The rest for testing

train_data, val_data, test_data = random_split(loader.dataset, [train_size, val_size, test_size])

# Create data loaders for each set
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)


In [11]:
for data in loader:
    print(data)
    break


DataBatch(edge_index=[2, 146], x=[88, 550], batch=[88], ptr=[33])


In [12]:
import torch
from torch_geometric.nn import GCN, summary

#model = GCN(-1, 64, num_layers=2, out_channels=550)
#x = torch.randn(100, 128)
#edge_index = torch.randint(100, size=(2, 20))

#print(summary(model, data.x, data.edge_index))

In [19]:

from torch import nn, optim, Tensor
from torch_geometric.nn import conv


# Define the Graph Convolutional Network (GCN) model
class GCNClassifier(torch.nn.Module):
    def __init__(self, hidden_dim, hidden_dim1, output_dim):
        super(GCNClassifier, self).__init__()
        self.model = nn.Sequential(
            conv.SAGEConv(-1, hidden_dim),
            nn.ReLU(),
            conv.GCNConv(hidden_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, output_dim),
        )

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        return self.model(x, edge_index)

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.linear = nn.Linear(out_channels, 550)


    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        x = self.conv1(x, edge_index).relu()
        #print(x.shape)
        x = self.conv2(x, edge_index)
        #print(x.shape)
        x = self.linear(x)
        #print(x.shape)
        return x

In [None]:
import random
# Set device for training, macbook
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create an instance of the model and move it to the device

output_dim = 550

# Create the model
model = GCN(-1, 64, output_dim).to(device)

#GCNClassifier(hidden_dim=100, hidden_dim1=100, output_dim=550).to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 50

# store losses
train_losses = []
val_losses = []

# store r2 scores
train_r2_scores = []
val_r2_scores = []

best_val_loss = float('inf') # Set initial best validation loss to infinity
patience = 5                # Number of epochs to wait for improvement in validation loss
epochs_no_improve = 0        # Number of epochs with no improvement in validation loss
best_epoch = 0               # Epoch at which we get the best validation loss

# epoch training times
epoch_times = []
#start_time = time.time()

for epoch in range(num_epochs):
    
    # Training phase
    model.train()
    total_train_loss = 0
    train_targets_list = []
    train_outputs_list = []
    model.train()  # Set the model to training mode
    total_loss = 0
    targets_list = []
    outputs_list = []

    for data in loader:
        data = data.to(device)
        num_nodes = data.x.shape[0]
        num_nodes_to_mask = int(0.2 * num_nodes)
        nodes_to_mask = random.sample(range(num_nodes), num_nodes_to_mask)
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[nodes_to_mask] = True
        masked_node_features = data.x.float() * mask.unsqueeze(-1).float().to(device)


        outputs = model(masked_node_features, data.edge_index.long())
        loss = criterion(outputs, data.x.float())

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure train loss and r2 score
        total_loss += loss.item() * data.num_graphs
        targets_list.append(data.x.float())
        outputs_list.append(outputs)

    #measure and print r2 and train loss
    train_loss = total_loss / len(loader.dataset)
    train_losses.append(train_loss)
    train_r2 = r2_score(torch.cat(targets_list).cpu().detach().numpy(), torch.cat(outputs_list).cpu().detach().numpy())
    train_r2_scores.append(train_r2)

        # Validation phase
    model.eval()  # Set the model to evaluation mode
    total_val_loss = 0
    val_targets_list = []
    val_outputs_list = []

    for data in val_loader:
        data = data.to(device)
        num_nodes = data.x.shape[0]
        num_nodes_to_mask = int(0.2 * num_nodes)
        nodes_to_mask = random.sample(range(num_nodes), num_nodes_to_mask)
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[nodes_to_mask] = True
        masked_node_features = data.x.float() * mask.unsqueeze(-1).float().to(device)
        
        with torch.no_grad():
            outputs = model(masked_node_features, data.edge_index.long())
            loss = criterion(outputs, data.x.float())
        total_val_loss += loss.item() * data.num_graphs
        val_targets_list.append(data.x.float())
        val_outputs_list.append(outputs)

    # Measure and print validation loss and R2
    val_loss = total_val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)
    val_r2 = r2_score(torch.cat(val_targets_list).cpu().detach().numpy(), torch.cat(val_outputs_list).cpu().detach().numpy())
    val_r2_scores.append(val_r2)
    print(f"Epoch {epoch+1}/{num_epochs} , train loss: {train_loss:.4f}, train r2: {train_r2:.4f} ,  val loss: {val_loss:.4f}, val r2: {val_r2:.4f}")


Epoch 1/50 , train loss: 0.5907, train r2: 0.0820 ,  val loss: 0.5621, val r2: 0.1031
Epoch 2/50 , train loss: 0.5548, train r2: 0.1098 ,  val loss: 0.5503, val r2: 0.1139
Epoch 3/50 , train loss: 0.5490, train r2: 0.1144 ,  val loss: 0.5507, val r2: 0.1130
Epoch 4/50 , train loss: 0.5464, train r2: 0.1168 ,  val loss: 0.5476, val r2: 0.1149
Epoch 5/50 , train loss: 0.5443, train r2: 0.1186 ,  val loss: 0.5431, val r2: 0.1199
Epoch 6/50 , train loss: 0.5435, train r2: 0.1195 ,  val loss: 0.5401, val r2: 0.1229
Epoch 7/50 , train loss: 0.5421, train r2: 0.1212 ,  val loss: 0.5463, val r2: 0.1179
Epoch 8/50 , train loss: 0.5421, train r2: 0.1214 ,  val loss: 0.5409, val r2: 0.1209
Epoch 9/50 , train loss: 0.5403, train r2: 0.1227 ,  val loss: 0.5367, val r2: 0.1250
Epoch 10/50 , train loss: 0.5411, train r2: 0.1223 ,  val loss: 0.5437, val r2: 0.1195
Epoch 11/50 , train loss: 0.5396, train r2: 0.1236 ,  val loss: 0.5470, val r2: 0.1164
Epoch 12/50 , train loss: 0.5393, train r2: 0.1241 ,

In [24]:
# Testing phase
model.eval()
total_test_loss = 0
test_targets_list = []
test_outputs_list = []

for data in test_loader:
    data = data.to(device)
    with torch.no_grad():
        outputs = model(data.x.float(), data.edge_index.long())
        loss = criterion(outputs, data.x.float())
    total_test_loss += loss.item() * data.num_graphs
    test_targets_list.append(data.x.float())
    test_outputs_list.append(outputs)

# Measure and print test loss and R2
test_loss = total_test_loss / len(test_loader.dataset)
test_r2 = r2_score(torch.cat(test_targets_list).cpu().detach().numpy(), torch.cat(test_outputs_list).cpu().detach().numpy())
print(f"Test loss: {test_loss:.4f}, test r2: {test_r2:.4f}")

Test loss: 0.5162, test r2: 0.1203
