In [None]:
import torch
import torch.nn.functional as F
from scipy.io import mmread
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from sklearn.metrics.pairwise import euclidean_distances

In [None]:
# Environment sanity check: first line is CUDA availability, second line is
# the value of torch.version.cuda.
print(torch.cuda.is_available(), torch.version.cuda, sep="\n")

True
11.8


In [None]:
# Read the expression matrix from the Matrix Market file 'scRNA.mtx'
# (path is relative to the notebook's working directory).
rawData = mmread('scRNA.mtx')
# Convert to COO layout.
# NOTE(review): .tocoo() exists only on sparse results; mmread is documented to
# return a dense ndarray for array-format files — confirm scRNA.mtx is stored
# in coordinate (sparse) format.
coo_matrix = rawData.tocoo()

[   30    44    58 ... 34279 34368 34393]


In [None]:
# Toy cell identifiers: the consecutive integers 0 through 99.
cells = np.arange(100)

# Toy expression values, one per cell: the pattern 100, 110, 120, 10000 is
# repeated for every group of four cells (25 groups -> 100 values) ...
gene_expression = [100, 110, 120, 10000] * 25
# ... except for three cells whose value deviates from the pattern.
gene_expression[27] = 60000  # last cell of group "Cells 24-27"
gene_expression[51] = 40000  # last cell of group "Cells 48-51"
gene_expression[79] = 70000  # last cell of group "Cells 76-79"

In [None]:

def split_train_test():
    """Placeholder for the train/test split; not implemented yet.

    Takes no arguments, does nothing, and returns ``None``.
    """
    return None

def build_cell_to_cell_graph(data, k=2, distance_threshold=500):
    """Build a k-nearest-neighbour graph over cells from their expression values.

    Every cell ``i`` gets a directed edge ``i -> j`` to each of its ``k``
    closest *other* cells ``j`` whose Euclidean distance (on the expression
    value) is at most ``distance_threshold``. A cell whose closest neighbour
    is already farther away than the threshold is reported as an outlier.

    NOTE(review): ``data`` is kept for interface compatibility but is not read.
    Node ids and node features come from the module-level ``cells`` and
    ``gene_expression``; the ``data.row`` / ``data.data`` wiring was commented
    out in the original and has not been re-enabled here.

    Side effects:
        Prints one ``Cell <src> is connected to Cell <tgt>`` line per edge,
        followed by the list of outlier row indices.

    Args:
        data: Unused (see note above).
        k: Maximum number of neighbours linked per cell. Defaults to 2.
        distance_threshold: Largest distance at which an edge is still
            created. Defaults to 500.

    Returns:
        A ``torch_geometric.data.Data`` object whose ``x`` is a float32 tensor
        of shape ``(num_cells, 1)`` and whose ``edge_index`` is a long tensor
        of shape ``(2, num_edges)`` (``(2, 0)`` when no edge qualifies).

    Raises:
        ValueError: If ``k`` is negative, or if ``cells`` and
            ``gene_expression`` do not have the same length.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    cell_ids = np.asarray(cells)
    x = torch.tensor(gene_expression, dtype=torch.float32).unsqueeze(1)
    num_cells = x.shape[0]
    if len(cell_ids) != num_cells:
        raise ValueError(
            f"cells has {len(cell_ids)} entries but gene_expression has {num_cells}"
        )

    # Hand scikit-learn a plain NumPy view instead of relying on implicit
    # tensor -> array conversion; skip the call entirely for empty input,
    # which scikit-learn rejects.
    features = x.numpy()
    if num_cells:
        distances = euclidean_distances(features, features)
    else:
        distances = np.empty((0, 0))

    edges = []
    outliers = []
    for i in range(num_cells):
        # Many cells share the same value, so distance 0 is heavily tied and
        # the cell itself is not guaranteed to sort first. Drop it explicitly
        # rather than slicing off position 0; a stable sort keeps the order
        # among tied neighbours deterministic.
        order = np.argsort(distances[i], kind="stable")
        neighbours = order[order != i][:k]

        for j in neighbours:
            if distances[i, j] <= distance_threshold:
                edges.append((i, int(j)))

        # The isolated cell is `i` itself, not the far-away neighbour.
        # `neighbours` is sorted by distance, so checking the first one is
        # equivalent to "no neighbour lies within the threshold".
        if neighbours.size and distances[i, neighbours[0]] > distance_threshold:
            outliers.append(i)

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # Keep the (2, 0) shape that PyG expects for an edgeless graph.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    for src, tgt in edges:
        print(f"Cell {cell_ids[src]} is connected to Cell {cell_ids[tgt]}")

    # Each row index is appended at most once and in increasing order, so the
    # list is already de-duplicated and sorted.
    print(outliers)

    # pyG data object
    return Data(edge_index=edge_index, x=x)

# Build the PyG graph object used by the rest of the notebook.
# NOTE(review): coo_matrix is passed in, but build_cell_to_cell_graph reads the
# module-level `cells` and `gene_expression` instead of this argument.
data = build_cell_to_cell_graph(coo_matrix)

Cell 0 is connected to Cell 4
Cell 0 is connected to Cell 12
Cell 1 is connected to Cell 5
Cell 1 is connected to Cell 13
Cell 2 is connected to Cell 6
Cell 2 is connected to Cell 14
Cell 3 is connected to Cell 7
Cell 3 is connected to Cell 15
Cell 4 is connected to Cell 4
Cell 4 is connected to Cell 12
Cell 5 is connected to Cell 5
Cell 5 is connected to Cell 13
Cell 6 is connected to Cell 6
Cell 6 is connected to Cell 14
Cell 7 is connected to Cell 7
Cell 7 is connected to Cell 15
Cell 8 is connected to Cell 4
Cell 8 is connected to Cell 12
Cell 9 is connected to Cell 5
Cell 9 is connected to Cell 13
Cell 10 is connected to Cell 6
Cell 10 is connected to Cell 14
Cell 11 is connected to Cell 7
Cell 11 is connected to Cell 15
Cell 12 is connected to Cell 4
Cell 12 is connected to Cell 12
Cell 13 is connected to Cell 5
Cell 13 is connected to Cell 13
Cell 14 is connected to Cell 6
Cell 14 is connected to Cell 14
Cell 15 is connected to Cell 7
Cell 15 is connected to Cell 15
Cell 16 is c