In [3]:
import hydra
import wandb
from ott.neural import datasets
import sys
from omegaconf import DictConfig
import jax.numpy as jnp
from jax import random
from typing import Optional, Literal
import jax
import pathlib
import optax
import yaml
from datetime import datetime
from flax import linen as nn
import functools
from tqdm import tqdm
from flax.training import train_state

from ott.neural.networks.layers import time_encoder
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils
import jax.tree_util as jtu
from ott.neural.networks.layers import time_encoder
import pandas as pd
import os

import hydra
from omegaconf import DictConfig, OmegaConf

from torch.utils.data import DataLoader
import numpy as np

import scanpy as sc

In [4]:
from src.ot_pert.metrics import compute_metrics, compute_mean_metrics
from src.ot_pert.nets.nets import VelocityFieldWithAttention
from src.ot_pert.utils import ConditionalLoader

In [7]:

def reconstruct_data(embedding: np.ndarray, projection_matrix: np.ndarray, mean_to_add: np.ndarray) -> np.ndarray:
    """Project an embedding back through ``projection_matrix`` and add an offset.

    Args:
        embedding: Array that is matrix-multiplied with ``projection_matrix.T``.
        projection_matrix: Matrix whose transpose is applied to ``embedding``.
        mean_to_add: Offset added (with numpy broadcasting) to the product.

    Returns:
        ``matmul(embedding, projection_matrix.T) + mean_to_add``.
    """
    back_projected = np.matmul(embedding, projection_matrix.T)
    return back_projected + mean_to_add

In [5]:
def data_match_fn(
    src_lin: Optional[jnp.ndarray],
    tgt_lin: Optional[jnp.ndarray],
    src_quad: Optional[jnp.ndarray],
    tgt_quad: Optional[jnp.ndarray],
    *,
    typ: Literal["lin", "quad", "fused"],
    epsilon: float = 1e-2,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
) -> jnp.ndarray:
    """Dispatch to the OTT matching helper selected by ``typ``.

    Args:
        src_lin: Source points for the linear term (used by "lin" and "fused").
        tgt_lin: Target points for the linear term (used by "lin" and "fused").
        src_quad: Source points for the quadratic term (used by "quad" and "fused").
        tgt_quad: Target points for the quadratic term (used by "quad" and "fused").
        typ: Which matching to run: "lin", "quad" or "fused".
        epsilon: Forwarded to ``match_linear`` only.
        tau_a: Forwarded to ``match_linear`` only.
        tau_b: Forwarded to ``match_linear`` only.

    Returns:
        Whatever ``solver_utils.match_linear`` / ``solver_utils.match_quadratic`` returns.

    Raises:
        NotImplementedError: If ``typ`` is not one of the three supported values.
    """
    if typ == "lin":
        return solver_utils.match_linear(
            x=src_lin,
            y=tgt_lin,
            scale_cost="mean",
            epsilon=epsilon,
            tau_a=tau_a,
            tau_b=tau_b,
        )

    # NOTE(review): epsilon / tau_a / tau_b are not passed on in the two
    # quadratic branches below -- confirm that solver defaults are intended.
    if typ == "quad":
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad)

    if typ == "fused":
        fused_kwargs = {"xx": src_quad, "yy": tgt_quad, "x": src_lin, "y": tgt_lin}
        return solver_utils.match_quadratic(**fused_kwargs)

    raise NotImplementedError(f"Unknown type: {typ}.")

In [19]:
# Build one (control -> perturbed) DataLoader per perturbation condition and
# keep the raw arrays per condition in dictionaries keyed by condition name.
adata_train = sc.read('/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_train_0_seen_genes.h5ad')
adata_train.obs = adata_train.obs.rename(columns={"perturbation_name": "condition"})

# Label of the unperturbed cells in obs["condition"]. The loop below must skip
# exactly this label; it previously compared against "ctrl", which never
# matched, so the control group was paired with itself as a target.
CONTROL_LABEL = "control"

dls = []

train_data_source = {}
train_data_target = {}
train_data_source_decoded = {}
train_data_target_decoded = {}
train_data_conditions = {}

# Subset the control cells once; they are the shared source for every condition.
adata_control = adata_train[adata_train.obs["condition"] == CONTROL_LABEL]
source = adata_control.obsm['X_pca']
# Densified with `.A` so it has the same form as `target_decoded` below.
source_decoded = adata_control.X.A
assert len(source) > 0, f"No cells with condition == {CONTROL_LABEL!r} found."

for cond in adata_train.obs["condition"].cat.categories:
    if cond == CONTROL_LABEL:
        continue

    # Subset once per condition instead of once per field.
    adata_cond = adata_train[adata_train.obs["condition"] == cond]
    target = adata_cond.obsm['X_pca']
    target_decoded = adata_cond.X.A
    condition_1 = adata_cond.obsm['emb_1']
    condition_2 = adata_cond.obsm['emb_2']

    # Every cell of a condition must carry the same embedding rows, because
    # only row 0 is used to represent the whole condition.
    assert np.all(condition_1 == condition_1[0]), f"emb_1 is not constant within condition {cond!r}"
    assert np.all(condition_2 == condition_2[0]), f"emb_2 is not constant within condition {cond!r}"

    # Stack the two embedding rows -> (1, 2, emb_dim), then repeat once per
    # source cell -> (len(source), 2, emb_dim).
    expanded_arr = np.stack((condition_1[0], condition_2[0]), axis=0)[None]
    conds = np.tile(expanded_arr, (len(source), 1, 1))

    dls.append(
        DataLoader(
            datasets.OTDataset(
                datasets.OTData(lin=source, condition=conds),
                datasets.OTData(lin=target),
            ),
            batch_size=128,
            shuffle=True,
        )
    )
    train_data_source[cond] = source
    train_data_target[cond] = target
    train_data_conditions[cond] = conds
    train_data_source_decoded[cond] = source_decoded
    train_data_target_decoded[cond] = target_decoded

# train_loader = ConditionalLoader(dls, seed=0)

# reconstruct_data_fn = functools.partial(reconstruct_data, projection_matrix=adata_train.varm["PCs"], mean_to_add=adata_train.varm["X_train_mean"].T)

# source_dim = source.shape[1]
# target_dim = source_dim
# condition_dim = condition_1.shape[1]

# source_dim = source.shape[1]
# target_dim = source_dim
# condition_dim = condition_1.shape[1]

# vf = VelocityFieldWithAttention(
#     num_heads=cfg.model.num_heads,
#     qkv_feature_dim=cfg.model.qkv_feature_dim,
#     max_seq_length=cfg.model.max_seq_length,
#     hidden_dims=cfg.model.hidden_dims,
#     time_dims=cfg.model.time_dims,
#     output_dims=cfg.model.output_dims+[target_dim],
#     condition_dims=cfg.model.condition_dims,
#     time_encoder = functools.partial(time_encoder.cyclical_time_encoder, n_freqs=cfg.model.time_n_freqs),
#     )

# print(vf)

# model = otfm.OTFlowMatching(vf,
#     flow=dynamics.ConstantNoiseFlow(cfg.model.flow_noise),
#     match_fn=jax.jit(functools.partial(data_match_fn, typ="lin", src_quad=None, tgt_quad=None, epsilon=cfg.model.epsilon, tau_a=cfg.model.tau_a, tau_b=cfg.model.tau_b)),
#     condition_dim=condition_dim,
#     rng=jax.random.PRNGKey(13),
#     optimizer=optax.MultiSteps(optax.adam(learning_rate=cfg.model.learning_rate), cfg.model.multi_steps)
# )

# print(model)

# training_logs = {"loss": []}


In [18]:
# Inspect the per-cell condition labels and their categories.
adata_train.obs["condition"]

index
AAACCTGAGAAGAAGC-1          control
AAACCTGCACGAAGCA-1          control
AAACCTGCAGCCTTGG-1            MAML2
AAACCTGCATTACCTT-1      ETS2+MAP7D1
AAACCTGGTATAATGG-1          control
                          ...      
TTTGTCAGTATAAACG-8    SAMD1+UBASH3B
TTTGTCAGTCAGAATA-8          control
TTTGTCATCAGTACGT-8            FOXA3
TTTGTCATCCCAACGG-8           BCORL1
TTTGTCATCTGGCGAC-8           MAP4K3
Name: condition, Length: 59579, dtype: category
Categories (134, object): ['AHR', 'ARID1A', 'ARRDC3', 'ATL1', ..., 'ZBTB25', 'ZC3HAV1', 'ZNF318', 'control']

In [15]:
# Element-wise select: wherever an entry of `b` equals 0.0, take the matching
# entry of [0, 0, 1] (broadcast across rows); otherwise keep the entry of `b`.
# NOTE(review): neither `torch` nor `b` is defined in the visible cells of this
# notebook -- this cell presumably relies on leftover kernel state; confirm.
torch.where(
    b == torch.tensor([[0.0000, 0.0000, 0.0000]]),
    torch.tensor([[0.0000, 0.0000, 1.0000]]),
    b,
)

tensor([[0.8811, 0.5417, 0.8617],
        [0.5732, 0.9433, 0.4122],
        [0.0000, 0.0000, 1.0000],
        [0.0000, 0.0000, 1.0000]])

In [20]:
# Scratch dictionary with a single entry, used by the truthiness check in the next cell.
a = dict(b=2)

In [21]:
# Scratch check of truthiness: prints "f" only when `a` is truthy
# (for the dict defined in the previous cell, that means non-empty).
if a:
    print("f")

f


In [32]:
# Scratch check: twice the difference of logs for values 10 and 10.0001.
2 * (np.log(10) - np.log(10.0001))

-1.999990000012275e-05

In [33]:
# Scratch check: twice the raw difference of the same two values (10 and 10.0001).
2 * (10 - 10.0001)

-0.00019999999999953388

In [35]:
# Paths of the AnnData files under the `norman` data directory, keyed by split.
# NOTE(review): 'adata_ood_path' points at the "adata_val_..." file -- confirm
# that this mapping is intended.
dicta = dict(
    adata_train_path='/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_train_0_seen_genes.h5ad',
    adata_test_path='/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_test_0_seen_genes.h5ad',
    adata_ood_path='/lustre/groups/ml01/workspace/ot_perturbation/data/norman/adata_val_0_seen_genes.h5ad',
)

In [36]:
import scanpy as sc

In [41]:
# Load an AnnData file and display its `emb_1` embedding matrix.
# NOTE(review): the variable is called `train` but it is read from
# 'adata_test_path'. The name is kept because a later cell reads `train.obsm`.
train = sc.read_h5ad(dicta["adata_test_path"])
train.obsm["emb_1"]

array([[-0.00188   , -0.00241112,  0.03851924, ...,  0.08719031,
        -0.12514378,  0.01683793],
       [ 0.04002452, -0.02404516,  0.00446936, ..., -0.00661617,
        -0.133329  ,  0.06671482],
       [-0.02541345, -0.00573373,  0.03895705, ...,  0.02111108,
        -0.07714359, -0.02685452],
       ...,
       [-0.04211034, -0.09814648, -0.00608418, ...,  0.10736639,
        -0.06857122, -0.0775749 ],
       [ 0.04481288, -0.02440529,  0.05576424, ...,  0.00169797,
        -0.1467784 ,  0.00504636],
       [-0.06218398, -0.0427513 ,  0.01320167, ...,  0.08727837,
        -0.10445533, -0.13681167]])

In [38]:
# List the keys stored in `obsm` of the AnnData object loaded into `train`.
train.obsm

AxisArrays with keys: X_pca, X_umap, emb_1, emb_2

In [43]:
# Root directory of the CSV inputs, followed by the individual files inside it.
DATA_PATH = "/lustre/groups/ml01/projects/super_rad_project/data"
RAD_PATH = DATA_PATH + "/Cas9_data_nofilter_mean_combn_with_singletons_minmax.csv"
HORLBECK_PATH = DATA_PATH + "/Horlbeck_minmax.csv"
SCORE_PATH = DATA_PATH + "/SCORE_minmax.csv"

In [45]:
# Read the SCORE table; the first CSV column becomes the DataFrame index.
score = pd.read_csv(filepath_or_buffer=SCORE_PATH, index_col=0)

In [50]:
# Peek at one distinct value of the `iv1` column (position 200 in order of first appearance).
score["iv1"].unique()[200]

'ACTA1'