In [None]:
import gc
import torch
import faiss 
import random
import numpy as np
from scipy.io import mmread
import torch.nn.functional as F
from torch.nn import TripletMarginLoss
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
import torch.optim as optim
# from sklearn.cluster import KMeans
# import torch_geometric.utils as pyg_utils
import matplotlib.pyplot as plt
# from sklearn.decomposition import PCA
# import pandas as pd
# import seaborn as sns

In [None]:
# Summarise the CUDA setup before any GPU work starts.
if not torch.cuda.is_available():
    print("No CUDA-compatible GPU found.")
else:
    # All per-device queries below target device 0.
    print(f"CUDA Device Count: {torch.cuda.device_count()}")
    print(f"Device Name: {torch.cuda.get_device_name(0)}")
    print(f"Compute Capability: {torch.cuda.get_device_capability(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9} GB")
    print(f"Memory Cached: {torch.cuda.memory_reserved(0) / 1e9} GB")

# Drop any cached allocator blocks left over from earlier runs in this kernel.
torch.cuda.empty_cache()

In [None]:
# Read the expression matrix from a Matrix Market file and convert it to COO
# layout, whose .row / .col / .data arrays are sliced directly by the next cell.
rawData = mmread('scRNA.mtx')
coo_matrix = rawData.tocoo()
print(coo_matrix)
print(coo_matrix.shape)

##### Process And Modify Data

In [None]:
def clean_and_split_data(coo_matrix, max_number):
    """Return a sparse matrix holding only the first ``max_number`` stored entries.

    Entries are taken in the order they appear in the input's ``row`` / ``col`` /
    ``data`` arrays; the result has the same class and shape as the input.

    Args:
        coo_matrix: Sparse matrix in COO layout (exposes ``nnz``, ``row``,
            ``col``, ``data`` and ``shape``).
        max_number: Number of leading stored entries to keep. Must be strictly
            smaller than ``coo_matrix.nnz``.

    Returns:
        A new matrix of the same class and shape with ``max_number`` entries.

    Raises:
        ValueError: If ``max_number`` is greater than or equal to the number of
            stored entries.
    """
    stored_count = coo_matrix.nnz

    # Guard first: asking for everything (or more) is rejected outright.
    if max_number >= stored_count:
        raise ValueError(f"max_nnz ({max_number}) must be less than total non-zero elements ({stored_count})")

    # Positions of the leading entries to keep.
    keep = np.arange(max_number)
    values, row_ids, col_ids = (
        component[keep]
        for component in (coo_matrix.data, coo_matrix.row, coo_matrix.col)
    )

    return type(coo_matrix)((values, (row_ids, col_ids)), shape=coo_matrix.shape)

# Keep only the first 900,000 stored entries (presumably to keep the graph
# construction below tractable — confirm the intended subset size).
processed_data = clean_and_split_data(coo_matrix=coo_matrix, max_number=900000)
print(processed_data)

##### Graph Data Object

In [None]:
def cell_graph(data, threshold):
    """Build a PyTorch Geometric graph linking each stored value to its nearest neighbours.

    Every entry of ``data.data`` becomes one node whose single feature is that
    value. A flat L2 FAISS index is searched for the 3 closest entries of each
    node; the first hit is skipped and the remaining ``k = 2`` are linked to the
    node when they lie within ``threshold``.

    Args:
        data: Object exposing a 1-D ``.data`` array of numeric values (the
            notebook passes a SciPy COO matrix).
        threshold: Maximum (non-squared) distance for an edge. It is squared
            before comparison because the flat L2 index reports squared
            distances.

    Returns:
        torch_geometric.data.Data: ``x`` of shape ``(num_nodes, 1)`` (float32)
        and ``edge_index`` of shape ``(2, num_edges)`` (long).
    """
    gene_expression = data.data

    # FAISS expects a 2-D float32 array: one row per node, one column per feature.
    x = np.asarray(gene_expression, dtype=np.float32).reshape(-1, 1)
    num_nodes = x.shape[0]

    similarity_index = faiss.IndexFlatL2(1)
    # Use the GPU when one is usable, otherwise stay on the CPU index so the
    # notebook still runs (same guard as the other FAISS helpers in this file).
    if torch.cuda.is_available() and hasattr(faiss, "StandardGpuResources"):
        # Kept in a local so the GPU resources outlive every use of the index.
        gpu_resource_manager = faiss.StandardGpuResources()
        similarity_index = faiss.index_cpu_to_gpu(gpu_resource_manager, 0, similarity_index)

    print(similarity_index.is_trained)
    print(f"FAISS index type: {type(similarity_index)}")

    similarity_index.add(x)
    k = 2
    distances, indices = similarity_index.search(x, k + 1)

    # Column 0 is treated as the query itself and skipped.
    # NOTE(review): with many tied values the query is not guaranteed to be the
    # first hit, so self-loops are possible — confirm whether that matters.
    neighbor_ids = indices[:, 1:k + 1]
    neighbor_dists = distances[:, 1:k + 1]
    within_threshold = neighbor_dists <= threshold ** 2

    # Vectorised replacement for the per-node Python loop. Boolean masking walks
    # the arrays in row-major order, so edges come out in the same
    # (node, neighbour-rank) order as the loop produced.
    source_ids = np.repeat(np.arange(num_nodes), k).reshape(num_nodes, k)
    edge_index_np = np.vstack((source_ids[within_threshold], neighbor_ids[within_threshold]))
    edge_index = torch.tensor(edge_index_np, dtype=torch.long)

    # Neighbours rejected by the threshold, de-duplicated (now sorted ascending).
    cleaned_outliers = np.unique(neighbor_ids[~within_threshold]).tolist()
    print(cleaned_outliers)

    x_tensor = torch.tensor(x, dtype=torch.float32)
    pyg_data = Data(edge_index=edge_index, x=x_tensor)
    print(pyg_data)
    return pyg_data

# Build the graph; threshold=500 is the largest value difference that still yields an edge.
data = cell_graph(data=processed_data,threshold=500)

In [None]:
# Collect unreachable Python objects first, then release PyTorch's cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Cluster the raw node features (one value per node) with FAISS k-means.
gene_expression_lvl = data.x.cpu().numpy()

k = 10  # number of clusters
d = gene_expression_lvl.shape[1]  # feature dimension
# Only request the GPU implementation when CUDA is actually available, in line
# with the other GPU guards in this notebook.
kmeans = faiss.Kmeans(d, k, niter=300, gpu=torch.cuda.is_available())

kmeans.train(gene_expression_lvl)

# Assign every node to its single nearest centroid.
_, labels = kmeans.index.search(gene_expression_lvl, 1)
labels = labels.flatten()

In [None]:
# Count the nodes assigned to each cluster. minlength=k guarantees one slot per
# cluster even when the highest-numbered clusters are empty; without it
# cluster_counts[i] below could raise IndexError.
cluster_counts = np.bincount(labels, minlength=k)

for i in range(k):
    print(f"Cluster {i}: {cluster_counts[i]} nodes")

In [None]:
# 1-D view of the node features, reused by the plotting cells below.
gene_expression_levels = data.x.cpu().numpy().flatten()

plt.figure(figsize=(10, 6))
unique_labels = np.unique(labels)

# One vertical strip of points per cluster: x is the cluster id, y the values in it.
for label in unique_labels:
    # Index the flattened array defined above so x and y are plain 1-D arrays,
    # as in the later per-cluster plot.
    cluster_cells = gene_expression_levels[labels == label]
    num_cells = len(cluster_cells)
    plt.scatter(np.full_like(cluster_cells, label), cluster_cells, alpha=0.5, label=f'Cluster {label} ({num_cells} cells)')

# Label the axes
plt.xlabel('Cluster ID')
plt.ylabel('Gene Expression Level')
plt.title('Gene Expression Levels per Cluster')


plt.legend()
plt.show()

In [None]:
# Scatter every value against its position in the feature array.
# The hasattr guard only converts when gene_expression_levels is a tensor; an
# object without a .cpu attribute (e.g. a NumPy array) is used as-is.
gene_expression_levels_cpu = gene_expression_levels.cpu().numpy() if hasattr(gene_expression_levels, "cpu") else gene_expression_levels
num_cells = len(gene_expression_levels_cpu)

# y-axis: position of each entry in the array.
cell_indices = np.arange(num_cells)

plt.figure(figsize=(12, 6))
plt.scatter(gene_expression_levels_cpu, cell_indices, alpha=0.5, s=2, c='blue')  

plt.xlabel('Gene Expression Level')
plt.ylabel('Cell Index')
plt.title('Gene Expression Levels Across Cells')
plt.grid(True)
plt.show()

In [None]:
# Same scatter as the previous cell, coloured by the k-means cluster id of each entry.
plt.figure(figsize=(12, 6))
plt.scatter(gene_expression_levels_cpu, cell_indices, c=labels, cmap='tab10', alpha=0.5, s=2)
plt.xlabel('Gene Expression Level')
plt.ylabel('Cell Index')
plt.title('Gene Expression Levels Across Cells (Colored by Cluster)')
plt.colorbar(label='Cluster ID')
plt.grid(True)
plt.show()

Graph Neural Network

In [None]:
class GraphSAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers producing one embedding per node.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every intermediate layer.
        out_channels: Size of the returned node embeddings.
        num_layers: Total number of ``SAGEConv`` layers (>= 1). With a single
            layer the model maps ``in_channels`` straight to ``out_channels``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        # Layer widths: in -> hidden x (num_layers - 1) -> out. Building from
        # this list makes the layer count always equal num_layers (previously
        # num_layers=1 silently produced two layers).
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(width_in, width_out) for width_in, width_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x, edge_index):
        """Return node embeddings for features ``x`` on the graph ``edge_index``."""
        # Every layer except the last is followed by ReLU and dropout (p=0.5,
        # active only in training mode).
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
        # The last layer is left linear so embeddings are not clipped by ReLU.
        x = self.convs[-1](x, edge_index)
        return x

Training Without Loss Function

In [None]:
# Baseline: embed the graph with an untrained (randomly initialised) GraphSAGE.
# is_available must be *called*; the bare function object is always truthy, so
# the original expression selected 'cuda' even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphSAGE(in_channels=1, hidden_channels=128, out_channels=64, num_layers=2).to(device)
data = data.to(device)

# Inference only: eval() disables dropout, no_grad() skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    embeddings = model(data.x, data.edge_index)

print(embeddings.shape)

# Saves the whole pickled module (architecture + weights), not just a state_dict.
torch.save(model, 'entire_model.pth')

In [None]:
# Cluster the baseline embeddings with FAISS k-means.
embeddings_np = embeddings.cpu().numpy()

d = embeddings_np.shape[1]  # embedding dimension
k = 10  # number of clusters
# Only request the GPU implementation when CUDA is actually available.
kmeans = faiss.Kmeans(d, k, niter=300, gpu=torch.cuda.is_available())

kmeans.train(embeddings_np)
# Assign every node to its single nearest centroid.
_, labels = kmeans.index.search(embeddings_np, 1)

# Put the labels on the notebook's selected device instead of a hard-coded
# 'cuda', so the cell also runs on CPU-only machines.
labels = torch.tensor(labels.flatten(), device=device)
print(labels.shape)

In [None]:
# Tally how many nodes landed in each embedding cluster.
cluster_counts = torch.bincount(labels)

# Convert once to plain Python ints, then report one line per cluster.
for cluster_id, n_cells in enumerate(cluster_counts.tolist()):
    print(f"Cluster {cluster_id}: {n_cells} cells")

In [None]:
# Bring the cluster labels back to host memory as NumPy for masking and plotting.
labels_cpu = labels.cpu().numpy()
# gene_expression_levels comes from an earlier plotting cell in this notebook.
gene_expression_levels_cpu = gene_expression_levels


plt.figure(figsize=(10, 6))
unique_labels = np.unique(labels_cpu)

# One vertical strip of points per cluster: x is the cluster id, y the values in it.
for label in unique_labels:
    cluster_cells = gene_expression_levels_cpu[labels_cpu == label]
    num_cells = len(cluster_cells)
    plt.scatter(np.full_like(cluster_cells, label), cluster_cells, alpha=0.5, label=f'Cluster {label} ({num_cells} cells)')


plt.xlabel('Cluster ID')
plt.ylabel('Gene Expression Level')
plt.title('Gene Expression Levels per Cluster')
plt.legend()
plt.show()

In [None]:
# Scatter every value against its position in the feature array.
# The hasattr guard only converts when gene_expression_levels is a tensor; an
# object without a .cpu attribute (e.g. a NumPy array) is used as-is.
gene_expression_levels_cpu = gene_expression_levels.cpu().numpy() if hasattr(gene_expression_levels, "cpu") else gene_expression_levels
num_cells = len(gene_expression_levels_cpu)

# y-axis: position of each entry in the array.
cell_indices = np.arange(num_cells)

plt.figure(figsize=(12, 6))
plt.scatter(gene_expression_levels_cpu, cell_indices, alpha=0.5, s=2, c='blue')  

plt.xlabel('Gene Expression Level')
plt.ylabel('Cell Index')
plt.title('Gene Expression Levels Across Cells')
plt.grid(True)
plt.show()

In [None]:
# Same scatter as the previous cell, coloured by the embedding-based cluster id.
plt.figure(figsize=(12, 6))
plt.scatter(gene_expression_levels_cpu, cell_indices, c=labels_cpu, cmap='tab10', alpha=0.5, s=2)
plt.xlabel('Gene Expression Level')
plt.ylabel('Cell Index')
plt.title('Gene Expression Levels Across Cells (Colored by Cluster)')
plt.colorbar(label='Cluster ID')
plt.grid(True)
plt.show()

In [None]:
# Collect unreachable Python objects first, then release PyTorch's cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Report PyTorch's CUDA memory counters in MB (bytes / 1024**2): current and peak
# values for both allocated and reserved memory.
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"Reserved : {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
print(f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
print(f"Max Reserved : {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")


Model 1  + Testing with parameter changes - No Loss Function

In [None]:
# Untrained GraphSAGE variants that differ only in depth (4, 8 and 16 layers).
model1 = GraphSAGE(in_channels=1, hidden_channels=64,out_channels=64, num_layers=4).to(device)
model1_1 = GraphSAGE(in_channels=1, hidden_channels=64,out_channels=64, num_layers=8).to(device)
model1_2 = GraphSAGE(in_channels=1, hidden_channels=64,out_channels=64, num_layers=16).to(device)


# Each model must be switched to eval mode itself. Previously only model1 was,
# so the two deeper models ran their forward pass with dropout still active.
model1.eval()
with torch.no_grad():
    embeddings_model1 = model1(data.x, data.edge_index)

print(embeddings_model1.shape)

model1_1.eval()
with torch.no_grad():
    embeddings_model1_1 = model1_1(data.x, data.edge_index)

print(embeddings_model1_1.shape)


model1_2.eval()
with torch.no_grad():
    embeddings_model1_2 = model1_2(data.x, data.edge_index)

print(embeddings_model1_2.shape)


# Persist the full pickled modules (architecture + weights).
torch.save(model1, 'entire_model1.pth')
torch.save(model1_1, 'entire_model1_1.pth')
torch.save(model1_2, 'entire_model1_2.pth')

In [None]:
# Collect unreachable Python objects first, then release PyTorch's cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

Model 2 with +loss function - Triplet Loss

In [None]:
def generate_triplets_with_faiss_gpu(embeddings, k=5, num_triplets_per_node=1):
    """Mine (anchor, positive, negative) embedding triplets with FAISS nearest neighbours.

    For every node the positive is drawn at random from its ``k`` nearest
    neighbours, and the negative is the closest node that is *not* among those
    neighbours, searched within the nearest ``min(50, num_nodes)`` nodes. Nodes
    for which no such negative exists contribute no triplets.

    Despite the name, a CPU index is used when CUDA is unavailable.

    Args:
        embeddings: Tensor of shape ``(num_nodes, emb_dim)``. The returned
            tensors are indexed from it, so gradients flow back to it.
        k: Number of nearest neighbours treated as positive candidates.
        num_triplets_per_node: Triplets emitted per node (same hard negative,
            independently sampled positive).

    Returns:
        Tuple ``(anchor, positive, negative)`` of tensors with shape
        ``(num_triplets, emb_dim)`` on the device of ``embeddings``.
    """
    device = embeddings.device
    emb_np = embeddings.detach().cpu().numpy().astype('float32')
    num_nodes, emb_dim = emb_np.shape

    index = faiss.IndexFlatL2(emb_dim)
    if torch.cuda.is_available():
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(emb_np)

    _, neighbors = index.search(emb_np, k + 1)

    # One batched search for negative candidates instead of a separate FAISS
    # call per node and per triplet; the query result does not depend on the
    # loop, so it is hoisted out. The candidate window is capped at 50.
    k_neg = min(50, num_nodes)
    _, neg_candidates = index.search(emb_np, k_neg)

    anchors, positives, negatives = [], [], []

    for i in range(num_nodes):
        pos_candidates = neighbors[i][1:]  # Exclude self
        excluded = set(neighbors[i].tolist())

        # Closest candidate (skipping the first hit) that is not a positive
        # neighbour. It is identical for every triplet of node i.
        hard_neg = next(
            (n for n in neg_candidates[i][1:].tolist() if n not in excluded),
            None,
        )

        for _ in range(num_triplets_per_node):
            # Drawn even when no negative exists, to keep the random stream
            # identical to the original implementation.
            pos_idx = random.choice(pos_candidates)

            if hard_neg is not None:
                anchors.append(i)
                positives.append(int(pos_idx))
                negatives.append(hard_neg)

    # Explicit long index tensors keep the indexing well-defined even when no
    # triplet was found (empty lists).
    anchor = embeddings[torch.tensor(anchors, dtype=torch.long, device=device)]
    positive = embeddings[torch.tensor(positives, dtype=torch.long, device=device)]
    negative = embeddings[torch.tensor(negatives, dtype=torch.long, device=device)]

    return anchor, positive, negative



In [None]:
# Triplet margin loss on Euclidean (p=2) distances with a margin of 0.5.
triplet_loss_fn = TripletMarginLoss(margin=0.5, p=2)

In [None]:
model3 = GraphSAGE(in_channels=1, hidden_channels=8, out_channels=4, num_layers=2).to(device)

# The optimizer must own model3's parameters. It was previously built from the
# baseline `model`, so every step() left model3 untouched.
optimizer = torch.optim.Adam(model3.parameters(), lr=0.001)
model3.train()

for epoch in range(5):
    optimizer.zero_grad()

    embeddings = model3(data.x, data.edge_index)

    # Triplets are re-mined every epoch from the current embeddings.
    anchor, positive, negative = generate_triplets_with_faiss_gpu(embeddings, k=5)

    loss = triplet_loss_fn(anchor, positive, negative)

    # Distance logging only; no_grad keeps it out of the autograd graph.
    with torch.no_grad():
        avg_pos_dist = F.pairwise_distance(anchor, positive, p=2).mean().item()
        avg_neg_dist = F.pairwise_distance(anchor, negative, p=2).mean().item()

    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f} | PosDist: {avg_pos_dist:.4f} | NegDist: {avg_neg_dist:.4f}")

    loss.backward()
    optimizer.step()

In [None]:
# Embed the graph with the triplet-trained model. The forward pass previously
# used model1_2, so the stored embeddings did not come from model3.
model3.eval()
with torch.no_grad():
    embeddings_model3 = model3(data.x, data.edge_index)

print(embeddings_model3.shape)


# Persist the full pickled module (architecture + weights).
torch.save(model3, 'entire_model3.pth')

In [None]:
# Collect unreachable Python objects first, then release PyTorch's cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

Model 4 +loss function - Contrastive Loss

In [None]:
# Pair sampling for training with contrastive loss
def generate_contrastive_pairs_faiss_gpu(embeddings, k=3, num_negatives=1):
    """
    Generate (i, j, y) pairs for contrastive loss using FAISS.

    Positives are each node's k nearest neighbours; negatives are nodes drawn
    uniformly at random that are not among those neighbours. Despite the name,
    a CPU index is used when CUDA is unavailable.

    Args:
        embeddings: Tensor [num_nodes, emb_dim]; only read (detached) here.
        k: # of nearest neighbors (positives).
        num_negatives: # of negative samples per node.

    Returns:
        anchor_idx, pair_idx (long tensors) and labels (float tensor,
        1 for positive, 0 for negative), all on the device of `embeddings`.

    Raises:
        ValueError: If negatives are requested but every node is already a
            neighbour of every other (num_nodes <= k + 1), so none can exist.
    """
    device = embeddings.device
    emb_np = embeddings.detach().cpu().numpy().astype('float32')
    num_nodes = emb_np.shape[0]

    # Without this guard the rejection sampling below would never terminate.
    if num_negatives > 0 and num_nodes <= k + 1:
        raise ValueError(
            f"Cannot sample negatives: num_nodes ({num_nodes}) must exceed k + 1 ({k + 1})"
        )

    index = faiss.IndexFlatL2(emb_np.shape[1])
    if torch.cuda.is_available():
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(emb_np)
    _, neighbors = index.search(emb_np, k + 1)

    anchor_idx, pair_idx, labels = [], [], []

    for i in range(num_nodes):
        neighbor_ids = neighbors[i].tolist()
        # Set membership is O(1); the original scanned the array on every draw.
        excluded = set(neighbor_ids)

        # Positive pairs from neighbors (skip self)
        for j in neighbor_ids[1:]:
            anchor_idx.append(i)
            pair_idx.append(j)
            labels.append(1)

        # Negative samples (not among k neighbors), by rejection sampling
        for _ in range(num_negatives):
            j = random.randint(0, num_nodes - 1)
            while j in excluded:
                j = random.randint(0, num_nodes - 1)
            anchor_idx.append(i)
            pair_idx.append(j)
            labels.append(0)

    # Explicit long dtype keeps the index tensors usable even when empty.
    return (
        torch.tensor(anchor_idx, dtype=torch.long, device=device),
        torch.tensor(pair_idx, dtype=torch.long, device=device),
        torch.tensor(labels, dtype=torch.float32, device=device),
    )


In [None]:
def contrastive_loss_fn(z_i, z_j, y, margin=2.0):
    """
    Mean contrastive loss over a batch of embedding pairs.

    Pairs with y == 1 are penalised by their squared distance (pulled together);
    pairs with y == 0 are penalised by the squared shortfall of their distance
    below `margin` (pushed apart until at least `margin` away).

    Args:
        z_i, z_j: Embedding tensors for the two sides of each pair.
        y: Per-pair label, 1 for positive and 0 for negative.
        margin: Distance beyond which a negative pair contributes no loss.

    Returns:
        Scalar tensor: the mean per-pair loss.
    """
    distance = F.pairwise_distance(z_i, z_j)

    # Attractive term, active for positive pairs only.
    pull_term = y * distance.pow(2)
    # Repulsive hinge term, active for negative pairs closer than the margin.
    push_term = (1 - y) * F.relu(margin - distance).pow(2)

    per_pair = pull_term + push_term
    return per_pair.mean()

In [None]:
model4 = GraphSAGE(in_channels=1, hidden_channels=8, out_channels=4, num_layers=4).to(device)

# model4 needs its own optimizer. The loop previously reused the global
# `optimizer` created for another model, so model4's weights were never updated.
optimizer = torch.optim.Adam(model4.parameters(), lr=0.001)

model4.train()
for epoch in range(100):
    optimizer.zero_grad()

    embeddings = model4(data.x, data.edge_index)

    # Pairs are re-sampled every epoch from the current embeddings.
    anchor_idx, pair_idx, labels = generate_contrastive_pairs_faiss_gpu(embeddings, k=5)

    z_i = embeddings[anchor_idx]
    z_j = embeddings[pair_idx]

    loss = contrastive_loss_fn(z_i, z_j, labels)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")


In [None]:
# Embed the graph with the contrastive-trained model. The forward pass
# previously used model1_2, so the stored embeddings did not come from model4.
model4.eval()
with torch.no_grad():
    embeddings_model4 = model4(data.x, data.edge_index)

print(embeddings_model4.shape)


# Persist the full pickled module (architecture + weights).
torch.save(model4, 'entire_model4.pth')