In [1]:
import os
import yaml
import itertools
from copy import deepcopy

import numpy as np

In [2]:
# File name of the baseline YAML config; it is read from configs/<MODEL_NAME>/.
CONFIG_FNAME = "bnfix_minmax_samp_split.yml"
# Model name; used as the sub-directory for both the baseline config and the
# generated configs.
MODEL_NAME = "CellDART"

In [3]:
# Load the baseline config that every generated variant is derived from.
# The encoding is passed explicitly so the read does not depend on the
# platform's locale default (matching how a_list.txt is written below).
with open(
    os.path.join("configs", MODEL_NAME, CONFIG_FNAME), "r", encoding="utf-8"
) as f:
    config = yaml.safe_load(f)
# Echo the parsed config back as YAML to confirm what was loaded.
print(yaml.safe_dump(config))

data_params:
  all_genes: false
  data_dir: data
  n_markers: 20
  n_mix: 8
  n_spots: 200000
  samp_split: true
  sample_id_n: '151673'
  scaler_name: minmax
  st_split: false
lib_params:
  manual_seed: 2634
model_params:
  celldart_kwargs:
    bn_momentum: 0.01
    emb_dim: 64
  model_version: bnfix_minmax_samp_split
train_params:
  alpha: 0.6
  alpha_lr: 5
  batch_size: 512
  early_stop_crit: 100
  early_stop_crit_adv: 10
  initial_train_epochs: 10
  lr: 0.001
  min_epochs: 10
  min_epochs_adv: 10
  n_iter: 3000
  pretraining: true



In [4]:
# Hyperparameter search space for CellDART.
# Every *_l list holds the candidate values for one config key; the generated
# configs are drawn from the Cartesian product of all of these lists.

# --- data_params ---
n_markers_l = [5, 10, 20, 40, 80]
n_mix_l = [5, 8, 10, 15, 20]
n_spots_l = [5_000, 10_000, 20_000, 50_000, 100_000, 200_000]
scaler_name_l = ["minmax", "standard"]

# --- model_params (merged into celldart_kwargs) ---
bn_momentum_l = [0.01, 0.1, 0.9, 0.99]

# --- train_params ---
alpha_l = [0.1, 0.6, 1.0, 2.0]
alpha_lr_l = [1, 2, 5, 10]
batch_size_l = [256, 512, 1024]
lr_l = [0.01, 0.001, 0.0001]


In [5]:
# Size of the full grid: the product of the number of candidate values of
# every swept hyperparameter.
total_configs = 1
for candidates in (
    n_markers_l,
    n_mix_l,
    n_spots_l,
    scaler_name_l,
    bn_momentum_l,
    alpha_l,
    alpha_lr_l,
    batch_size_l,
    lr_l,
):
    total_configs *= len(candidates)
total_configs

172800

In [6]:
# Display the parsed baseline config as a Python dict.
config

{'data_params': {'all_genes': False,
  'data_dir': 'data',
  'n_markers': 20,
  'n_mix': 8,
  'n_spots': 200000,
  'samp_split': True,
  'sample_id_n': '151673',
  'scaler_name': 'minmax',
  'st_split': False},
 'lib_params': {'manual_seed': 2634},
 'model_params': {'celldart_kwargs': {'bn_momentum': 0.01, 'emb_dim': 64},
  'model_version': 'bnfix_minmax_samp_split'},
 'train_params': {'alpha': 0.6,
  'alpha_lr': 5,
  'batch_size': 512,
  'early_stop_crit': 100,
  'early_stop_crit_adv': 10,
  'initial_train_epochs': 10,
  'lr': 0.001,
  'min_epochs': 10,
  'min_epochs_adv': 10,
  'n_iter': 3000,
  'pretraining': True}}

In [7]:
# Draw a reproducible random subset of the hyperparameter grid and write one
# YAML config per selected grid point.

# Per-section override dicts; each is merged into the matching section of the
# baseline config.
data_params_l = [
    dict(n_markers=n_markers, n_mix=n_mix, n_spots=n_spots, scaler_name=scaler_name)
    for n_markers, n_mix, n_spots, scaler_name in itertools.product(
        n_markers_l, n_mix_l, n_spots_l, scaler_name_l
    )
]
model_params_l = [dict(bn_momentum=bn_momentum) for bn_momentum in bn_momentum_l]
train_params_l = [
    dict(alpha=alpha, alpha_lr=alpha_lr, batch_size=batch_size, lr=lr)
    for alpha, alpha_lr, batch_size, lr in itertools.product(
        alpha_l, alpha_lr_l, batch_size_l, lr_l
    )
]

# The sampled indices below are positions in the enumerated grid, so the grid
# size must agree with total_configs. This catches a stale total_configs when
# the candidate lists were edited without re-running the cell that counts them.
grid_size = len(data_params_l) * len(model_params_l) * len(train_params_l)
assert grid_size == total_configs, (
    f"total_configs={total_configs} does not match the grid size {grid_size}"
)

rng = np.random.default_rng(567)

# Flat grid indices to keep: 1000 distinct positions in range(total_configs).
yes_samples = set(rng.choice(total_configs, size=1000, replace=False))

# exist_ok=True replaces the separate exists() check and its check-then-create race.
out_dir = os.path.join("configs/generated", MODEL_NAME)
os.makedirs(out_dir, exist_ok=True)

for i, (data_params, model_params, train_params) in enumerate(
    itertools.product(data_params_l, model_params_l, train_params_l)
):
    if i not in yes_samples:
        continue

    # Start from an independent copy so the baseline config is never mutated.
    new_config = deepcopy(config)
    new_config["data_params"].update(data_params)
    new_config["model_params"]["celldart_kwargs"].update(model_params)
    new_config["train_params"].update(train_params)

    # Each generated config gets its own seed, drawn from the same generator in
    # iteration order; int() converts the NumPy scalar so safe_dump can
    # serialize it.
    new_config["lib_params"]["manual_seed"] = int(rng.integers(0, 2**32))

    # Every generated config uses this n_iter regardless of the baseline value.
    new_config["train_params"]["n_iter"] = 30000

    # The grid index is part of the version so each file name is unique.
    version = f"gen_v1_perm_{i}"
    new_config["model_params"]["model_version"] = version

    with open(os.path.join(out_dir, f"{version}.yml"), "w", encoding="utf-8") as f:
        yaml.safe_dump(new_config, f)

In [8]:
# Inspect the last config produced by the generation loop; new_config keeps the
# value assigned in the final selected iteration.
new_config


{'data_params': {'all_genes': False,
  'data_dir': 'data',
  'n_markers': 80,
  'n_mix': 20,
  'n_spots': 200000,
  'samp_split': True,
  'sample_id_n': '151673',
  'scaler_name': 'standard',
  'st_split': False},
 'lib_params': {'manual_seed': 2223721707},
 'model_params': {'celldart_kwargs': {'bn_momentum': 0.99, 'emb_dim': 64},
  'model_version': 'gen_v1_perm_172778'},
 'train_params': {'alpha': 2.0,
  'alpha_lr': 2,
  'batch_size': 512,
  'early_stop_crit': 100,
  'early_stop_crit_adv': 10,
  'initial_train_epochs': 10,
  'lr': 0.0001,
  'min_epochs': 10,
  'min_epochs_adv': 10,
  'n_iter': 30000,
  'pretraining': True}}

In [9]:
import glob

In [10]:
# Write an index of the generated configs: one YAML file name per line.
# glob.glob returns matches in arbitrary order, so the names are sorted to keep
# the line order of a_list.txt stable across runs and filesystems.
gen_dir = os.path.join("configs/generated", MODEL_NAME)
lines = sorted(
    os.path.basename(name) for name in glob.glob(os.path.join(gen_dir, "*.yml"))
)
with open(
    os.path.join(gen_dir, "a_list.txt"),
    mode="wt",
    encoding="utf-8",
) as myfile:
    myfile.write("\n".join(lines))
    # Terminate the last entry so line-based readers see every name.
    myfile.write("\n")