In [1]:
from datasets import load_dataset
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP
import pandas as pd

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
dataset = load_dataset("mnist")
dataset.set_format('numpy')

In [3]:
img1d = np.array([dataset['train'][i]['image'].flatten() for i in range(len(dataset['train']))])
labels= dataset['train']['label']
img_shape= dataset['train'][0]['image'].shape
img_shape

(28, 28)

## Projecting data into pre-defined programs

In [4]:
# for each digit, we will plot the average of pixels from the images belonging to that digit
imgs_grouped_by_label = [dataset['train']['image'][labels == i] for i in range(10)]
avg_img = [imgs_grouped_by_label[i].mean(axis=0) for i in range(10)]
digit_programs = [avg_img[i].flatten() for i in range(10)]

----

### building KNN graph for mnist

In [5]:
# import scanpy as sc
import anndata 
import scanpy as sc
mnist_adata = anndata.AnnData(X= img1d)

In [6]:
sc.pp.neighbors(mnist_adata,n_neighbors=5)
mnist_adata.obsp['connectivities']
edge_list = mnist_adata.obsp['connectivities'].nonzero()

         Falling back to preprocessing with `sc.pp.pca` and default params.




In [7]:
# %%
from datasets import load_dataset
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP
import pandas as pd
from models import geneprog_encoder_linear
from datasets import load_dataset

import numpy as np
import os
import time
from datetime import datetime
import wandb
import torch
from torch.utils.data import TensorDataset,DataLoader
from sklearn.metrics import r2_score
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.metrics import log_loss
from sklearn.metrics import accuracy_score

from torch_geometric.loader import GraphSAINTRandomWalkSampler
from torch_geometric.data import Data
# %%

In [9]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
# from torch_geometric.datasets import Planetoid
import torch.nn as nn

class GCN(torch.nn.Module):

    def __init__(self, num_node_features: int, hidden_dim: int, output_dim: int, dropout: float):
        super().__init__()
        self.dropout = dropout

        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(p=dropout)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index): 

        x = self.conv1(x, edge_index)
        x = self.act1(x)
        x = self.drop1(x)
        x = self.conv2(x, edge_index)

        return x

In [48]:
import networkx as nx
import matplotlib.pyplot as plt
from random import randint

def visualize_graph(G, color):
    # fig_num  = randint(0,10000)
    # plt.figure(num=fig_num,figsize=(7,7))
    fig, ax = plt.subplots(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, scale= 20, seed=42), ax = ax, with_labels=False,
                     node_color=color, node_size=10,linewidths=1, cmap="Set2")
    return fig


def visualize_embedding(h, color, epoch=None, loss=None):
    fig_num  = randint(0,10000)
    plt.figure(num=fig_num,figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    return plt.figure(fig_num)

In [60]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
# from torch_geometric.datasets import Planetoid
import torch.nn as nn
from models import geneprog_encoder_MLP

class GCN(torch.nn.Module):

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, dropout: float):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # self.conv1 = GCNConv(in_dim, in_dim)
        # self.mlp = geneprog_encoder_MLP(in_dim,out_dim)

        self.conv1 = GCNConv(in_dim, out_dim)

    def forward(self, x, edge_index): 

        x = self.conv1(x, edge_index)
        # x = self.mlp(x)

        return x

In [61]:
from torch_geometric.utils import to_networkx

program_def = digit_programs
tru_labels = labels
WANDB_LOGGING = True
LEARNING_RATE = 0.0005
WEIGHT_DECAY = 1e-5
N_EPOCHS = 250
OUTPUT_PREFIX = "./gene_program_runs"

label_list  = [0,1,2,3,4,5,6,7,8,9]

datetimestamp = datetime.now().strftime(r'%Y_%m_%d___%H_%M_%S')
RUN_NAME = f"GCN_{datetimestamp}"



SAVE_DIR = os.path.join(OUTPUT_PREFIX, RUN_NAME)
os.makedirs(SAVE_DIR, exist_ok=True)


device = torch.device('mps:0')


scaler = MinMaxScaler()
X = scaler.fit_transform(img1d)
mnist_x = torch.tensor(X, dtype=torch.float)
mnist_labels = torch.tensor(labels, dtype=torch.long)
mnist_edge_list = torch.tensor(edge_list, dtype=torch.long)
prog_def_tensor = torch.tensor(digit_programs, dtype=torch.float).to(device)
data = Data(x=mnist_x, edge_index=mnist_edge_list,y=mnist_labels)
num_cells, num_genes = data.x.shape
num_prog, _ = prog_def_tensor.shape




model = GCN(num_genes,num_genes,num_prog,dropout=0.6).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)



# GraphSAINT hyperparameters
NUM_MINIGRAPH_INPUTS_PER_EPOCH = 20  # Defining length of an epoch here
NUMBER_OF_RANDOM_WALKS = 100  # This is dependent on GPU memory constraints
RANDOM_WALK_LENGTH = 0


loader = GraphSAINTRandomWalkSampler(
    data, 
    batch_size=NUMBER_OF_RANDOM_WALKS, 
    walk_length=RANDOM_WALK_LENGTH,
    num_steps=NUM_MINIGRAPH_INPUTS_PER_EPOCH, 
    sample_coverage=100,  # leave this as 100, has to do with calculating statistics on initial loading
    log=False
)

# NeighborSampler






iterations_per_anneal_cycle = N_EPOCHS #// 5  # 5 cosine decays during training
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=iterations_per_anneal_cycle, eta_min=1e-7)

# Initialize WandB
if WANDB_LOGGING:
    curr_run = wandb.init(
        project="gene_programs",
        entity="sinag",
        name=RUN_NAME
    )



for epoch in range(N_EPOCHS):
    temp_loss = []
    temp_pearson_r = []
    temp_r2_score = []
    temp_cross_ents = []
    temp_auc = []
    temp_acc = []
    model.train()

    for idx, subgraph in enumerate(loader):   
        optimizer.zero_grad()

        ### Visualize the graph
        # G = to_networkx(subgraph, to_undirected=True)
        # fig = visualize_graph(G, color=subgraph.y)
        # fig.savefig(os.path.join(SAVE_DIR, f"graph_epoch_{epoch}_idx_{idx}.png"), dpi=300)
        # plt.close()
        ####

        subgraph = subgraph.to(device)
        batch_program_scores = model(subgraph.x,subgraph.edge_index)  # shape: [cells, gene_programs]

        X_reconst = torch.matmul(batch_program_scores, prog_def_tensor)  # [cells, gene_programs] x [gene_programs, num_genes]
        loss = criterion(subgraph.x, X_reconst)
        loss.backward()

        proba = batch_program_scores.softmax(dim=1)

        flattened_x = subgraph.x.flatten().detach().cpu().numpy()
        flattened_reconst = X_reconst.flatten().detach().cpu().numpy()

        proba_cpu = proba.detach().cpu().numpy()
        batch_labels_cpu = subgraph.y.detach().cpu().numpy()

        crossent_val = log_loss(y_true=batch_labels_cpu, y_pred=proba_cpu,labels=label_list)
        top_prg = np.argmax(proba_cpu,axis = 1)
        acc = accuracy_score(batch_labels_cpu,top_prg)

        
        temp_cross_ents.append(crossent_val)
        auc = roc_auc_score(y_true=batch_labels_cpu, y_score=proba_cpu, multi_class='ovr',labels=label_list)


        r2_value = r2_score(flattened_x, flattened_reconst)
        pearson_r_value = np.corrcoef(flattened_x, flattened_reconst)[0, 1]

        loss_val = loss.item()
        
        temp_acc.append(acc)
        temp_auc.append(auc)
        temp_pearson_r.append(pearson_r_value)
        temp_r2_score.append(r2_value)
        temp_loss.append(loss_val)

        optimizer.step()
        scheduler.step(epoch + idx / len(loader)) # Adjust learning rate


    # Compute metrics
    avg_auc = np.mean(temp_auc)
    avg_acc = np.mean(temp_acc)
    avg_loss = np.mean(temp_loss)
    avg_r2_score = np.mean(temp_r2_score)
    avg_pearsonr_score = np.mean(temp_pearson_r)
    avg_crossent = np.mean(temp_cross_ents)

    if WANDB_LOGGING:
        wandb.log({
            "Learning Rate": scheduler.get_last_lr()[0],
            "Loss (MSE-reconstruction)": avg_loss,
            "R2 (reconstruction)": avg_r2_score,
            "Pearson (reconstruction)": avg_pearsonr_score,
            "Cross Entropy (program scores - labels)": avg_crossent,
            "AUC (program scores vs labels)": avg_auc,
            "Accuracy (top program vs labels)": avg_acc
        }, step=epoch)

if WANDB_LOGGING:
    curr_run.finish()

# model.eval()
# program_scores = model(torch.from_numpy(X).to(device))
# reconst_all = torch.matmul(program_scores, prog_def_tensor)
# program_scores = program_scores.detach().cpu().numpy()
# reconst_all = reconst_all.detach().cpu().numpy()

# torch.save(model.state_dict(), os.path.join(SAVE_DIR, "model.pt"))

ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.