In [6]:
import cellbox
import os
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
import shutil
import argparse
import json
import glob
import time
from tensorflow.compat.v1.errors import OutOfRangeError
from cellbox.utils import TimeLogger
import pickle
tf.disable_v2_behavior()

import torch
import torch.nn as nn

In [4]:
# Shared test grid: the i-th entries of seeds / lambdas / others form one test case.
seeds = [7, 87, 62, 45, 23]
# (l1_lambda, l2_lambda) per test case
lambdas = [
    (2.0, 3.0),
    (0.1, 0.01),
    (0.001, 0.0001),
    (0.01, 0.1),
    (0.0001, 0.001),
]
# (envelope_form, ode_solver, n_T, pert_form) per test case
others = [
    ("tanh", "heun", 100, "by u"),
    ("tanh", "euler", 100, "by u"),
    ("clip linear", "heun", 100, "fix x"),
    ("tanh", "heun", 100, "fix x"),
    ("tanh", "midpoint", 100, "fix x"),
]

### Initialize weights with correct masks

In [25]:
def weight_init(seed, n_x=99, n_protein_nodes=82, n_activity_nodes=87):
    """Sample a random interaction matrix W with CellBox's structural zeros applied.

    Seeds NumPy's global RNG with ``seed``, draws W ~ N(0.01, 1.0) of shape
    (n_x, n_x) and zeroes the entries forbidden by the mask (the same mask that
    ``CellBox.get_variables`` builds):

      * the diagonal (no self-regulation);
      * rows ``n_activity_nodes:`` (no ingoing edges into those nodes);
      * columns ``n_protein_nodes:n_activity_nodes`` (no outgoing edges from those nodes);
      * block ``[n_protein_nodes:n_activity_nodes, n_activity_nodes:]``.

    Args:
        seed: value passed to ``np.random.seed`` (reseeds the global RNG).
        n_x: total number of nodes.
        n_protein_nodes: boundary of the first node group.
        n_activity_nodes: boundary of the second node group; the boundaries must
            satisfy ``0 <= n_protein_nodes <= n_activity_nodes <= n_x``.

    Returns:
        np.ndarray of shape (n_x, n_x): the masked weights.

    Raises:
        ValueError: if the node-group boundaries are out of order.
    """
    # Reject bad boundaries with a readable message; previously they only surfaced
    # as a negative-dimension / broadcast error from np.zeros.
    if not 0 <= n_protein_nodes <= n_activity_nodes <= n_x:
        raise ValueError(
            "expected 0 <= n_protein_nodes <= n_activity_nodes <= n_x, got "
            "n_protein_nodes={}, n_activity_nodes={}, n_x={}".format(n_protein_nodes, n_activity_nodes, n_x))

    np.random.seed(seed)
    W = np.random.normal(0.01, 1.0, size=(n_x, n_x))

    W_mask = 1.0 - np.eye(n_x)
    W_mask[n_activity_nodes:, :] = 0.0
    W_mask[:, n_protein_nodes:n_activity_nodes] = 0.0
    W_mask[n_protein_nodes:n_activity_nodes, n_activity_nodes:] = 0.0

    return W * W_mask

In [56]:
# Scratch check: indexing rows with a list keeps the column axis, i.e.
# a[[1, 3, 5]] on an (89, 99) array has shape (3, 99) -- the same pattern
# input_init uses to pick sample rows.
a = np.zeros((89, 99))
b = a[[1, 3, 5]]
b.shape

(3, 99)

### Initialize input matrix

In [57]:
def input_init(seed,
               pert_file="/users/ngun7t/Documents/cellbox-jun-6/data/pert.csv",
               expr_file="/users/ngun7t/Documents/cellbox-jun-6/data/expr.csv",
               n_samples=4):
    """Sample a small batch of matching rows from the real perturbation/expression data.

    Seeds NumPy's global RNG with ``seed``, reads both header-less CSV files and
    draws ``n_samples`` distinct row indices, returning those rows from each.

    Args:
        seed: value passed to ``np.random.seed``.
        pert_file: perturbation CSV (defaults to the original hard-coded path).
        expr_file: expression CSV (defaults to the original hard-coded path).
        n_samples: number of distinct rows to draw.

    Returns:
        (inp, out): the sampled rows of the perturbation and expression matrices.

    Raises:
        ValueError: if the two files have different numbers of rows, or if
            ``n_samples`` exceeds that number (raised by ``np.random.choice``).
    """
    np.random.seed(seed)
    real_inp = pd.read_csv(pert_file, header=None).to_numpy()
    real_out = pd.read_csv(expr_file, header=None).to_numpy()
    if real_inp.shape[0] != real_out.shape[0]:
        raise ValueError("pert/expr row counts differ: {} vs {}".format(real_inp.shape[0], real_out.shape[0]))
    # Sample from the actual row count instead of a hard-coded 89; for 89 rows
    # this draws exactly the same indices as np.random.choice(list(range(89)), ...).
    rand_ind = np.random.choice(real_inp.shape[0], replace=False, size=(n_samples,)).tolist()
    inp = real_inp[rand_ind]
    out = real_out[rand_ind]
    return inp, out


In [58]:
# Reference run: seed every RNG in play (NumPy, PyTorch, TensorFlow graph-level)
# so the TensorFlow / PyTorch comparison cells below are reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
tf.compat.v1.set_random_seed(seed)

# weight_init and input_init reseed NumPy with `seed` themselves.
# NOTE(review): the recorded output below ((89, 99) and an index list) came from an
# earlier input_init that printed; the current definition prints nothing.
W_rand = weight_init(seed)
inp, out = input_init(seed)

(89, 99)
[44, 53, 30, 12]


### Build the TensorFlow model, initialize its weights, and run a forward pass

In [83]:
def set_seed(in_seed):
    """Seed TensorFlow's graph-level RNG, then NumPy's global RNG, with ``int(in_seed)``."""
    seed_value = int(in_seed)
    for seeder in (tf.compat.v1.set_random_seed, np.random.seed):
        seeder(seed_value)


def prepare_workdir(in_cfg, md5_hash=None, work_idx=None):
    """Create the experiment output folders, chdir into them and write the record files.

    Layout created relative to the current working directory::

        results/<experiment_id>_<md5_hash>/config.json
        results/<experiment_id>_<md5_hash>/<in_cfg.working_index>/record_eval.csv

    Side effects:
        * sets ``in_cfg.root_dir`` (cwd at call time) and ``in_cfg.node_index``;
        * for "leave one out" experiments, appends ``_<drug_index>`` to
          ``in_cfg.model_prefix``;
        * sets ``in_cfg.working_index`` to ``<model_prefix>_<work_idx, zero-padded to 3>``;
        * deletes any existing working folder of that name and leaves the process
          cwd inside a fresh one.  Calling this twice without returning to
          ``in_cfg.root_dir`` nests a second ``results/`` tree inside the first.

    Args:
        in_cfg: config object (attribute access; serialized via ``vars``).
        md5_hash: hash used in the experiment folder name.  Defaults to the
            notebook-level ``md5`` global (original behavior).
        work_idx: run index used in ``working_index``.  Defaults to the
            notebook-level ``working_index`` global (original behavior).

    Returns:
        0 on success.

    Raises:
        AttributeError: a "leave one out" experiment has no ``drug_index``.
    """
    if md5_hash is None:
        md5_hash = md5
    if work_idx is None:
        work_idx = working_index

    # Read Data
    in_cfg.root_dir = os.getcwd()
    in_cfg.node_index = pd.read_csv(in_cfg.node_index_file, header=None, names=None) \
        if hasattr(in_cfg, 'node_index_file') else pd.DataFrame(np.arange(in_cfg.n_x))

    # Create Output Folder (reuse it if it already exists)
    experiment_path = 'results/{}_{}'.format(in_cfg.experiment_id, md5_hash)
    os.makedirs(experiment_path, exist_ok=True)
    # DataFrames (e.g. node_index) are not JSON-serializable, so they are left out.
    out_cfg = {key: value for key, value in vars(in_cfg).items() if not isinstance(value, pd.DataFrame)}
    os.chdir(experiment_path)
    with open('config.json', 'w') as f:
        json.dump(out_cfg, f, indent=4)

    if "leave one out" in in_cfg.experiment_type:
        # The old `except Exception('...')` clause could never match: it would
        # itself raise TypeError because an instance is not an exception class.
        if not hasattr(in_cfg, 'drug_index'):
            raise AttributeError('Drug index not specified')
        in_cfg.model_prefix = '{}_{}'.format(in_cfg.model_prefix, in_cfg.drug_index)

    in_cfg.working_index = in_cfg.model_prefix + "_" + str(work_idx).zfill(3)

    # Start every run from an empty working folder.
    shutil.rmtree(in_cfg.working_index, ignore_errors=True)
    os.makedirs(in_cfg.working_index)
    os.chdir(in_cfg.working_index)

    with open("record_eval.csv", 'w') as f:
        f.write("epoch,iter,train_loss,valid_loss,train_mse,valid_mse,test_mse,time_elapsed\n")

    print('Working directory is ready at {}.'.format(experiment_path))
    return 0

experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.random_partition.CellBox.json"
working_index = 0  # run index read by prepare_workdir (zero-padded folder suffix)

cfg = cellbox.config.Config(experiment_config_path)
cfg.ckpt_path_full = os.path.join('./', cfg.ckpt_name)
# Hashed before the test-specific overrides below; prepare_workdir reads this
# global `md5` for the results folder name.
md5 = cellbox.utils.md5(cfg)
cfg.drug_index = 5         # Change this for testing purposes
cfg.seed = seed
set_seed(seed)
print(vars(cfg))

prepare_workdir(cfg)
logger = cellbox.utils.TimeLogger(time_logger_step=1, hierachy=3)
args = cfg

# Only the first training stage is needed to build the graph for the forward-pass
# comparison, so take it directly rather than looping and breaking after i == 0.
# The stage settings come from cfg.stages in the JSON config (a hand-written
# `stage` dict here would just be shadowed).
stage = cfg.stages[0]
set_seed(cfg.seed)
cfg = cellbox.dataset.factory(cfg)
args.sub_stages = stage['sub_stages']
args.n_T = stage['nT']
model = cellbox.model.factory(args)

{'experiment_id': 'Example_RP', 'model_prefix': 'seed', 'ckpt_name': 'model11.ckpt', 'export_verbose': 3, 'experiment_type': 'random partition', 'sparse_data': False, 'batchsize': 4, 'trainset_ratio': 0.7, 'validset_ratio': 0.8, 'n_batches_eval': None, 'add_noise_level': 0, 'dT': 0.1, 'ode_solver': 'heun', 'envelope_form': 'tanh', 'envelope': 0, 'pert_form': 'by u', 'ode_degree': 1, 'ode_last_steps': 2, 'n_iter_buffer': 50, 'n_iter_patience': 100, 'weight_loss': 'None', 'l1lambda': 0.0001, 'l2lambda': 0.0001, 'model': 'CellBox', 'pert_file': '/users/ngun7t/Documents/cellbox-jun-6/data/pert.csv', 'expr_file': '/users/ngun7t/Documents/cellbox-jun-6/data/expr.csv', 'node_index_file': '/users/ngun7t/Documents/cellbox-jun-6/data/node_Index.csv', 'n_protein_nodes': 82, 'n_activity_nodes': 87, 'n_x': 99, 'envelop_form': 'tanh', 'envelop': 0, 'n_epoch': 100, 'n_iter': 100, 'stages': [{'nT': 100, 'sub_stages': [{'lr_val': 0.1, 'l1lambda': 0.01, 'n_iter_patience': 1000}, {'lr_val': 0.01, 'l1lamb

In [39]:
# NOTE(review): the recorded output `(89,)` is stale. This cell ran (In [39]) before the
# current input_init (In [57]), which returns a 4-row batch; re-run to refresh it.
inp.shape

(89,)

In [80]:
from cellbox.utils import loss, optimize

class PertBio:
    """Abstract perturbation model (TensorFlow 1.x graph mode).

    Subclasses implement ``get_variables`` (fill ``self.params``) and ``forward``
    (map an initial state and a perturbation batch to a prediction).  ``build``
    wires the train/monitor/eval predictions and ``get_ops`` attaches the loss
    and optimizer ops.
    """
    def __init__(self, args):
        self.args = args
        self.n_x = args.n_x
        self.pert_in, self.expr_out = args.pert_in, args.expr_out
        self.iter_train, self.iter_monitor, self.iter_eval = args.iter_train, args.iter_monitor, args.iter_eval
        # (x, y) batch tensors pulled from each dataset iterator.
        self.train_x, self.train_y = self.iter_train.get_next()
        self.monitor_x, self.monitor_y = self.iter_monitor.get_next()
        self.eval_x, self.eval_y = self.iter_eval.get_next()
        self.l1_lambda, self.l2_lambda = self.args.l1_lambda_placeholder, self.args.l2_lambda_placeholder
        # Initial ODE states; subclasses fill these in before calling forward().
        self.train_y0, self.monitor_y0, self.eval_y0 = None, None, None
        self.lr = self.args.lr

    def get_ops(self):
        """Create the train/monitor/eval loss ops and the optimizer op.

        ``args.weight_loss == 'expr'`` weights each loss by its target expression;
        the string ``'None'`` uses the unweighted loss.

        Raises:
            ValueError: for any other ``args.weight_loss`` value (previously this
                fell through silently and failed later on ``self.train_loss``).
        """
        if self.args.weight_loss == 'expr':
            weighted = True
        elif self.args.weight_loss == 'None':
            weighted = False
        else:
            raise ValueError("Unsupported weight_loss: {!r} (expected 'expr' or 'None')".format(self.args.weight_loss))

        def _loss(y, yhat):
            # Returns (loss, mse_loss) as produced by the module-level `loss`.
            if weighted:
                return loss(y, yhat, self.params['W'], self.l1_lambda, self.l2_lambda, weight=y)
            return loss(y, yhat, self.params['W'], self.l1_lambda, self.l2_lambda)

        self.train_loss, self.train_mse_loss = _loss(self.train_y, self.train_yhat)
        self.monitor_loss, self.monitor_mse_loss = _loss(self.monitor_y, self.monitor_yhat)
        self.eval_loss, self.eval_mse_loss = _loss(self.eval_y, self.eval_yhat)

        self.op_optimize = optimize(self.train_loss, self.lr)

    def get_variables(self):
        """get model parameters (overwritten by model configuration)"""
        raise NotImplementedError

    def forward(self, x, mu):
        """forward propagation (overwritten by model configuration)"""
        raise NotImplementedError

    def build(self):
        """Create parameters, the three forward passes and the loss ops; return self."""
        self.params = {}
        self.get_variables()
        self.train_yhat = self.forward(self.train_y0, self.train_x)
        self.monitor_yhat = self.forward(self.monitor_y0, self.monitor_x)
        # Eval predictions must come from the eval inputs (was self.train_x).
        self.eval_yhat = self.forward(self.eval_y0, self.eval_x)
        self.get_ops()
        return self


class CellBox(PertBio):
    """CellBox model (TensorFlow 1.x graph mode): predicts expression by integrating an ODE.

    ``build`` can pin the initial weight matrix and the training batch to given
    arrays, so the TensorFlow graph can be compared one-to-one with the PyTorch port.
    """
    def build(self, W_rand=None, inp=None, out=None):
        """Build parameters, ODE forward passes and loss/optimizer ops.

        Args:
            W_rand: initial value for the interaction matrix W, shape (n_x, n_x).
                If None, ``get_variables`` samples ``np.random.normal(0.01, size=(n_x, n_x))``.
            inp: if not None, replaces the iterator-fed training perturbations
                (``self.train_x``) with a float32 constant.
            out: if not None, replaces the iterator-fed training targets
                (``self.train_y``) with a float32 constant.

        Returns:
            self

        Raises:
            ValueError: if ``args.pert_form`` is neither 'by u' nor 'fix x'.
        """
        self.W_rand = W_rand
        self.inp = inp
        self.out = out
        self.params = {}
        self.get_variables()
        if inp is not None:
            self.train_x = tf.constant(self.inp, name="sample_input", dtype=tf.float32)
        if out is not None:
            self.train_y = tf.constant(self.out, name="sample_output", dtype=tf.float32)
        if self.args.pert_form == 'by u':
            # Every split starts from the same all-zero state.
            y0 = tf.constant(np.zeros((self.n_x, 1)), name="x_init", dtype=tf.float32)
            self.train_y0 = y0
            self.monitor_y0 = y0
            self.eval_y0 = y0
            self.gradient_zero_from = None
        elif self.args.pert_form == 'fix x':  # fix level of node x (here y) by input perturbation u (here x)
            self.train_y0 = tf.transpose(self.train_x)
            self.monitor_y0 = tf.transpose(self.monitor_x)
            self.eval_y0 = tf.transpose(self.eval_x)
            self.gradient_zero_from = self.args.n_activity_nodes
        else:
            # Previously an unknown pert_form left y0 / gradient_zero_from unset
            # (or stale from an earlier build) instead of failing here.
            raise ValueError("Unsupported pert_form: {!r} (expected 'by u' or 'fix x')".format(self.args.pert_form))

        # ODE-specific params
        self.envelope_fn = cellbox.kernel.get_envelope(self.args)
        self.ode_solver = cellbox.kernel.get_ode_solver(self.args)
        self._dxdt = cellbox.kernel.get_dxdt(self.args, self.params)
        self.convergence_metric_train, self.train_yhat = self.forward(self.train_y0, self.train_x)
        self.convergence_metric_monitor, self.monitor_yhat = self.forward(self.monitor_y0, self.monitor_x)
        self.convergence_metric_eval, self.eval_yhat = self.forward(self.eval_y0, self.eval_x)
        self.get_ops()
        return self

    def forward(self, y0, mu):
        """Integrate the ODE from ``y0`` under perturbation ``mu``.

        Returns:
            (convergence_metric, yhat): ``convergence_metric`` concatenates, along
            axis 0, the moments over the last ``args.ode_last_steps`` states and
            dx/dt at the final state; ``yhat`` is the final state, transposed.
        """
        # mu is transposed before being handed to the solver (sparse input is densified).
        if isinstance(mu, tf.SparseTensor):
            mu_t = tf.sparse.to_dense(tf.sparse.transpose(mu))
        else:
            mu_t = tf.transpose(mu)
        ys = self.ode_solver(y0, mu_t, self.args.dT, self.args.n_T, self._dxdt, self.gradient_zero_from)
        # [n_T, n_x, batch_size]
        ys = ys[-self.args.ode_last_steps:]
        # [n_iter_tail, n_x, batch_size]
        # NOTE(review): tf.nn.moments returns (mean, variance), so `sd` holds the
        # variance, not the standard deviation.
        mean, sd = tf.nn.moments(ys, axes=0)
        yhat = tf.transpose(ys[-1])
        dxdt = self._dxdt(ys[-1], mu_t)
        # [n_x, batch_size] for last ODE step
        convergence_metric = tf.concat([mean, sd, dxdt], axis=0)
        return convergence_metric, yhat

    def get_variables(self):
        """
        Initialize parameters in the Hopfield equation

        W starts from ``self.W_rand`` when it is set (see ``build``), otherwise from
        ``np.random.normal(0.01, size=(n_x, n_x))``.

        Mutates:
            self.params(dict):{
                W (tf.Tensor): interaction matrix with constraints enforced, shape: [n_x, n_x]
                alpha (tf.Tensor): softplus(alpha variable), shape: [n_x, 1]
                eps (tf.Tensor): softplus(eps variable), shape: [n_x, 1]
                psi (tf.Tensor): softplus(psi variable), only when args.envelope == 2
            }
        """
        n_x, n_protein_nodes, n_activity_nodes = self.n_x, self.args.n_protein_nodes, self.args.n_activity_nodes
        with tf.compat.v1.variable_scope("initialization", reuse=True):
            # Enforce constraints (i: recipient)
            #   no self regulation wii = 0
            #   ingoing wij for drug nodes (88th to 99th) = 0 [n_activity_nodes 87:]
            #       w[87:99, _] = 0
            #   outgoing wij for phenotypic nodes (83th to 87th) [n_protein_nodes 82 : n_activity_nodes 87]
            #       w[_, 82:87] = 0
            #   ingoing wij for phenotypic nodes from drug nodes (direct) [n_protein_nodes 82 : n_activity_nodes 87]
            #       w[82:87, 87:99] = 0
            W_init = getattr(self, 'W_rand', None)
            if W_init is None:
                W_init = np.random.normal(0.01, size=(n_x, n_x))
            W = tf.Variable(W_init, name="W", dtype=tf.float32)
            W_mask = (1.0 - np.diag(np.ones([n_x])))
            W_mask[n_activity_nodes:, :] = np.zeros([n_x - n_activity_nodes, n_x])
            W_mask[:, n_protein_nodes:n_activity_nodes] = np.zeros([n_x, n_activity_nodes - n_protein_nodes])
            W_mask[n_protein_nodes:n_activity_nodes, n_activity_nodes:] = np.zeros([n_activity_nodes - n_protein_nodes,
                                                                                    n_x - n_activity_nodes])
            self.params['W'] = W_mask * W

            # softplus keeps alpha / eps (and psi) positive.
            eps = tf.Variable(np.ones((n_x, 1)), name="eps", dtype=tf.float32)
            alpha = tf.Variable(np.ones((n_x, 1)), name="alpha", dtype=tf.float32)
            self.params['alpha'] = tf.nn.softplus(alpha)
            self.params['eps'] = tf.nn.softplus(eps)

            if self.args.envelope == 2:
                psi = tf.Variable(np.ones((n_x, 1)), name="psi", dtype=tf.float32)
                self.params['psi'] = tf.nn.softplus(psi)

In [61]:
# Pass the fixed weights and sample batch explicitly so the TF graph starts from the
# same W_rand / (inp, out) that the PyTorch model receives further down.
model = CellBox(args).build(W_rand, inp, out)
lr = 0.1
l1_lambda = 0.1
l2_lambda = 0.01
iterations = 10

In [62]:
# Fresh session: initialize all graph variables and the train/monitor dataset iterators.
sess = tf.compat.v1.Session()
sess.run(tf.global_variables_initializer())
sess.run(model.iter_train.initializer, feed_dict=args.feed_dicts['train_set'])
sess.run(model.iter_monitor.initializer, feed_dict=args.feed_dicts['valid_set'])
# NOTE(review): `train_input` is never read afterwards, and each get_next() call adds another op to the graph.
train_input = args.iter_train.get_next()[0].eval(session=sess)
# train_input is only a placeholder. The actual input the model uses is defined within the model's function above
# Forward pass only: prediction from the initial weights.
# NOTE(review): the np.ones feeds for pert_in/expr_out look like dummies; confirm train_yhat does not read them.
yhat = sess.run(model.train_yhat, feed_dict={args.pert_in: np.ones((4, 99)), args.expr_out: np.ones((4, 99))})
# One optimizer step; the two losses are fetched in the same run as the update.
_, loss_train_i, loss_train_mse_i = sess.run(
                    (model.op_optimize, model.train_loss, model.train_mse_loss), 
                    feed_dict={
                        args.pert_in: np.ones((4, 99)), 
                        args.expr_out: np.ones((4, 99)),
                        model.lr: lr,
                        model.l1_lambda: l1_lambda,
                        model.l2_lambda: l2_lambda
                        }
                    )

2023-08-23 00:05:36.408801: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /cm/shared/apps/lsf10/10.1/linux3.10-glibc2.17-x86_64/lib:/data/weirauchlab/opt/lib:/data/weirauchlab/opt/lib64:/data/weirauchlab/local/lib:/users/ngun7t/anaconda3/envs/cellbox-3.6-2/lib/:/users/ngun7t/anaconda3/envs/cellbox-3.6-2/lib/:/users/ngun7t/anaconda3/envs/cellbox-3.6-2/lib/
2023-08-23 00:05:36.408886: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)
2023-08-23 00:05:36.409003: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (bmiclusterp2.chmcres.cchmc.org): /proc/driver/nvidia/version does not exist
2023-08-23 00:05:36.410629: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorF

In [64]:
# TensorFlow MSE fetched together with the optimizer step above.
loss_train_mse_i

0.8145397

### Initialize weights for the PyTorch model

In [65]:
class ModelConfig:
    """Lightweight config object for building the PyTorch CellBox via cellbox.model_torch.factory.

    Only the constructor arguments vary; every other field holds a fixed test
    value and can be overwritten on the instance afterwards (e.g. ``pert_form``).
    """

    def __init__(self, model, n_x, envelope_form, ode_solver, n_T):
        # Model identity and size.
        self.model, self.n_x = model, n_x
        # No dataset iterators in this standalone config.
        self.iter_train = self.iter_monitor = self.iter_eval = None
        self.lr = 0.1
        # Node-group sizes (same values as the CellBox JSON config).
        self.n_protein_nodes = 82
        self.n_activity_nodes = 87
        self.pert_form = "by u"

        # Envelope / ODE settings.
        self.envelope_form, self.envelope_fn = envelope_form, None
        self.polynomial_k, self.ode_degree, self.envelope = 2, 1, 0
        self.ode_solver, self.dT = ode_solver, 0.1
        self.ode_last_steps, self.n_T = 2, n_T
        self.gradient_zero_from = None


In [67]:
# Let's try to get the model from CellBox, with all the necessary configs
# NOTE(review): this rebinds the notebook-level `args` (the TF config `cfg` until now)
# to a PyTorch ModelConfig; any later cell that reads `args` gets this object.
args = ModelConfig("CellBox", 99, "tanh", "heun", 100)
torch_cellbox = cellbox.model_torch.factory(args)[0]

# Overwrite the PyTorch W with the same masked weights (W_rand) drawn above.
for w in torch_cellbox.named_parameters():
    if w[0] == "params.W": w[1].data = torch.tensor(W_rand, dtype=torch.float32)

In [69]:
# Use the PyTorch loss through its module. `from cellbox.utils_torch import optimize, loss`
# would rebind the notebook-level `loss` / `optimize`, which PertBio.get_ops (TensorFlow)
# looks up at call time, so a TF model built later in the notebook would silently get
# the PyTorch versions.
import cellbox.utils_torch

lr = 0.1
l1_lambda = 0.1
l2_lambda = 0.01
loss_fn = cellbox.utils_torch.loss

torch_cellbox.train()
prediction = torch_cellbox(torch.zeros((args.n_x, 1), dtype=torch.float32), torch.tensor(inp, dtype=torch.float32))
convergence_metric, yhat = prediction

# Explicit lookup: a KeyError here beats silently reusing a stale `param_mat`.
param_mat = dict(torch_cellbox.named_parameters())["params.W"]

loss_train_i_torch, loss_train_mse_i_torch = loss_fn(torch.tensor(out, dtype=torch.float32), yhat, param_mat, l1=l1_lambda, l2=l2_lambda)

In [72]:
# First return value of `loss` (presumably MSE plus the L1/L2 penalties): TensorFlow, then PyTorch.
print(loss_train_i)
print(loss_train_i_torch.item())

727.26514
727.2650756835938


In [73]:
# MSE part of the loss: TensorFlow, then PyTorch.
print(loss_train_mse_i)
print(loss_train_mse_i_torch.item())

0.8145397
0.8145418167114258


### Test cases

#### Input and output matrices (saved as forward-pass test fixtures)

In [95]:
# Regenerate the forward-pass fixtures: one pickle per test case holding the masked
# weights, a sampled (inp, out) batch, the L1/L2 lambdas and the model options.
# NOTE(review): pickle files are only safe to load if they come from this trusted pipeline.
import pickle

# NOTE(review): same values as the seeds/lambdas/others cell near the top; keep both in sync.
seeds = [7, 87, 62, 45, 23]
lambdas = [(2.0, 3.0), (0.1, 0.01), (0.001, 0.0001), (0.01, 0.1), (0.0001, 0.001)]
others = [
    ("tanh", "heun", 100, "by u"),
    ("tanh", "euler", 100, "by u"),
    ("clip linear", "heun", 100, "fix x"),
    ("tanh", "heun", 100, "fix x"),
    ("tanh", "midpoint", 100, "fix x")
]

# NOTE(review): the recorded output below came from an earlier input_init that printed.
for seed, lamb, other in zip(seeds, lambdas, others):
    W_rand = weight_init(seed)
    inp, out = input_init(seed)
    data = {
        "seed": seed,
        "W": W_rand,
        "inp": inp,
        "out": out,
        "l1_lambda": lamb[0],
        "l2_lambda": lamb[1],
        "envelope_form": other[0],
        "ode_solver": other[1],
        "n_T": other[2],
        "pert_form": other[3]
    }
    # File name encodes seed and (l1, l2); the TF and PyTorch cells below open the same name.
    with open(f"/users/ngun7t/Documents/cellbox-jun-6/test_arrays/forward_pass/forward_input_{seed}_{lamb[0]}_{lamb[1]}.pkl", "wb") as f:
        pickle.dump(data, f)

(89, 99)
[13, 51, 17, 20]
(89, 99)
[56, 66, 52, 55]
(89, 99)
[55, 67, 14, 36]
(89, 99)
[83, 0, 25, 1]
(89, 99)
[26, 46, 56, 3]


In [84]:
# The TF models need the CellBox config prepared by cellbox.dataset.factory
# (feed_dicts, iterators, placeholders). Take it from `cfg`, not `args`: `args`
# is rebound to a PyTorch ModelConfig in the PyTorch section above, so
# `tf_args = args` only worked when run right after the config cell.
tf_args = cfg

In [96]:
# Build TensorFlow models from the saved fixtures and record their outputs.
# NOTE(review): each iteration adds a new CellBox to the same default graph (the
# dataset iterators in tf_args live there), so the graph grows with every case.
lr = 0.1
for seed, lamb, other in zip(seeds, lambdas, others):
    with open(f"/users/ngun7t/Documents/cellbox-jun-6/test_arrays/forward_pass/forward_input_{seed}_{lamb[0]}_{lamb[1]}.pkl", "rb") as f:
        data = pickle.load(f)

    print(f"Working on {seed}")

    # other = (envelope_form, ode_solver, n_T, pert_form)
    tf_args.envelope_form, tf_args.ode_solver, tf_args.n_T, tf_args.pert_form = other

    W_rand, l1_lambda, l2_lambda, inp, out = data["W"], data["l1_lambda"], data["l2_lambda"], data["inp"], data["out"]
    model = CellBox(tf_args).build(W_rand, inp, out)
    # Feeds for the dataset placeholders (same dummy values as before).
    dummy_feed = {tf_args.pert_in: np.ones((4, 99)), tf_args.expr_out: np.ones((4, 99))}

    # The context manager closes the session; previously one was leaked per test case.
    # The unused `train_input = ...get_next()...eval()` was dropped: its value was never
    # read and every get_next() call added another op to the graph.
    with tf.compat.v1.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(model.iter_train.initializer, feed_dict=tf_args.feed_dicts['train_set'])
        sess.run(model.iter_monitor.initializer, feed_dict=tf_args.feed_dicts['valid_set'])
        yhat = sess.run(model.train_yhat, feed_dict=dummy_feed)
        _, loss_train_i, loss_train_mse_i = sess.run(
            (model.op_optimize, model.train_loss, model.train_mse_loss),
            feed_dict={**dummy_feed, model.lr: lr, model.l1_lambda: l1_lambda, model.l2_lambda: l2_lambda},
        )

    tf_out = {
        "yhat": yhat,
        "loss_train": loss_train_i,
        "loss_train_mse": loss_train_mse_i
    }
    # Name the output exactly like the input fixture (seed, lamb[0], lamb[1]),
    # which is also the name the PyTorch validation cell opens.
    with open(f"/users/ngun7t/Documents/cellbox-jun-6/test_arrays/forward_pass/forward_out_{seed}_{lamb[0]}_{lamb[1]}.pkl", "wb") as f:
        pickle.dump(tf_out, f)


Working on 7
Working on 87
Working on 62
Working on 45
Working on 23


#### Validate against PyTorch

In [2]:
class ModelConfig:
    """Lightweight config object for building the PyTorch CellBox via cellbox.model_torch.factory.

    Only the constructor arguments vary; every other field holds a fixed test
    value and can be overwritten on the instance afterwards (e.g. ``pert_form``).
    """

    def __init__(self, model, n_x, envelope_form, ode_solver, n_T):
        # Model identity and size.
        self.model, self.n_x = model, n_x
        # No dataset iterators in this standalone config.
        self.iter_train = self.iter_monitor = self.iter_eval = None
        self.lr = 0.1
        # Node-group sizes (same values as the CellBox JSON config).
        self.n_protein_nodes = 82
        self.n_activity_nodes = 87
        self.pert_form = "by u"

        # Envelope / ODE settings.
        self.envelope_form, self.envelope_fn = envelope_form, None
        self.polynomial_k, self.ode_degree, self.envelope = 2, 1, 0
        self.ode_solver, self.dT = ode_solver, 0.1
        self.ode_last_steps, self.n_T = 2, n_T
        self.gradient_zero_from = None

# Shared PyTorch config; the validation loop below overwrites envelope_form,
# ode_solver, n_T and pert_form for each test case.
torch_args = ModelConfig("CellBox", 99, "tanh", "heun", 100)

In [14]:
# Compare the PyTorch forward pass and losses against the saved TensorFlow outputs.
yhat_diff = 0.01  # max allowed absolute difference per predicted entry
loss_fn = cellbox.utils_torch.loss
mismatched_seeds = []

for seed, lamb, other in zip(seeds, lambdas, others):
    print(f"Working on {seed=}")
    with open(f"/users/ngun7t/Documents/cellbox-jun-6/test_arrays/forward_pass/forward_input_{seed}_{lamb[0]}_{lamb[1]}.pkl", "rb") as f:
        data = pickle.load(f)
    with open(f"/users/ngun7t/Documents/cellbox-jun-6/test_arrays/forward_pass/forward_out_{seed}_{lamb[0]}_{lamb[1]}.pkl", "rb") as f:
        data_out = pickle.load(f)

    # other = (envelope_form, ode_solver, n_T, pert_form)
    torch_args.envelope_form, torch_args.ode_solver, torch_args.n_T, torch_args.pert_form = other
    torch_cellbox = cellbox.model_torch.factory(torch_args)[0]

    # Load the same masked W the TensorFlow model used.
    param_mat = dict(torch_cellbox.named_parameters())["params.W"]
    param_mat.data = torch.tensor(data["W"], dtype=torch.float32)

    l1_lambda = data["l1_lambda"]
    l2_lambda = data["l2_lambda"]

    torch_cellbox.train()
    if torch_args.pert_form == "by u":
        y0 = torch.zeros((torch_args.n_x, 1), dtype=torch.float32)
    elif torch_args.pert_form == "fix x":
        y0 = torch.tensor(data["inp"].T, dtype=torch.float32)
    else:
        # Previously an unknown pert_form silently reused the last iteration's prediction.
        raise ValueError(f"Unsupported pert_form: {torch_args.pert_form!r}")
    convergence_metric, yhat = torch_cellbox(y0, torch.tensor(data["inp"], dtype=torch.float32))

    loss_train_i_torch, loss_train_mse_i_torch = loss_fn(torch.tensor(data["out"], dtype=torch.float32), yhat, param_mat, l1=l1_lambda, l2=l2_lambda)

    abs_diff = np.abs(data_out['yhat'] - yhat.detach().cpu().numpy())
    similar = np.all(abs_diff < yhat_diff)
    print(f">>> TF: {data_out['loss_train']}, Torch: {loss_train_i_torch.item()}")
    print(f">>> TF: {data_out['loss_train_mse']}, Torch: {loss_train_mse_i_torch.item()}")
    print(f">>> Prediction similar: {similar} (max |diff| = {abs_diff.max():.3g})")
    if not similar:
        mismatched_seeds.append(seed)

# Surface mismatches explicitly rather than leaving a lone "False" in the log
# (seed 87 exceeded the yhat tolerance in the recorded run).
print(f"Seeds with yhat outside tolerance {yhat_diff}: {mismatched_seeds}")

Working on seed=7
>>> TF: 36551.5, Torch: 36551.5
>>> TF: 0.8670288920402527, Torch: 0.867030143737793
>>> Prediction similar: True
Working on seed=87
>>> TF: 726.19677734375, Torch: 726.19677734375
>>> TF: 0.8856361508369446, Torch: 0.8856316208839417
>>> Prediction similar: False
Working on seed=62
>>> TF: 8.151704788208008, Torch: 8.151703834533691
>>> TF: 0.8972810506820679, Torch: 0.8972810506820679
>>> Prediction similar: True
Working on seed=45
>>> TF: 863.7166137695312, Torch: 863.7166748046875
>>> TF: 0.7445549368858337, Torch: 0.7445532083511353
>>> Prediction similar: True
Working on seed=23
>>> TF: 9.63156795501709, Torch: 9.631569862365723
>>> TF: 0.8667668104171753, Torch: 0.8667678236961365
>>> Prediction similar: True
