In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# re-imported automatically before each cell runs (edits to project code take
# effect without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

# Resolve the project root and configure the session around it.
# The root is the directory identified by a `.gitignore` file, searched for
# starting at the current directory.
# NOTE: keep this call ahead of the `src.*` / `configs.*` imports below; with
# `pythonpath=True` the root directory is what makes them importable.
base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    # export the root directory as the PROJECT_ROOT environment variable
    project_root_env_var=True,
    # load environment variables from a `.env` file in the root, if present
    dotenv=True,
    # put the root directory on PYTHONPATH (helps with imports)
    pythonpath=True,
    # make the root directory the working directory (helps with file paths)
    cwd=True,
)
import sys
import pickle
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as pyg
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    RichProgressBar,
    TQDMProgressBar,
    ModelCheckpoint,
)
import wandb
import pandas as pd

from src.data.datasets import *
from src.data.datamodules import SpatialDataModule
from src.data.graph_utils import check_radius
from src.data.utils import (
    load_celltypes,
    load_sample_names,
    load_prepared_data,
    load_real_groundtruth,
)
from src.models.modules import DeconvolutionModel, ln_loss
from src.models.gnn import DissectSpatial
from src.models.dissect import Dissect
from src.utils.utils import seed_everything


from configs.main_config import config


# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# the libraries used below.
import warnings

warnings.simplefilter("ignore")

# Render inline figures at 80 dpi.
plt.rcParams["figure.dpi"] = 80

# Fix the random seed(s) through the project helper so runs are repeatable.
seed_everything(seed=42)

In [3]:
# Folder holding all input data; relative to the current working directory.
data_dir = "./data"

# --- Alternative configuration (currently disabled) -------------------------
# experiment_dir = "experiments/experiment_mouse_st"
# sc_data_dir = f"{data_dir}/single-cell"
# sc_file_name = "Allenbrain_forSimulation_uniquect.h5ad"
# st_data_dir = f"{base_path}/data/spatial"
# st_file_name = "V1_Mouse_Brain_Sagittal_Anterior.h5ad"
# radius = 0.02
# st_file_name = "puck_kidney_mus_normal.h5ad"

# --- Active configuration ---------------------------------------------------
# Experiment folder; prepared data, cell-type names and sample names are read
# from it further down.
experiment_dir = "experiments/experiment_kidney_slideSeq_v2_105"

# File loaded as `sc_data`.
sc_data_dir = f"{data_dir}/spatial/kidney_slideSeq_v2"
sc_file_name = "UMOD-KI.KI-4b.h5ad"

# File loaded as `st_data`.
st_data_dir = f"{data_dir}/spatial/simulations_kidney_slideSeq_v2"
st_file_name = "UMOD-KI.KI-4b_resolution105.h5ad"

In [4]:
# Read the two AnnData (.h5ad) files named in the configuration cell.
st_data = sc.read_h5ad(f"{st_data_dir}/{st_file_name}")
sc_data = sc.read_h5ad(f"{sc_data_dir}/{sc_file_name}")

# Arrays prepared for this experiment: real expression (full and training
# variant) plus the simulated expression and its labels.
(
    X_real,
    X_real_train,
    X_sim,
    y_sim,
) = load_prepared_data(experiment_dir)

# Ground truth for the real spots: every `obs` column from the third one on.
# NOTE(review): presumably the first two `obs` columns are not cell-type
# fractions — confirm against the file, this relies on column position only.
y_real = st_data.obs.loc[:, st_data.obs.columns[2:]].to_numpy()

In [5]:
# Re-seed so this cell gives the same result when it is re-run on its own.
seed_everything(seed=42)

# The two datasets differ only in the real expression matrix passed first:
# `X_real_train` for training, `X_real` for evaluation.
train_data = prepare_dataset(
    X_real_train,
    X_sim,
    y_sim,
    y_real=y_real,
    st_data=st_data,
)
test_data = prepare_dataset(
    X_real,
    X_sim,
    y_sim,
    y_real=y_real,
    st_data=st_data,
)

# Training loader: shuffled batches of 64, incomplete final batch dropped.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=True,
)

# Evaluation loader: fixed order, up to 3000 samples per batch.
test_loader = DataLoader(
    test_data,
    batch_size=3000,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [6]:
# batch = next(iter(train_loader))

In [7]:
# Names stored with the experiment's datasets.
# NOTE: a further callback for plotting would need these original cell-type
# and sample names as well.
sample_names = load_sample_names(f"{experiment_dir}/datasets/sample_names.txt")
celltype_names = load_celltypes(f"{experiment_dir}/datasets/celltypes.txt")

# Callbacks: keep the checkpoint with the lowest `train/total_loss` and also
# always save the most recent one, both under `checkpoints/`.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    monitor="train/total_loss",
    mode="min",
    save_last=True,
)

In [8]:
# Network sized from the data: gene count from the columns of `X_real`,
# cell-type count from the columns of `y_sim`.
net = Dissect(
    use_pos=True,
    num_celltypes=y_sim.shape[1],
    num_genes=X_real.shape[1],
)

# Lightning module wrapping the network.
model = DeconvolutionModel(
    net,
    spatial_data=st_data,
    celltype_names=celltype_names,
    sample_names=sample_names,
    sim_loss_fn="kl_div",
    # l1_lambda=0.0,
    # l2_lambda=0.0,
    beta=None,
    # identical lower and upper bound for alpha
    alpha_min=0.1,
    alpha_max=0.1,
)

In [9]:
# Close any W&B run that is still open, e.g. from an earlier execution of
# this cell.
wandb.finish()
wandb_mode = "online"
# wandb_mode = "disabled"
wandb_logger = WandbLogger(project="dissect-spatial", log_model=True, mode=wandb_mode)

# training
trainer = pl.Trainer(
    max_epochs=5000,
    max_steps=-1,  # no step limit; `max_epochs` alone ends training
    accelerator="gpu",
    limit_train_batches=1,  # one batch per epoch, so epochs == optimizer steps
    log_every_n_steps=1,
    check_val_every_n_epoch=500,  # validate 10 times over the 5000 epochs
    # NOTE(review): GPU index 7 is specific to the machine this was run on;
    # adjust before running elsewhere.
    devices=[7],
    precision=32,
    logger=wandb_logger,
    deterministic="warn",
    enable_checkpointing=True,
    # fast_dev_run=True,
    # profiler="simple",
    enable_progress_bar=False,
    callbacks=[checkpoint_callback],
)

# Always close the W&B run, including when training raises or is interrupted;
# otherwise the run stays open and is only cleaned up by the `wandb.finish()`
# at the top of this cell on the next execution.
try:
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
except BaseException:
    # BaseException so that KeyboardInterrupt is covered too; mark the run as
    # failed, then let the original error propagate.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdschaub[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type    | Params
---------------------------------
0 | net  | Dissect | 4.0 M 
---------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.028    Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5000` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████▆▆▆▆▆▆▆▆
train/l1_loss,███▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁
train/l2_loss,███▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁
train/mix_loss,▁▁▂▂▂▄▃▄▄▆▇▇▇█▇█▆▅▆█▆▇▆▅▆▅▅▆▇▇▅▆▅▆▆▆▇█▅▇
train/sim_loss,█▇▆▆▅▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train/total_loss,█▇▆▆▅▅▅▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
validation/mean_corr,▁▅▆▇▇▇████
validation/mean_rmse,█▄▃▂▁▂▁▁▁▁

0,1
epoch,4999.0
train/beta,10.0
train/l1_loss,0.54035
train/l2_loss,0.01353
train/mix_loss,0.00119
train/sim_loss,0.2632
train/total_loss,0.82898
trainer/global_step,4999.0
validation/mean_corr,0.7055
validation/mean_rmse,0.08214
