In [2]:
import scanpy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import GCNConv, MessagePassing
import random
from matplotlib.pyplot import rc_context

In [3]:
# Getting data
adata = scanpy.read_h5ad('adata_head_S_v1.0.h5ad')

sampled_indices_obs = adata.obs.sample(frac=0.1, random_state=42).index
sampled_indices_var = adata.var.sample(frac=0.1, random_state=42).index
sampled_data = adata[sampled_indices_obs, sampled_indices_var]

In [4]:
# Import the gene-gene correlation matrix
gene_gene_interactions = np.load('gg_interactions.npy')

In [5]:
# Converting & filtering
# Load the mapping file of protein IDs to preferred names
mapping_file = '7227.protein.info.v11.5.txt'  # Path to your mapping file
protein_mapping = {}

# Read the mapping file and populate the protein_mapping dictionary
with open(mapping_file, 'r') as file:
    next(file)  # Skip the header line if present
    for line in file:
        protein_id, preferred_name, _, _ = line.strip().split('\t')
        protein_mapping[preferred_name] = protein_id

# Remove genes not in the interaction matrix from sampled_data
sampled_data = sampled_data[:, [protein_mapping.get(gene, None) is not None for gene in sampled_data.var_names]]

# Update the var_names in sampled_data with protein IDs
sampled_data.var_names = [protein_mapping.get(gene, None) for gene in sampled_data.var_names]

# Print the number of genes before and after filtering
print("Number of genes before filtering:", sampled_data.X.shape[1])
print("Number of genes after filtering:", sampled_data.shape[1])


Number of genes before filtering: 1184
Number of genes after filtering: 1184


In [6]:
# Create a set of genes present in sampled_data.var_names
selected_genes = set(sampled_data.var_names)

# Find the indices of selected genes in the gene_gene_interactions matrix
gene_indices = [i for i, gene in enumerate(gene_gene_interactions[0]) if gene in selected_genes]

# Filter gene_gene_interactions matrix
filtered_interactions = gene_gene_interactions[gene_indices][:, gene_indices]
filtered_interactions = filtered_interactions.astype(float)

In [31]:
## Defining and training the model
gene_expression = sampled_data.X.toarray()  # Gene expression data, shape: (num_cells, num_genes)
gene_names = sampled_data.var_names # List of gene names in the same order as gene_expression

adjacency_matrix = filtered_interactions # Matrix is not weighted - should weight it

# Convert elements of adjacency_matrix to a numeric type
adjacency_matrix = np.array(adjacency_matrix, dtype=np.float32)

adjacency_tensor = torch.tensor(adjacency_matrix, dtype=torch.float)
gene_expression_tensor = torch.tensor(gene_expression, dtype=torch.float)

edge_index = torch.nonzero(adjacency_tensor != 0, as_tuple=False).t()

# Set y to be either sex or age
sex_matrix = np.array([0 if sex == "female" else 1 for sex in np.array(sampled_data.obs.sex)])
age_matrix = np.array([0 if age == "5" else 1 if age == "30" else 2 if age == "50" else 3 for age in np.array(sampled_data.obs.age)])
data = Data(x=gene_expression_tensor, edge_index=edge_index, edge_attr=adjacency_tensor, y=age_matrix)

In [33]:
class GNNModel(MessagePassing):
    def __init__(self, num_genes, hidden_dim, num_classes):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(num_genes, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        print("x", x)
        print("softmax", F.log_softmax(x, dim=1))
        return F.log_softmax(x, dim=1)
    
# Extract the number of genes from the dataset
num_genes = sampled_data.X.shape[1]
hidden_dim = 256
# Extract the number of cell types/classes from the dataset
num_classes = len(np.unique(data.y))
model = GNNModel(num_genes, hidden_dim, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Split the dataset into train, validation, and test sets
transform = RandomLinkSplit(is_undirected=True, num_val=0.1, num_test=0.1)
train_data, val_data, test_data = transform(data)

# Print the sizes of the splits
print("Train set size:", train_data.num_edges)
print("Validation set size:", val_data.num_edges)
print("Test set size:", test_data.num_edges)

# Training loop
model.train()
num_epochs = 100
for epoch in range(num_epochs):
    # print("training")
    optimizer.zero_grad()
    out = model(train_data.x, train_data.edge_index)
    # print("out", out)
    out = out.float().squeeze()
    target = torch.from_numpy(train_data.y)
    # since out is a cells by classes matrix and target is just a one dimensional matrix with the correct class for each cell,
    # nll_loss would already be comparing the different probabilities of each class to find the loss
    loss = F.nll_loss(out, target)
    loss.backward()
    optimizer.step()
    print(f'Loss: {loss:.4f}')
    
# Evaluation
model.eval()
pred = model(test_data.x, test_data.edge_index).argmax(dim=1)
correct = (pred == torch.tensor(test_data.y)).sum()
acc = int(correct) / int(test_data.x.shape[0])
print(f'Accuracy: {acc:.4f}')

Train set size: 26168
Validation set size: 26168
Test set size: 29438
x tensor([[ 0.2342,  0.1543, -0.5926, -0.0831],
        [-0.0151, -0.5212, -0.2462, -0.2578],
        [ 0.2730,  0.1734, -0.4794, -0.0868],
        ...,
        [ 1.5978, -0.7168, -1.0910, -0.2555],
        [ 1.0733,  0.0377, -0.6206,  0.1673],
        [ 0.0748,  0.1413, -0.8264, -0.2385]], grad_fn=<AddBackward0>)
softmax tensor([[-1.1278, -1.2076, -1.9546, -1.4451],
        [-1.1572, -1.6633, -1.3883, -1.3999],
        [-1.1230, -1.2227, -1.8755, -1.4828],
        ...,
        [-0.2803, -2.5949, -2.9691, -2.1336],
        [-0.6642, -1.6998, -2.3580, -1.5702],
        [-1.1650, -1.0985, -2.0662, -1.4783]], grad_fn=<LogSoftmaxBackward0>)
Loss: 1.4649
x tensor([[-4.2908, -5.2834,  6.0560, -1.1786],
        [-3.0244, -4.9407,  4.6302, -1.7838],
        [-3.7691, -4.5338,  5.0301, -0.8479],
        ...,
        [-4.5939, -5.4558,  6.0048, -1.1786],
        [-6.2755, -6.9979,  8.5295, -0.9822],
        [-3.6699, -5.0876, 