In [None]:
# Automatically reload edited modules (e.g. the project's `src.*` code) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import pyrootutils

# Locate the project root by searching upward from the current directory for a
# `.gitignore`, then configure the session from it:
#   - export the root as the PROJECT_ROOT environment variable,
#   - load environment variables from `<root>/.env` if that file exists,
#   - add the root to PYTHONPATH so the `src.*` imports below resolve,
#   - chdir into the root so relative file paths are root-relative.
base_path = pyrootutils.setup_root(
    search_from=".",
    indicator=[".gitignore"],
    cwd=True,
    pythonpath=True,
    dotenv=True,
    project_root_env_var=True,
)
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.utilities.model_summary import ModelSummary
import wandb
import copy
import pandas as pd
from omegaconf import DictConfig
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _flush_loggers, configure_log

from src.data.datasets import *
from src.data.utils import *
from src.utils.wandb import *
from src.utils.config_utils import *
from src.utils.experiment import run_experiment
# from src.utils.data_paths_valid import *
from src.utils.data_paths_anca import *


from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.nn import GCNConv, GCN, MLP, TransformerConv, GATv2Conv, GAT




import warnings

# Silence every warning category for the rest of the session.
warnings.simplefilter("ignore")

# Render figures at a modest resolution.
plt.rcParams.update({"figure.dpi": 80})

## Define data paths

In [None]:
# Collect the spatial-transcriptomics datasets to train on together with their
# corresponding experiment paths. The inputs (st_data_files, sc_data_files,
# celltype_cols, ...) come from one of the star imports above — presumably
# `src.utils.data_paths_anca`; TODO confirm.
st_paths, experiment_paths = get_paths_for_training(
    st_data_files,
    sc_data_files,
    experiment_dirs=experiment_dirs,
    use_old_experiment_paths=True,
    celltype_cols=celltype_cols,
    extra_settings_prep=extra_settings_prep,
    extra_settings_sim=extra_settings_sim,
)


In [None]:
# Inspect the experiment paths the training loop below iterates over.
experiment_paths

## Setup experiment config

### Load default config

In [None]:
# Hydra experiment preset applied on top of train.yaml. Set to None to use the
# plain defaults; "hybrid" is another available preset.
experiment_name = "multi_channel"
overrides = [] if experiment_name is None else [f"experiment={experiment_name}"]

# NOTE(review): `setup_root(cwd=True)` in the setup cell changed the working
# directory to the project root; confirm how Hydra resolves this relative path
# from a notebook (the training loop passes "../../configs" to run_experiment).
config_path = "../configs"
config_name = "train.yaml"

# Drop any global Hydra instance left over from a previous execution of this
# cell before initializing again.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base="1.3", config_path=config_path)

config = compose(
    config_name=config_name,
    overrides=overrides,
    return_hydra_config=True,
)
# Register the composed config with HydraConfig (presumably for downstream code
# that reads it), then make the config writable, drop the `hydra` node and
# record which experiment preset was used.
HydraConfig.instance().set_config(config)
OmegaConf.set_struct(config, False)
del config["hydra"]
config.experiment = experiment_name

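### Load config from previous experiment

Running this section replaces the default config composed above with the config of an earlier W&B run.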

In [None]:
# Fetch the config stored with a previous run. NOTE: this REPLACES the `config`
# composed in the default-config cell above.
# Earlier candidates: "northern-river-217" / "fiery-music-355" (project
# "dissect-spatial"), "electric-field-1", and "electric-sweep-37" /
# "swift-sweep-658" (project "DISSECT-src").
project = "multi-channel-gnn"
run_name = "eager-wind-245"

config = convert_wandb_to_dict_config(get_run_config(run_name, project=project))
config.experiment = "multi_channel"

### Set custom values

#### Modify network architecture

In [None]:
# Optional network-architecture overrides — uncomment the lines to apply.
# NOTE(review): if re-enabled, the bare `except:` below also swallows unrelated
# errors; narrowing it to the exception raised for a missing key is safer —
# confirm that exception type first.
# config.net.latent_dim = 64
# # config.net.encoder_kwargs.fusion = "gating"
# # config.net.encoder_kwargs.fusion = "concat_skip"
# config.net.encoder_kwargs.fusion = "concat"
# config.net.encoder_kwargs.lin_channel = True
# config.net.encoder_kwargs.num_heads = 4
# config.net.encoder_kwargs.num_layers = 1
# config.net.encoder_kwargs.use_ffn = False
# config.net.encoder_kwargs.use_pos = False
# config.net.encoder_kwargs.plain_last = False
# # config.net.use_pos = True
# config.net.encoder_kwargs.use_sparse = False
# config.net.decoder_kwargs.norm = None
# config.net.encoder_kwargs.norm = None
# config.net.encoder_kwargs.norm_last = False
# try:
#     del config.net.encoder_kwargs.knn
#     del config.net.encoder_kwargs.cosine
#     # del config.net.encoder_kwargs.num_heads
# except:
#     pass
# config.net.encoder_kwargs.spatial_channel_kwargs.num_layers = 2
# config.net.encoder_kwargs.spatial_channel_kwargs.activation = "elu"
# config.net.encoder_kwargs.spatial_channel_kwargs.back_to_latent = True
# config.net.encoder_kwargs.spatial_channel_kwargs.conv_type = "gcn"

In [None]:
# Show the `net` section of the config that training will use.
print_config(config.net)

#### Modify training parameters

In [None]:
# Optional data overrides — uncomment the lines to apply.
# config.data.radius = 0.04
# config.data.use_rctd_genes = True
# config.data.with_platform_effects = True
# config.data.renormalize = False

In [None]:
# Model / trainer overrides applied on top of the loaded config; the commented
# lines are optional alternatives.
# config.model.l2_lambda = 5e-5
# config.model.l1_lambda = 0.0
# config.model.weight_decay = 0.0
# config.model.beta = 5
# config.model.learning_rate = 1e-3
# config.model.alpha_max = 0.1
# config.model.alpha_min = 0.0
# config.model.g_mix_from_g_real = False
# config.model.combine_graphs = True
config.model.combine_graphs = False

# NOTE(review): presumably forwarded to the Lightning Trainer, where max_steps
# caps the number of training steps — confirm in run_experiment.
config.trainer.max_steps = 5000
# config.trainer.check_val_every_n_epoch = 500

In [None]:
# Show the `model` section of the config that training will use.
print_config(config.model)
# print_config(config.trainer)

## Run training

In [None]:
# run training on all datasets if no dataset is specified
tags = "latest"
tags = "latestv2"
tags = "anca_latest"
# tags = "reference_switch"
# tags = "latestv2-switched"
project = "multi-channel-gnn"
wandb_mode = "disabled"
wandb_mode = "online"

if wandb_mode == "online":
    config.trainer.fast_dev_run = False
else:
    config.trainer.fast_dev_run = True
config.trainer.deterministic = True
# config.trainer.deterministic = False
for st_path, experiment_path in list(zip(st_paths, experiment_paths))[0::]:
    print(f"Running experiment on {st_path} and experiment path {experiment_path}")
    metric_dict, object_dict = run_experiment(
        config,
        st_path,
        experiment_path,
        wandb_mode=wandb_mode,
        tags=tags,
        config_path="../../configs",
        progress_bar=False,
        device=7,
        project=project,
    )


In [None]:
# Layer-by-layer summary of the model from the last training run
# (max_depth=-1 lists all nested submodules).
summary = ModelSummary(object_dict["model"], max_depth=-1)
print(summary)

In [None]:
# NOTE(review): leftover scratch arithmetic, unused by any other cell — delete
# it or turn it into a named, explained quantity.
1500 // 1000

In [None]:
# Display the network object from the last training run.
object_dict["net"]

In [None]:
# Print the shape of every parameter tensor of the last run's network, one per line.
param_shapes = [str(tensor.shape) for tensor in object_dict["net"].parameters()]
if param_shapes:
    print("\n".join(param_shapes))

## Run multiple trainings