Here we try to reproduce the `Model/synthNetDataScreen.py` script from Avlant's work

- We focus only on the simulation of the synthetic data.
- We load the input matrix X saved by the torch version and simulate the same conditions (unfortunately the seed was not stored, so X cannot simply be regenerated).
- We compare the Yfull output from torch with the Yfull output from this simulation.

Result:
the norm of the difference between the two Yfull outputs is on the order of 1e-6, which could be due to machine precision (the JAX output is float32).



In [1]:
import os
os.environ["JAX_DEBUG_NANS"] = "True"
import sys
sys.path.insert(0, '../')
from nn_cno import nn_models
import numpy as np
import scipy as sp
import pandas as pd
import equinox as eqx
import optax
import jax.tree_util as jtu
from jax.experimental import sparse
from jax import numpy as jnp
import jax
import matplotlib.pyplot as plt
import functools as ft

# import argparse
# The commented-out code below is used to evaluate the different parts on a cluster,
# where the condition index is passed on the command line.

# #Get data number
# parser = argparse.ArgumentParser(prog='Macrophage simulation')
# parser.add_argument('--selectedCondition', action='store', default=None)
# args = parser.parse_args()
# curentId = int(args.selectedCondition)

currentID = 0     # condition index; goes between 0-14 based on the sunSlurmSynthScreen.sh file.

Initialization of the model using the new implementation:

In [2]:

# Look up the screening condition whose 'Index' column equals `currentID`.
testCondtions = pd.read_csv('synthNetScreen/conditions.tsv', sep='\t', low_memory=False)
selectedCondition = testCondtions.loc[testCondtions['Index'] == currentID]
# Fail early with a readable message instead of an opaque scalar-conversion
# error when the index is missing or duplicated in the conditions table.
if len(selectedCondition) != 1:
    raise ValueError(
        f"Expected exactly one row with Index == {currentID} in conditions.tsv, "
        f"found {len(selectedCondition)}"
    )
# `.iloc[0]` yields a scalar, so no array-to-int conversion is needed.
simultaniousInput = int(selectedCondition['Ligands'].iloc[0])
N = int(selectedCondition['DataSize'].iloc[0])
print(currentID, simultaniousInput, N)

inputAmplitude = 3
projectionAmplitude = 1.2

modelFile = "data/KEGGnet-Model.tsv"
annotationFile = 'data/KEGGnet-Annotation.tsv'
parameterFile = 'synthNetScreen/equationParams.txt'

# Both networks are built from the same files and amplitudes; only
# `parameterizedModel` additionally loads the parameters from `parameterFile`.
networkArguments = dict(networkFile=modelFile,
                        nodeAnnotationFile=annotationFile,
                        inputAmplitude=inputAmplitude,
                        projectionAmplitude=projectionAmplitude)

parameterizedModel = nn_models.bioNetwork(**networkArguments)
parameterizedModel.loadParams(parameterFile)

Model = nn_models.bioNetwork(**networkArguments)

0 2 10


Load the data saved in the original pytorch version

In [6]:
# Read the three CSV files from the torch run into NumPy arrays.
torchFiles = (
    "./synthNetData_testing/X.csv",
    "./synthNetData_testing/Y.csv",
    "./synthNetData_testing/YfullRef.csv",
)
X_torch, Y_torch, YfullRef_torch = (
    pd.read_csv(csvPath).to_numpy() for csvPath in torchFiles
)

Compare the simulation between torch and jax using the same input, the output should be very similar...

In [7]:
# Map the parameterized model over the leading axis of X_torch (one call per
# row); both returned values are stacked along their leading axis.
batchedModel = jax.vmap(parameterizedModel.model, in_axes=0, out_axes=(0, 0))
Y_jax, YfullRef_jax = batchedModel(X_torch)

In [38]:
# np.savetxt(X = YfullRef_jax.squeeze().to_py() - YfullRef_torch,fname="./synthNetData_testing/YfullRef_diff.csv")
#Y2b = Y2.reshape((Y2.shape[0],1,Y2.shape[1]))


In [10]:
# Norm of the difference between the JAX and torch Yfull outputs.
# np.asarray converts the JAX array to NumPy and works on both old and new
# JAX versions, unlike the `.to_py()` method; squeeze() drops the singleton
# axis so the shape matches YfullRef_torch.
np.linalg.norm(np.asarray(YfullRef_jax).squeeze() - YfullRef_torch)

2.9239067557038426e-06

In [8]:
# Display the reference Yfull array loaded from YfullRef.csv (torch run).
YfullRef_torch

array([[-1.50177683e-05,  2.08252244e-02,  1.14711486e-02, ...,
        -2.86063750e-03, -1.50625385e-04,  5.90374019e-02],
       [-1.50177683e-05,  2.36817807e-02,  3.04472496e-01, ...,
        -2.85994995e-03, -1.50640516e-04,  5.90374019e-02],
       [-1.50177683e-05,  2.08252244e-02,  1.14711486e-02, ...,
        -2.87101615e-03, -1.50625385e-04,  6.56646511e-01],
       ...,
       [-1.50177683e-05,  2.08252244e-02,  1.14711486e-02, ...,
        -2.82572184e-03, -1.50625385e-04,  5.90374019e-02],
       [-1.50177683e-05,  2.08252244e-02,  1.14711486e-02, ...,
        -2.87648543e-03, -7.52463104e-05,  5.90374019e-02],
       [-1.50177683e-05,  2.08252244e-02,  1.14711486e-02, ...,
        -4.74807864e-03, -1.50625385e-04,  5.90374019e-02]])

In [9]:
# Display the Yfull array returned by the vmapped JAX model for the same inputs.
YfullRef_jax

DeviceArray([[[-1.5017786e-05,  2.0825224e-02,  1.1471149e-02, ...,
               -2.8606369e-03, -1.5062539e-04,  5.9037402e-02]],

             [[-1.5017786e-05,  2.3681780e-02,  3.0447251e-01, ...,
               -2.8599496e-03, -1.5064051e-04,  5.9037402e-02]],

             [[-1.5017786e-05,  2.0825224e-02,  1.1471149e-02, ...,
               -2.8710163e-03, -1.5062539e-04,  6.5664649e-01]],

             ...,

             [[-1.5017786e-05,  2.0825224e-02,  1.1471149e-02, ...,
               -2.8257216e-03, -1.5062539e-04,  5.9037402e-02]],

             [[-1.5017786e-05,  2.0825224e-02,  1.1471149e-02, ...,
               -2.8764852e-03, -7.5246309e-05,  5.9037402e-02]],

             [[-1.5017786e-05,  2.0825224e-02,  1.1471149e-02, ...,
               -4.7480790e-03, -1.5062539e-04,  5.9037402e-02]]],            dtype=float32)

In [45]:
# Show row 1 of the torch input matrix.
X_torch[1,]

array([0.09766712, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.40363993, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

Checking the input layer:

Condition 0 : everything is zero

In [6]:
# Apply only the first layer (layers[0]) to every row of X_torch, then show
# the result for row 0 (index 0 on the second axis).
inputLayerBatched = jax.vmap(parameterizedModel.model.layers[0], in_axes=0, out_axes=0)
inputLayerBatched(X_torch)[0, 0, :]

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

Condition 1:

In [5]:
# Apply only the first layer (layers[0]) to every row of X_torch, then show
# the result for row 1 (index 0 on the second axis).
inputLayerBatched = jax.vmap(parameterizedModel.model.layers[0], in_axes=0, out_axes=0)
inputLayerBatched(X_torch)[1, 0, :]

NameError: name 'X_torch' is not defined

Problem: the locations of the non-zero entries do not match the torch version.

In [10]:
# Show the inOutIndices array stored on the first model layer.
parameterizedModel.model.layers[0].inOutIndices

array([ 59,  39,  36, 109, 256, 287, 366,  33,   7, 154, 231, 205,  76,
        28, 130,  47,  20, 258,  29,  48, 116, 371, 125, 357,  83,   2,
        50, 242,  30,  60,  80,  61,  79,  57, 323, 122, 372,  65,  63,
       176, 157,  56, 387,  72,  54,  52,  55, 343, 381,  49, 276, 367,
       252, 149, 188,  92, 330, 118, 187, 163, 189, 353, 129, 208, 106,
        51,  96, 379, 150,  58, 266,  68, 346,  62, 324, 100,  97,  41,
         9,  53,  74, 218, 204, 318, 386, 101, 238,  67,  64,  82, 265,
        73, 137, 102, 117,  40,  69,  46, 103,  89, 121])

In [3]:
# Show network.inName (presumably the names of the input nodes; verify in nn_models).
parameterizedModel.network.inName


array(['O14511', 'O14788', 'O14944', 'O43557', 'O75093', 'O75094',
       'O75326', 'O94813', 'O96014', 'P00734', 'P01019', 'P01042',
       'P01133', 'P01135', 'P01137', 'P01138', 'P01178', 'P01189',
       'P01213', 'P01215', 'P01236', 'P01270', 'P01308', 'P01344',
       'P01350', 'P01374', 'P01375', 'P01562', 'P01574', 'P01579',
       'P01583', 'P01584', 'P01889', 'P02452', 'P02751', 'P04085',
       'P04196', 'P04439', 'P04628', 'P05019', 'P05112', 'P05230',
       'P05305', 'P06307', 'P07585', 'P08311', 'P08476', 'P08700',
       'P09038', 'P09326', 'P09603', 'P10321', 'P10586', 'P12272',
       'P12643', 'P12644', 'P13501', 'P14210', 'P15018', 'P15514',
       'P15692', 'P16619', 'P20783', 'P20827', 'P21583', 'P22301',
       'P23560', 'P29459', 'P33681', 'P34130', 'P35070', 'P41159',
       'P41221', 'P42081', 'P43405', 'P48061', 'P49771', 'P50591',
       'P56975', 'P61278', 'P61812', 'P78536', 'P80075', 'Q02297',
       'Q06643', 'Q14005', 'Q14393', 'Q14623', 'Q15465', 'Q6ZM

In [14]:
# Look up the node name stored at position 401 of network.nodeNames.
parameterizedModel.network.nodeNames[401]

'Q9Y243'

In [4]:
# Show the inOutIndices array stored on the first model layer.
parameterizedModel.model.layers[0].inOutIndices

array([  2,   7,   9,  20,  28,  29,  30,  33,  36,  39,  40,  41,  46,
        47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
        60,  61,  62,  63,  64,  65,  67,  68,  69,  72,  73,  74,  76,
        79,  80,  82,  83,  89,  92,  96,  97, 100, 101, 102, 103, 106,
       109, 116, 117, 118, 121, 122, 125, 129, 130, 137, 149, 150, 154,
       157, 163, 176, 187, 188, 189, 204, 205, 208, 218, 231, 238, 242,
       252, 256, 258, 265, 266, 276, 287, 318, 323, 324, 330, 343, 346,
       353, 357, 366, 367, 371, 372, 379, 381, 386, 387])