In [1]:
import networkx as nx
import numpy as np
import scanpy as sc
import squidpy as sq
import torch
from sklearn.metrics import r2_score
from torch import Tensor
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, conv
from torch_geometric.utils import from_networkx
from torch_geometric.utils.convert import from_scipy_sparse_matrix
from torch import nn, optim
from tqdm.auto import tqdm

In [2]:
adata = sc.read("../example_files/img_1199670929.h5ad")

In [24]:
sq.gr.spatial_neighbors(adata=adata, radius=30, key_added="adjacency_matrix", coord_type="generic")
edge_index, edge_weight = from_scipy_sparse_matrix(adata.obsp["adjacency_matrix_connectivities"])
x = torch.tensor(adata.X.toarray(), dtype=torch.double)
print(f"mean node degree: {edge_index.shape[1]/len(adata):.1f}")

mean node degree: 8.9


In [4]:
# Create a large graph
G = nx.Graph()

# Add nodes with features to the graph
for i, features in enumerate(adata.X.toarray()):
    G.add_node(i, features=features)

# Add edges to the graph
G.add_edges_from(edge_index.t().tolist())

In [5]:
# create subgraphs from each node of G using networkx
subgraphs = []
for node in tqdm(G.nodes()):
    subgraphs.append(nx.ego_graph(G, node, radius = 1))

  0%|          | 0/26230 [00:00<?, ?it/s]

In [25]:
# mean number of nodes per subgraph
print(f"Mean number of nodes per subgraph {np.mean([graph.number_of_nodes() for graph in subgraphs]):.2f}")

Mean number of nodes per subgraph 4.92


In [10]:
loader = DataLoader([from_networkx(graph, group_node_attrs=['features']) for graph in tqdm(subgraphs)], batch_size=64, shuffle=True)

  0%|          | 0/26230 [00:00<?, ?it/s]

  data[key] = torch.tensor(value)


In [21]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        #self.conv2 = GCNConv(hidden_channels, out_channels)
        self.linear = nn.Linear(hidden_channels, 550)


    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        x = self.conv1(x, edge_index).relu()
        #print(x.shape)
        #x = self.conv2(x, edge_index).relu()
        #print(x.shape)
        x = self.linear(x)
        #print(x.shape)
        return x

In [23]:
import random
# Set device for training, macbook
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create an instance of the model and move it to the device
output_dim = 550

# Create the model
model = GCN(550, 64, output_dim).to(device)

display(model)


#GCNClassifier(hidden_dim=100, hidden_dim1=100, output_dim=550).to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 1000

# store losses
train_losses = []
val_losses = []

# store r2 scores
train_r2_scores = []
val_r2_scores = []

best_val_loss = float('inf') # Set initial best validation loss to infinity
patience = 5                 # Number of epochs to wait for improvement in validation loss
epochs_no_improve = 0        # Number of epochs with no improvement in validation loss
best_epoch = 0               # Epoch at which we get the best validation loss

# epoch training times
epoch_times = []

for epoch in range(num_epochs):

    model.train()  # Set the model to training mode
    total_loss = 0
    targets_list = []
    outputs_list = []

    for data in loader:

        data = data.to(device)

        # Mask 20% of node features
        num_nodes = data.x.shape[0]
        num_nodes_to_mask = int(0.2 * num_nodes)
        nodes_to_mask = random.sample(range(num_nodes), num_nodes_to_mask)
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[nodes_to_mask] = True
        masked_node_features = data.x.float() * mask.unsqueeze(-1).float()


        outputs = model(masked_node_features, data.edge_index.long())
        loss = criterion(outputs, data.x.float())

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure train loss and r2 score
        total_loss += loss.item() * data.num_graphs
        targets_list.append(data.x.float())
        outputs_list.append(outputs)

    #measure and print r2 and train loss
    train_loss = total_loss / len(loader.dataset)
    train_losses.append(train_loss)
    train_r2 = r2_score(torch.cat(targets_list).cpu().detach().numpy(), torch.cat(outputs_list).cpu().detach().numpy())
    train_r2_scores.append(train_r2)
    print(f"Epoch {epoch+1}/{num_epochs}, train loss: {train_loss:.4f}, train r2: {train_r2:.4f}")

GCN(
  (conv1): GCNConv(550, 64)
  (linear): Linear(in_features=64, out_features=550, bias=True)
)

Epoch 1/1000, train loss: 1.0145, train r2: -0.2945
Epoch 2/1000, train loss: 0.8425, train r2: -0.1283
Epoch 3/1000, train loss: 0.7749, train r2: -0.0625
Epoch 4/1000, train loss: 0.7092, train r2: -0.0072
Epoch 5/1000, train loss: 0.6498, train r2: 0.0351
Epoch 6/1000, train loss: 0.6150, train r2: 0.0569
Epoch 7/1000, train loss: 0.6016, train r2: 0.0664
Epoch 8/1000, train loss: 0.5931, train r2: 0.0739
Epoch 9/1000, train loss: 0.5904, train r2: 0.0779
Epoch 10/1000, train loss: 0.5869, train r2: 0.0816
Epoch 11/1000, train loss: 0.5847, train r2: 0.0850
Epoch 12/1000, train loss: 0.5822, train r2: 0.0875
Epoch 13/1000, train loss: 0.5805, train r2: 0.0903
Epoch 14/1000, train loss: 0.5803, train r2: 0.0913
Epoch 15/1000, train loss: 0.5777, train r2: 0.0936
Epoch 16/1000, train loss: 0.5766, train r2: 0.0952
Epoch 17/1000, train loss: 0.5751, train r2: 0.0968
Epoch 18/1000, train loss: 0.5742, train r2: 0.0980
Epoch 19/1000, train loss: 0.5730, train r2: 0.0994
Epoch 20/1000, tr

KeyboardInterrupt: 