In [1]:
import muon
import ott
import functools
import logging
import typing as t

import anndata as ad
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import scanpy as sc
import scipy.sparse as sp
import scipy.stats as ss
from ott.geometry import costs, geometry, pointcloud
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from sklearn import metrics, model_selection
from ott.geometry import costs as sparse_costs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the multimodal dataset from an .h5mu file.
# NOTE(review): absolute, cluster-specific path — consider a configurable data directory.
mudata = muon.read("/lustre/groups/ml01/workspace/moscot_paper/pancreas_revision/mudata_with_annotation_all.h5mu")



In [3]:
# Cell-type labels that are kept when filtering on the `cell_type` annotation.
endocrine_celltypes = [
    "Ngn3 high",
    "Fev+", "Fev+ Alpha", "Fev+ Beta", "Fev+ Delta",
    "Eps. progenitors",
    "Alpha", "Beta", "Delta", "Epsilon",
]

In [4]:
# Take the RNA modality and keep only cells whose `cell_type` is in `endocrine_celltypes`.
# `.copy()` materialises the boolean-mask subset so that later writes to `adata.obs`
# happen on an independent object instead of on a view of the parent AnnData.
adata = mudata["rna"]
adata = adata[adata.obs["cell_type"].isin(endocrine_celltypes)].copy()


In [5]:
def adapt_time(x):
    """Map the ``stage`` label of one observation row to a numeric time point.

    Args:
        x: Mapping-like row (e.g. a row of ``adata.obs``) with a ``"stage"`` entry.

    Returns:
        ``14.5``, ``15.5`` or ``16.5`` for the stages ``"E14.5"``, ``"E15.5"``
        and ``"E16.5"`` respectively.

    Raises:
        ValueError: If the stage is none of the three supported labels.
    """
    stage = x["stage"]
    for label, time_point in (("E14.5", 14.5), ("E15.5", 15.5), ("E16.5", 16.5)):
        if stage == label:
            return time_point
    # Name the offending value so that unexpected annotations are easy to track down.
    raise ValueError(
        f"Unexpected stage {stage!r}; expected one of 'E14.5', 'E15.5', 'E16.5'."
    )

# Convert the `stage` label of every cell to a numeric time (row-wise) and store it as a
# categorical `time` column in `adata.obs`.
adata.obs['time'] = adata.obs.apply(adapt_time, axis=1).astype("category")

  adata.obs['time'] = adata.obs.apply(adapt_time, axis=1).astype("category")


In [6]:
# Keep only the 15.5 and 16.5 time points. `.copy()` materialises the subset so that later
# in-place changes (replacing `adata.X`, normalisation) act on an independent AnnData
# rather than on a view that has to be converted implicitly.
adata = adata[adata.obs["time"].isin((15.5, 16.5))].copy()

In [7]:
# Replace the main matrix with the contents of the "raw_counts" layer.
adata.X = adata.layers["raw_counts"]

In [8]:
# Normalise total counts per cell, then apply log1p. The return values are not used,
# so both calls are relied upon to update `adata` in place.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

  view_to_actual(adata)




In [9]:
# Select the top 1000 highly variable genes; `subset=True` restricts `adata` to them.
sc.pp.highly_variable_genes(adata, inplace=True, subset=True, n_top_genes=1000)

  disp_grouped = df.groupby("mean_bin")["dispersions"]


In [10]:
# Display the AnnData summary (shape, obs and var columns).
adata

AnnData object with n_obs × n_vars = 6131 × 1000
    obs: 'sample', 'name', 'stage', 'stage_num', 'int_id', 'seq_id_gex_id', 'seq_id_atac', 'reporter', 'experiment_batch', 'sequencing_batch', 'n_counts', 'log_counts', 'n_counts_rank', 'n_genes', 'log_genes', 'mt_frac', 'rp_frac', 'ambi_frac', 'final_doublets', 'final_doublets_cat', 'doublet_calls', 'batch', 'size_factors', 'leiden', 'leiden_05_rna', 'leiden_05_atac', 'leiden_1_rna', 'leiden_1_atac', 'leiden_combined', 'leiden_gex_graph', 'leiden_ATAC_graph', 'leiden_wnn_graph', 'cell_type', 'cell_type_refined', 'S_score', 'G2M_score', 'phase', 'proliferation', 'time'
    var: 'gene_ids', 'feature_types', 'genome', 'interval', 'ambient_genes_E14_5-0', 'is_ambient_E14_5-0', 'n_counts-0', 'n_counts-1', 'ambient_genes_E15_5-1', 'is_ambient_E15_5-1', 'n_counts-2', 'ambient_genes_NVF_E15-5_Rep2-2', 'is_ambient_NVF_E15-5_Rep2-2', 'n_counts-3', 'ambient_genes_NVF_E16-5_Rep1-3', 'is_ambient_NVF_E16-5_Rep1-3', 'is_ambient', 'n_counts', 'n_cells'

In [11]:
# Dense expression matrices of the cells at time 15.5 (early) and 16.5 (late).
# `.toarray()` is used instead of the `.A` shorthand, and dense matrices are passed through
# unchanged, so the cell works for sparse and dense `X` alike.
gex_early = adata[adata.obs["time"] == 15.5].X
gex_early = gex_early.toarray() if sp.issparse(gex_early) else np.asarray(gex_early)
gex_late = adata[adata.obs["time"] == 16.5].X
gex_late = gex_late.toarray() if sp.issparse(gex_late) else np.asarray(gex_late)

In [12]:
# Short aliases for the early and late expression matrices.
x=gex_early
y=gex_late

In [13]:
# JIT-compiled Sinkhorn solver with default settings; used by `entropic_map`.
solver = jax.jit(sinkhorn.Sinkhorn())


def entropic_map(x, y, cost_fn: costs.TICost) -> t.Callable[..., jnp.ndarray]:
    """Fit an entropic transport map between two point clouds.

    Builds a point-cloud geometry from ``x`` and ``y`` with ``cost_fn``, solves the
    resulting linear problem with the module-level ``solver`` and converts the
    solution to dual potentials.

    Args:
        x: First point cloud passed to ``PointCloud``.
        y: Second point cloud passed to ``PointCloud``.
        cost_fn: Cost function used for the geometry.

    Returns:
        The ``transport`` method of the dual potentials, i.e. a callable (not an
        array) that is applied to points later on.
    """
    geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
    output = solver(linear_problem.LinearProblem(geom))
    dual_potentials = output.to_dual_potentials()
    return dual_potentials.transport

In [14]:
# Fit the entropic map from `x` to `y` using the ElasticL1 cost with scaling_reg=50.0.
map_l1 = entropic_map(x, y, costs.ElasticL1(scaling_reg=50.0))

In [15]:
# First 10 rows of `x`, used as a small batch for inspecting the map.
x_red = x[:10,:]

In [16]:
# Apply the fitted map to the 10 selected rows.
push_forward = map_l1(x_red)

In [17]:
# Per row: number of features whose mapped value differs from the input by more than 1e-4.
(np.abs(push_forward-x_red) > 1e-4).sum(axis=1)

array([ 3, 22,  8,  9,  3,  4, 10,  1,  4,  9])

In [34]:
# Per row: number of features whose mapped value differs from the input by more than 1e-4.
# NOTE(review): same expression as an earlier cell but with a later execution count and a
# different recorded output — presumably computed from different kernel state; confirm
# which result is current and drop the other cell.
(np.abs(push_forward-x_red) > 1e-4).sum(axis=1)

array([ 59, 111, 107, 102,  85,  86,  85,  59, 104,  59])

In [31]:
# Sum of all differences between mapped and input values (over rows and features).
(push_forward - x_red).sum()

Array(-566.8052, dtype=float32)

In [16]:
# Apply the fitted map to the 10 selected rows.
# NOTE(review): duplicate of an earlier cell; the recorded output is an out-of-memory error
# whose buffer shapes do not match the 1000-gene data above — presumably from a different
# run; consider removing this cell and its output.
push_forward = map_l1(x_red)

2024-02-21 10:23:33.029516: W external/xla/xla/service/hlo_rematerialization.cc:2941] Can't reduce memory use below -57.88GiB (-62152086496 bytes) by rematerialization; only reduced to 84.67GiB (90912130800 bytes), down from 84.67GiB (90912130800 bytes) originally
2024-02-21 10:23:43.105432: W external/tsl/tsl/framework/bfc_allocator.cc:487] Allocator (GPU_0_bfc) ran out of memory trying to allocate 84.67GiB (rounded to 90912130816)requested by op 
2024-02-21 10:23:43.110162: W external/tsl/tsl/framework/bfc_allocator.cc:499] *___________________________________________________________________________________________________
E0221 10:23:43.110249 2849849 pjrt_stream_executor_client.cc:2766] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 90912130800 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   67.75MiB
              constant allocation:         0B
        maybe_live_out allocation:   84.67G

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 90912130800 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   67.75MiB
              constant allocation:         0B
        maybe_live_out allocation:   84.67GiB
     preallocated temp allocation:         0B
                 total allocation:   84.73GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 84.67GiB
		Operator: op_name="jit(<lambda>)/jit(main)/sub" source_file="/tmp/ipykernel_2849849/2834232686.py" source_line=6
		XLA Label: fusion
		Shape: f32[5185,1699,2580]
		==========================

	Buffer 2:
		Size: 51.03MiB
		Entry Parameter Subshape: f32[5185,2580]
		==========================

	Buffer 3:
		Size: 16.72MiB
		Entry Parameter Subshape: f32[1699,2580]
		==========================



In [None]:
# Inspect the shape of the mapped batch.
push_forward.shape

In [None]:
# JIT-compiled Sinkhorn divergence. The arguments named in `static_argnames` are treated
# as static by `jax.jit` (they must be hashable and a new value triggers recompilation).
sink_div = jax.jit(
    sinkhorn_divergence.sinkhorn_divergence,
    static_argnames=["geom", "cost_fn", "epsilon", "batch_size"],
)

def _extract_cost_matrix(
    x: jnp.ndarray,
    y: jnp.ndarray,
    cost_fn: sparse_costs.RegTICost,
    epsilon: t.Optional[float] = None,
    scale_cost: t.Union[t.Literal["mean", "max_cost"], float] = 1.0,
    batch_size: int = 4,
) -> t.Tuple[geometry.Geometry, pointcloud.PointCloud]:
    """Build a dense-cost geometry by evaluating the cost in row batches.

    The cost matrix between ``x`` and ``y`` is computed ``batch_size`` rows of ``x``
    at a time (unscaled), concatenated, and wrapped in a ``Geometry`` that applies
    ``scale_cost`` and ``epsilon``. A ``PointCloud`` over the full ``x``/``y`` with
    the geometry's epsilon is returned alongside it.

    Args:
        x: Rows of the cost matrix.
        y: Columns of the cost matrix.
        cost_fn: Cost function evaluated between rows of ``x`` and ``y``.
        epsilon: Regularisation passed to the ``Geometry``.
        scale_cost: Cost scaling applied by the ``Geometry`` only.
        batch_size: Number of rows of ``x`` processed per batch.

    Returns:
        Tuple ``(geom, pc)`` with identical shapes.
    """
    # One unscaled cost block per slice of `x`; scaling is applied once by `Geometry`.
    row_blocks = [
        pointcloud.PointCloud(
            jnp.array(x[start : start + batch_size]),
            y,
            cost_fn=cost_fn,
            scale_cost=1.0,
        ).cost_matrix
        for start in range(0, x.shape[0], batch_size)
    ]
    dense_cost = jnp.concatenate(row_blocks)

    geom = geometry.Geometry(
        cost_matrix=dense_cost, scale_cost=scale_cost, epsilon=epsilon
    )
    pc = pointcloud.PointCloud(x, y, epsilon=geom.epsilon, cost_fn=cost_fn)

    expected_shape = (x.shape[0], y.shape[0])
    assert geom.shape == expected_shape
    assert geom.shape == pc.shape
    return geom, pc


@functools.partial(
    jax.jit,
    static_argnames=("tau_a", "tau_b", "max_iterations", "threshold"),
)
def _solve(
    geom: geometry.Geometry,
    pc: t.Optional[
        pointcloud.PointCloud
    ] = None,  # for potentials in case of ElasticNet/GroupNorm
    *,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
    **kwargs: t.Any,
) -> t.Tuple[sinkhorn.SinkhornOutput, potentials.EntropicPotentials]:
    """Run Sinkhorn on ``geom`` and return the output with its dual potentials.

    Args:
        geom: Geometry defining the linear problem that is solved.
        pc: Optional point cloud. When given, the potentials are built from the
            solver's ``f``/``g`` and a linear problem defined on ``pc`` instead of
            being derived from the solver output directly.
        tau_a: Passed to ``LinearProblem``.
        tau_b: Passed to ``LinearProblem``.
        **kwargs: Forwarded to ``sinkhorn.Sinkhorn``.

    Returns:
        Tuple ``(out, dual_potentials)``.
    """
    sinkhorn_solver = sinkhorn.Sinkhorn(**kwargs)
    out = sinkhorn_solver(linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b))

    if pc is not None:
        pc_prob = linear_problem.LinearProblem(pc, tau_a=tau_a, tau_b=tau_b)
        return out, potentials.EntropicPotentials(out.f, out.g, pc_prob)

    return out, out.to_dual_potentials()


def solve(
    ds: Dataset,
    cost_fn: t.Union[costs.SqEuclidean, sparse_costs.RegTICost],
    epsilon: t.Optional[float] = None,
    batch_size: int = 4,
    scale_cost: t.Union[t.Literal["mean", "max_cost"], float] = 1.0,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
    **kwargs: t.Any,
) -> t.Tuple[sinkhorn.SinkhornOutput, potentials.EntropicPotentials]:
    """Solve the transport problem between ``ds.trn_x`` and ``ds.trn_y``.

    For ``SqEuclidean`` a ``PointCloud`` geometry is solved directly. For
    ``RegTICost`` the cost matrix is first extracted in batches and the potentials
    are built on the accompanying point cloud. The error curve of the solver is
    plotted on the current matplotlib figure as a side effect.

    Args:
        ds: Dataset providing ``trn_x`` and ``trn_y``.
        cost_fn: Cost function; decides which branch is taken.
        epsilon: Entropic regularisation.
        batch_size: Batch size of the geometry / cost extraction.
        scale_cost: Cost scaling. Only used in the ``SqEuclidean`` branch; it is
            not forwarded in the ``RegTICost`` branch.
        tau_a: Passed to ``_solve``.
        tau_b: Passed to ``_solve``.
        **kwargs: Forwarded to ``_solve``.

    Returns:
        Tuple ``(out, dual_potentials)``.

    Raises:
        TypeError: If ``cost_fn`` is neither of the two supported types.
    """
    if isinstance(cost_fn, costs.SqEuclidean):
        pc = None
        geom = pointcloud.PointCloud(
            ds.trn_x,
            ds.trn_y,
            cost_fn=cost_fn,
            scale_cost=scale_cost,
            batch_size=batch_size,
            epsilon=epsilon,
        )
    elif isinstance(cost_fn, sparse_costs.RegTICost):
        geom, pc = _extract_cost_matrix(
            ds.trn_x,
            ds.trn_y,
            cost_fn=cost_fn,
            epsilon=epsilon,
            batch_size=batch_size,
        )
    else:
        raise TypeError(type(cost_fn))

    out, dp = _solve(geom, pc, tau_a=tau_a, tau_b=tau_b, **kwargs)

    # Convergence diagnostics: errors up to the last executed iteration.
    plt.plot(out.errors[: out.n_iters])
    plt.title(f"converged: {out.converged}")

    return out, dp


def predict(
    ds: Dataset,
    dp: potentials.EntropicPotentials,
    *,
    forward: bool,
    batch_size: int = 4,
    nan_to_num: t.Optional[float] = 0.0,
) -> t.Tuple[jnp.ndarray, float, float, float]:
    """Transport the test data with the dual potentials and summarise the prediction.

    Args:
        ds: Dataset; ``tst_x`` is transported when ``forward`` is true, ``tst_y``
            otherwise. For PCA datasets (``ds.is_pca``) predictions are additionally
            up-projected with ``ds.upproject``.
        dp: Potentials whose ``transport`` method is applied batch-wise.
        forward: Direction of the transport.
        batch_size: Number of rows transported per call.
        nan_to_num: Replacement for NaNs in the prediction; ``None`` disables it.

    Returns:
        ``(pred, perc_close, perc_neg, min_neg)`` for non-PCA datasets, or
        ``(pred, pred_trans, perc_close, perc_neg, min_neg)`` for PCA datasets, where
        ``pred_trans`` is the prediction before up-projection. ``perc_close`` is the
        fraction of entries close to the reference (the expression matrix of the
        matching test AnnData for PCA datasets, the input data otherwise),
        ``perc_neg`` the fraction of negative entries and ``min_neg`` the minimum
        of the prediction.
    """
    source = ds.tst_x if forward else ds.tst_y
    n_rows = source.shape[0]

    latent_chunks, full_chunks = [], []
    for start in range(0, n_rows, batch_size):
        chunk = np.asarray(
            dp.transport(source[start : start + batch_size], forward=forward)
        )
        if ds.is_pca:
            latent_chunks.append(chunk)
            chunk = np.asarray(ds.upproject(chunk))
        full_chunks.append(chunk)

    pred = np.concatenate(full_chunks)
    if ds.is_pca:
        pred_trans = np.concatenate(latent_chunks)
        assert source.shape == pred_trans.shape
        assert source.shape[0] == pred.shape[0]
    else:
        assert source.shape == pred.shape
    if nan_to_num is not None:
        pred = np.nan_to_num(pred, nan=nan_to_num, copy=False)

    if ds.is_pca:
        expected = ds.adata_tst_x.X if forward else ds.adata_tst_y.X
        if sp.issparse(expected):
            expected = expected.A
        reference = expected
    else:
        reference = source
    perc_close = np.sum(np.isclose(pred, reference)) / pred.size
    perc_neg = np.sum(pred < 0) / pred.size
    min_neg = np.min(pred)

    if ds.is_pca:
        return jnp.asarray(pred), jnp.asarray(pred_trans), perc_close, perc_neg, min_neg

    return jnp.asarray(pred), perc_close, perc_neg, min_neg


def pca_metric(
    ds: Dataset,
    ds_pca: Dataset,
    data_hat: jnp.ndarray,  # reduced dim data
    data_hat_raw: jnp.ndarray,  # full dim data
    *,
    frac: float,
    cost_fn: costs.CostFn,
    batch_size: t.Optional[int],
    forward: bool = True,
) -> t.Tuple[float, float]:
    """Sinkhorn divergences of a PCA-based prediction in reduced and full space.

    Args:
        ds: Non-PCA dataset supplying the full-dimensional test data.
        ds_pca: PCA dataset supplying the reduced test data.
        data_hat: Prediction in the reduced space.
        data_hat_raw: Prediction in the full space.
        frac: Fraction of the mean cost used as epsilon.
        cost_fn: Cost for the full-space divergence (the reduced-space divergence
            always uses ``SqEuclidean``).
        batch_size: Batch size for the full-space divergence.
        forward: Compare against ``tst_y`` when true, ``tst_x`` otherwise.

    Returns:
        Tuple ``(div, div_raw)``: reduced-space and full-space divergence.
    """
    assert not ds.is_pca
    assert ds_pca.is_pca

    # reduced data is (n, 50), full data is (n, g)
    reduced, full = (ds_pca.tst_y, ds.tst_y) if forward else (ds_pca.tst_x, ds.tst_x)

    return (
        solve_sink_div(reduced, data_hat, frac=frac, cost_fn=costs.SqEuclidean()),
        solve_sink_div(
            full, data_hat_raw, frac=frac, cost_fn=cost_fn, batch_size=batch_size
        ),
    )


def evaluate(
    ds: Dataset,
    data_hat: jnp.ndarray,
    *,
    forward: bool,
    cost_fn: t.Union[costs.SqEuclidean, sparse_costs.RegTICost],
    genes: t.Optional[t.List[str]] = None,
    epsilon: t.Optional[float] = None,
    **kwargs: t.Any,
) -> t.Tuple[float, float, int]:
    """Compute the Sinkhorn divergence between a prediction and the test data.

    Args:
        ds: Dataset; the prediction is compared against ``tst_y`` when ``forward``
            is true and against ``tst_x`` otherwise.
        data_hat: Predicted data.
        forward: Direction of the prediction.
        cost_fn: Cost function of the divergence.
        genes: Optional gene names; when given, both arrays are restricted to the
            columns whose name in ``ds.adata.var_names`` is listed. Requires a
            non-PCA dataset and a ``SqEuclidean`` cost.
        epsilon: Entropic regularisation.
        **kwargs: Forwarded to the Sinkhorn divergence.

    Returns:
        Tuple ``(primal_cost, divergence, n_genes)``. ``primal_cost`` is always NaN
        (see the TODO below); ``n_genes`` is the number of columns compared.
    """
    data = ds.tst_y if forward else ds.tst_x
    if genes is not None:
        assert not ds.is_pca, "Cannot be a PCA dataset"
        assert isinstance(cost_fn, costs.SqEuclidean)
        mask = jnp.asarray(ds.adata.var_names.isin(genes))
        n_genes = int(jnp.sum(mask))
        data = data[:, mask]
        data_hat = data_hat[:, mask]
    else:
        n_genes = data_hat.shape[1]

    out_div: sinkhorn_divergence.SinkhornDivergenceOutput = sink_div(
        pointcloud.PointCloud,
        data_hat,
        data,
        epsilon=epsilon,
        cost_fn=cost_fn,
        **kwargs,
    )

    # TODO(michalk8): the primal cost is not computed (NaN placeholder) because solving on
    # the point cloud tries to allocate `f32[2594,620,5000]`, (n, m, d). The previously
    # unreachable code after this return has been removed.
    return np.nan, float(out_div.divergence), n_genes


def feature_correlation(
    ds: Dataset,
    data_hat: jnp.ndarray,
    *,
    forward: bool,
) -> t.Dict[str, np.ndarray]:
    """Per-feature correlation between a prediction and the data it was computed from.

    Args:
        ds: Dataset; ``tst_x`` is used as reference when ``forward`` is true,
            ``tst_y`` otherwise.
        data_hat: Predicted data with the same shape as the reference.
        forward: Direction of the prediction.

    Returns:
        Dict with the NaN-ignoring mean and standard deviation over features of the
        Pearson and Spearman statistics, keyed ``pearson_mean``, ``pearson_std``,
        ``spearman_mean`` and ``spearman_std``.
    """
    reference = np.asarray(ds.tst_x if forward else ds.tst_y)
    predicted = np.asarray(data_hat)

    pearson, spearman = [], []
    for col in range(predicted.shape[1]):
        pred_col, ref_col = predicted[:, col], reference[:, col]
        pearson.append(ss.pearsonr(pred_col, ref_col).statistic)
        spearman.append(ss.spearmanr(pred_col, ref_col).statistic)

    summary = {}
    for name, values in (("pearson", pearson), ("spearman", spearman)):
        values = np.asarray(values)
        # NaN-aware aggregates: constant features yield NaN correlations.
        summary[f"{name}_mean"] = np.nanmean(values)
        summary[f"{name}_std"] = np.nanstd(values)
    return summary


def r2(
    ds: Dataset,
    data_hat: jnp.ndarray,
    *,
    forward: bool,
    genes: t.Optional[t.List[str]],
) -> t.Dict[str, float]:
    """R2 scores of per-gene means and of mean differences to the other condition.

    Args:
        ds: Dataset providing ``tst_x``, ``tst_y``, ``adata.var_names`` and
            ``dosage_mask`` (a mapping from dosage to a row mask).
        data_hat: Predicted data.
        forward: When true the target is ``tst_y`` and the other condition is
            ``tst_x``; reversed otherwise.
        genes: Genes to evaluate on; ``None`` selects all of ``ds.adata.var_names``.

    Returns:
        Dict with ``mean-all-r2`` / ``log2fc-all-r2`` plus one ``mean-<dosage>-r2``
        and ``log2fc-<dosage>-r2`` entry per dosage.
    """
    if genes is None:
        genes = ds.adata.var_names
    gene_mask = ds.adata.var_names.isin(genes)

    target_src, other_src = (ds.tst_y, ds.tst_x) if forward else (ds.tst_x, ds.tst_y)
    data = np.asarray(target_src)[:, gene_mask]
    other = np.asarray(other_src)[:, gene_mask]
    data_hat = np.asarray(data_hat)[:, gene_mask]

    def _scores(
        observed: jnp.ndarray, predicted: jnp.ndarray, baseline: jnp.ndarray
    ) -> t.Tuple[float, float]:
        # Per-gene means of each group; the fold change is the difference of means.
        observed_mean = np.mean(observed, axis=0)
        predicted_mean = np.mean(predicted, axis=0)
        baseline_mean = np.mean(baseline, axis=0)
        return (
            metrics.r2_score(observed_mean, predicted_mean),
            metrics.r2_score(
                observed_mean - baseline_mean, predicted_mean - baseline_mean
            ),
        )

    mean_all, log2fc_all = _scores(data, data_hat, other)
    res = {"mean-all-r2": mean_all, "log2fc-all-r2": log2fc_all}

    for dosage, row_mask in ds.dosage_mask.items():
        # The dosage mask is applied to whichever array holds the `tst_y` rows.
        if forward:
            mean_r2, log2fc_r2 = _scores(data[row_mask], data_hat, other)
        else:
            mean_r2, log2fc_r2 = _scores(data, data_hat, other[row_mask])
        res[f"mean-{dosage}-r2"] = mean_r2
        res[f"log2fc-{dosage}-r2"] = log2fc_r2

    return res


def gene_metric(
    ds: Dataset,
    data_hat: jnp.ndarray,
    *,
    forward: bool,
    genes: t.Sequence[str],
    p: float = 0.9,
) -> t.Dict[str, t.Any]:
    """Rank-biased overlap between predicted and ground-truth gene rankings.

    Genes are ranked by the mean difference between the prediction and the data it
    was computed from (``tst_x`` if ``forward`` else ``tst_y``), largest first.

    NOTE(review): ``rbo`` is not imported in the visible part of this notebook.

    Args:
        ds: Dataset providing the test data and ``adata.var_names``.
        data_hat: Predicted data.
        forward: Direction of the prediction.
        genes: Ground-truth gene ranking.
        p: Persistence parameter passed to ``rbo``.

    Returns:
        Dict with the RBO score, the ground-truth genes, the per-gene scores, all
        gene names and the evaluation depth ``k``.
    """
    gt_genes = list(genes)
    depth = len(gt_genes)
    predicted = np.array(data_hat)
    reference = np.asarray(ds.tst_x if forward else ds.tst_y)

    vals = jnp.mean((predicted - reference), axis=0)  # log2fc
    descending = np.argsort(vals)[::-1]
    pred_genes = list(ds.adata.var_names[descending])

    similarity = rbo.RankingSimilarity(pred_genes, gt_genes)
    metric = similarity.rbo(k=depth, p=p)

    return {
        "rbo": metric,
        "gt_genes": gt_genes,
        "scores_hat": vals,
        "all_genes": list(ds.adata.var_names),
        "k": depth,
    }


def compute_epsilons(
    ds: Dataset,
    data_hat: jnp.ndarray,
    cost_fn: t.Union[costs.SqEuclidean, sparse_costs.RegTICost],
    *,
    forward: bool,
    fracs: t.Sequence[float],
    batch_size: int,
    use_train: bool = False,
) -> t.Sequence[t.Optional[float]]:
    """Turn fractions of the mean cost into epsilon values.

    Args:
        ds: Dataset providing the comparison data.
        data_hat: Data whose mean cost to the comparison data is measured.
        cost_fn: Cost function of the mean cost.
        forward: Selects ``tst_y`` (true) or ``tst_x`` (false) when ``use_train``
            is false.
        fracs: Fractions of the mean cost.
        batch_size: Batch size of the mean-cost computation.
        use_train: Use ``ds.trn_y`` as comparison data, regardless of ``forward``.

    Returns:
        One entry per fraction: ``None`` for fractions ``<= 0``, otherwise
        ``fraction * mean_cost`` as a float.
    """
    if use_train:
        data = ds.trn_y  # assumes `data_x` is `ds.trn_x`
    elif forward:
        data = ds.tst_y
    else:
        data = ds.tst_x

    mean = mean_cost(data_hat, data, cost_fn=cost_fn, batch_size=batch_size)

    epsilons = []
    for frac in fracs:
        epsilons.append(None if frac <= 0 else float(frac * mean))
    return epsilons


def mean_cost(
    x: jnp.ndarray,
    y: jnp.ndarray,
    *,
    cost_fn: costs.CostFn,
    batch_size: t.Optional[int],
) -> float:
    """Return the reciprocal of the ``"mean"`` cost scaling of a point cloud.

    Args:
        x: First point cloud.
        y: Second point cloud.
        cost_fn: Cost function of the point cloud.
        batch_size: Batch size passed to ``PointCloud``.

    Returns:
        ``1 / inv_scale_cost`` of the point cloud built with ``scale_cost="mean"``.
    """
    pc = pointcloud.PointCloud(
        x, y, cost_fn=cost_fn, batch_size=batch_size, scale_cost="mean"
    )
    inverse_mean = pc.inv_scale_cost
    return float(1.0 / inverse_mean)


def solve_sink_div(
    data: jnp.ndarray,
    data_hat: jnp.ndarray,
    *,
    frac: float,
    cost_fn: costs.CostFn,
    batch_size: int = 1024,
) -> float:
    """Sinkhorn divergence with epsilon set to a fraction of the mean cost.

    Args:
        data: First point cloud.
        data_hat: Second point cloud.
        frac: Fraction of the mean cost between the two clouds used as epsilon.
        cost_fn: Cost function.
        batch_size: Batch size of the mean-cost and divergence computations.

    Returns:
        The divergence as a Python float.
    """
    mean = mean_cost(data, data_hat, cost_fn=cost_fn, batch_size=batch_size)
    epsilon = frac * mean

    result = sink_div(
        pointcloud.PointCloud,
        data,
        data_hat,
        cost_fn=cost_fn,
        epsilon=epsilon,
        batch_size=batch_size,
    )
    return float(result.divergence)

In [None]:
import argparse
import gc
import logging
import pathlib
import pickle
import time
import typing as t

import anndata as ad
import jax.numpy as jnp
import numpy as np
import scanpy as sc
from ott.geometry import costs
from tqdm import tqdm

import data_utils as du
import sparse_costs

# Label of the control condition in `perturbation_name` (reference group for markers).
CONTROL = "DMSO"
# Root logger used for progress messages.
LOGGER = logging.getLogger()
# Fraction of the mean cost used as epsilon for the evaluation divergences.
MEAN_COST_FRAC = 0.1


def _compute_markers(
    adata: ad.AnnData,
    pert: str,
    n_genes: t.Optional[int],
    key_added: t.Optional[str] = None,
    *,
    alpha: t.Optional[float] = None,
    log2fc_min: t.Optional[float] = None,
) -> t.List[str]:
    """Rank genes of each perturbation against the control and return one group's names.

    Args:
        adata: Annotated data with a ``perturbation_name`` column; the ranking
            results are stored on it under ``key_added``.
        pert: Perturbation whose ranked genes are returned.
        n_genes: Number of genes to rank.
        key_added: Key under which the ranking is stored and read back.
        alpha: Optional p-value cutoff applied when reading the ranking.
        log2fc_min: Optional minimum log2 fold change applied when reading it.

    Returns:
        Names of the ranked genes that pass the filters.
    """
    ranking_options = dict(
        groupby="perturbation_name",
        reference=CONTROL,
        rankby_abs=True,
        n_genes=n_genes,
        method="wilcoxon",
        key_added=key_added,
    )
    sc.tl.rank_genes_groups(adata, **ranking_options)

    markers = sc.get.rank_genes_groups_df(
        adata, pert, key=key_added, pval_cutoff=alpha, log2fc_min=log2fc_min
    )
    logging.warning(f"#pert markers: `{len(markers)}/{n_genes}`")

    names = markers["names"]
    return list(names)


def _subset_genes(
    ds: du.Dataset, n_genes: int, markers: t.Optional[t.Sequence[str]]
) -> du.Dataset:
    """Restrict a dataset to highly variable (or expressed) genes of its training data.

    Args:
        ds: Dataset whose training control/perturbed AnnData objects are combined
            for gene selection.
        n_genes: Number of highly variable genes to keep. If not positive, all
            genes with a non-zero value in at least one cell are kept instead.
        markers: Optional genes added to the highly variable genes (ignored when
            ``n_genes`` is not positive).

    Returns:
        The result of ``ds.subset_genes`` for the selected gene names.
    """
    combined = ds.adata_trn_x.concatenate(ds.adata_trn_y)  # combine trn ctrl/pert

    if n_genes > 0:
        sc.pp.highly_variable_genes(combined, n_top_genes=n_genes, subset=True)
        var_names = combined.var_names
        if markers is not None:
            var_names = list(set(var_names) | set(markers))
        LOGGER.warning(f"Using `{len(var_names)}` HVGs")
        return ds.subset_genes(var_names)

    # Keep genes that are expressed (> 0) in at least one cell.
    n_expressing_cells = (combined.X > 0).sum(0).A.squeeze()  # (g,)
    expressed = n_expressing_cells > 0
    var_names = combined.var_names[np.where(expressed)]
    LOGGER.warning(f"Using `{len(var_names)}` genes")
    return ds.subset_genes(var_names)


def _subset_markers(
    ds: du.Dataset, *markerss: t.Sequence[str]
) -> t.Iterator[t.Sequence[str]]:
    """Yield each marker list restricted to genes present in ``ds.adata.var_names``.

    Args:
        ds: Dataset providing ``adata.var_names``.
        *markerss: One or more marker sequences.

    Yields:
        One array per input sequence with the absent genes removed, order kept.
    """
    for marker_list in markerss:
        marker_arr = np.asarray(marker_list)
        present = np.isin(marker_arr, ds.adata.var_names)
        yield marker_arr[present]


def _cost_fns(
    gammas: t.Sequence[float],
    k: t.Optional[int] = None,
) -> t.Dict[str, t.Union[costs.SqEuclidean, sparse_costs.RegTICost]]:
    """Build the named cost functions used for training and evaluation.

    Args:
        gammas: Regularisation strengths; one cost per family and gamma is created.
        k: If given, a k-overlap cost family ``kov-<k>`` is added.

    Returns:
        Dict that starts with ``"sqeucl"`` and then holds one ``"<family>-<gamma>"``
        entry per family (``elastic``, ``stvs`` and optionally ``kov-<k>``) and gamma.

    Raises:
        NotImplementedError: If an unknown cost family name is encountered.
    """
    fns = {"sqeucl": costs.SqEuclidean()}
    names = ("elastic", "stvs") if k is None else ("elastic", "stvs", f"kov-{k}")
    for name in names:
        for gamma in gammas:
            if name == "elastic":
                cost_fn = sparse_costs.ElasticNet(gam=gamma, lam=0.0)
            elif name == "stvs":
                cost_fn = sparse_costs.ElasticSTVS(gam=gamma)
            elif name == f"kov-{k}":
                cost_fn = sparse_costs.ElasticSqKOverlap(k=k, gam=gamma)
            else:
                # `NotImplemented` is a non-callable constant; the exception class is
                # `NotImplementedError`.
                raise NotImplementedError(name)
            fns[f"{name}-{gamma}"] = cost_fn

    return fns


def _run(
    adata: ad.AnnData,
    args: argparse.Namespace,
    *,
    pert_markers: t.Sequence[str],
    hvgs: t.Sequence[str],
) -> t.List[t.Dict[str, float]]:
    """Train and evaluate every training cost on each train/validation split.

    For each split the dataset is restricted to the selected genes, a PCA version
    is created, and for every training cost function an epsilon is selected, the
    transport problem is solved and predictions in both directions are scored
    (closeness statistics, PCA-based divergences, feature correlations, Sinkhorn
    divergences under every test cost, gene ranking and R2 metrics).

    Args:
        adata: Data passed to ``du.train_val_split``.
        args: Parsed options; the attributes read here are ``gamma``, ``k``,
            ``tst_gammas``, ``drug``, ``n_folds``, ``test_size``, ``seed``,
            ``n_hvgs``, ``ensure_markers_present``, ``n_pcs`` and ``trn_fracs``.
        pert_markers: Perturbation marker genes.
        hvgs: Highly variable genes used for gene-restricted metrics.

    Returns:
        One dict of metrics per split; keys are prefixed with the training cost name.
    """
    trn_cost_fns = _cost_fns([args.gamma], k=args.k)
    tst_cost_fns = _cost_fns(
        [args.gamma] if args.tst_gammas is None else args.tst_gammas, k=args.k
    )
    res = []

    for ds in tqdm(
        du.train_val_split(
            adata,
            ctrl=CONTROL,
            pert=args.drug,
            n_splits=args.n_folds,
            test_size=args.test_size,
            seed=args.seed,
        ),
        total=args.n_folds,
    ):
        ds = _subset_genes(
            ds,
            args.n_hvgs,
            markers=pert_markers if args.ensure_markers_present else None,
        )
        # NOTE(review): `pert_markers` is rebound here, so every later split starts from
        # the list already filtered by the previous split — confirm this is intended.
        pert_markers, *_ = _subset_markers(ds, pert_markers)
        logging.warning(f"#pert markers after gene subset: `{len(pert_markers)}`")

        ds_pca = ds.pca(n_pcs=args.n_pcs)
        # Set in the "sqeucl" iteration below and reused by the later training costs;
        # this relies on "sqeucl" being the first key returned by `_cost_fns`.
        pca_potentials = None

        tmp = {}
        for trn_name, trn_cost_fn in trn_cost_fns.items():
            batch_size = 1024 if trn_name == "sqeucl" else 4
            epsilon, epsilon_stats = du.find_trn_epsilon(
                ds, cost_fn=trn_cost_fn, fracs=args.trn_fracs, batch_size=batch_size
            )
            tmp[f"{trn_name}_epsilon"] = epsilon
            tmp[f"{trn_name}_epsilon-metric"] = epsilon_stats

            out, potentials = du.solve(
                ds, cost_fn=trn_cost_fn, epsilon=epsilon, batch_size=batch_size
            )
            if trn_name == "sqeucl":
                # for PCA metric
                pca_epsilon, pca_epsilon_stats = du.find_trn_epsilon(
                    ds_pca,
                    cost_fn=trn_cost_fn,
                    fracs=args.trn_fracs,
                    batch_size=batch_size,
                )
                tmp[f"{trn_name}_pca-epsilon"] = pca_epsilon
                tmp[f"{trn_name}_pca-epsilon-metric"] = pca_epsilon_stats
                _, pca_potentials = du.solve(
                    ds_pca,
                    cost_fn=costs.SqEuclidean(),
                    epsilon=pca_epsilon,
                    batch_size=batch_size,
                )

            for fwd in [True, False]:
                data_hat, perc_close, perc_neg, min_neg = du.predict(
                    ds, potentials, forward=fwd, batch_size=4, nan_to_num=0.0
                )
                # metadata
                tmp[f"{trn_name}_%close_{fwd}"] = perc_close
                tmp[f"{trn_name}_%neg_{fwd}"] = perc_neg
                tmp[f"{trn_name}_min-neg_{fwd}"] = min_neg
                # Negative predictions are clipped to zero after their statistics are stored.
                data_hat = jnp.clip(data_hat, 0.0, None)

                # PCA metric: down-project and solve in sqeucl.
                data_proj = (ds.tst_y if fwd else ds.tst_x) @ ds_pca.evecs
                data_hat_proj = data_hat @ ds_pca.evecs
                tmp[f"{trn_name}_pca-proj-div_{fwd}"] = du.solve_sink_div(
                    data_proj,
                    data_hat_proj,
                    frac=MEAN_COST_FRAC,
                    cost_fn=costs.SqEuclidean(),
                    batch_size=batch_size,
                )

                pca_hat_raw, pca_hat, *_ = du.predict(
                    ds_pca,
                    pca_potentials,
                    forward=fwd,
                    batch_size=batch_size,
                )
                # compute sink-div in PCA/up-projected PCA space
                pca_div, pca_div_raw = du.pca_metric(
                    ds,
                    ds_pca,
                    data_hat=pca_hat,
                    data_hat_raw=pca_hat_raw,
                    frac=MEAN_COST_FRAC,
                    cost_fn=trn_cost_fn,
                    batch_size=batch_size,
                    forward=fwd,
                )
                tmp[f"{trn_name}_pca-div_{fwd}"] = pca_div
                tmp[f"{trn_name}_pca-uproj-div_{fwd}"] = pca_div_raw

                # correlation
                for k, v in du.feature_correlation(ds, data_hat, forward=fwd).items():
                    tmp[f"{trn_name}_{k}_{fwd}"] = v

                for tst_name, tst_cost_fn in tst_cost_fns.items():
                    tst_batch_size = 1024 if tst_name == "sqeucl" else 4
                    epsilons = du.compute_epsilons(
                        ds,
                        data_hat,
                        tst_cost_fn,
                        forward=fwd,
                        fracs=[MEAN_COST_FRAC],
                        batch_size=tst_batch_size,
                    )
                    # NOTE: the loop variable below rebinds the training `epsilon`; it is
                    # not read again before being reassigned for the next training cost.
                    for ix, epsilon in enumerate(epsilons):
                        # Sinkhorn divergence/primal cost metric on all genes
                        primal_cost, sink_div, _ = du.evaluate(
                            ds,
                            data_hat,
                            forward=fwd,
                            cost_fn=tst_cost_fn,
                            epsilon=epsilon,
                            batch_size=tst_batch_size,
                        )
                        tmp[f"{trn_name}_{tst_name}-primal-{ix}_{fwd}"] = primal_cost
                        tmp[f"{trn_name}_{tst_name}-div-{ix}_{fwd}"] = sink_div
                        tmp[f"{trn_name}_{tst_name}-epsilon-{ix}_{fwd}"] = epsilon

                        if tst_name == "sqeucl":
                            # Gene-restricted divergences are only computed for sqeucl.
                            for kind, genes in zip(
                                ["markers", "hvgs"], [pert_markers, hvgs]
                            ):
                                primal_cost, sink_div, n_common = du.evaluate(
                                    ds,
                                    data_hat,
                                    forward=fwd,
                                    cost_fn=costs.SqEuclidean(),
                                    genes=genes,
                                    epsilon=epsilon,
                                )
                                tmp[
                                    f"{trn_name}_{tst_name}-{kind}-primal-{ix}_{fwd}"
                                ] = primal_cost
                                tmp[
                                    f"{trn_name}_{tst_name}-{kind}-div-{ix}_{fwd}"
                                ] = sink_div
                                tmp[
                                    f"{trn_name}_{tst_name}-{kind}-common-{ix}_{fwd}"
                                ] = n_common

                for kind, genes in zip(["markers", "hvgs"], [pert_markers, hvgs]):
                    for k, v in du.gene_metric(
                        ds, data_hat, forward=fwd, genes=genes
                    ).items():
                        tmp[f"{trn_name}_{kind}-{k}_{fwd}"] = v
                    # R2 between average expression and log2fc
                    for k, v in du.r2(ds, data_hat, forward=fwd, genes=genes).items():
                        tmp[f"{trn_name}_{kind}-{k}_{fwd}"] = v
                gc.collect()
        # One metrics dict per split, covering all training costs.
        res.append(tmp)
    return res


def main(args: argparse.Namespace) -> None:
    """Run the benchmark for one drug / cell line / dose and pickle the metrics.

    Reads the data from ``args.data_path``, upper-cases the gene names, selects the
    control and drug cells (optionally a single dose) of ``args.cell_line``,
    computes evaluation HVGs and perturbation markers, runs ``_run`` and writes the
    result to ``args.save_dir``.

    Args:
        args: Parsed options; besides the attributes used by ``_run`` this reads
            ``save_dir``, ``seed``, ``data_path``, ``drug``, ``dose``,
            ``cell_line``, ``n_hvgs_eval``, ``n_degs`` and ``gamma``.
    """
    save_dir = pathlib.Path(args.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    np.random.seed(args.seed)

    adata = sc.read(args.data_path)
    adata.var_names = adata.var_names.str.upper()

    # Successive filters: condition, optional dose, cell line (copied at the end).
    is_ctrl_or_drug = adata.obs["perturbation_name"].isin([CONTROL, args.drug])
    bdata = adata[is_ctrl_or_drug]
    if args.dose is not None:
        has_dose = bdata.obs["dose_character"].isin(["0", args.dose])
        bdata = bdata[has_dose]
    in_cell_line = bdata.obs["cell_type"] == args.cell_line
    bdata = bdata[in_cell_line].copy()
    assert bdata.n_obs, "No cells have been selected."

    sc.pp.highly_variable_genes(bdata, n_top_genes=args.n_hvgs_eval)
    hvgs = list(bdata.var_names[bdata.var["highly_variable"]])
    logging.warning(f"Using `{len(hvgs)}` HVGs for metrics")

    # compute the markers on the trn/tst data
    pert_markers = _compute_markers(
        bdata, pert=args.drug, n_genes=args.n_degs, key_added="markers_small"
    )

    LOGGER.warning(
        f"Running for cell line `{args.cell_line}`, dose `{args.dose}` with data shape `{bdata.shape}`, "
        f"`{len(pert_markers)}` perturbation markers and `{len(hvgs)}` HVGs for evaluation"
    )
    started = time.perf_counter()
    metrics = _run(bdata, args, pert_markers=pert_markers, hvgs=hvgs)
    LOGGER.warning(f"Done in `{time.perf_counter() - started}s`")

    fname = (
        save_dir / f"result_{args.drug}_{args.cell_line}_{args.dose}_{args.gamma}.pkl"
    )
    with open(fname, "wb") as fout:
        pickle.dump(metrics, fout)
