In [1]:
# Automatically reload edited project modules (e.g. under src/) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

# Locate the repository root (the directory holding .gitignore, searched from the
# current directory) and configure the session around it. This must run before
# the `src.*` / `configs.*` imports below, which rely on the root being importable.
base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    cwd=True,  # chdir into the root so relative file paths resolve from there
    pythonpath=True,  # put the root on PYTHONPATH so project packages import cleanly
    dotenv=True,  # read <root>/.env into the environment when the file exists
    project_root_env_var=True,  # export PROJECT_ROOT pointing at the root directory
)
import sys
import pickle
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as pyg
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    RichProgressBar,
    TQDMProgressBar,
    ModelCheckpoint,
)
import wandb
import copy
import pandas as pd
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from omegaconf import DictConfig
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _flush_loggers, configure_log

from src.data.datasets import *
from src.data.datamodules import SpatialDataModule
from src.data.graph_utils import check_radius
from src.data.utils import load_celltypes, load_sample_names
from src.models.modules import DeconvolutionModel, ln_loss, beta_scheduler
from src.models.dissect_spatial import DissectSpatial, DissectHetero, DissectSpatialHybrid
from src.models.dissect import Dissect
from src.utils.utils import seed_everything
from src.train import train
from src.utils.wandb import *
from src.utils.config_utils import *

from configs.main_config import config


# ignore all warnings
import warnings

# Suppress every Python warning for the rest of this kernel session.
warnings.filterwarnings("ignore")

# Keep inline figures at a modest resolution.
plt.rcParams.update({"figure.dpi": 80})

# Seed the project's RNGs up front so the analysis is reproducible.
seed_everything(seed=42)

## Load configs

### Load run config

In [3]:
sweep_ids = ["4trahk5d", "5wtb892z", "65rym7bj"]
# Only the first sweep is analysed in this notebook. Bind it explicitly rather
# than letting later cells (the W&B query below) depend on the loop variable of
# a one-iteration `for` loop leaking out.
sweep_id = sweep_ids[0]
sweep_runs = get_sweep_runs_for_id(sweep_id)
print(f"Found {len(sweep_runs)} runs for sweep {sweep_id}")
# NOTE(review): `sweep_runs` is not used by the following cells shown here; the
# runs are re-queried through `wandb.Api` instead — consider reusing this result.

Found 368 runs for sweep 4trahk5d


In [7]:
# Query every run of the selected sweep (`sweep_id` is bound in the cell above).
entity = "dschaub"
project = "DISSECT-src"  # alternative project: "dissect-spatial"
api = wandb.Api()
runs = api.runs(f"{entity}/{project}", filters={"sweep": sweep_id})

In [10]:
# Sanity-check the query by printing the tags of the first returned run only.
run = next(iter(runs), None)
if run is not None:
    print(run.config["tags"])

['dev']


In [13]:
def _last_logged_value(run, key):
    """Return the most recent value logged under ``key`` in ``run``'s history.

    Propagates whatever ``run.history`` / pandas indexing raise when the metric
    is unavailable (presumably ``KeyError`` for a missing column and
    ``IndexError`` for an empty history — TODO confirm); the caller treats any
    such error as "metric not available for this run".
    """
    return run.history(keys=[key])[key].iloc[-1]


def _is_finite_number(value):
    """Return True if ``value`` is a finite real number (rejects NaN/inf, str, None)."""
    try:
        return bool(np.isfinite(value))
    except TypeError:  # np.isfinite rejects str / None / other non-numeric objects
        return False


MAX_RUNS_TO_COLLECT = 10  # stop once this many runs with usable metrics are found

run_names = []
mean_cccs = []
mean_rmses = []
skipped_runs = []  # (run name, reason) for every run that was not collected
for run in runs:
    try:
        mean_ccc = _last_logged_value(run, "validation/mean_ccc")
        mean_rmse = _last_logged_value(run, "validation/mean_rmse")
    except Exception as err:  # unlike a bare `except:`, lets KeyboardInterrupt through
        skipped_runs.append((run.name, repr(err)))
        continue
    # String or NaN metrics would corrupt the CCC ranking in the next cell
    # (np.argsort sorts NaN to the end, i.e. into the "top k").
    if not (_is_finite_number(mean_ccc) and _is_finite_number(mean_rmse)):
        skipped_runs.append(
            (run.name, f"non-finite metrics: ccc={mean_ccc!r}, rmse={mean_rmse!r}")
        )
        continue
    mean_cccs.append(mean_ccc)
    mean_rmses.append(mean_rmse)
    run_names.append(run.name)
    if len(run_names) >= MAX_RUNS_TO_COLLECT:
        break

# Backward-compatible alias for the previous (misspelled) variable name.
mena_rmses = mean_rmses
    

In [14]:
# get top k runs
k = 5
top_k_indices = np.argsort(mean_cccs)[-k:]
top_k_runs = [run_names[i] for i in top_k_indices]

In [19]:
# Fetch the config of every top-k run, keyed by run name. The previous loop
# reassigned `run_config` on each iteration, so only the last config survived.
# The project is passed explicitly so the lookup targets the same W&B project
# the runs were queried from, as the explicit `get_run_config(..., project=...)`
# call further down already does.
top_k_configs = {
    name: get_run_config(run_name=name, project=project) for name in top_k_runs
}
# Preserve the old end state: `run_name` / `run_config` refer to the last entry
# of `top_k_runs` (the highest CCC, given the ascending argsort ordering).
if top_k_runs:
    run_name = top_k_runs[-1]
    run_config = top_k_configs[run_name]

In [4]:
# Train from one specific, hand-picked sweep run (replaces any `run_config` set above).
run_config = get_run_config(project="DISSECT-src", run_name="sweet-sweep-72")

In [5]:
# Pull the Hydra experiment out of the run config: it is applied separately as an
# `experiment=<name>` override when composing the base config, so it is removed
# from the run config itself.
has_experiment = "experiment" in run_config
experiment_name = run_config["experiment"] if has_experiment else None
if has_experiment:
    del run_config["experiment"]
print(f"Experiment name: {experiment_name}")

Experiment name: hybrid


### Load base config

In [6]:
# Clear any previous Hydra initialisation so `initialize` can be called again
# when this cell is re-run.
hydra.core.global_hydra.GlobalHydra.instance().clear()

config_path = "../configs"
config_name = "train.yaml"
initialize(version_base="1.3", config_path=config_path)

# Select the run's experiment (if any) the same way the CLI would: `experiment=<name>`.
overrides = [] if experiment_name is None else [f"experiment={experiment_name}"]

# NOTE: this rebinds `config`, shadowing the `config` imported from configs.main_config.
config = compose(
    config_name=config_name,
    overrides=overrides,
    return_hydra_config=True,
)
# Register the hydra node with HydraConfig before it is stripped from `config`
# (presumably needed by code that later calls HydraConfig.get() — TODO confirm).
HydraConfig.instance().set_config(config)
# Turn off struct mode so keys can be deleted/added freely from here on.
OmegaConf.set_struct(config, False)
del config["hydra"]


In [7]:
# Combine the composed Hydra base config with the W&B run config via the
# star-imported `prepare_config` helper (its exact merge semantics are not visible
# in this notebook — TODO confirm). It rebinds `config` in place of the base config.
config = prepare_config(config, run_config)

Net in config: False


## Run training

In [8]:
# Notebook-specific overrides applied on top of the run-derived config.
wandb_mode = "online"
config.logger.wandb.mode = wandb_mode
config.trainer.devices = [7]  # train on the single GPU with index 7
config.model.plotting = False
config.model.save_predictions = True
config.extras.print_config = False
# NOTE(review): the captured output reports `Trainer(limit_train_batches=1)`,
# i.e. a debug-sized run coming from the composed config — confirm that is intended.
train(config)

Global seed set to 42
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdschaub[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Output()