In [None]:
from moscot.problems.generic._generic import ConditionalNeuralProblem
import jax.numpy as jnp
import scanpy as sc
import pickle as pkl
import jax
from pathlib import Path
import sys

In [None]:
# Build the conditional neural problem on the training data, together with the
# embeddings used as conditioning information.
neural_problem = ConditionalNeuralProblem(adata_train, embedding_data=embedding_data)

# Preparation settings: one group per `cov_drug` value, using the "X_pca"
# representation as the joint attribute and an explicit policy over `subset`.
prepare_kwargs = dict(
    key="cov_drug",
    joint_attr="X_pca",
    policy="explicit",
    subset=subset,
)
neural_problem.prepare(**prepare_kwargs)

print("INFO: Training the model")

# Solver settings. Note that `valid_freq` is larger than `iterations`
# (presumably so that validation never triggers during this run -- TODO confirm
# against the moscot solver documentation).
solve_kwargs = dict(
    cond_dim=494,
    embedding_data=embedding_data,
    best_model_metric=None,
    iterations=100_000,
    train_set=1.0,
    valid_freq=5_000_000,
    compute_wasserstein_baseline=False,
    pretrain_iters=0,
)
neural_problem.solve(**solve_kwargs)

In [None]:
def _push_single(x, cond):
    """Push one source sample `x` through the solved problem under condition `cond`."""
    return neural_problem.solution.push(x=x, cond=cond)


# Vectorise the per-sample push over the leading axis of both arguments.
# `push_results` is kept as a second name bound to the same vmapped callable.
batch_predictor = jax.vmap(_push_single)
push_results = batch_predictor

# Compiled variant of the batched predictor, used by the evaluation loop.
jitted_batch_predictor = jax.jit(batch_predictor)


# Run the trained model on the "ood" and "test" datasets: for every non-control
# `cov_drug` group, the control cells of the matching cell line are pushed
# through the model, conditioned on the group's embedding.
for name, adata in zip(["ood", "test"], [adata_ood, adata_test]):
    # Summary of the dataset being evaluated: its split label(s), number of
    # cells and number of distinct `cov_drug` groups.
    print(
        f"INFO: Evaluating on {adata.obs['split_ood_finetuning'].unique()}"
        f"(with {adata.shape[0]} cells and {len(adata.obs['cov_drug'].unique())} conditions)"
    )
    for cell_line_condition in adata.obs["cov_drug"].unique():
        # Each `cov_drug` label is split on "_" into (cell line, condition).
        # NOTE(review): the two-name unpacking raises ValueError for any label
        # that does not contain exactly one underscore -- confirm no cell line
        # or condition name contains "_".
        cell_line, condition = cell_line_condition.split("_")
        # NOTE(review): this message is also printed for control groups, which
        # are skipped immediately below.
        print(f"INFO: Evaluating {cell_line}_{condition}")
        if condition == "control":
            continue
        
        # Source cells are the control cells of the same cell line, taken in
        # their "X_pca" representation. For the "ood" dataset they come from
        # `adata_train`; otherwise from the dataset being evaluated.
        # (Presumably the "ood" data holds no usable controls -- TODO confirm.)
        if name == "ood":
            source_gex = adata_train[
                (adata_train.obs["cell_type"] == cell_line)
                & (adata_train.obs["condition"] == "control")
            ].obsm["X_pca"]
        else:
            source_gex = adata[
                (adata.obs["cell_type"] == cell_line)
                & (adata.obs["condition"] == "control")
            ].obsm["X_pca"]

        # The conditioning vector is the cell-line embedding concatenated with
        # the condition embedding; groups lacking either embedding are skipped.
        try:
            embedding = jnp.hstack(
                [embedding_data[cell_line], embedding_data[condition]]
            )
        except KeyError:
            print(f"Skipping {cell_line}_{condition}")
            continue

        # Add a leading axis and repeat the conditioning vector so there is one
        # copy per source cell, as required by the vmapped predictor.
        encoded_condition = jnp.expand_dims(embedding, axis=0)
        repeated_condition = jnp.repeat(
            encoded_condition, source_gex.shape[0], axis=0
        )

        # Predict for all source cells at once and flatten to rows of width 25.
        # NOTE(review): 25 is hard-coded; presumably the number of "X_pca"
        # components -- confirm, or derive it from `source_gex.shape[1]`.
        predicted_gex = jitted_batch_predictor(
            source_gex, repeated_condition
        ).reshape(-1, 25)
