In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    project_root_env_var=True,  # set the PROJECT_ROOT environment variable to root directory
    dotenv=True,  # load environment variables from .env if exists in root directory
    pythonpath=True,  # add root directory to the PYTHONPATH (helps with imports)
    cwd=True,  # change current working directory to the root directory (helps with filepaths)
)
import sys
import pickle
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as pyg
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    RichProgressBar,
    TQDMProgressBar,
    ModelCheckpoint,
)
import wandb
import copy
import pandas as pd
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from omegaconf import DictConfig
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _flush_loggers, configure_log

from src.data.datasets import *
from src.data.datamodules import SpatialDataModule
from src.data.graph_utils import check_radius
from src.data.utils import load_celltypes, load_sample_names
from src.models.modules import DeconvolutionModel, ln_loss, beta_scheduler
from src.models.dissect_spatial import DissectSpatial, DissectHetero, DissectSpatialHybrid
from src.models.dissect import Dissect
from src.utils.utils import seed_everything
from src.train import train
from src.utils.wandb import *
from src.utils.config_utils import *

from configs.main_config import config


# ignore all warnings
import warnings

warnings.filterwarnings("ignore")

plt.rcParams["figure.dpi"] = 80

In [3]:
st_data_files = [
    "spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad",
    "spatial/simulations_kidney_slideSeq_v2/UMOD-KI.KI-4b_resolution105.h5ad",
    "spatial/simulations_heart_seqFISH/embryo1_resolution0.11.h5ad",
]
experiment_dirs = [
    "experiments/experiment_kidney_slideSeq_v2_UMOD-WT.WT-2a_resolution75",
    "experiments/experiment_kidney_slideSeq_v2_105",
    "experiments/experiment_heart_seqFISH_embryo1_resolution0.11-new",
]
st_paths = ["${paths.data_dir}" + f  for f in st_data_files]
experiment_paths = ["${paths.root_dir}" + "/" + dir for dir in experiment_dirs]

In [4]:
sweep_ids = ["4trahk5d", "5wtb892z", "65rym7bj"]
experiment_names = ["hybrid", None, "transformer"]
for sweep_id, experiment_name in list(zip(sweep_ids, experiment_names))[1:2]:
    sweep_runs = get_sweep_runs_for_id(sweep_id)
    print(f"Found {len(sweep_runs)} runs for sweep {sweep_id}")
    entity="dschaub"
    project="DISSECT-src"
    api = wandb.Api()
    runs = api.runs(entity + "/" + project, filters={"sweep": sweep_id})
    run_names, mean_cccs, mean_rmses = extract_metrics_from_runs(runs, max_runs=10)
    # get top k runs
    k = 5
    top_k_indices = np.argsort(mean_cccs)[-k:]
    top_k_runs = [run_names[i] for i in top_k_indices]
    for run_name in top_k_runs:
        for st_path, experiment_path in zip(st_paths, experiment_paths):
            run_config = get_run_config(run_name)
            if "experiment" in run_config:
                exp_name = run_config["experiment"]
                del run_config["experiment"]
            else:
                exp_name = experiment_name
            print(f"Experiment name: {exp_name}")
            # load base config
            config_path = "../configs"
            config_name = "train.yaml"
            hydra.core.global_hydra.GlobalHydra.instance().clear()
            initialize(version_base="1.3", config_path=config_path)
            if exp_name is not None:
                overrides = [f"experiment={exp_name}"]
            else:
                overrides = []
            config = compose(
                config_name=config_name,
                overrides=overrides,
                return_hydra_config=True,
            )
            HydraConfig.instance().set_config(config)
            OmegaConf.set_struct(config, False)
            del config["hydra"]
            config = prepare_config(config, run_config)
            # set extra values
            config.data.st_path = st_path
            config.data.reference_dir = experiment_path
            config.trainer.max_steps = 10

            # configure more for training
            config.experiment = f"{sweep_id}_{exp_name}_test"
            tags = [sweep_id, f"{exp_name}", "test"]
            config.logger.wandb.tags = tags
            config.tags = tags
            config.extras.print_config = False
            config.model.save_predictions = True
            config.model.plotting = False
            config.trainer.devices = [7]
            wandb_mode = "online"
            config.logger.wandb.mode = wandb_mode
            # print_config(config)
            metric_dict, _ = train(config)
        break
    break

Found 1368 runs for sweep 5wtb892z


  1%|          | 9/1368 [00:05<12:48,  1.77it/s]


Experiment name: None


Global seed set to 42


Net in config: False


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdschaub[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Output()

`Trainer.fit` stopped: `max_steps=10` reached.


VBox(children=(Label(value='1.756 MB of 1.756 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▂▃▃▄▅▆▆▇█
train/beta,▁▁▁▁▁▁▁▁▁▁
train/l1_loss,█▇▆▆▅▄▃▂▂▁
train/l2_loss,█▇▆▆▅▄▃▂▂▁
train/mix_loss,▁▁▁▂▂▃▄▆█▇
train/sim_loss,▂▃█▇▇▄▁▇▃▄
train/total_loss,▆▆█▇▆▃▁▄▁▁
trainer/global_step,▁▂▃▃▄▅▆▆▇█

0,1
epoch,9.0
train/beta,5.0
train/l1_loss,38893.12109
train/l2_loss,1773.49939
train/mix_loss,0.0
train/sim_loss,0.90056
train/total_loss,1.09584
trainer/global_step,9.0


Experiment name: None


Global seed set to 42


Net in config: False


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Output()

`Trainer.fit` stopped: `max_steps=10` reached.


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train/beta,▁▁▁▁▁▁▁▁▁▁
train/l1_loss,█▇▆▆▅▄▃▂▂▁
train/l2_loss,█▇▆▅▅▄▃▂▂▁
train/mix_loss,▁▁▁▂▂▂▃▄▅█
train/sim_loss,▅▃█▅▃▆▁▁▃▁
train/total_loss,▇▆█▆▄▆▂▂▃▁
trainer/global_step,▁▂▃▃▄▅▆▆▇█

0,1
epoch,9.0
train/beta,5.0
train/l1_loss,40658.03906
train/l2_loss,1776.8396
train/mix_loss,0.0
train/sim_loss,0.97647
train/total_loss,1.18013
trainer/global_step,9.0


Experiment name: None


Global seed set to 42


Net in config: False


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Output()

`Trainer.fit` stopped: `max_steps=10` reached.


0,1
epoch,▁▂▃▃▄▅▆▆▇█
train/beta,▁▁▁▁▁▁▁▁▁▁
train/l1_loss,█▇▆▆▅▄▃▃▂▁
train/l2_loss,█▇▆▆▅▄▃▃▂▁
train/mix_loss,▁▁▂▃▃▅▅▇▇█
train/sim_loss,▅▂▂▂▇▆█▆▁▅
train/total_loss,▇▃▄▃▇▆█▆▁▄
trainer/global_step,▁▂▃▃▄▅▆▆▇█

0,1
epoch,9.0
train/beta,5.0
train/l1_loss,27205.8125
train/l2_loss,1812.67749
train/mix_loss,0.0
train/sim_loss,1.59217
train/total_loss,1.73246
trainer/global_step,9.0
