In [1]:
# Re-import edited project modules (e.g. `src.*`) automatically before executing code,
# so changes to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

# Resolve the project root (the directory containing `.gitignore`, searched from the
# current directory) and, from there:
#   - make it the current working directory, so relative paths like "data/..." work,
#   - add it to PYTHONPATH, so `src.*` imports resolve,
#   - export it as the PROJECT_ROOT environment variable,
#   - load environment variables from a `.env` file in it, if one exists.
base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    cwd=True,
    pythonpath=True,
    project_root_env_var=True,
    dotenv=True,
)
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.utilities.model_summary import ModelSummary
import wandb
import copy
import pandas as pd
from scipy.sparse import csr_matrix, coo_matrix
from omegaconf import DictConfig
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _flush_loggers, configure_log

from src.data.datasets import *
from src.data.utils import *
from src.utils.wandb import *
from src.utils.config_utils import *
from src.utils.experiment import run_experiment
from src.utils.data_paths_valid import *

# from src.utils.data_paths_anca import *


from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.nn import GCNConv, GCN, MLP, TransformerConv, GATv2Conv, GAT


import warnings

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings (e.g. from
# torch / lightning); consider narrowing to specific categories or modules.
warnings.simplefilter("ignore")

# Render figures at 80 dpi.
plt.rcParams.update({"figure.dpi": 80})

## Define data paths

In [3]:
# Build the spatial-data paths and the matching experiment directories (later used as
# `config.data.reference_dir`) from the dataset tables. The tables are presumably
# provided by the `src.utils.data_paths_valid` star import.
st_paths, experiment_paths = get_paths_for_training(
    st_data_files,
    sc_data_files,
    experiment_dirs=experiment_dirs,
    use_old_experiment_paths=True,
    celltype_cols=celltype_cols,
    extra_settings_prep=extra_settings_prep,
    extra_settings_sim=extra_settings_sim,
)

In [5]:
# Paired inputs for DiGraph training: entry i of `st_files_real` (relative to the data
# directory) corresponds to entry i of `st_files_sim` (relative to the project root).
st_files_real, st_files_sim = (
    [
        "spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad",
        "spatial/simulations_kidney_slideSeq_v2/UMOD-KI.KI-4b_resolution105.h5ad",
    ],
    [
        "../deconvolution/simulation/slideseqv2_mouse_kidney_2a/st_simulated.h5ad",
        "../deconvolution/simulation/slideseqv2_mouse_kidney_4b/st_simulated.h5ad",
    ],
)

In [6]:
# Load the spatial and single-cell AnnData objects for one dataset index and show
# which matrix type backs `.X` in each.
# NOTE(review): `sc` is not imported explicitly in the imports cell; presumably it
# is scanpy, brought in by one of the `src.*` star imports.
index = 6
st_data, sc_data = (
    sc.read_h5ad(f"data/{files[index]}") for files in (st_data_files, sc_data_files)
)
print(*(type(adata.X) for adata in (st_data, sc_data)), sep="\n")

<class 'numpy.ndarray'>
<class 'scipy.sparse._csr.csr_matrix'>


## Setup experiment config

### Load default config

In [7]:
# Compose the base config from ../../configs/train.yaml with the "multi_channel"
# experiment preset (other values used before: None for no preset, "hybrid"),
# switching the data section to digraph.yaml.
config_path, config_name = "../../configs", "train.yaml"
experiment_name = "multi_channel"

base_config = load_base_config(
    config_name=config_name,
    config_path=config_path,
    experiment_name=experiment_name,
    overrides=["data=digraph.yaml"],
)

### Load config from a previous W&B run

In [8]:
# Start from the config logged by a previous W&B run. Runs used before include
# "northern-river-217" / "fiery-music-355" (project "dissect-spatial") and
# "electric-field-1" (project "multi-channel-gnn").
project = "multi-channel-gnn"
run_name = "eager-wind-245"

config = convert_wandb_to_dict_config(get_run_config(run_name, project=project))
config.experiment = "multi_channel"

# Optionally replace the run's data section with the freshly composed base config's
# (which carries the `data=digraph.yaml` override).
config.data = base_config.data

Net in config: True


### Set custom values

#### Modify network architecture

In [13]:
# Encoder channel selection: keep only the linear channel, switch off the spatial,
# multi-head-attention and latent channels, and disable `inter_skip`.
# Other encoder knobs tuned here before include fusion ("gating" / "concat" /
# "concat_skip"), num_heads, num_layers, use_ffn, use_pos, plain_last, use_sparse,
# norm / norm_last and the per-channel `*_channel_kwargs`.
config.net.encoder_kwargs.update(
    {
        "lin_channel": True,
        "spatial_channel": False,
        "mha_channel": False,
        "latent_channel": False,
        "inter_skip": False,
    }
)


In [14]:
# Inspect the network section of the config after the architecture overrides.
print_config(config.net)


use_id: false
use_pos: false
_target_: src.models.dissect_spatial.DissectSpatial
activation: elu
latent_dim: 256
encoder_type: gnn
decoder_kwargs:
  norm: None
  num_layers: 2
  hidden_channels: 64
encoder_kwargs:
  norm: None
  fusion: concat_skip
  use_id: false
  dropout: 0
  use_ffn: false
  use_pos: false
  num_heads: 4
  inter_skip: false
  norm_first: false
  num_layers: 1
  plain_last: false
  use_sparse: true
  lin_channel: true
  mha_channel: false
  sim_pos_enc: false
  ff_hidden_dim: 256
  latent_channel: false
  spatial_channel: false
  mha_channel_kwargs:
    norm: None
    num_heads: 8
    activation: None
    num_layers: 1
  init_embed_hidden_channels:
  - 512
  - 256
  latent_channel_kwargs:
    norm: None
    num_heads: 8
    activation: None
    num_layers: 1
  spatial_channel_kwargs:
    norm: None
    num_heads: 8
    activation: None
    num_layers: 1



#### Modify training parameters

In [10]:
# config.data.radius = 0.04
# config.data.use_rctd_genes = True
# config.data.with_platform_effects = True
# config.data.renormalize = False


In [15]:
# Model / trainer overrides. Other model knobs tuned here before include
# l2_lambda, l1_lambda, weight_decay, beta, learning_rate, alpha_max / alpha_min
# and g_mix_from_g_real.
config.model.combine_graphs = False

# Train for 5000 optimisation steps, validating every 500 epochs.
config.trainer.max_steps = 5000
config.trainer.check_val_every_n_epoch = 500


In [None]:
# print_config(config.model)
# print_config(config.trainer)


## Run training

In [None]:
# Train on the (spatial data, experiment directory) pairs from `get_paths_for_training`.
# NOTE: only the first pair is run (`[:1]`); widen the slice to train on every dataset.
# Tags used for earlier runs: "latest", "anca_latest", "reference_switch", "latestv2-switched".
tags = ["latestv2"]  # passed as a list, like the tags of the DiGraph training cell
project = "multi-channel-gnn"
# Modes: "online" / "offline" log to W&B; "disabled" runs fully without logging;
# "test" disables logging and runs Lightning's `fast_dev_run` smoke test.
wandb_mode = "online"

if wandb_mode not in {"online", "offline", "disabled", "test"}:
    raise ValueError(f"Unknown wandb_mode: {wandb_mode!r}")
if wandb_mode == "test":
    wandb_mode = "disabled"
    config.trainer.fast_dev_run = True
else:
    config.trainer.fast_dev_run = False
config.trainer.deterministic = True

# `config` is mutated in place per dataset; `metric_dict` / `object_dict` keep only the
# results of the last run in the loop.
for st_path, experiment_path in list(zip(st_paths, experiment_paths))[:1]:
    print(f"Running experiment on {st_path} and experiment path {experiment_path}")
    config.data.st_path = st_path
    config.data.reference_dir = experiment_path
    metric_dict, object_dict = run_experiment(
        config,
        wandb_mode=wandb_mode,
        tags=tags,
        config_path="../../configs",
        progress_bar=False,
        device=7,  # presumably the CUDA device index; adjust for the host machine
        project=project,
    )

In [None]:
# Per-layer parameter summary of the trained model (presumably a LightningModule);
# max_depth=-1 expands all nested submodules.
summary = ModelSummary(object_dict["model"], max_depth=-1)
print(summary)


In [None]:
# Display the network object returned by `run_experiment` (for a torch module, its
# repr lists the submodule structure).
object_dict["net"]

In [None]:
# Print the shape of every parameter tensor of the network.
for param in object_dict["net"].parameters():
    print(param.shape)


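## Run DiGraph training

Train on one paired real / simulated Slide-seqV2 kidney section, using the `data=digraph.yaml` data config.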

In [16]:
# Build config paths for the paired real / simulated spatial datasets. They are kept as
# interpolation strings (`${paths.data_dir}`, `${paths.root_dir}`); `{{`/`}}` escape
# the literal braces inside the f-strings.
st_paths_real = [f"${{paths.data_dir}}{st_file}" for st_file in st_files_real]
st_paths_sim = [f"${{paths.root_dir}}/{st_file}" for st_file in st_files_sim]
# Real and simulated entries are paired by position, so the lists must line up.
assert len(st_paths_real) == len(st_paths_sim), "real/simulated path lists are misaligned"

tags = ["digraph"]
project = "multi-channel-gnn"
# Modes: "online" / "offline" log to W&B; "disabled" runs fully without logging;
# "test" disables logging and runs Lightning's `fast_dev_run` smoke test.
wandb_mode = "online"

if wandb_mode not in {"online", "offline", "disabled", "test"}:
    raise ValueError(f"Unknown wandb_mode: {wandb_mode!r}")
if wandb_mode == "test":
    wandb_mode = "disabled"
    config.trainer.fast_dev_run = True
else:
    config.trainer.fast_dev_run = False

config.trainer.deterministic = True

# Which real/simulated pair to train on (0 -> section 2a, 1 -> section 4b).
pair_index = 0
st_path_real = st_paths_real[pair_index]
st_path_sim = st_paths_sim[pair_index]

print(f"Running experiment on {st_path_real} and {st_path_sim}")
config.data.st_path_real = st_path_real
config.data.st_path_sim = st_path_sim
metric_dict, object_dict = run_experiment(
    config,
    wandb_mode=wandb_mode,
    tags=tags,
    config_path="../../configs",
    progress_bar=False,
    device=7,  # presumably the CUDA device index; adjust for the host machine
    project=project,
    overrides=["data=digraph.yaml"],
)

Running experiment on ${paths.data_dir}spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad and ${paths.root_dir}/../deconvolution/simulation/slideseqv2_mouse_kidney_2a/st_simulated.h5ad


Global seed set to 42


Key combine_graphs not in base config.
Key num_heads not in base config.
Key experiment not in base config.
Using deterministic mode
Removing genes with variance below 0.1
Genes in real data before filtering: 18961
Genes in simulated data before filtering: 18961
Genes in real data after filtering: 6113
Genes in simulated data after filtering: 13668
There are 6113 common genes in both datasets.
Input kwargs: {'num_heads': 4}


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669316325957577, max=1.0…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 280 K 
----------------------------------------
280 K     Trainable params
0         Non-trainable params
280 K     Total params
1.123     Total estimated model params size (MB)


Epoch: 0, mean_rmse: 0.13705699145793915, mean_corr: 0.0021247830932867185, mean_ccc: -3.0253615101293047e-05
Saving predictions...
Epoch: 499, mean_rmse: 0.16550347208976746, mean_corr: 0.11210999563122157, mean_ccc: 0.05838067045735656
Saving predictions...
Epoch: 999, mean_rmse: 0.16775226593017578, mean_corr: 0.11462167091773724, mean_ccc: 0.0665320533176891
Saving predictions...
Epoch: 1499, mean_rmse: 0.17159324884414673, mean_corr: 0.14774045026536492, mean_ccc: 0.10666477261501954
Saving predictions...
Epoch: 1999, mean_rmse: 0.1791640818119049, mean_corr: 0.09149982338274629, mean_ccc: 0.0912811945428558
Saving predictions...
Epoch: 2499, mean_rmse: 0.18254795670509338, mean_corr: 0.06559020202275195, mean_ccc: 0.07754926793556242
Saving predictions...
Epoch: 2999, mean_rmse: 0.18457502126693726, mean_corr: 0.05227734628803131, mean_ccc: 0.07107377777621235
Saving predictions...
Epoch: 3499, mean_rmse: 0.18641170859336853, mean_corr: 0.04247076805673238, mean_ccc: 0.0648367593

`Trainer.fit` stopped: `max_steps=5000` reached.


VBox(children=(Label(value='20.337 MB of 20.337 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████▆▆▆▆▆▆▆▆
train/l1_loss,▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████████
train/l2_loss,▁▁▁▁▂▂▂▂▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇████████
train/mix_loss,▁▁▁▁▁▁▁▁▂▁▂▂▃▄▁▅▅▆▆▅▇▇▅██▇███▆▇▂▆█▅▇▆▇▆▇
train/sim_loss,█▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_loss,█▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validation/mean_ccc,▁▂█▆▄▃▂▂▁▁
validation/mean_ccc_,▁▂█▆▄▃▂▂▁▁

0,1
epoch,4999.0
train/beta,10.0
train/l1_loss,36822.35547
train/l2_loss,912.90271
train/mix_loss,0.00024
train/sim_loss,0.00571
train/total_loss,0.00906
trainer/global_step,4999.0
validation/mean_ccc_,0.05686
validation/mean_corr_,0.0338


## Run multiple trainings