## Import packages

In [1]:
# ignore scanpy warnings
import warnings

warnings.filterwarnings("ignore")

import scipy
import networkx as nx
import numpy as np
import squidpy as sq
import scanpy as sc
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, GCNConv, MLP
from torch_geometric.nn.dense.linear import Linear
from tqdm import tqdm

## Load dataset

In [None]:
# Load squidpy's MERFISH example dataset as an AnnData object
adata = sq.datasets.merfish()
# As the last expression of the cell, this displays a summary of the AnnData object
adata

## Preprocessing
Our data is stored in AnnData format, which behaves similarly to a pandas DataFrame. We can access the data matrix using the `.X` attribute, and the cell and gene names using the `.obs_names` and `.var_names` attributes, respectively.

In [None]:
# Select a single tissue slice by its Bregma value; .copy() turns the sliced
# view into an independent AnnData object for the in-place scanpy calls below.
bregma = 1
adata = adata[adata.obs["Bregma"] == bregma, :].copy()

# Filtering: keep cells with at least 10 total counts and genes detected in
# at least 5 cells.
sc.pp.filter_cells(adata, min_counts=10)
sc.pp.filter_genes(adata, min_cells=5)

# Normalization: keep the pre-normalization matrix in the "counts" layer,
# then normalize per-cell totals and log1p-transform adata.X in place.
adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, inplace=True)
sc.pp.log1p(adata)

In [None]:
# Total expression of every gene across all cells. Summing directly on adata.X
# works for both sparse (returns a 1 x n_genes np.matrix) and dense matrices
# and avoids densifying the whole expression matrix just to reduce it.
gene_totals = np.asarray(adata.X.sum(axis=0)).ravel()
top_gene_idx = gene_totals.argmax()
top_gene = adata.var_names[top_gene_idx]
print(f"The gene with the highest total expression is {top_gene}")

In [None]:
# Spatial scatter of the cells, colored by annotated cell class and by the
# expression of the most highly expressed gene found above
sc.pl.spatial(adata, color=["Cell_class", top_gene], spot_size=0.01)

## Construct a graph by connecting each cell to its k nearest neighbors

Here we connect each node to its k nearest neighbors and then make the graph undirected (if cell j is a neighbor of cell i, cell i also becomes a neighbor of cell j).

In [None]:
k = 10  # number of nearest neighbors per cell

coords = adata.obsm["spatial"]
kdtree = scipy.spatial.KDTree(coords)
# Query k + 1 neighbors (p=2: Euclidean distance). Every cell's own position
# is in the tree, so the cell itself is normally returned as its first
# neighbor at distance 0.
distances, indices = kdtree.query(coords, k=k + 1, p=2)

In [None]:
# Row i lists the indices of the k + 1 nearest neighbors of cell i
print("Indices:\n", indices)
print(f"Indices shape: {indices.shape}")

In [None]:
# Build a COO edge list: for every query cell (row of `indices`) add one edge
# from each of its k + 1 nearest neighbors to the cell itself. Because the cell
# is normally its own nearest neighbor, each node gets a self-loop (weight 0).
n_cells, n_neighbors = indices.shape
edge_index = torch.stack(
    [
        torch.as_tensor(indices.ravel(), dtype=torch.long),  # source: neighbor
        torch.arange(n_cells).repeat_interleave(n_neighbors),  # target: query cell
    ],
    dim=0,
)
# Edge weights are the Euclidean distances, kept 1-D with shape (num_edges,)
# as documented for gcn_norm and as produced by the distance-graph path.
edge_weight = torch.as_tensor(distances.ravel(), dtype=torch.float32)

In [None]:
# Symmetrize the graph by adding every reverse edge. Edges between mutual
# nearest neighbors then appear twice; the default reduce="add" would sum their
# weights (doubling their distance) while one-way edges keep theirs.
# reduce="mean" keeps every edge weight equal to the actual distance.
edge_index, edge_weight = pyg.utils.to_undirected(
    edge_index, edge_weight, num_nodes=coords.shape[0], reduce="mean"
)

In [None]:
# Sanity check: every edge (u, v) should now have a matching reverse edge (v, u)
print(f"The graph is undirected: {pyg.utils.is_undirected(edge_index)}")

Let's visualize the graph using NetworkX.

In [None]:
# Wrap the connectivity in a PyG Data object (no node features are needed)
# and convert it to an undirected networkx graph for plotting.
data = Data(edge_index=edge_index, num_nodes=coords.shape[0])
g = pyg.utils.to_networkx(data, to_undirected=True)

In [None]:
# Remove self-loops for better visualization (each cell is normally its own
# nearest neighbor). Materialize the generator first so the graph is not
# mutated while selfloop_edges() is still iterating over it.
g.remove_edges_from(list(nx.selfloop_edges(g)))

plt.figure(figsize=(10, 10))
# networkx documents `pos` as a dict {node: (x, y)}; the node ids created by
# to_networkx are the row indices of `coords`.
nx.draw(g, pos=dict(enumerate(coords)), node_size=10, width=0.3)

## Task 1: construct a graph by connecting cells only if they are within a certain distance of each other
Hint: you can loop over all possible pairs of cells using a nested for loop and calculate the distance between them using the `np.linalg.norm` function

In [None]:
radius = 0.1  # cells within this distance (units of adata.obsm["spatial"]) are connected

# define empty matrices to hold the results:
# dense n_cells x n_cells matrices of pairwise distances and 0/1 adjacency
# (memory grows quadratically with the number of cells)
dist_mat = np.zeros((coords.shape[0], coords.shape[0]))
adj_mat = np.zeros((coords.shape[0], coords.shape[0]))

In [None]:
# Task 1: fill dist_mat and adj_mat with a nested loop over all pairs of cells.
# The comparison cell further below expects adj_mat to include self-loops
# (diagonal set to 1) and dist_mat to hold distances only for pairs within
# `radius` (0 elsewhere), matching scipy's sparse_distance_matrix.
"""
insert your code here
"""

Scipy provides some optimized functions for this task.

In [None]:
# Same radius graph with scipy: sparse_distance_matrix stores the distance of
# every pair of points that are at most `radius` apart.
kdtree = scipy.spatial.KDTree(coords)
dist_mat_hat = kdtree.sparse_distance_matrix(kdtree, radius, p=2)
dist_mat_hat = scipy.sparse.csr_matrix(dist_mat_hat)
# Binarize: every stored positive distance becomes an edge. Zero distances
# (a cell paired with itself, or two cells at identical coordinates) are not
# counted, so this adjacency has no self-loops.
adj_mat_hat = (dist_mat_hat > 0).toarray().astype(int)

In [None]:
# Adjacency: the scipy version has no self-loops, so add the identity before
# comparing with the loop-based adj_mat (which includes them).
# np.array_equal also returns False instead of failing on a shape mismatch.
print(
    "The results are the same:",
    np.array_equal(adj_mat_hat + np.eye(adj_mat_hat.shape[0]), adj_mat),
)

# Distances are floats computed by two different routines (np.linalg.norm vs
# the KD-tree), so they can differ in the last bits: compare with a tolerance.
print(
    "The results are the same:",
    np.allclose(dist_mat_hat.toarray(), dist_mat),
)

Execute the cell below if you want to run the subsequent computation on the distance graph instead of the nearest-neighbor graph.

In [None]:
# edge_index, edge_weight = pyg.utils.from_scipy_sparse_matrix(dist_mat_hat)

## Task 2: aggregate the gene expression per neighborhood
This aggregated expression (one profile per cell neighborhood) is the reconstruction target for the graph autoencoder.

In a first version, we perform the aggregation with a simple matrix multiplication (for the distance-based case, we could have directly used the adjacency matrix).  
Hint 1: you can convert the edge_index to a sparse matrix using the `pyg.utils.to_scipy_sparse_matrix` function  
Hint 2: the gene expression is stored in the `.X` attribute of the Anndata object  
Hint 3: make sure to normalize the adjacency matrix by the degree of each node

In [None]:
# Task 2: compute X_agg, the degree-normalized neighborhood aggregate of adata.X.
# The comparison cell further below calls X_agg.toarray(), so a scipy sparse
# result is expected.
"""
insert your code here, name the result X_agg
"""

We can perform the same cell aggregation in the PyTorch Geometric message passing framework.

In [None]:
class GraphAggregation(MessagePassing):
    """Parameter-free neighborhood aggregation.

    Every node collects the unchanged features of its neighbors (``x_j``) and
    combines them with the reduction given by ``aggr``.

    Args:
        aggr (str): Reduction used to combine incoming messages. Default is "mean".
    """

    def __init__(self, aggr="mean"):
        super().__init__(aggr=aggr)

    def forward(self, x, edge_index, **kwargs):
        """Run a single round of neighborhood aggregation.

        Args:
            x (Tensor): Node feature matrix.
            edge_index (LongTensor): Graph connectivity in COO format.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            Tensor: Aggregated node features.
        """
        aggregated = self.propagate(edge_index, x=x)
        return aggregated

    def message(self, x_j):
        """Forward the neighbor's features as the message, unmodified.

        Args:
            x_j (Tensor): Features of the neighboring node of each edge.

        Returns:
            Tensor: ``x_j`` itself.
        """
        return x_j

In [None]:
# Mean of the neighbors' features for every cell, via message passing
mean_agg = GraphAggregation(aggr="mean")
# Dense tensor (torch default float dtype) of the log-normalized expression;
# .toarray() assumes adata.X is a scipy sparse matrix
X = torch.Tensor(adata.X.toarray())
X_agg_pyg = mean_agg(X, edge_index)

In [None]:
# Compare the matrix-multiplication result from Task 2 (X_agg) with the
# message-passing result; both should be the neighborhood mean expression.
print(
    "Results are the same (up to numeric error):",
    np.allclose(X_agg.toarray(), X_agg_pyg.numpy(), atol=1e-9),
)

## Task 3: define your own graph convolutional network in Pytorch Geometric

The equation for the graph convolutional layer is given by:
$$
H^{i+1} = (\hat{D}^{-\frac{1}{2}} \hat{A} \hat{D}^{-\frac{1}{2}}) H^{i} W + b
$$
However, you can ignore the normalization for now (already handled by `gcn_norm`) and just implement the last part of the equation:
$$
H^{i+1} = \dots H^{i} W + b
$$

In [None]:
def gcn_norm(edge_index, num_nodes, edge_weight=None, dtype=None):
    """Symmetric GCN normalization of edge weights (no self-loops are added).

    Every edge (u, v) with weight w is rescaled to
    ``w / sqrt(deg(u) * deg(v))``, where ``deg`` is the weighted in-degree,
    i.e. the sum of the weights of all edges pointing to a node. Nodes with
    zero degree get a factor of 0 instead of infinity.

    Args:
        edge_index (Tensor): Connectivity of shape (2, num_edges).
        num_nodes (int): Total number of nodes in the graph.
        edge_weight (Tensor, optional): Weight of every edge, shape (num_edges,).
            Defaults to None, meaning all edges have weight 1.
        dtype (torch.dtype, optional): dtype of the all-ones weights created when
            ``edge_weight`` is None. Defaults to None.

    Returns:
        Tuple[Tensor, Tensor]: The unchanged edge index and the normalized edge weights.
    """
    if edge_weight is None:
        edge_weight = torch.ones(
            (edge_index.size(1),), device=edge_index.device, dtype=dtype
        )

    source, target = edge_index
    # weighted in-degree: accumulate each edge's weight onto its target node
    degree = pyg.utils.scatter(
        edge_weight, target, dim=0, dim_size=num_nodes, reduce="sum"
    )

    inv_sqrt_degree = degree.pow(-0.5)
    # isolated nodes: 0 ** -0.5 == inf -> use 0 so they contribute nothing
    inv_sqrt_degree = inv_sqrt_degree.masked_fill(
        inv_sqrt_degree == float("inf"), 0
    )

    normalized_weight = inv_sqrt_degree[source] * edge_weight * inv_sqrt_degree[target]
    return edge_index, normalized_weight

In [3]:
class GCNLayer(MessagePassing):
    """Graph Convolutional Network (GCN) layer.

    With the default ``aggr="add"`` this computes ``out = A_norm (x W) + b``,
    where ``A_norm`` holds the edge weights produced by ``gcn_norm``
    (symmetric degree normalization) when ``normalize=True``, or the given
    ``edge_weight`` (1 per edge if None) otherwise. No self-loops are added;
    include them in ``edge_index`` if they are wanted.

    Args:
        in_channels (int): Number of input channels/features.
        out_channels (int): Number of output channels/features.
        normalize (bool, optional): Whether to normalize the edge weights with
            ``gcn_norm``. Defaults to True.
        aggr (str, optional): Aggregation method for message passing. Defaults to "add".

    Attributes:
        linear (Linear): Bias-free, glorot-initialized linear transformation (W).
        bias (torch.nn.Parameter): Additive bias (b), initialized to zeros.
    """

    def __init__(self, in_channels, out_channels, normalize=True, aggr="add"):
        super(GCNLayer, self).__init__(aggr=aggr)
        self.normalize = normalize

        self.linear = Linear(
            in_channels, out_channels, bias=False, weight_initializer="glorot"
        )
        self.bias = Parameter(torch.zeros(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize W (glorot) and reset the bias to zeros."""
        super().reset_parameters()
        self.linear.reset_parameters()
        # The bias must be reset as well, otherwise a trained bias survives a reset.
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the layer.

        Args:
            x (Tensor): Node features of shape (num_nodes, in_channels).
            edge_index (LongTensor): Edge indices of shape (2, num_edges).
            edge_weight (Tensor, optional): Per-edge weights. Defaults to None
                (every edge has weight 1).

        Returns:
            Tensor: Node features of shape (num_nodes, out_channels).
        """
        if self.normalize:
            edge_index, edge_weight = gcn_norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype
            )

        # H W: transform the node features first (same order as GCNConv)
        x = self.linear(x)

        # A_norm (H W): weighted aggregation of the neighbors' messages
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

        # + b: the bias is added once per node, after aggregation
        out = out + self.bias
        return out

    def message(self, x_j, edge_weight=None):
        # x_j holds the features of the neighbor at the other end of each edge
        # (the edge source under PyG's default "source_to_target" flow).
        # Adding x_i to the signature would also expose the receiving (target)
        # node's features; it is not needed here, so it is omitted to avoid an
        # extra (num_edges, out_channels) gather. Any other keyword passed to
        # propagate() (such as edge_weight) can be requested the same way.
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

In [None]:
# Set the seed for reproducibility.
# The same seed is set again before building GCNConv below, presumably so that
# both layers start from identical weights (needed for the allclose check).
torch.manual_seed(42)

gcn_layer_custom = GCNLayer(adata.X.shape[1], 32)
h_custom = gcn_layer_custom(X_agg_pyg, edge_index)
h_custom

### Check that it gives the same result as the built-in GCN implementation

In [None]:
# Set the seed for reproducibility (same seed as for the custom layer above).
# add_self_loops=False because the custom GCNLayer does not add self-loops either.
torch.manual_seed(42)

gcn_layer = GCNConv(adata.X.shape[1], 32, add_self_loops=False)
h = gcn_layer(X_agg_pyg, edge_index)
h

In [None]:
# The custom layer should reproduce PyG's GCNConv output
print("Results are the same:", torch.allclose(h, h_custom, atol=1e-9))

## Bonus 1: formulate the same architecture in native Pytorch using matrix multiplications
Hint: for simplicity, you can reuse the `gcn_norm` function and convert the adjacency matrix to a tensor.

## Bonus 2: implement a more general Message Passing Layer
Hint: adapt the Pytorch Geometric Message Passing class according to the general message passing equation

## Define your node-level graph autoencoder
This model takes as input a cell graph with gene expression features and learns a latent representation of each cell's neighborhood by reconstructing the neighborhood gene expression.

In [None]:
class GraphEncoder(nn.Module):
    """Map node features to a hidden representation with stacked GCN layers.

    The input is first projected to ``hidden_channels`` by a linear layer and
    then passed through ``n_layers`` GCNLayer blocks, each followed by a ReLU.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int): The number of hidden channels.
        n_layers (int, optional): The number of graph convolutional layers. Defaults to 2.
        normalize (bool, optional): Forwarded to every GCNLayer; whether its edge
            weights are normalized with gcn_norm. Defaults to True.
    """

    def __init__(self, in_channels, hidden_channels, n_layers=2, normalize=True):
        super().__init__()
        self.linear = Linear(in_channels, hidden_channels)
        gcn_layers = [
            GCNLayer(hidden_channels, hidden_channels, normalize=normalize)
            for _ in range(n_layers)
        ]
        self.convs = nn.ModuleList(gcn_layers)

    def forward(self, x, edge_index, edge_weight=None):
        hidden = self.linear(x)
        for layer in self.convs:
            hidden = F.relu(layer(hidden, edge_index, edge_weight))
        return hidden

In [None]:
class Decoder(nn.Module):
    """MLP decoder for the GNN model, built on PyTorch Geometric's MLP.

    The MLP is built with ``plain_last=False`` and ``norm=None``. In PyG's MLP,
    this means the activation is also applied after the output layer and no
    normalization layers are used.

    Args:
        in_channels (int): Number of input channels.
        hidden_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels.
        n_layers (int, optional): Number of MLP layers. Defaults to 2.
        **kwargs: Extra keyword arguments forwarded to ``MLP``.
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, n_layers=2, **kwargs
    ):
        super().__init__()
        mlp_config = dict(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=n_layers,
            plain_last=False,
            norm=None,
        )
        self.mlp = MLP(**mlp_config, **kwargs)

    def forward(self, x):
        decoded = self.mlp(x)
        return decoded

In [None]:
class GraphAutoEncoder(nn.Module):
    """Node-level graph autoencoder: a GraphEncoder followed by an MLP Decoder."""

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        n_layers_encoder=1,
        n_layers_decoder=1,
        **kwargs,
    ):
        """Build the encoder and decoder.

        Args:
            in_channels (int): Number of input channels.
            hidden_channels (int): Size of the latent (hidden) representation.
            out_channels (int): Number of reconstructed output channels.
            n_layers_encoder (int, optional): Number of GCN layers in the encoder. Defaults to 1.
            n_layers_decoder (int, optional): Number of MLP layers in the decoder. Defaults to 1.
            **kwargs: Forwarded to GraphEncoder only (e.g. ``normalize``).
        """
        super().__init__()
        self.encoder = GraphEncoder(
            in_channels, hidden_channels, n_layers=n_layers_encoder, **kwargs
        )
        self.decoder = Decoder(
            hidden_channels, hidden_channels, out_channels, n_layers=n_layers_decoder
        )

    def forward(self, x, edge_index, edge_weight=None):
        latent = self.encoder(x, edge_index, edge_weight)
        return self.decoder(latent)

## Train your model

In [None]:
# Use a GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"You are using {device}")

In [None]:
seed = 42
n_epochs = 400
n_genes = adata.X.shape[1]
n_layers_encoder = 1
n_layers_decoder = 1

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Parameters are initialized on the CPU (same RNG stream as before), then moved
# to the selected device.
model = GraphAutoEncoder(
    n_genes, 32, n_genes, n_layers_encoder, n_layers_decoder
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Device copies of the training data; the CPU originals stay untouched for
# later cells. edge_weight is cast to the feature dtype because the
# distance-graph path (from_scipy_sparse_matrix) yields float64 weights, which
# would otherwise clash with the float32 layer weights.
x_train = X_agg_pyg.to(device)
edge_index_train = edge_index.to(device)
edge_weight_train = (
    None
    if edge_weight is None
    else edge_weight.to(device=device, dtype=x_train.dtype)
)

for epoch in range(n_epochs):
    model.train()
    optimizer.zero_grad()
    out = model(x_train, edge_index_train, edge_weight_train)
    # Autoencoder objective: reconstruct the neighborhood-aggregated expression
    loss = criterion(out, x_train)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# Move the trained model back to the CPU so downstream cells can feed it the
# CPU tensors (X_agg_pyg, edge_index, edge_weight) directly.
model = model.cpu()

## Task 4: Extract and visualize the latent embeddings
Hint: our model has an encoder that maps the input gene expression to the latent space. We can use this encoder to extract the latent embeddings of the cells

In [None]:
# Switch to evaluation mode before extracting embeddings; as the last
# expression of the cell, the returned module is displayed
model.eval()

In [None]:
# Task 4: run the trained encoder (model.encoder) on the graph and store the
# latent embeddings in `h`. The next cells print h.shape and write h to
# adata.obsm, so convert it to an array (e.g. under torch.no_grad(), then .numpy()).
"""
insert your code here
"""

In [None]:
# Expected: one row per cell and one column per latent dimension
print(f"Encoded features shape: {h.shape}")

In [None]:
# Store the embedding in AnnData so scanpy can use it via use_rep="X_gnn"
adata.obsm["X_gnn"] = h

Hint: to visualize the embeddings, you can use PCA or UMAP.

In [None]:
# Visualize the embeddings stored in adata.obsm["X_gnn"], e.g. with PCA or UMAP
"""
insert your code here
"""

## Define spatial domains via Leiden clustering

In [None]:
# Build the kNN graph on the GNN embedding (not on raw expression) and cluster
# it with Leiden; a low resolution yields few, coarse clusters, which are
# interpreted as spatial domains (labels end up in adata.obs["leiden"]).
sc.pp.neighbors(adata, use_rep="X_gnn")
sc.tl.leiden(adata, resolution=0.1)

## Visualize the spatial domains

In [None]:
# Spatial scatter of the cells colored by Leiden cluster (spatial domain)
sc.pl.spatial(adata, color="leiden", spot_size=0.01)

## Task 5: Analyze the cell type proportions for each spatial domain
Hint: the only data you need is given in the DataFrame below.

In [None]:
# One row per cell: its spatial domain (Leiden cluster) and annotated cell class
df = adata.obs[["leiden", "Cell_class"]]

## Further reading
1. [Geometric deep learning resources](https://geometricdeeplearning.com/)
2. [Graph (variational) autoencoder paper](https://arxiv.org/abs/1611.07308)
3. [Pytorch Geometric documentation](https://pytorch-geometric.readthedocs.io/en/latest/)
4. [Google tuning playbook](https://github.com/google-research/tuning_playbook)