Here we try to reproduce the `Model/synthNetDataScreen.py` script from Avlant's work

In [1]:
import sys
sys.path.insert(0, '../')
from nn_cno import nn_models
import numpy as np
import scipy as sp
import pandas as pd
import equinox as eqx
import optax
import jax.tree_util as jtu
from jax.experimental import sparse
from jax import numpy as jnp
import jax
import matplotlib.pyplot as plt
import functools as ft

# import argparse

It was run on the cluster. The args were found in the `sunSlurmSynthScreen.sh` file. It takes a value between 0 and 14. 

In [2]:
# This code is used to evaluate the different parts on a cluster. 

# #Get data number
# parser = argparse.ArgumentParser(prog='Macrophage simulation')
# parser.add_argument('--selectedCondition', action='store', default=None)
# args = parser.parse_args()
# curentId = int(args.selectedCondition)

currentID = 0     # goes between 0-14 based on the sunSlurmSynthScreen.sh file. 


In [3]:

testCondtions = pd.read_csv('synthNetScreen/conditions.tsv', sep='\t', low_memory=False)
simultaniousInput = int(testCondtions.loc[currentID == testCondtions['Index'],:]['Ligands'].values)
N = int(testCondtions.loc[currentID == testCondtions['Index'],:]['DataSize'].values)
print(currentID, simultaniousInput, N)

inputAmplitude = 3
projectionAmplitude = 1.2

0 2 10


Initialization of the model using the new implementation:

In [4]:
modelFile = "data/KEGGnet-Model.tsv"
annotationFile = 'data/KEGGnet-Annotation.tsv'
parameterFile = 'synthNetScreen/equationParams.txt'

parameterizedModel = nn_models.bioNetwork(networkFile=modelFile, 
                                          nodeAnnotationFile=annotationFile,
                                          inputAmplitude=inputAmplitude,
                                          projectionAmplitude=projectionAmplitude)
parameterizedModel.loadParams(parameterFile)

Model = nn_models.bioNetwork(networkFile=modelFile, nodeAnnotationFile=annotationFile)

In-silico data was generated by simulating the model with given parameters.
I think the in-silico model was a result of some kind of optimization, because the paper
mentions that the random parameterization results in pretty uniform output. 


In [5]:
#Generate data
X = np.zeros((N, len(parameterizedModel.network.inName)))
for i in range(1, N): #skip 0 to include a ctrl sample i.e. zero input
    X[i, (i-1) % len(parameterizedModel.network.inName)] = np.random.rand(1) #stimulate each receptor at least once
    X[i, np.random.randint(0, len(parameterizedModel.network.inName), simultaniousInput-1)] = np.random.rand(simultaniousInput-1)
# compared X with pytorch version

In this code the `Yfull` is used for regularization. This is the state variable after the recurrent NN layer, but before the output layer. It contains the state of all the proteins in the network. 

rewrote the `call` function of the model to return both, but was not  sure how to use vmap if there are multiple outputs, but it seems simple. 

In [6]:
controlIndex = 0
# Y, YfullRef = parameterizedModel(X)
Y, YfullRef = jax.vmap(parameterizedModel.model, in_axes=(0),out_axes=(0,0))(X)
print(Y.shape)
print(YfullRef.shape)



Y_old = jax.vmap(parameterizedModel.model.call_old, in_axes=(0),out_axes=(0))(X)
print(Y_old.shape)
assert((Y_old == Y).all())

input shape:(101,)
(10, 1, 55)
(10, 1, 409)
input shape:(101,)
(10, 1, 55)


In [7]:
#Setup optimizer
noiseLevel = 10
batchSize = 5
MoAFactor = 0.1
spectralFactor = 1e-3
maxIter = 10000

I used the `dataloader_inf` for the toy model. It reuses the training data if the batch `size * n_steps` is larger than the available training data. 

but we should use the `dataloader_finite` for the comparison with Avlant's approach. This use each data only once. 

In [8]:
# pytorch has a data loader, but jax does not.
# trainloader = torch.utils.data.DataLoader(range(N), batch_size=batchSize, shuffle=True)

# here we take the example from https://docs.kidger.site/equinox/examples/train_rnn/
# from equinox

# this is another version of dataloader, which can be emptied. 
import jax.random as jrandom

# 
def dataloader_inf(arrays, batch_size, *, key):
    """ Create an infinite dataloader from data. 
    
    It goes through the data and returns batches of size batch_size.
    When it finishes, it starts again from the beginning.

    Args:
        arrays: list of arrays to be batched.
        batch_size: size of the batch.
        key: random key.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def dataloader_finite(arrays, batch_size, *, key):
    """ Create an finite dataloader from data. 
    
    It goes through the data and returns batches of size batch_size, finally, 
    it returns the last batch if the data is not divisible by batch_size.

    Args:
        arrays: list of arrays to be batched.
        batch_size: size of the batch.
        key: random key.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)

    perm = jrandom.permutation(key, indices)
    (key,) = jrandom.split(key, 1)
    start = 0
    end = batch_size
    while end < dataset_size:
        batch_perm = perm[start:end]
        yield tuple(array[batch_perm] for array in arrays)
        start = end
        end = start + batch_size
    
    if start < dataset_size:
        batch_perm = perm[start:]
        yield tuple(array[batch_perm] for array in arrays)


def dataIndexloader(arrays, batch_size, *, key):
    """ Create the indices used in the batch.
    
    Args:
        arrays: list of arrays to be batched.
        batch_size: size of the batch.
        key: random key.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)

    perm = jrandom.permutation(key, indices)
    (key,) = jrandom.split(key, 1)
    start = 0
    end = batch_size
    while end < dataset_size:
        batch_perm = perm[start:end]
        yield batch_perm
        start = end
        end = start + batch_size
    
    if start < dataset_size:
        batch_perm = perm[start:]
        yield batch_perm


## Check how to rewrite the `uniformLoss` 

In [9]:
def uniformLoss(curState, dataIndex, YhatFull, targetMin = 0, targetMax = 0.99, maxConstraintFactor = 10):
    #data = curState.detach().clone()
    #data[dataIndex, :] = YhatFull
    loss = uniformLossBatch(curState, targetMin = targetMin, targetMax = targetMax, maxConstraintFactor = maxConstraintFactor)
    return loss

def uniformLossBatch(YhatFull, targetMin = 0, targetMax = 0.99, maxConstraintFactor = 10):
    targetMean = (targetMax-targetMin)/2
    targetVar = (targetMax-targetMin)**2/12

    factor = 1
    meanFactor = factor
    varFactor = factor
    minFactor = factor
    maxFactor = factor
    maxConstraintFactor = factor * maxConstraintFactor

    nodeMean = jnp.mean(YhatFull, axis=0)
    nodeVar = jnp.mean(jnp.square(YhatFull-nodeMean), axis=0)
    maxVal = jnp.amax(YhatFull, axis=0)
    minVal = jnp.amin(YhatFull, axis=0)

    meanLoss = meanFactor * jnp.sum(jnp.square(nodeMean - targetMean))
    varLoss =  varFactor * jnp.sum(jnp.square(nodeVar - targetVar))
    maxLoss = maxFactor * jnp.sum(jnp.square(maxVal - targetMax))
    minloss = minFactor * jnp.sum(jnp.square(minVal- targetMin))
    maxConstraint = -maxConstraintFactor * jnp.sum(maxVal[maxVal<=0]) #max value should never be negative

    loss = meanLoss + varLoss + minloss + maxLoss + maxConstraint
    return loss



testing `uniformLoss`

In [10]:
batch_size = 5
maxIter = 100
loader_key = jax.random.PRNGKey(142)
progress = nn_models.rnnModel.OptimProgress(maxIter)


In [11]:
iter_data = dataIndexloader((X, Y), batch_size, key=loader_key)
dataIndex = next(iter_data)
x = X[dataIndex, :]
y = Y[dataIndex, :]
print(x.shape)
(pred_y, pred_yFull) = jax.vmap(Model.model, out_axes=(0,0))(x)
#currentState = jax.random.uniform(shape=(N, Model.model.layers[1].biases.shape[0]),key=statekey)
currentState = jnp.zeros(shape=(N, 1, Model.model.layers[1].biases.shape[0]))

print(currentState.shape)
print(pred_yFull.shape)
stateLossFactor = 1e-4

currentState = currentState.at[dataIndex,1,:].set(jnp.squeeze(pred_yFull))

# State constraints: 
stateLoss = stateLossFactor * uniformLoss(currentState, dataIndex, pred_yFull, maxConstraintFactor = 50)
stateLoss


(5, 101)
input shape:(101,)
(10, 1, 409)
(5, 1, 409)


DeviceArray(0.05038045, dtype=float32)

## Checking spectral loss computation

In [12]:
#@jax.jit
def oneStepDeltaActivationFactor(x, leak):
    """ derivative of the activation function F(x):
        x < 0:        leak 
        0 < x < 0.5 : 1
        0.5 < x :      0.25/x^2
    """
    y = jnp.ones(x.shape[1]) #derivative = 1 if nothing else is stated
    y = jnp.where(x <= 0, leak, y)  #let derivative be 0.01 at x=0
    
    y = jnp.where(x > 0.5, 0.25/(x**2), y)
    return y

def lreig(A):
    # for DENSE Matrix
    #fall back if eigs fails
    e, w, v = sp.linalg.eig(A, left = True)
    selected = np.argmax(np.abs(e))
    eValue = e[selected]
    # selected = (e == eValue)

    # if numpy.sum(selected) == 1:
    w = w[:,selected]
    v = v[:,selected]
    # else:
    #     w = numpy.sum(w[:,selected], axis=1, keepdims=True)
    #     v = numpy.sum(v[:,selected], axis=1, keepdims=True)
    #     w = w/norm(w)
    #     v = v/norm(v)
    return eValue, v, w

# A: sparse matrix
@jax.custom_jvp
def getspectralRadius(weights, ind, M):
    
    A = sp.sparse.csr_matrix((weights, ind), shape=M.shape, dtype='float32')
    tolerance = 10**-6

    try:
        e, v = sp.sparse.linalg.eigs(A, k=1, which='LM', ncv=100, tol = tolerance)
        v = v[:,0]
        e = e[0]
    except  (KeyboardInterrupt, SystemExit):
        raise
    except:
        print('Forward fail (did not find any eigenvalue with eigs)')
        tmpA = A.toarray()
        e, v, w = lreig(tmpA) #fall back to solving full eig problem

    spectralRadius = np.abs(e)
    #ctx.e = e
    #ctx.v = v
    #ctx.w = np.empty(0)

    return spectralRadius

def getspectralRadius_save_out(weights, ind, M):
    
    A = sp.sparse.csr_matrix((weights, ind), shape=M.shape, dtype='float32')
    tolerance = 10**-6

    try:
        e, v = sp.sparse.linalg.eigs(A, k=1, which='LM', ncv=100, tol = tolerance)
        v = v[:,0]
        e = e[0]
    except  (KeyboardInterrupt, SystemExit):
        raise
    except:
        print('Forward fail (did not find any eigenvalue with eigs)')
        tmpA = A.toarray()
        e, v, w = lreig(tmpA) #fall back to solving full eig problem

    spectralRadius = np.abs(e)
    e = e
    v = v
    w = np.empty(0)

    return spectralRadius, e, v, w

@getspectralRadius.defjvp
def AgetspectralRadius_jvp(primals, tangents):
    weights, ind, M = primals
    w_dot, ind_dot, M_dot  = tangents
    primal_out, e, v, w = getspectralRadius_save_out(weights, ind, M)

    tolerance = 10**-6
    networkList = ind
    A = sp.sparse.csr_matrix((weights, ind), shape=M.shape, dtype='float32')
    
    tmpA = A
    tmpA = tmpA.T  #tmpA.T.toarray()

    if w.shape[0]==0:
        try:
            eT = e
            if np.isreal(eT): #does for some reason not converge if imag = 0
                eT = eT.real
            e2, w = sp.sparse.linalg.eigs(tmpA, k=1, sigma=eT, OPpart='r', tol=tolerance)
            selected = 0 #numpy.argmin(numpy.abs(e2-eT))
            w = w[:,selected]
            e2 = e2[selected]
            #Check if same eigenvalue
            if abs(e-e2)>(tolerance*10):
                print('Backward fail (eigs left returned different eigenvalue)')
                w = np.empty(0)
                #e, v, w = lreig(tmpA) #fall back to solving whole eig problem
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            print('Backward fail (did not find any eigenvalue with eigs)')
            #e, v, w = lreig(tmpA) #fall back to solving full eig problem
            delta = np.zeros(weights.shape)


    if w.shape[0] != 0:
        divisor = w.T.dot(v).flatten()
        if abs(divisor) == 0:
            delta = np.zeros(weights.shape)
            print('Empty eig')
        else:
            delta = np.multiply(w[networkList[0]], v[networkList[1]])/divisor
            direction = e/np.abs(e)
            delta = (delta/direction).real
    else:
        #print('Empty eig')
        delta = np.zeros(weights.shape)

    #deltaFilter = numpy.not_equal(numpy.sign(delta), numpy.sign(ctx.weights))
    #delta[deltaFilter] = 0

    #delta = torch.tensor(delta, dtype = grad_output.dtype)

    constrainNorm = True
    if constrainNorm:
        norm = np.linalg.norm(delta)
        if norm>10:
            delta = delta/norm #typical seems to be ~0.36
        #delta = delta * numpy.abs(ctx.weights)
        #delta = delta/norm(delta)

    tangent_out = jnp.dot(delta,w_dot)

    return primal_out, tangent_out


In [17]:

#@jax.jit
def spectralLoss(model, YhatFull, expFactor = 20, lb=0.5, spectralTarget =0.95):
    """ spectral loss of the transmission function:
    
    check Fig3, panel A
    derivativeOfActivation(steadystate) * A 
    """
    
    print("spectral loss compiled")

    # selects a random condition to be used in the spectral computation:
    randomIndex = np.random.randint(YhatFull.shape[0])

    # evaluate the derivative of the activation function at the steady state:
    activationFactor = oneStepDeltaActivationFactor(YhatFull[randomIndex,:], model.layers[1].leak)
    weightFactor = activationFactor.flatten()[model.layers[1].networkList[0]]
    multipliedWeightFactor =  model.layers[1].weights * weightFactor
    
    M = jnp.zeros(shape=model.layers[1].A.shape)
    ind = model.layers[1].networkList
    
    spectralRadius = getspectralRadius(multipliedWeightFactor, ind, M)
    
    scaleFactor = 1/jnp.exp(expFactor *  spectralTarget)
    
    #spectralRadiusLoss = scaleFactor * spectralRadius
    #spectralRadiusLoss = scaleFactor * (jnp.exp(expFactor*spectralRadius)-1)

    # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
    spectralRadiusLoss = jnp.where(spectralRadius > lb, 
                                   scaleFactor * (jnp.exp(expFactor*spectralRadius)-1),
                                   jnp.zeros(1))[0]
    # if spectralRadius>lb:
    #     spectralRadiusLoss = scaleFactor * (jnp.exp(expFactor*spectralRadius)-1)
    # else:
    #     spectralRadiusLoss = 0.0

    return spectralRadiusLoss

In [18]:
# Test the spectral loss function
iter_data = dataIndexloader((X, Y), batch_size, key=loader_key)
dataIndex = next(iter_data)
x = X[dataIndex, :]
y = Y[dataIndex, :]
M = jnp.zeros(shape = Model.model.layers[1].A.shape)

(pred_y, pred_yFull) = jax.vmap(Model.model, out_axes=(0,0))(x)

spectralLoss(Model.model, pred_yFull, expFactor = 21,spectralTarget = Model.trainingParameters.spectralTarget)

input shape:(101,)
spectral loss compiled


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[832])>with<DynamicJaxprTrace(level=0/2)>
While tracing the function spectralLoss at /var/folders/cx/9kyr3rt90c974wdygym_lhgh0000gn/T/ipykernel_61193/747990753.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'model'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [190]:
v,g = jax.value_and_grad(spectralLoss)(Model.model, pred_yFull, expFactor = 21,spectralTarget =Model.trainingParameters.spectralTarget)
print(v)
g.layers[1].weights[1:100]

spectral loss compiled
6.831739e-06


DeviceArray([ 1.04067749e-13,  3.01219045e-14, -1.19719213e-16,
              4.39322770e-17,  1.30915325e-07,  5.25309861e-06,
              3.13666715e-06, -3.93654389e-08,  3.93511891e-06,
              5.26099448e-15, -2.22454712e-08, -1.61918603e-08,
             -1.00334593e-11, -9.72264935e-09, -7.02504055e-10,
             -7.07684533e-09, -4.38524304e-12, -3.90130282e-07,
             -2.53217669e-09, -6.85135074e-11, -3.11517612e-09,
             -6.05953403e-07, -2.44680471e-08, -3.65580307e-21,
              2.77452926e-14, -3.16464963e-07, -5.34783145e-16,
             -7.64424704e-11,  6.10337247e-06,  8.33020181e-11,
              3.45965788e-07,  9.09419415e-11,  2.22807744e-10,
              3.97434206e-07, -1.79171948e-15, -7.10606644e-16,
              5.87045739e-14, -1.39061841e-12,  6.47419726e-13,
              1.02864556e-06,  4.25671244e-13, -2.31076578e-13,
             -1.38507344e-16, -2.81680366e-14,  4.67178921e-14,
              4.45526557e-14,  1.0464290

## Put together all the loss functions

In [191]:
# freeze layer: https://docs.kidger.site/equinox/examples/frozen_layer/
# We copy the pyTree and indicate that we want to estimate the parmeters only on the recurrent layer. 
# we use the generated pyTree to filter the gradient. 
filter_spec = jtu.tree_map(lambda _: False, Model.model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[1].weights, tree.layers[1].biases),
    filter_spec,
    replace=(True, True),
)
#Setup optimizer
MoAFactor = 0.1
L2beta = 1e-8
spectralFactor = 0
ligandFactor = 1e-5 #0.1 in toy
stateLossFactor = 1e-4

#@jax.jit
def criterion(pred_y, y):
    # Trains with respect to binary cross-entropy
    return jnp.mean((y - pred_y) ** 2)

#@jax.jit
def criterion_mat(pred_y, y):
    # Trains with respect to binary cross-entropy
    return jnp.mean((y - pred_y) ** 2)


# jax/optax alternative: 
#! @eqx.filter_jit
# @ft.partial(eqx.filter_value_and_grad, arg=filter_spec)
@eqx.filter_value_and_grad
def compute_loss(model, x, y, dataIndex, currentState):
    
    (pred_y, pred_yFull) = jax.vmap(model, out_axes=(0,0))(x)

    currentState = currentState.at[dataIndex,1,:].set(jnp.squeeze(pred_yFull))

    # Avlant adds random noise between input layer and the recurrent layer. 
    fitLoss = criterion(pred_y,y)

    # signConstraints
    violated_weights = model.layers[1].getViolations(model.layers[1].weights)
    signConstraint = MoAFactor * jnp.sum(jnp.abs(jnp.where(violated_weights,model.layers[1].weights,0)))
    
    # ligand constraints: 
    ligandConstraint = ligandFactor * jnp.sum(jnp.square(model.layers[1].biases[model.layers[0].inOutIndices,0]))
    
    # State constraints: 
    stateLoss = stateLossFactor * uniformLoss(currentState, dataIndex, pred_yFull, maxConstraintFactor = 50)

    # Parameter costraints: 
    biasLoss = L2beta * jnp.sum(jnp.square(model.layers[1].biases))
    weightLoss = L2beta * jnp.sum(jnp.square(model.layers[1].weights))
    
    spectralRadiusLoss = spectralFactor * spectralLoss(model, pred_yFull, expFactor = 21)

    # sum of constraints:
    # loss = fitLoss + signConstraint + ligandConstraint + weightLoss + biasLoss + spectralFactor * spectralRadiusLoss + stateLoss
    loss = fitLoss + signConstraint + ligandConstraint + weightLoss + biasLoss + stateLoss + spectralRadiusLoss 

    return loss



In [192]:
#compute_loss_vg = jax.value_and_grad(compute_loss,has_aux=True)

# we have to use the partitioning trick for the make_step
#! @ft.partial(jax.jit, static_argnums=1)
def make_step(params,static, x, y, opt_state, dataIndex, currentState):
    model = eqx.combine(params, static)
    loss, grads = compute_loss(model, x, y, dataIndex, currentState)
    
    #print("grads:")
    #print(grads.layers[1].weights)
    
    updates, opt_state = optim.update(grads, opt_state)
    params = eqx.apply_updates(params, updates)
    return loss, params, opt_state, currentState


# Whole optimization

In [162]:
def oneCycle(e, maxIter, maxHeight = 1e-3, startHeight=1e-5, endHeight=1e-5, minHeight = 1e-7, peak = 1000):
    phaseLength = 0.95 * maxIter
    if e<=peak:
        effectiveE = e/peak
        lr = (maxHeight-startHeight) * 0.5 * (np.cos(np.pi*(effectiveE+1))+1) + startHeight
    elif e<=phaseLength:
        effectiveE = (e-peak)/(phaseLength-peak)
        lr = (maxHeight-endHeight) * 0.5 * (np.cos(np.pi*(effectiveE+2))+1) + endHeight
    else:
        lr = endHeight
    return lr

In [1]:
optim = optax.inject_hyperparams(optax.adam)(learning_rate=1)
opt_state = optim.init(Model.model)

resetState = opt_state # state is a tuple, copy it by assignmment
params, static = eqx.partition(Model.model, eqx.is_array)

maxIter = 100
batch_size = 5
loader_key = jax.random.PRNGKey(142)
loader_key, statekey = jax.random.split(loader_key)

progress = nn_models.rnnModel.OptimProgress(maxIter)
steps = np.round(X.shape[0]/batch_size).astype(int)

currentState = jax.random.uniform(shape=(N, 1, Model.model.layers[1].biases.shape[0]),key=statekey)

for e in range(maxIter):

    currentLoss = []
    learning_rate_e = oneCycle(e, maxIter, maxHeight = 2e-3, minHeight = 1e-8, peak = 200)
    opt_state.hyperparams['learning_rate'] = learning_rate_e
    
    curLoss = []
    curEig = []
    iter_data = dataIndexloader((X, Y), batch_size, key=loader_key)

    for dataIndex in iter_data:
        x = X[dataIndex, :]
        y = Y[dataIndex, :]

        loss, params, opt_state, currentState = make_step(params, static, x, y, opt_state, dataIndex, currentState)
        
        currentLoss.append(loss.item())

    #fitLoss = compute_test(params, static, Xtest, Ytest)
    trained_model = eqx.combine(params, static)
    

    # compute progress on the validation set:     
    #Yhat = jax.vmap(trained_model)(Xtest)
    #fitLoss = criterion_mat(Yhat,Ytest)
    
    progress.stats['violations'][e] = np.sum(trained_model.layers[1].getViolations()).item()
    #progress.stats['test'][e] = fitLoss.item()
    progress.storeProgress(loss=currentLoss, lr=learning_rate_e, violations=np.sum(trained_model.layers[1].getViolations(trained_model.layers[1].weights)).item())

    if e % 10 == 0:
        progress.printStats(e)

    if np.logical_and(e % 100 == 0, e>0):
        opt_state = resetState




NameError: name 'optax' is not defined