# Imports

In [None]:
from scvi.external import CPA
from scvi.external.cpa import register_dataset
from scvi.data import setup_anndata
from scvi.distributions import NegativeBinomial

import torch
import numpy as np
import pandas as pd
import anndata as ad
import plotnine as p9
import scanpy as sc
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

In [None]:
# Read the Norman2019 .h5ad file into `adata`; every cell below operates on this object.
# NOTE(review): absolute, machine-specific path — adjust it when running elsewhere.
adata = sc.read('/data/yosef2/users/pierreboyeau/CPA/datasets/Norman2019_prep_new.h5ad')

In [None]:
# Names of the AnnData fields (and the doser type) that the cells below look up.
keys = dict(
    cell_type_key="cell_type",
    dose_key="dose_val",
    doser_type="linear",
    perturbation_key="condition",
    split_key="split1",
    treatments_key="treatments",
)

In [None]:
# Per-cell perturbation labels and dose strings. Both are "+"-separated, so a
# combination is written as "<drug>+<drug>" with doses "<dose>+<dose>".
conditions = adata.obs[keys["perturbation_key"]]
dosages = adata.obs[keys["dose_key"]].values
codes = conditions.cat.categories.str.split("+")

# Get list of unique drugs
unfolded_codes = np.array([drug for code in codes for drug in code])
codes_list = np.unique(unfolded_codes)

# Construct matrix representation of drugs: one row per cell, one column per
# drug (columns follow the sorted order of `codes_list`), entries are doses.
n_drugs = codes_list.shape[0]
n_cells = conditions.shape[0]
drugs_obsm = np.zeros((n_cells, n_drugs))

# Explicit drug -> column lookup. A boolean mask over `codes_list` would select
# columns in sorted-drug order while the doses are listed in condition order,
# which misassigns doses whenever a combination is not written alphabetically.
drug_to_column = {drug: column for column, drug in enumerate(codes_list)}

# Plain array of labels: `conditions` is indexed by cell name, so `conditions[i]`
# with an integer relies on a positional fallback that pandas has deprecated.
condition_values = conditions.to_numpy()

for i in tqdm(range(n_cells)):
    cell_drug_names = condition_values[i].split("+")
    cell_doses = np.array(dosages[i].split("+")).astype(float)
    cell_columns = [drug_to_column[drug] for drug in cell_drug_names]
    # A single dose still broadcasts to all drugs of the cell; any other length
    # mismatch raises ValueError from NumPy. If a drug is listed twice, the last
    # listed dose is the one stored.
    drugs_obsm[i, cell_columns] = cell_doses

In [None]:
# Attach the (n_cells, n_drugs) dose matrix to the AnnData object under the
# treatments key so it can be referenced by name when registering the dataset.
adata.obsm[keys["treatments_key"]] = drugs_obsm

In [None]:
import torch
from compert.train import prepare_compert

# Load a previously trained checkpoint onto the CPU; the file unpacks into a
# (state, args, metrics) triple.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
# NOTE(review): absolute, machine-specific path — adjust it when running elsewhere.
state, args, metrics = torch.load('/data/yosef2/users/pierreboyeau/CPA//pretrained_models/Norman2019_prep_new_deg_collect/relu/sweep_Norman2019_prep_new_relu_split1_model_seed=16_epoch=80.pt', map_location=torch.device('cpu'))


# Model

In [None]:
# Register `adata` with scvi, then register the CPA-specific fields: the dose
# matrix stored in `adata.obsm` and the cell-type column as a categorical key.
setup_anndata(adata)
# The return value is kept because the CPA constructor takes it as `batch_keys_to_dim`.
batch_keys_to_dim = register_dataset(
    adata,
    treatments_key=keys["treatments_key"],
    cat_keys=[keys["cell_type_key"]],
)

In [None]:
# Options unpacked into the CPA constructor.
module_kwargs = dict(
    autoencoder_depth=3,
    autoencoder_width=256,
    doser="linear",
)

# Options handed to `model.train` as `plan_kwargs`.
plan_kwargs = dict(
    adversary_depth=4,
    adversary_lr=0.00010436428115895668,
    adversary_steps=5,
    adversary_wd=0.00020547590628803208,
    adversary_width=128,
    autoencoder_wd=1.2089626892966399e-05,
    penalty_adversary=0.26161136412599345,
    reg_adversary=58.64779813400515,
    step_size_lr=25,
    lr=0.0002333608728691712,
)

# Options unpacked directly into `model.train`.
trainer_kwargs = dict(
    max_epochs=1,
    early_stopping_patience=20,
    batch_size=128,
)

In [None]:
# from pytorch_lightning.profiler import AdvancedProfiler

# profiler = AdvancedProfiler(output_filename="profiler1.p")

In [None]:
# Build the CPA model on the registered AnnData object.
# The split column is taken from `keys` instead of a hard-coded "split", so the
# model uses the split configured for this notebook ("split1").
model = CPA(
    adata=adata,
    batch_keys_to_dim=batch_keys_to_dim,
    split_key=keys["split_key"],
    **module_kwargs,
)
# Train with early stopping monitored on the validation reconstruction loss.
# `plan_kwargs` is passed as a single argument; `trainer_kwargs` is unpacked
# into keyword arguments of `train`.
model.train(
    early_stopping_monitor="reconstruction_loss_validation",
    plan_kwargs=plan_kwargs,
#     profiler=profiler,
    **trainer_kwargs,
)

In [None]:
# return self.dosers(drugs) @ self.drug_embeddings.weight

In [None]:
# Plot three training curves side by side, one history entry per panel.
fig, axes = plt.subplots(ncols=3, figsize=(10, 3))
history_keys = (
    "adv_loss_train",
    "adv_penalty_train",
    "reconstruction_loss_train",
)
for panel, history_key in zip(axes, history_keys):
    model.history[history_key].plot(ax=panel)

# Analysis

### Basic predictions