In [None]:
import torch
import torch.nn.functional as F
from scipy.io import mmread
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from sklearn.neighbors import kneighbors_graph
from sklearn.cluster import DBSCAN



In [None]:
# Read the Matrix Market file and make sure the result is in COO format.
rawData = mmread('scRNA.mtx')
coo_matrix = rawData.tocoo()

# Show a compact summary instead of printing the matrix itself: the matrix
# holds tens of millions of stored entries, so dumping coordinates/values
# floods the notebook output without telling the reader anything useful.
print(
    f'shape={coo_matrix.shape}, '
    f'stored elements={coo_matrix.nnz}, '
    f'dtype={coo_matrix.dtype}'
)

<COOrdinate sparse matrix of dtype 'float64'
	with 86422438 stored elements and shape (34619, 27180)>
  Coords	Values
  (30, 0)	878.85046
  (44, 0)	2636.5513
  (58, 0)	878.85046
  (74, 0)	878.85046
  (77, 0)	878.85046
  (109, 0)	878.85046
  (225, 0)	3515.4019
  (247, 0)	878.85046
  (272, 0)	878.85046
  (273, 0)	1318.2756
  (365, 0)	878.85046
  (366, 0)	2636.5513
  (421, 0)	878.85046
  (431, 0)	1757.7009
  (440, 0)	878.85046
  (442, 0)	5273.1025
  (499, 0)	16698.158
  (594, 0)	878.85046
  (649, 0)	1757.7009
  (688, 0)	878.85046
  (706, 0)	878.85046
  (831, 0)	1757.7009
  (835, 0)	2636.5513
  (867, 0)	878.85046
  (911, 0)	2636.5513
  :	:
  (32162, 27179)	142.14685
  (32202, 27179)	88.841736
  (32298, 27179)	177.68347
  (32423, 27179)	59.227825
  (32425, 27179)	177.68347
  (32564, 27179)	88.841736
  (33015, 27179)	177.68347
  (33454, 27179)	177.68347
  (33504, 27179)	177.68347
  (33538, 27179)	177.68347
  (33578, 27179)	88.841736
  (33636, 27179)	177.68347
  (33647, 27179)	284.01257
  (33

In [None]:

def split_train_test():
    """Placeholder for a train/test split.

    Not implemented yet: takes no arguments, does nothing and returns ``None``.
    """
    return None


def build_cell_to_cell_graph(data, k=2):
    """Build a k-nearest-neighbour graph over cells from placeholder data.

    Each node gets a single scalar feature (a placeholder cell value) and is
    connected to the ``k`` nodes whose placeholder expression value is closest
    to its own.

    Args:
        data: Currently unused. The function still runs on hard-coded
            placeholder values; the intended inputs (``data.row`` and
            ``data.data``) are not wired in yet.
        k: Number of neighbours per node. Defaults to 2.

    Returns:
        A ``torch_geometric.data.Data`` object with ``x`` of shape
        ``(num_nodes, 1)`` and ``edge_index`` of shape
        ``(2, num_nodes * k)`` and dtype ``torch.long``.
    """
    # Placeholder inputs until the real matrix is used.
    # TODO: replace with `cells = data.row` and `expression = data.data`.
    fake_cells = [0, 1, 2, 3, 4, 5, 6, 7, 6, 8, 9, 20, 40, 50, 60, 70, 60, 100, 101, 102, 10, 99]
    fake_expressions = [120, 120, 110, 100, 90, 30, 89, 100, 90, 70, 78, 66, 120, 100, 110]

    # The two placeholder lists differ in length. Every node needs both a
    # feature and an expression value, otherwise `x` describes more nodes than
    # the k-NN graph was built over; keep only the entries that have both.
    num_nodes = min(len(fake_cells), len(fake_expressions))

    # float32 keeps memory usage down.
    x = torch.tensor(fake_cells[:num_nodes], dtype=torch.float32).unsqueeze(1)
    expression = torch.tensor(fake_expressions[:num_nodes], dtype=torch.float32).unsqueeze(1)

    # Connect every node to its k most similar nodes (by expression value).
    adj_matrix = kneighbors_graph(
        expression.numpy(), n_neighbors=k, mode='connectivity', include_self=False
    )

    # Convert each index array separately and stack them. Passing the tuple of
    # ndarrays straight to torch.tensor() is slow and triggers a UserWarning.
    rows, cols = adj_matrix.nonzero()
    edge_index = torch.stack(
        [
            torch.as_tensor(rows, dtype=torch.long),
            torch.as_tensor(cols, dtype=torch.long),
        ]
    )

    return Data(edge_index=edge_index, x=x)

# Build the cell-to-cell graph used by the cells below.
# NOTE(review): build_cell_to_cell_graph currently ignores its argument and
# works on hard-coded placeholder values, so `coo_matrix` has no effect yet.
data = build_cell_to_cell_graph(coo_matrix)

tensor([[  0.],
        [  1.],
        [  2.],
        [  3.],
        [  4.],
        [  5.],
        [  6.],
        [  7.],
        [  6.],
        [  8.],
        [  9.],
        [ 20.],
        [ 40.],
        [ 50.],
        [ 60.],
        [ 70.],
        [ 60.],
        [100.],
        [101.],
        [102.],
        [ 10.],
        [ 99.]]) tensor([[120.],
        [120.],
        [110.],
        [100.],
        [ 90.],
        [ 30.],
        [ 89.],
        [100.],
        [ 90.],
        [ 70.],
        [ 78.],
        [ 66.],
        [120.],
        [100.],
        [110.]])
tensor([[ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
          9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14],
        [12,  1, 12,  0, 14,  1, 13,  7,  8,  6, 11,  9,  8,  4, 13,  3,  4,  6,
         11, 10,  9,  6,  9, 10,  0,  1,  7,  3,  2,  1]])


  edge_index = torch.tensor(adj_matrix.nonzero(), dtype=torch.long)


In [None]:

# NOTE(review): not referenced by any of the visible cells — confirm whether
# this is still needed before relying on or removing it.
num_nodes = 10

class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE encoder using mean-aggregating ``SAGEConv`` layers.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr="mean")
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr="mean")

    def forward(self, x, edge_index):
        """Apply conv1, a ReLU, then conv2; no activation on the output."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv2(hidden, edge_index)


#loss function
def contrastive_clustering_loss(embeddings, edge_index, margin=1.0):
    """Margin-based contrastive loss over graph edges with random negatives.

    For each edge ``(src, dst)`` a random negative node ``neg`` is drawn and
    the term ``relu(d(src, dst) - d(src, neg) + margin)`` is computed, where
    ``d`` is the squared Euclidean distance between embeddings. The result is
    the mean over all edges.

    Args:
        embeddings: Tensor of shape ``(num_nodes, dim)``, one row per node.
        edge_index: Integer tensor of shape ``(2, num_edges)`` holding source
            and destination node indices.
        margin: Required gap between positive and negative distances.

    Returns:
        A scalar tensor. It is zero (but still attached to ``embeddings`` so
        ``backward()`` works) when there are no edges or fewer than two nodes,
        because no valid positive/negative pair exists in those cases.
    """
    num_nodes = embeddings.size(0)
    num_edges = edge_index.size(1)

    # Without edges the mean below would be NaN, and with a single node there
    # is no node other than the anchor to use as a negative.
    if num_edges == 0 or num_nodes < 2:
        return embeddings.sum() * 0.0

    src, dst = edge_index
    pos_distances = (embeddings[src] - embeddings[dst]).pow(2).sum(dim=1)

    # Draw negatives uniformly from all nodes except the anchor itself: a
    # non-zero offset added modulo num_nodes can never land back on `src`.
    # Sampling the anchor would give a negative distance of exactly zero and
    # add a spurious `pos + margin` term to the loss.
    offsets = torch.randint(1, num_nodes, (num_edges,), device=src.device)
    neg_dst = (src + offsets) % num_nodes
    neg_distances = (embeddings[src] - embeddings[neg_dst]).pow(2).sum(dim=1)

    return F.relu(pos_distances - neg_distances + margin).mean()

In [None]:
# Seed PyTorch's RNG before anything stochastic happens, so that weight
# initialisation and the random negative sampling in the loss are the same
# on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

# in_channels=1 because each node carries a single scalar feature.
model = GraphSAGE(in_channels=1, hidden_channels=16, out_channels=16).to(device)

# Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
# Train the encoder on the contrastive loss with full-batch updates.
NUM_EPOCHS = 200
LOG_INTERVAL = 20

# The mode is never switched inside the loop, so setting it once is enough.
model.train()

for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()

    embeddings = model(data.x, data.edge_index)
    loss = contrastive_clustering_loss(embeddings, data.edge_index)

    loss.backward()
    optimizer.step()

    is_logging_epoch = (epoch % LOG_INTERVAL) == 0
    if is_logging_epoch:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

Epoch: 000, Loss: 354.1731
Epoch: 020, Loss: 14.2802
Epoch: 040, Loss: 25.3115
Epoch: 060, Loss: 5.6164
Epoch: 080, Loss: 0.9164
Epoch: 100, Loss: 1.3584
Epoch: 120, Loss: 1.3187
Epoch: 140, Loss: 2.0002
Epoch: 160, Loss: 1.8381
Epoch: 180, Loss: 0.7326


In [None]:
# Clustering — TODO: not implemented yet (DBSCAN is imported at the top but not used in the cells shown).
