In [1]:
import squidpy as sq
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from sklearn.model_selection import StratifiedKFold
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import from_scipy_sparse_matrix
import random
from sklearn import metrics
from EquivariantGNN import EquivariantGNN

In [2]:
adata = sq.datasets.seqfish()

labels = adata.obs['celltype_mapped_refined'].cat.codes.values
classes = np.unique(labels)
_, counts = np.unique(labels, return_counts=True)

In [3]:
if 'spatial_distances' not in adata.obsp:
    sq.gr.spatial_neighbors(adata, n_neighs=6, coord_type='generic')

In [4]:
# Build the graph for the GNN models
def get_graph_data(X) -> Data:
    edge_index, edge_attr = from_scipy_sparse_matrix(
        adata.obsp['spatial_distances'])
    edge_attr = edge_attr.view(-1, 1).float()
    x_index, x_values = from_scipy_sparse_matrix(X)
    x = torch.sparse_coo_tensor(x_index, x_values)
    y = torch.tensor(labels, dtype=torch.long)
    pos = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
    data = Data(edge_index=edge_index, edge_attr=edge_attr, x=x.to_dense(), y=y, pos=pos)
    return data

In [5]:
def calculate_stats(y_true, y_pred, weight=1):
    return {
        'accuracy_score': metrics.accuracy_score(y_true, y_pred),
        'balanced_accuracy': metrics.balanced_accuracy_score(y_true, y_pred),
        'f1_score': metrics.f1_score(y_true, y_pred, average='macro', labels=classes, zero_division=0, ),
        'recall': metrics.recall_score(y_true, y_pred, average='macro', labels=classes, zero_division=0),
        'precision_score': metrics.precision_score(y_true, y_pred, average='macro', labels=classes, zero_division=0),
        'weight': weight,
    }

def average_stats(stats_list):
    keys = stats_list[0].keys()
    total_weight = sum([stats['weight'] for stats in stats_list])
    avg = {}
    for key in keys:
        sum = 0
        for stats in stats_list:
            sum += stats[key] * stats['weight']
        avg[key] = sum / total_weight
    return avg

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def train(model, loader, optimizer, class_weights=None):
    model.train()
    for batch in loader:
        batch = batch.to(device)
        batch_size = batch.batch_size
        optimizer.zero_grad()
        out = model(batch)[:batch_size]
        loss = F.cross_entropy(
            out, batch.y[:batch_size], weight=class_weights)
        loss.backward()
        optimizer.step()

def calculate_loss(model, loader, class_weights=None):
    model.eval()
    loss = 0.
    count = 0
    for batch in loader:
        batch = batch.to(device)
        batch_size = batch.batch_size
        out = model(batch)[:batch_size]
        batch_loss = F.cross_entropy(
            out, batch.y[:batch_size], weight=class_weights)
        loss += batch_loss * batch_size
        count += batch_size
    return loss / count

def eval(model, loader):
    model.eval()
    stats = []
    for batch in loader:
        batch = batch.to(device)
        batch_size = batch.batch_size
        out = model(batch)[:batch_size]
        predictions = torch.argmax(out, dim=1)
        stats.append(calculate_stats(
            batch.y[:batch_size], predictions, weight=batch_size))
    return average_stats(stats)


def neighbor_batch_k_fold(model, data, num_neighbors, class_weights=None, n_splits=3):
    model = model.to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, weight_decay=0.0001)
    stats_list = []
    patience = 10
    patience_counter = 0
    max_epochs = 200
    tol = 1e-4
    val_fraction = 0.1
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=1604)
    loader_batch_size = 8
    for train_index, test_index in skf.split(data.x, labels):

        random.shuffle(train_index)
        val_size = int(val_fraction * len(train_index))
        val_index = train_index[-val_size:]
        train_index = train_index[:-val_size]
        train_loader = NeighborLoader(data=data, num_neighbors=num_neighbors, 
                                      batch_size=loader_batch_size, 
                                      input_nodes=torch.tensor(train_index))
        val_loader = NeighborLoader(data=data, num_neighbors=num_neighbors, 
                                    batch_size=loader_batch_size, 
                                    input_nodes=torch.tensor(val_index))
        test_loader = NeighborLoader(data=data, num_neighbors=num_neighbors, 
                                     batch_size=loader_batch_size, 
                                     input_nodes=torch.tensor(test_index))

        best_stats = None
        model.reset_parameters()
        best_score = np.Inf
        old_val_loss = np.Inf
        for epoch in range(max_epochs):
            train(model, train_loader, optimizer, class_weights)
            val_loss = calculate_loss(model, val_loader, class_weights)
            epoch_stats = eval(model, test_loader)
            if val_loss + tol >= old_val_loss:
                patience_counter += 1
                if patience_counter == patience:
                    break
            else:
                patience_counter = 0
                if val_loss < best_score:
                    best_score = val_loss
                    best_stats = epoch_stats
                    best_stats['epoch'] = epoch
            old_val_loss = val_loss

        stats_list.append(best_stats)
    avg_stats = average_stats(stats_list)
    return avg_stats


In [6]:
data = get_graph_data(adata.X)
data.n_id = torch.arange(data.num_nodes)

In [7]:
torch.cuda.empty_cache()
model = EquivariantGNN(in_dim=adata.X.shape[1], out_dim=len(
    labels), emb_dim=256, num_layers=4, edge_dim=0)


In [8]:
neighbor_batch_k_fold(model, data, num_neighbors=[10] * 4,)

RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 2.00 GiB total capacity; 1.12 GiB already allocated; 0 bytes free; 1.18 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF