In [None]:
import numpy as np
import scipy.sparse as sp
from scipy.io import mmread
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv

In [None]:
# Environment sanity check, one value per line: whether torch sees a CUDA
# device, the value of torch.version.cuda, and the installed torch version.
print(torch.cuda.is_available(), torch.version.cuda, torch.__version__, sep="\n")

In [None]:
# Path to the raw Matrix Market file (relative to the notebook's working directory).
SCRNA_PATH = 'scRNA.mtx'

# mmread returns a sparse COO matrix for "coordinate" files but a dense ndarray
# for "array" files; normalise both to COO so downstream cells can rely on
# .row / .col / .data.
rawData = mmread(SCRNA_PATH)
coo_matrix = rawData.tocoo() if sp.issparse(rawData) else sp.coo_matrix(rawData)
print(coo_matrix)
print(coo_matrix.shape)

In [None]:
def clean_and_split_data(coo_matrix, max_number):
    """Return a new sparse matrix holding only the first ``max_number`` stored entries.

    Despite the name, no cleaning or splitting is performed: entries are taken in
    the matrix's storage order (not a random sample).

    Args:
        coo_matrix: A SciPy sparse matrix/array. Non-COO inputs are converted
            with ``.tocoo()`` first.
        max_number: Number of stored entries to keep; must satisfy
            ``0 <= max_number <= nnz``.

    Returns:
        A COO sparse matrix of the same class and shape as the (converted) input,
        holding copies of the selected ``data`` / ``row`` / ``col`` values.

    Raises:
        ValueError: If ``max_number`` is negative or exceeds the number of
            stored entries.
    """
    coo = coo_matrix.tocoo()
    # nnz counts *stored* entries, which may include explicit zeros or duplicates.
    total_nnz = coo.nnz

    if max_number < 0:
        raise ValueError(f"max_number ({max_number}) must be non-negative")
    # Keeping every stored entry is valid; only asking for more than exist is an error.
    if max_number > total_nnz:
        raise ValueError(
            f"max_number ({max_number}) must not exceed total non-zero elements ({total_nnz})"
        )

    keep = slice(0, max_number)
    # Slices are views; copy so the result never shares buffers with the input.
    return type(coo)(
        (coo.data[keep].copy(), (coo.row[keep].copy(), coo.col[keep].copy())),
        shape=coo.shape,
    )

# Subsample: keep only the first 1,000,000 stored entries, in storage order.
processed_data = clean_and_split_data(coo_matrix, max_number=1_000_000)
print(processed_data)

In [None]:
def _knn_1d(values, k):
    """Exact k nearest neighbours (self excluded) of every point in a 1-D float array.

    In 1-D, the k nearest neighbours of the point at sorted position p always lie
    within sorted positions p-k .. p+k. So only 2k candidates per point are scored:
    O(n log n) time and O(n * k) memory instead of an (n, n) distance matrix.

    Returns:
        (nbr_idx, nbr_dist), both of shape (n, k) and indexed by original position.
        Row i lists i's neighbours nearest-first. Slots that cannot be filled
        (fewer than k other points exist) hold index -1 and distance inf.
    """
    n = values.shape[0]
    order = np.argsort(values, kind="stable")
    sorted_vals = values[order]

    # Candidate sorted positions: k to the left and k to the right of each point.
    offsets = np.concatenate([np.arange(-k, 0), np.arange(1, k + 1)])
    cand_pos = np.arange(n)[:, None] + offsets[None, :]
    in_range = (cand_pos >= 0) & (cand_pos < n)
    cand_pos = np.clip(cand_pos, 0, max(n - 1, 0))
    cand_dist = np.where(
        in_range, np.abs(sorted_vals[cand_pos] - sorted_vals[:, None]), np.inf
    )

    pick = np.argsort(cand_dist, axis=1, kind="stable")[:, :k]
    nbr_pos = np.take_along_axis(cand_pos, pick, axis=1)
    nbr_dist = np.take_along_axis(cand_dist, pick, axis=1)
    nbr_valid = np.take_along_axis(in_range, pick, axis=1)
    nbr_idx = np.where(nbr_valid, order[nbr_pos], -1)

    # Rows are currently in sorted order; scatter them back to original positions.
    out_idx = np.empty_like(nbr_idx)
    out_dist = np.empty_like(nbr_dist)
    out_idx[order] = nbr_idx
    out_dist[order] = nbr_dist
    return out_idx, out_dist


def build_cell_to_cell_graph_normal(data, threshold, k=2, cells=None):
    """Build a k-nearest-neighbour graph over the stored values of a sparse matrix.

    Every stored entry of ``data`` becomes one node whose single feature is its
    value (``data.data``). Each node is linked to its ``k`` nearest *other* nodes
    (absolute difference of values) when that distance is ``<= threshold``.
    Every formed edge is printed. Then the sorted list of neighbour ids rejected
    for being farther than ``threshold`` is printed.

    Args:
        data: Sparse matrix (e.g. the output of ``clean_and_split_data``); only
            its ``.data`` array is read.
        threshold: Maximum neighbour distance for an edge to be created.
        k: Neighbours considered per node (default 2, the previous hard-coded value).
        cells: Optional sequence with one label per node, used in the printed
            edge list. Defaults to node indices 0..n-1.
            NOTE(review): the previous code read an undefined global ``cells``.
            Confirm which labels (e.g. ``data.row``?) are actually wanted.

    Returns:
        ``torch_geometric.data.Data`` with ``x`` of shape (n, 1), float32, and
        ``edge_index`` of shape (2, E), long. Edges are ordered by source node,
        then by neighbour rank. Self-loops are never produced.

    Raises:
        ValueError: If ``k`` is negative or ``cells`` has the wrong length.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    gene_expression = data.data
    x = torch.tensor(gene_expression, dtype=torch.float32).unsqueeze(1)
    values = x.squeeze(1).numpy().astype(np.float64)
    n = values.shape[0]

    labels = np.arange(n) if cells is None else np.asarray(cells)
    if len(labels) != n:
        raise ValueError(f"cells has {len(labels)} labels but the graph has {n} nodes")

    nbr_idx, nbr_dist = _knn_1d(values, k)
    has_nbr = nbr_idx >= 0
    close = nbr_dist <= threshold
    is_edge = has_nbr & close
    # NaN distances compare False, so (as in the old if/else) they count as "too far".
    is_far = has_nbr & ~close

    # np.nonzero and boolean masks walk row-major: source ascending, then neighbour rank.
    src_idx = np.nonzero(is_edge)[0]
    tgt_idx = nbr_idx[is_edge]
    edge_index = torch.tensor(np.stack([src_idx, tgt_idx]), dtype=torch.long)

    for src, tgt in zip(labels[src_idx], labels[tgt_idx]):
        print(f"Cell {src} is connected to Cell {tgt}")

    # NOTE(review): this records the *neighbour* j that was too far away, not the
    # node i whose neighbourhood is sparse. Kept as before; confirm the intent.
    cleaned_outliers = sorted(set(nbr_idx[is_far].tolist()))
    print(cleaned_outliers)

    return Data(edge_index=edge_index, x=x)

# Graph over the subsampled entries; neighbours farther apart than 500 are not linked.
py_data = build_cell_to_cell_graph_normal(data=processed_data, threshold=500)

In [None]:
class GraphSAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU + dropout between consecutive layers.

    Args:
        in_channels: Input feature size per node.
        hidden_channels: Feature size of every intermediate layer.
        out_channels: Output feature size per node.
        num_layers: Total number of ``SAGEConv`` layers (>= 1). With 1 layer the
            model maps ``in_channels -> out_channels`` directly.
        dropout: Dropout probability applied after each non-final layer, only
            while ``self.training`` is True (default 0.5, the previous fixed value).

    Raises:
        ValueError: If ``num_layers < 1`` (previously such values silently built
            two layers).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout = dropout

        # Channel sizes at each layer boundary: in -> hidden ... hidden -> out.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x, edge_index):
        """Run all layers. The final layer's output is returned without activation."""
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)