In [1]:
import argparse
import copy
import csv
import itertools
import math
import os.path
import pdb
import random
import sys
import time
from collections import Counter
from itertools import permutations, product
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torch.optim as optim
from scipy.integrate import solve_ivp, odeint
from scipy.spatial.distance import braycurtis
from scipy.stats import entropy, pearsonr, spearmanr
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# DKO Model

In [4]:
def get_batch(ztrn,ptrn,mb_size):
    """Draw a random mini-batch of paired rows from two tensors.

    Parameters
    ----------
    ztrn : torch.Tensor
        Input tensor; its rows are selected with the sampled indices.
    ptrn : torch.Tensor
        Target tensor; the size of its first dimension defines the pool of
        row indices to sample from.
    mb_size : int
        Number of distinct rows to draw. Sampling is without replacement, so
        it must not exceed ``ptrn.size(dim=0)``.

    Returns
    -------
    tuple of torch.Tensor
        ``(ztrn[s, :], ptrn[s, :])`` for the same sampled indices ``s``, both
        moved to the module-level ``device``.
    """
    n_rows = ptrn.size(dim=0)
    # The same index tensor is applied to both inputs so the pairs stay aligned.
    s = torch.from_numpy(np.random.choice(np.arange(n_rows, dtype=np.int64), mb_size, replace=False))
    batch_p = ztrn[s,:]
    batch_q = ptrn[s,:]
    # The former `batch_t = t[:batch_time]` was never used and only made this
    # function depend on two unrelated globals; it has been dropped.
    return batch_p.to(device), batch_q.to(device)


def loss_bc(p_i,q_i):
    """Bray-Curtis style dissimilarity: ``sum(|p_i - q_i|) / sum(|p_i + q_i|)``.

    Both sums run over every element of the (broadcast) tensors, so the
    result is a single scalar tensor.
    """
    abs_difference = (p_i - q_i).abs().sum()
    abs_total = (p_i + q_i).abs().sum()
    return abs_difference / abs_total


def process_data(P):
    """Normalise every column of ``P`` to sum to one and return it transposed.

    Parameters
    ----------
    P : numpy.ndarray
        2-D array; each column is scaled by its own sum.

    Returns
    -------
    torch.Tensor
        ``float32`` tensor holding the transpose of the column-normalised
        array, i.e. one row per input column.

    Notes
    -----
    A column whose sum is zero is returned as all zeros. Dividing it by its
    sum would produce a row of NaN that propagates through every downstream
    loss and prediction. The input array is not modified.
    """
    col_sums = P.sum(axis=0)[np.newaxis,:]
    # Replace zero sums by one so that 0/0 becomes 0 instead of NaN.
    col_sums[col_sums == 0] = 1
    normalised = (P/col_sums).astype(np.float32)
    return torch.from_numpy(normalised.T)


class ODEFunc(torch.nn.Module):
    """Two stacked ``N x N`` linear layers with a multiplicative, sum-normalised readout.

    The layer width is taken from the module-level ``N`` when the object is
    constructed.
    """

    def __init__(self):
        super().__init__()
        self.fcc1 = torch.nn.Linear(N, N)
        self.fcc2 = torch.nn.Linear(N, N)

    def forward(self, y):
        """Return ``y * |fcc2(fcc1(y))|`` divided by the sum of all its elements."""
        magnitude = torch.abs(self.fcc2(self.fcc1(y)))
        numerator = torch.mul(y, magnitude)
        denominator = torch.sum(torch.mul(y, magnitude))
        return numerator / denominator
    
def train_reptile(max_epochs,mb,LR,ztrn,ptrn,ztst,ptst,zval,pval,zall,pall):
    """Train an ``ODEFunc`` with Adam and predict with the best validation checkpoint.

    Each epoch draws one mini-batch with ``get_batch``, sums ``loss_bc`` over
    its samples, evaluates the current weights on the validation set, and then
    takes one optimiser step. The weights with the lowest mean validation loss
    (ties go to the later epoch) are kept and used for the final predictions.
    Each snapshot is taken *before* that epoch's update.

    Parameters
    ----------
    max_epochs : int
        Number of optimisation steps.
    mb : int
        Mini-batch size passed to ``get_batch``.
    LR : float
        Adam learning rate.
    ztrn, ptrn : torch.Tensor
        Training inputs and targets, one sample per row.
    ztst : torch.Tensor
        Test inputs; predictions are produced only when it is 2-D.
    ptst : torch.Tensor
        Unused; kept for interface compatibility.
    zval, pval : torch.Tensor
        Validation inputs and targets used for checkpoint selection.
    zall : torch.Tensor
        Inputs for which predictions are returned as ``qtrn``.
    pall : torch.Tensor
        Unused; kept for interface compatibility.

    Returns
    -------
    loss_train : list of float
        Mean training loss per sample, one entry per epoch.
    qtst : numpy.ndarray
        Predictions for ``ztst``, shape ``(ztst.size(0), N)``.
    qtrn : numpy.ndarray
        Predictions for ``zall``, shape ``(zall.size(0), N)``.
        Both arrays stay all-zero when no epoch is run or ``ztst`` is not 2-D.

    Notes
    -----
    Relies on the module-level names ``N``, ``device``, ``ODEFunc``,
    ``get_batch`` and ``loss_bc``.
    """
    loss_train = []
    loss_val = []
    qtst = np.zeros((ztst.size(dim=0),N))
    qtrn = np.zeros((zall.size(dim=0),N))

    func = ODEFunc().to(device)
    optimizer = torch.optim.Adam(func.parameters(), lr=LR)

    # Start from +inf so the first finite validation loss always yields a
    # checkpoint; a fixed threshold of 1 could leave `best_model` undefined.
    best_val = float('inf')
    best_model = None
    n_val = zval.size(dim=0)

    for _ in range(max_epochs):
        optimizer.zero_grad()
        batch_p, batch_q = get_batch(ztrn,ptrn,mb)

        # loss of the training set (the model normalises per sample, so the
        # samples are pushed through one at a time)
        loss = 0
        for i in range(mb):
            p_pred = torch.reshape(func(batch_p[i]),(1,N))
            loss = loss + loss_bc(p_pred.unsqueeze(dim=0),batch_q[i].unsqueeze(dim=0))
        loss_train.append(loss.item()/mb)

        # validation set: no gradients are needed, so skip graph construction
        with torch.no_grad():
            l_val = 0
            for i in range(n_val):
                p_pred = torch.reshape(func(zval[i].to(device)),(1,N))
                l_val = l_val + loss_bc(p_pred.unsqueeze(dim=0),pval[i].to(device).unsqueeze(dim=0))
        mean_val = float(l_val)/n_val
        loss_val.append(mean_val)
        if mean_val <= best_val:
            best_val = mean_val
            best_model = copy.deepcopy(func)

        # update the neural network
        loss.backward()
        optimizer.step()

    # If epochs ran but no validation loss was comparable (e.g. all NaN),
    # fall back to the last weights instead of failing on an unbound name.
    if best_model is None and max_epochs > 0:
        best_model = func

    if best_model is not None and len(ztst.size())==2:
        with torch.no_grad():
            for i in range(ztst.size(dim=0)):
                pred_test = best_model(ztst[i].to(device))
                # .cpu() keeps the numpy conversion valid on non-CPU devices
                qtst[i,:] = pred_test.reshape(-1).cpu().numpy()
            for i in range(zall.size(dim=0)):
                pred_all = best_model(zall[i].to(device))
                qtrn[i,:] = pred_all.reshape(-1).cpu().numpy()
    return loss_train,qtst,qtrn

# Step 1: Generate knockout perturbation data.

In [None]:
# P_test: shape = (N_genes, n_perturbations)
#   Each column corresponds to one perturbation event.
#   A perturbation event is defined by a pair (cell_id, gene_id) indicating that gene_id is knocked out in cell_id. 

# Recorder: shape = (n_perturbations, 2)
#     Records the (cell_id, gene_id) pair of each perturbation (saved as Recoder.csv).

# n_perturbations:
#   In single-cell data, this equals the total number of knocked-out genes across all cells.
#   If you create one KO per gene with expression > 0 in each cell:
#       n_perturbations = sum over cells of (number of genes with expression > 0 in that cell).

In [17]:
# Load the unperturbed expression matrix; columns are cells and rows are genes
# (each column is treated as one cell in the loop below).
Orignial = pd.read_csv(f'./data/Ptrain.csv',header = None).values.astype(float)

# Take the dimensions from the data itself so the cell runs on a fresh kernel
# (N_genes / N_cells were not defined anywhere in the notebook).
N_genes, N_cells = Orignial.shape

# One knockout is generated for every (gene, cell) entry with non-zero expression.
perurbed_times = len(np.nonzero(Orignial)[0])

perturb_count = 0
perturb_matrix = np.zeros((N_genes, perurbed_times))

# records[k] = (cell_idx, gene_idx) of the k-th column of perturb_matrix
records = np.zeros((perurbed_times, 2),dtype = int)

for cell_idx in tqdm(range(N_cells)):
    original_cell = Orignial[:, cell_idx].copy()
    nonzero_genes = np.nonzero(original_cell)[0]
    for gene_idx in nonzero_genes:
        # Copy the cell and silence exactly one expressed gene.
        perturb_cell = original_cell.copy()
        perturb_cell[gene_idx] = 0

        perturb_matrix[:,perturb_count] = perturb_cell
        records[perturb_count] = (cell_idx, gene_idx)
        perturb_count+=1

# Write into ./data/ because the later steps read ./data/Ptest.csv and
# ./data/Recoder.csv; writing to the working directory left them reading
# missing or stale files.
matrix_df = pd.DataFrame(perturb_matrix)
matrix_df.to_csv(f'./data/Ptest.csv',header = None,index = None)
records_df  = pd.DataFrame(records)
records_df.to_csv(f'./data/Recoder.csv',header = ['cell_index','gene_index'],index = False)

100%|██████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 5207.11it/s]


# Step 2: Predicting the gene expression profile after gene KO

In [19]:
# hyperparameters
max_epochs = 5000
device = 'cpu'
batch_time = 100
t = torch.arange(0.0, 100.0, 0.01)

# Seed every random source so the train/validation split, the mini-batch
# sampling and the weight initialisation are reproducible across runs.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training data: P holds the expression values, Z its binarised
# (expressed / not expressed) counterpart used as model input.
filepath_train = f'./data/Ptrain.csv'
P = np.loadtxt(filepath_train,delimiter=',')
Z = P.copy()
Z[Z>0] = 1

pall = process_data((P))
zall = process_data(Z)

# Hold out 20% of the columns for validation.
number_of_cols = P.shape[1]
random_indices = np.random.choice(number_of_cols, size=int(0.2*number_of_cols), replace=False)
train_indices = np.setdiff1d(range(0,number_of_cols),random_indices)
P_val = P[:,random_indices]
Z_val = Z[:,random_indices]
P_train =  P[:,train_indices]
Z_train =  Z[:,train_indices]


ptrn= process_data((P_train))
pval = process_data((P_val))
ztrn = process_data(Z_train)
zval = process_data(Z_val)

# Knockout profiles to predict, binarised the same way as the training data.
filepath_test = f'./data/Ptest.csv'
P1 = np.loadtxt(filepath_test,delimiter=',')
Z1 = P1.copy()
Z1[Z1>0] = 1
ptst = process_data((P1))
ztst = process_data(Z1)

# N is read by ODEFunc when the model is constructed.
M, N = ptrn.shape

print(ztst.shape)
# pre training to select the parameter
LR = 0.01
mb = 20

loss_train,qtst,qtrn = train_reptile(max_epochs,mb,LR,ztrn,ptrn,ztst,ptst,zval,pval,zall,pall)
np.savetxt(f'./data/PredKO.csv',qtst,delimiter=',')
np.savetxt(f'./data/PPred.csv',qtrn,delimiter=',')

torch.Size([39343, 100])
(39343, 100)


# Step 3: Calculate the gene KO impact score

In [8]:
Original_model = np.loadtxt(f'./data/PPred.csv', delimiter=',')# shape: (N_cells,genes)
predss = np.loadtxt(f'./data/PredKO.csv', delimiter=',')  # shape: (perturbation, N_genes)
records = pd.read_csv(f'./data/Recoder.csv')
Perturb_cell_index = list(records['cell_index'])
Perturb_gene_index = list(records['gene_index'])

# Every prediction row needs its (cell, gene) record; fail early with a clear
# message instead of an IndexError or a length mismatch further down.
if len(Perturb_cell_index) != predss.shape[0]:
    raise ValueError(
        f'Recoder.csv has {len(Perturb_cell_index)} rows but PredKO.csv has '
        f'{predss.shape[0]} predictions'
    )

# Impact score of a knockout = Bray-Curtis distance between the model's
# prediction for the knocked-out profile and its prediction for the same
# cell without the knockout.
p_dis = []
for i in range(predss.shape[0]):
    c_id = Perturb_cell_index[i] #This cell_id of KO gene_id

    Without_knock =  Original_model[c_id,:].copy()
    Without_knock = Without_knock/np.sum(Without_knock)
    one_pred = predss[i,:]

    p_dis.append(braycurtis(one_pred, Without_knock))

df = pd.DataFrame({
    "gene_id":Perturb_gene_index,
    "cell_id":Perturb_cell_index,
    "k_pred": p_dis
})
# The output folder is not created anywhere else in the notebook.
os.makedirs('./results', exist_ok=True)
df.to_csv(f'./results/Impact_results.csv',index = None)