In [None]:
import functools
import pickle
from typing import Iterable, Literal, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import joypy
import numpy as np
import optax
import pandas as pd
import scanpy as sc
import yaml
from flax.core.frozen_dict import FrozenDict
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, genot, otfm
from ott.neural.networks.layers import time_encoder
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
from tqdm import tqdm

from ot_pert.metrics import compute_mean_metrics, compute_metrics
from ot_pert.nets.nets import CondVelocityField, VelocityFieldWithAttention

In [None]:
# Norman et al. splits, all stored under one directory and differing only by split index.
# NOTE(review): every file carries the `adata_train_` prefix; by the variable names, split 1
# is used as the test set and split 2 as the OOD set — confirm this mapping.
train_path, test_path, ood_path = (
    f"/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_train_{split}_seen_genes.h5ad"
    for split in (0, 1, 2)
)

In [None]:
def read_perturbation_adata(path):
    """Read an ``.h5ad`` file and expose its ``perturbation_name`` column as ``condition``.

    Args:
        path: Location of the ``.h5ad`` file.

    Returns:
        The loaded AnnData, with ``obs`` replaced by a copy whose ``perturbation_name``
        column is renamed to ``condition``.

    Raises:
        AssertionError: If neither ``perturbation_name`` nor ``condition`` is present,
            since every downstream cell groups cells by ``obs["condition"]``.
    """
    adata = sc.read_h5ad(path)
    adata.obs = adata.obs.rename(columns={"perturbation_name": "condition"})
    assert "condition" in adata.obs.columns, f"{path}: no 'perturbation_name'/'condition' column in .obs"
    return adata


adata_train = read_perturbation_adata(train_path)
adata_test = read_perturbation_adata(test_path)
adata_ood = read_perturbation_adata(ood_path)


In [None]:
# NOTE(review): this cell was carried over from a config-driven training script. It read
# `cfg.training.*` from a `cfg` object and selected a `compute_metrics_fast` function, but
# this notebook defines neither. It also called `load_data` (defined in a later cell) with
# a `cfg` positional argument its signature does not accept, so it raised on a fresh kernel.
# The evaluation splits are built by the later cell that calls `load_data(adata_test ...)`
# and `load_data(adata_ood ...)`. A training DataLoader is not needed for inference here.
# This cell now only selects the metric function.
comp_metrics_fn = compute_metrics

In [None]:
# Number of leading components of the `.obsm` cell embedding used as the model's state space.
PCs = 30
# `.obsm` keys: the cell embedding, then the two condition embeddings that are stacked
# into a length-2 condition sequence per perturbation (presumably one per perturbed gene).
obsm_key_data, obsm_key_cond_1, obsm_key_cond_2 = "X_pca", "emb_1", "emb_2"

In [None]:
def _to_dense(matrix) -> np.ndarray:
    """Return ``matrix`` as a dense NumPy array, accepting SciPy-sparse or dense input."""
    # `.A` is deprecated on SciPy sparse matrices and does not exist on ndarrays;
    # `toarray()` is the portable spelling for sparse input.
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def load_data(adata, return_dl: bool = False, batch_size: Optional[int] = None):
    """Build per-perturbation (control -> perturbed) pairs from an AnnData object.

    Cells with ``obs["condition"] == "control"`` form the source population shared by every
    perturbation. The cells of each other condition form that perturbation's target. Model
    inputs are the first ``PCs`` columns of ``adata.obsm[obsm_key_data]``. Each perturbation
    is encoded by stacking one row each of ``adata.obsm[obsm_key_cond_1]`` and
    ``adata.obsm[obsm_key_cond_2]``. ``PCs`` and the ``obsm_key_*`` names are the
    module-level settings defined earlier in the notebook.

    Args:
        adata: AnnData with a categorical ``obs["condition"]`` column containing
            ``"control"``. When ``return_dl`` is False it must also provide
            ``uns["rank_genes_groups_cov_all"]``.
        return_dl: If True, return one shuffled ``DataLoader`` per perturbation wrapped in
            a ``ConditionalLoader``. Otherwise return dictionaries keyed by condition.
        batch_size: Batch size of each ``DataLoader``. Required when ``return_dl`` is True.

    Returns:
        ``ConditionalLoader`` if ``return_dl`` is True. Otherwise a dict with these keys:
        ``"source"`` and ``"target"`` (embeddings), ``"source_decoded"`` and
        ``"target_decoded"`` (dense ``adata.X``), ``"conditions"`` (arrays of shape
        ``(n_control, 2, emb_dim)``), each keyed by condition name, plus ``"deg_dict"``
        (the entries of ``uns["rank_genes_groups_cov_all"]`` for the conditions found).

    Raises:
        ValueError: If ``return_dl`` is True and ``batch_size`` is not given.
        AssertionError: If a condition's embedding rows differ across its cells.
    """
    if return_dl and batch_size is None:
        raise ValueError("`batch_size` is required when `return_dl=True`.")

    condition_labels = adata.obs["condition"]
    control = adata[condition_labels == "control"]
    source = control.obsm[obsm_key_data][:, :PCs]

    dls = []
    data_source = {}
    data_target = {}
    data_source_decoded = {}
    data_target_decoded = {}
    data_conditions = {}
    # Densified once and shared by every condition, matching the dense targets below.
    source_decoded = None if return_dl else _to_dense(control.X)

    for cond in condition_labels.cat.categories:
        if cond == "control":
            continue
        perturbed = adata[condition_labels == cond]
        if perturbed.n_obs == 0:
            # Categories can outlive their cells after splitting; there is nothing to pair.
            continue

        target = perturbed.obsm[obsm_key_data][:, :PCs]
        condition_1 = perturbed.obsm[obsm_key_cond_1]
        condition_2 = perturbed.obsm[obsm_key_cond_2]
        # Every cell of a perturbation must carry the same embeddings, so row 0 represents all.
        assert np.all(np.all(condition_1 == condition_1[0], axis=1))
        assert np.all(np.all(condition_2 == condition_2[0], axis=1))
        # (1, 2, emb_dim): both embeddings as a length-2 sequence, repeated per control cell.
        expanded_arr = np.stack((condition_1[0], condition_2[0]), axis=0)[None]
        conds = np.tile(expanded_arr, (len(source), 1, 1))

        if return_dl:
            dls.append(
                DataLoader(
                    datasets.OTDataset(
                        datasets.OTData(lin=source, condition=conds),
                        datasets.OTData(lin=target),
                    ),
                    batch_size=batch_size,
                    shuffle=True,
                )
            )
        else:
            data_source[cond] = source
            data_target[cond] = target
            data_source_decoded[cond] = source_decoded
            data_target_decoded[cond] = _to_dense(perturbed.X)
            data_conditions[cond] = conds

    if return_dl:
        # NOTE(review): `ConditionalLoader` is not imported anywhere in this notebook; import it
        # (presumably from the ot_pert package) before calling with `return_dl=True`.
        return ConditionalLoader(dls, seed=0)

    deg_dict = {k: v for k, v in adata.uns["rank_genes_groups_cov_all"].items() if k in data_conditions}

    return {
        "source": data_source,
        "target": data_target,
        "source_decoded": data_source_decoded,
        "target_decoded": data_target_decoded,
        "conditions": data_conditions,
        "deg_dict": deg_dict,
    }


In [None]:
# Evaluation splits as per-condition dictionaries; inference needs no DataLoaders.
# `return_dl` is passed explicitly because `load_data` declares it as a required argument.
test_data = load_data(adata_test, return_dl=False)
ood_data = load_data(adata_ood, return_dl=False)

In [None]:
yaml_config = """
num_heads: 4
qkv_feature_dim: 32
max_seq_length: 2
hidden_dims: [1024, 1024, 1024]
output_dims: [1024, 1024, 1024]
condition_dims: [4096, 4096, 4096]
time_dims: [512, 512, 512]
time_n_freqs: 1024
flow_noise: 1.0
learning_rate: 0.00005
multi_steps: 20
epsilon: 0.01
tau_a: 0.999
tau_b: 0.999
dropout_rate: 0.2
"""

In [None]:
# Parse the hyperparameter text with PyYAML's safe loader (plain Python types only).
model_config = yaml.load(yaml_config, Loader=yaml.SafeLoader)

In [None]:
# The velocity field's last layer must match the data dimension: the first `PCs`
# components of the cell embedding. Deriving it keeps the two in sync.
output_dim = PCs
# Width of each condition embedding passed to the flow-matching model.
# NOTE(review): hard-coded; presumably equals adata.obsm[obsm_key_cond_1].shape[1] — confirm.
condition_dim = 1024

In [None]:
# Conditional velocity field. All sizes come from `model_config`, and the final entry
# of `output_dims` is the data dimension, so the field's output lives in PCA space.
vf = VelocityFieldWithAttention(
    num_heads=model_config["num_heads"],
    qkv_feature_dim=model_config["qkv_feature_dim"],
    max_seq_length=model_config["max_seq_length"],
    dropout_rate=model_config["dropout_rate"],
    hidden_dims=model_config["hidden_dims"],
    condition_dims=model_config["condition_dims"],
    time_dims=model_config["time_dims"],
    time_encoder=functools.partial(
        time_encoder.cyclical_time_encoder,
        n_freqs=model_config["time_n_freqs"],
    ),
    output_dims=[*model_config["output_dims"], output_dim],
)


# Flow-matching wrapper. match_fn=None means no matching function is supplied to the model.
# optax.MultiSteps accumulates gradients over `multi_steps` calls before each Adam update.
# The PRNG seed is fixed at 13 for reproducibility.
model = otfm.OTFlowMatching(
    vf,
    flow=dynamics.ConstantNoiseFlow(model_config["flow_noise"]),
    match_fn=None,
    condition_dim=condition_dim,
    rng=jax.random.PRNGKey(13),
    optimizer=optax.MultiSteps(
        optax.adam(model_config["learning_rate"]),
        every_k_schedule=model_config["multi_steps"],
    ),
)

In [None]:
# NOTE(review): this checkpoint lives under `otfm/combosciplex/`, while the data loaded
# above is the Norman split. Confirm this is the intended model for these cells.
load_path = "/lustre/groups/ml01/workspace/ot_perturbation/models/otfm/combosciplex/fearless-durian-2804_model.pkl"

# Deserialize the saved velocity-field parameters. pickle.load can execute arbitrary
# code, so only load checkpoints from a trusted location.
with open(load_path, "rb") as f:
    loaded_params = pickle.load(f)

In [None]:
# Wrap the loaded parameter pytree in a FrozenDict and swap it into the model's train state.
# `.replace(params=...)` produces a new state with only `params` changed (flax struct semantics).
new_params = FrozenDict(loaded_params)

if not hasattr(model.vf_state, "replace"):
    raise ValueError(
        "model.vf_state has no `.replace` method; cannot load saved parameters into a state "
        f"of type {type(model.vf_state).__name__}."
    )
model.vf_state = model.vf_state.replace(params=new_params)

In [None]:
# Push each test condition's control cells through the learned flow under that
# condition's embedding. Both dicts share the same condition keys.
predictions_test = jtu.tree_map(
    lambda control_cells, condition_tokens: model.transport(control_cells, condition_tokens),
    test_data["source"],
    test_data["conditions"],
)


In [None]:
# Same transport for the out-of-distribution conditions.
predictions_ood = jtu.tree_map(
    lambda control_cells, condition_tokens: model.transport(control_cells, condition_tokens),
    ood_data["source"],
    ood_data["conditions"],
)


In [None]:
def reconstruct_data(embedding, projection_matrix, mean_to_add):
    """Map low-dimensional embeddings back to the original feature space.

    Computes ``embedding @ projection_matrix.T + mean_to_add``. The product uses
    ``np.matmul``, so array-like inputs (e.g. JAX arrays) are converted to NumPy first.

    Args:
        embedding: Coordinates in the reduced space, typically ``(n_cells, k)``.
        projection_matrix: Loadings, typically ``(n_features, k)``; transposed before use.
        mean_to_add: Offset broadcast against the ``(n_cells, n_features)`` product.

    Returns:
        The reconstructed data.
    """
    loadings_t = projection_matrix.T
    projected_back = np.matmul(embedding, loadings_t)
    return projected_back + mean_to_add

In [None]:
# Decoder from the model's PCA space back to gene space, using the training split's PCA.
# The model only sees the first `PCs` components (`obsm[obsm_key_data][:, :PCs]`), so only
# the matching loading columns are kept. Otherwise the matmul shapes disagree whenever
# more components are stored; the slice is a no-op if exactly `PCs` are stored.
reconstruct_data_fn = functools.partial(
    reconstruct_data,
    projection_matrix=adata_train.varm["PCs"][:, :PCs],
    mean_to_add=adata_train.varm["X_train_mean"].T,
)

In [None]:
# Decode both prediction sets, condition by condition, from PCA space to gene space.
predictions_test_decoded, predictions_ood_decoded = (
    jtu.tree_map(reconstruct_data_fn, predicted) for predicted in (predictions_test, predictions_ood)
)