In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    project_root_env_var=True,  # set the PROJECT_ROOT environment variable to root directory
    dotenv=True,  # load environment variables from .env if exists in root directory
    pythonpath=True,  # add root directory to the PYTHONPATH (helps with imports)
    cwd=True,  # change current working directory to the root directory (helps with filepaths)
)
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.utilities.model_summary import ModelSummary
import wandb
import copy
import pandas as pd
from omegaconf import DictConfig
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _flush_loggers, configure_log

from src.data.datasets import *
from src.utils.wandb import *
from src.utils.config_utils import *
from src.utils.experiment import run_experiment
from src.utils.data_paths_valid import get_paths_for_training


from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.nn import GCNConv, GCN, MLP, TransformerConv, GATv2Conv, GAT




# ignore all warnings
import warnings

warnings.filterwarnings("ignore")

plt.rcParams["figure.dpi"] = 80

## Define data paths

In [3]:
st_paths, experiment_paths = get_paths_for_training()
# st_paths, experiment_paths

In [4]:
experiment_paths

['${paths.root_dir}/experiments/experiment_kidney_slideSeq_v2_UMOD-WT.WT-2a_resolution75',
 '${paths.root_dir}/experiments/experiment_kidney_slideSeq_v2_105',
 '${paths.root_dir}/experiments/experiment_heart_seqFISH/embryo1_resolution0.11',
 '${paths.root_dir}/experiments/experiment_heart_seqFISH/embryo2_resolution0.11',
 '${paths.root_dir}/experiments/experiment_heart_seqFISH/embryo3_resolution0.11',
 '${paths.root_dir}/experiments/experiment_hypothalamus_MERFISH/0_resolution0.04',
 '${paths.root_dir}/experiments/experiment_hypothalamus_MERFISH/1_resolution0.04',
 '${paths.root_dir}/experiments/experiment_mouse_st',
 '${paths.root_dir}/experiments/experiment_lymph_node',
 '${paths.root_dir}/experiments/experiment_kidney/UMOD-WT.WT-2a_resolution75+kidney_full_v3_harmonized_celltypes_raw_mouseGenes_cell_type',
 '${paths.root_dir}/experiments/experiment_kidney/UMOD-KI.KI-4b_resolution105+kidney_full_v3_harmonized_celltypes_raw_mouseGenes_cell_type',
 '${paths.root_dir}/experiments/experi

## Setup experiment config

### Load default config

In [5]:
experiment_name = None
# experiment_name = "hybrid"
experiment_name = "multi_channel"
if experiment_name is not None:
    overrides = [f"experiment={experiment_name}"]
else:
    overrides = []

hydra.core.global_hydra.GlobalHydra.instance().clear()
# import GlobalHydra

config_path = "../configs"
config_name = "train.yaml"

initialize(version_base="1.3", config_path=config_path)
config = compose(
    config_name=config_name,
    overrides=overrides,
    return_hydra_config=True,
)
HydraConfig.instance().set_config(config)
OmegaConf.set_struct(config, False)
del config["hydra"]
config.experiment = experiment_name
# print_config(config)

### Load config from previous experiment

In [6]:
# run_name = "northern-river-217"
run_name = "fiery-music-355"
project = "dissect-spatial"

run_name = "electric-field-1"
run_name = "eager-wind-245"
project = "multi-channel-gnn"

config = get_run_config(run_name, project=project)
# run_name = "electric-sweep-37"
# # run_name = "swift-sweep-658"
# config = get_run_config(run_name, project="DISSECT-src")
config = convert_wandb_to_dict_config(config)
config.experiment = "multi_channel"

Net in config: True


### Set custom values

#### Modify network architecture

In [7]:
# config.net.latent_dim = 64
# # config.net.encoder_kwargs.fusion = "gating"
# # config.net.encoder_kwargs.fusion = "concat_skip"
# config.net.encoder_kwargs.fusion = "concat"
# config.net.encoder_kwargs.lin_channel = True
# config.net.encoder_kwargs.num_heads = 4
# config.net.encoder_kwargs.num_layers = 1
# config.net.encoder_kwargs.use_ffn = False
# config.net.encoder_kwargs.use_pos = False
# config.net.encoder_kwargs.plain_last = False
# # config.net.use_pos = True
# config.net.encoder_kwargs.use_sparse = False
# config.net.decoder_kwargs.norm = None
# config.net.encoder_kwargs.norm = None
# config.net.encoder_kwargs.norm_last = False
# try:
#     del config.net.encoder_kwargs.knn
#     del config.net.encoder_kwargs.cosine
#     # del config.net.encoder_kwargs.num_heads
# except:
#     pass
# config.net.encoder_kwargs.spatial_channel_kwargs.num_layers = 2
# config.net.encoder_kwargs.spatial_channel_kwargs.activation = "elu"
# config.net.encoder_kwargs.spatial_channel_kwargs.back_to_latent = True
# config.net.encoder_kwargs.spatial_channel_kwargs.conv_type = "gcn"

In [8]:
print_config(config.net)

use_id: false
use_pos: false
_target_: src.models.dissect_spatial.DissectSpatial
activation: elu
latent_dim: 256
encoder_type: gnn
decoder_kwargs:
  norm: None
  num_layers: 2
  hidden_channels: 64
encoder_kwargs:
  norm: None
  fusion: concat_skip
  use_id: false
  dropout: 0
  use_ffn: false
  use_pos: false
  num_heads: 4
  inter_skip: true
  norm_first: false
  num_layers: 1
  plain_last: false
  use_sparse: true
  lin_channel: true
  mha_channel: true
  sim_pos_enc: false
  ff_hidden_dim: 256
  latent_channel: false
  spatial_channel: true
  mha_channel_kwargs:
    norm: None
    num_heads: 8
    activation: None
    num_layers: 1
  init_embed_hidden_channels:
  - 512
  - 256
  latent_channel_kwargs:
    norm: None
    num_heads: 8
    activation: None
    num_layers: 1
  spatial_channel_kwargs:
    norm: None
    num_heads: 8
    activation: None
    num_layers: 1



#### Modify training parameters

In [9]:
# config.data.radius = 0.04
print(config.data.radius)

0.02


In [10]:
# config.model.l2_lambda = 5e-5
# config.model.l1_lambda = 0.0
# config.model.weight_decay = 0.0
# config.model.beta = 5
# config.model.learning_rate = 1e-3
# config.model.alpha_max = 0.1
# config.model.alpha_min = 0.0
# config.model.g_mix_from_g_real = False
config.trainer.max_steps = 5000
# config.trainer.check_val_every_n_epoch = 500

In [11]:
print_config(config.model)
# print_config(config.trainer)

beta: None
_target_: src.models.modules.DeconvolutionModel
plotting: false
alpha_max: 0.2
alpha_min: 0
l1_lambda: 0
l2_lambda: 1.0e-06
log_hparams: false
sim_loss_fn: kl_div
weight_decay: 0
learning_rate: 1.0e-05
save_predictions: true
move_data_to_device: true



## Run training

In [12]:
# run training on all datasets if no dataset is specified
tags = "latest"
tags = "latestv2"
# tags = "reference_switch"
# tags = "latestv2-switched"
project = "multi-channel-gnn"
wandb_mode = "disabled"
wandb_mode = "online"
if wandb_mode == "online":
    config.trainer.fast_dev_run = False
else:
    config.trainer.fast_dev_run = True
config.trainer.deterministic = True
# config.trainer.deterministic = False
for st_path, experiment_path in list(zip(st_paths, experiment_paths))[0:1]:
    print(f"Running experiment on {st_path} and experiment path {experiment_path}")
    metric_dict, object_dict = run_experiment(
        config,
        st_path,
        experiment_path,
        wandb_mode=wandb_mode,
        tags=tags,
        config_path="../../configs",
        progress_bar=False,
        device=7,
        project=project,
    )


Running experiment on ${paths.data_dir}spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad and experiment path ${paths.root_dir}/experiments/experiment_kidney_slideSeq_v2_UMOD-WT.WT-2a_resolution75


Global seed set to 42


Key num_heads not in base config.
Key experiment not in base config.
Using deterministic mode
Input kwargs: {'num_heads': 4}


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdschaub[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 1.1 M 
----------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.295     Total estimated model params size (MB)


Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...
Saving predictions...


`Trainer.fit` stopped: `max_steps=5000` reached.


VBox(children=(Label(value='19.870 MB of 19.870 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████▆▆▆▆▆▆▆▆
train/l1_loss,██▇▇▇▆▆▅▅▄▄▃▃▂▂▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃
train/l2_loss,█▇▇▆▆▅▅▄▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃
train/mix_loss,▁▁▃▄▄▅▅▆▆▇▅█▇▇██▅▅▄▅▅▅▅▅▅▅▄▂▄▅▂▅▃▄▃▄▅▄▃▃
train/sim_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validation/mean_ccc,▁▆▇▇▇█████
validation/mean_ccc_,▁▆▇▇▇█████

0,1
epoch,4999.0
train/beta,10.0
train/l1_loss,60202.34766
train/l2_loss,1866.63696
train/mix_loss,0.00144
train/sim_loss,0.12418
train/total_loss,0.14049
trainer/global_step,4999.0
validation/mean_ccc_,0.65714
validation/mean_corr_,0.68853


In [None]:
summary = ModelSummary(object_dict["model"], max_depth=-1)
print(summary)

In [32]:
1500 // 1000

1

In [None]:
object_dict["net"]

DissectSpatial(
  (encoder): MultiChannelGNNEncoder(
    (mlp): MLP(-1, 512, 256, 256)
    (layers): ModuleList(
      (0): MultiChannelGNNBlock(
        (spatial_channel): GNNChannel(
          (convs): ModuleList(
            (0): GATv2Conv(256, 256, heads=8)
            (1): GATv2Conv(256, 256, heads=8)
          )
          (norms): ModuleList(
            (0): None
            (1): None
          )
          (dropouts): ModuleList(
            (0): Dropout(p=0.0, inplace=False)
            (1): Dropout(p=0.0, inplace=False)
          )
          (activations): ModuleList(
            (0): ELU(alpha=1.0)
            (1): ELU(alpha=1.0)
          )
          (concat_linears): ModuleList(
            (0): Linear(in_features=2048, out_features=256, bias=True)
            (1): Linear(in_features=2048, out_features=256, bias=True)
          )
        )
        (lin_channel): Linear(256, 256, bias=True)
        (mha_channel): MHAChannel(
          (convs): ModuleList(
            (0): Mu

In [None]:
for param in object_dict["net"].parameters():
    print(param.shape)

In [8]:
ablation_settings = {
    "data": {
        "radius": [0.0001, 0.01, 0.03, 0.04],
    },
    "net": {
        "latent_dim": [64, 128, 512],
        "activation": ["relu", "leaky_relu", "elu"],
        "use_pos": [True],
    },
    "encoder": {
        "num_heads": [1, 2, 4],
        "inter_skip": [False],
        "mha_channel": [False],
        "lin_channel": [False],
        "spatial_channel": [False],
        "num_layers": [2, 3],
        "plain_last": [True],
        "use_ffn": [True],
        "fusion": ["concat", "gating", "concat_simple"],
        "latent_dim": [64, 128, 512],
        "use_pos": [True],
    },
    "model": {
        "l2_lambda": [0.0, 1e-5, 1e-7],
        "beta": [0.0, 7.5],
        "alpha_max": [0.2, 0.5],
        "learning_rate": [1e-3, 1e-4],
    },
}

In [9]:
for setting, entries in ablation_settings.items():
    for entry, values in entries.items():
        for value in values:
            config_copy = copy.deepcopy(config)
            if setting == "data":
                config_copy.data[entry] = value
            if setting == "net": 
                config_copy.net[entry] = value
            if setting == "encoder": 
                config_copy.net.encoder_kwargs[entry] = value
            if setting == "model":
                config_copy.model[entry] = value
            print(config_copy.model)
    break

{'beta': 'None', '_target_': 'src.models.modules.DeconvolutionModel', 'plotting': False, 'alpha_max': 0.2, 'alpha_min': 0, 'l1_lambda': 0, 'l2_lambda': 1e-06, 'log_hparams': False, 'sim_loss_fn': 'kl_div', 'weight_decay': 0, 'learning_rate': 1e-05, 'save_predictions': False, 'move_data_to_device': True}
{'beta': 'None', '_target_': 'src.models.modules.DeconvolutionModel', 'plotting': False, 'alpha_max': 0.2, 'alpha_min': 0, 'l1_lambda': 0, 'l2_lambda': 1e-06, 'log_hparams': False, 'sim_loss_fn': 'kl_div', 'weight_decay': 0, 'learning_rate': 1e-05, 'save_predictions': False, 'move_data_to_device': True}
{'beta': 'None', '_target_': 'src.models.modules.DeconvolutionModel', 'plotting': False, 'alpha_max': 0.2, 'alpha_min': 0, 'l1_lambda': 0, 'l2_lambda': 1e-06, 'log_hparams': False, 'sim_loss_fn': 'kl_div', 'weight_decay': 0, 'learning_rate': 1e-05, 'save_predictions': False, 'move_data_to_device': True}
{'beta': 'None', '_target_': 'src.models.modules.DeconvolutionModel', 'plotting': Fal