In [2]:
# Table of parameter names and their default values.
# The keys double as the search strings that replace_text_in_file() rewrites to
# `params.<KEY>`, so each key must be spelled exactly as it appears in the
# project sources being refactored. The dict's insertion order is the order in
# which the keys are processed and reported.
# NOTE(review): the values are presumably the defaults carried by the project's
# `params` object; nothing in this notebook reads them — confirm against utils/params.py.
default_values = {
    "GENE_EMB_NAME": "gene2vec", # gene2vec coding_pseudo coding_lncrna coding_smallrna coding_hs_mouse coding
    "TOTAL_NUMBER_OF_DATASETS": 5,
    "DATASET_TO_GET_FOR_MIXED_DATASET": None, # ["both", "human", "nonhuman", None]
    "HIDDEN_SIZE": 200,
    "PERFORMER_NET_LAST_LAYER_REQUIRES_GRAD": True,
    "FINETUNE_TO_RECONSTRUCT_EXPR_OF_ALL_GENES": False,
    "USE_PRETRAIN_MODEL_FOR_FINETUNE": True,
    "PRETRAIN_EXPERIMENT_FOR_FINETUNE": "exp9",
    "TRANSFORMER_MODEL_NAME": "Bert",
    "LAYER_NORM_EPS": 1e-12,
    "OUTPUT_ATTENTIONS": False,
    "OUTPUT_HIDDEN_STATES": False,
    "ONLY_USE_PERTURBED_GENE_TO_PREDICT": False,
    "LEARN_ON_ZERO_EXPR_GENES": False,
    # NOTE(review): "TENSOBOARD" looks like a typo of "TENSORBOARD"; the key is a
    # search string, so only rename it together with the project sources.
    "OUTPUT_PARAMETER_HIST_TO_TENSOBOARD_BY_BATCH": False,
    "TRANSFORMER_NORM_FIRST": True,
    "TRANSFORMER_HIDDEN_ACT_FUNC": "gelu",
    "MIN_MEAN_VAL_FOR_ZSCORE": 0.1,
    "SAMPLE_NUMBER_FOR_EACH_PERTURBATION": 10,
    "PERTURBED_GENE_ALWAYS_IN_INPUT_EXPR_IN_PERTURB_DATASET": False,
    "PRETRAIN_LOSS_ONLY_ON_MASKED_GENES": True,
    "USE_AND_KEEP_ZERO_EXPR_GENES": True,
    "NUM_OF_GENES_SELECTED": -1, # -1 for selecting all genes
    "ONLY_USE_POSITIVE_ZSCORES_IN_TRAINING": False,
    "SHUFFLE_GENE_INDICES_IN_EVALUATION": False,
    "SHUFFLE_EXPR_INDICES_IN_EVALUATION": False,
    "METHOD_TO_COMBINE_INPUT_AND_ENCODING": None,
    "NUM_BINS": 100,
    "MASK_FRACTIONS": [0.15, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
    "PERCENT_OF_MASKED_GENES_ASSIGNED_AS_TOKEN_ZERO": 0.8,
    "PERCENT_OF_MASKED_GENES_ASSIGNED_AS_RANDOM_TOKENS": 0.1,
    "BATCH_SIZE": 8,
    "MODEL_DEPTH": 6,
    "NUM_HEADS": 8,
    "DIM_HEAD": 32,
    # NOTE(review): "RPOJECTION" looks like a typo of "PROJECTION"; the key is a
    # search string, so only rename it together with the project sources.
    "NO_RPOJECTION": False,
    "MODEL_REVERSIBLE": True,
    "FEATURE_REDRAW_INTERVAL": 1000,
    "EMB_DROPOUT": 0.,
    "FF_DROPOUT": 0.1,
    "ATTN_DROPOUT": 0.1,
    "OUTPUTLAYER2FCS_DROPOUT_RATE": 0.1,
    "GENERALIZED_ATTENTION": False,
    "EXPRESSION_EMB_TYPE": "positional",
    "TO_OUT_LAYER_TYPE": "2FCs",
    "OUTPUT_LAYER_HIDDEN_SIZE1": 40,
    "OUTPUT_LAYER_HIDDEN_SIZE2": 20,
    "PRETRAINED_TOKEN_EMB_FOR_INIT": False,
    "GENE_ID_EMB_REQUIRES_GRAD": True,
    "EXPR_EMB_REQUIRES_GRAD": True,
    "BASE_LR": 0.00001,
    "MAX_LR": 0.0001,
    "EPOCH_TO_HAVE_MANUAL_LR": 30,
    "ONE_CYCLE_LR_PCT_START": 0.2,
    "ONE_CYCLE_LR_DIV_FACTOR": 5,
    "ONE_CYCLE_LR_TOTAL_STEPS": 40,
    "ONE_CYCLE_LR_EPOCHS": 40,
    "STEP_SIZE_UP": 4,
    "EXPR_DISCRETIZATION_METHOD": "Direct_quantile",
    "TRAINING_SET_FRACTION": 0.9,
    "GRADIENT_ACCUMULATION_STEPS": 5,
    "OPTIMIZER": "AdamW",
    "ADAMW_WEIGHT_DECAY": 0.01,
    "LOSS_FN": "MSE",
    "SCHEDULER": "OneCycleLR",
    "SAVE_CHECK_POINT_BY_BATCHES": False,
    "FRACTION_OF_SAMPLES_TO_BE_FAKE": 0.5, 
    "FRACTION_OF_GENES_TO_HAVE_RANDOM_EXPR": 0.3,
    "SPECIFIED_PRETRAIN_MODEL_CHECKPOINT_PATH": None
}

In [3]:
# Names that replace_text_in_file() looks for in a source file and, when present,
# lists in an explicit `from train.common_params_funs import ...` line.
common_params_funcs = [
    "TENSORBOARD_LOG_DIR_PATH",
    "PRETRAIN_MODEL_CHECKPOINT_PATH",
    "BASE_SEED",
    "config",
    "set_seed",
    "get_special_encoding",
    "cleanup",
    "normalize_expression",
    "worker_init_fn",
    "custom_histogram",
    "add_histogram_to_tensorboard",
    "get_current_learning_rate",
    "shuffle_sequences_old",
    "shuffle_sequences",
    "get_pred_using_model_and_input",
    "extract_hidden_states",
    "get_layers_in_model",
    "get_gene_symbols_filt_on_z_dup",
    "get_gene2idx",
    "get_gene2idx_of_whole_gene_emb",
    "get_gene2idx_no_special_token",
]

# Same idea for the explicit `from train.common import ...` line.
common_funcs = [
    "output_to_a_file",
    "train",
    "evaluate",
    "output_parameter_hist_to_tensorboard",
    "initiate_model",
]

In [9]:

def replace_text_in_file(input_file, output_file, param_names=None,
                         common_param_names=None, common_names=None):
    """
    Rewrite bare parameter constants in a source file to ``params.<NAME>`` and
    turn the project's star imports into explicit imports.

    The input file is read, transformed in memory and written to the output
    file. Three things happen to the text:

    1. Every whole-identifier occurrence of a name in ``param_names`` becomes
       ``params.<name>``. Occurrences that are part of a longer identifier
       (``HIDDEN_SIZE`` inside ``OUTPUT_LAYER_HIDDEN_SIZE1``) or that follow a
       dot (``params.HIDDEN_SIZE``, ``obj.HIDDEN_SIZE``) are left alone, so
       running the function on its own output changes nothing further.
       Matching is textual: names inside strings and comments are rewritten too.
    2. If anything was rewritten, ``from utils.params import params`` is
       prepended unless the file already has that import line.
    3. The first ``from train.common_params_funs import *`` or
       ``from train.common import *`` line is replaced by explicit import lines
       listing only the names actually used outside import statements; any
       further star-import line of either kind is emptied so the explicit
       imports are not duplicated.

    A file whose path contains ``params.py`` is copied through unchanged.

    Parameters:
    input_file (str): The path to the input file.
    output_file (str): The path to the output file where the modified content will be written.
    param_names (iterable of str, optional): Names to prefix with ``params.``.
        Defaults to the keys of the module-level ``default_values``.
    common_param_names (iterable of str, optional): Names importable from
        ``train.common_params_funs``. Defaults to the module-level ``common_params_funcs``.
    common_names (iterable of str, optional): Names importable from
        ``train.common``. Defaults to the module-level ``common_funcs``.

    Returns:
    None. Progress is printed; file-access and decoding errors are printed
    rather than raised.
    """
    # Function-scope imports keep this cell runnable on its own.
    import os
    import re

    # Fall back to the notebook-level tables only when the caller gives none.
    if param_names is None:
        param_names = default_values
    if common_param_names is None:
        common_param_names = common_params_funcs
    if common_names is None:
        common_names = common_funcs

    def _identifier_pattern(name):
        # Match `name` only as a complete identifier that is not an attribute:
        # no word character or dot before it, no word character after it.
        return re.compile(r"(?<![\w.])" + re.escape(name) + r"(?!\w)")

    def _explicit_import(module, names, code_text):
        # Build "from <module> import a, b" for the names used in code_text.
        used = [name for name in names if _identifier_pattern(name).search(code_text)]
        if not used:
            return None
        return f"from {module} import {', '.join(used)}"

    try:
        # Read the input file
        with open(input_file, 'r', encoding='utf-8') as file:
            file_contents = file.read()
        modified_contents = file_contents

        # The params module itself must not be rewritten; copy it verbatim.
        if "params.py" not in os.fspath(input_file):
            found = False
            for target_text in param_names:
                replacement_text = f"params.{target_text}"
                modified_contents, hits = _identifier_pattern(target_text).subn(
                    lambda _match, text=replacement_text: text, modified_contents)
                if hits:
                    print(f"replace({target_text}, {replacement_text})")
                    found = True
            already_imported = re.search(
                r"^from utils\.params import params\b", modified_contents, re.MULTILINE)
            if found and not already_imported:
                modified_contents = "from utils.params import params\n" + modified_contents

            # Look for used names outside import statements only; otherwise the
            # package name in "from train.common import *" would count as a use
            # of the function `train`.
            code_text = "\n".join(
                line for line in modified_contents.splitlines()
                if not re.match(r"\s*(?:from|import)\s", line))
            import_lines = [
                line for line in (
                    _explicit_import("train.common_params_funs", common_param_names, code_text),
                    _explicit_import("train.common", common_names, code_text),
                ) if line is not None
            ]
            if import_lines:
                # Either star import may have supplied names from both modules,
                # so the first one found receives every explicit import line.
                pending = ["\n".join(import_lines)]

                def _swap_star_import(_match):
                    return pending.pop() if pending else ""

                modified_contents = re.sub(
                    r"^from train\.(?:common_params_funs|common) import \*",
                    _swap_star_import, modified_contents, flags=re.MULTILINE)

        # Write the modified contents to the output file
        with open(output_file, 'w', encoding='utf-8') as file:
            file.write(modified_contents)

        print(f"File '{input_file}' has been processed and output to '{output_file}'.")
    except (OSError, UnicodeError) as e:
        # Only I/O and decoding problems are reported and swallowed; programming
        # errors (e.g. a missing notebook global) now surface as exceptions.
        print(f"An error occurred: {e}")

# Root of the pre-refactoring source snapshot and the scratch directory that
# receives the rewritten files.
SOURCE_ROOT = "/g/data/yr31/zs2131/tasks/2023/RNA_expr_net/doc/before_changing_params/src"
OUTPUT_DIR = "/g/data/yr31/zs2131/tasks/2023/RNA_expr_net/tmp/src_refactoring"

# Known refactoring jobs: label -> (path below SOURCE_ROOT, file name in OUTPUT_DIR).
REFACTOR_JOBS = {
    "pretrain": ("train/pretrain.py", "AA.py"),
    "ARCHSDataset": ("data/ARCHSDataset.py", "BB.py"),
}

# Only one job runs per execution of this cell; change the label to switch.
# Previously both path pairs were assigned in sequence and the first pair was
# silently overwritten before the call.
ACTIVE_JOB = "ARCHSDataset"

relative_source, output_name = REFACTOR_JOBS[ACTIVE_JOB]
input_file = f"{SOURCE_ROOT}/{relative_source}"
output_file = f"{OUTPUT_DIR}/{output_name}"

replace_text_in_file(input_file, output_file)


replace(GENE_EMB_NAME, params.GENE_EMB_NAME)
replace(TRANSFORMER_MODEL_NAME, params.TRANSFORMER_MODEL_NAME)
replace(MIN_MEAN_VAL_FOR_ZSCORE, params.MIN_MEAN_VAL_FOR_ZSCORE)
replace(USE_AND_KEEP_ZERO_EXPR_GENES, params.USE_AND_KEEP_ZERO_EXPR_GENES)
replace(NUM_OF_GENES_SELECTED, params.NUM_OF_GENES_SELECTED)
replace(ONLY_USE_POSITIVE_ZSCORES_IN_TRAINING, params.ONLY_USE_POSITIVE_ZSCORES_IN_TRAINING)
replace(NUM_BINS, params.NUM_BINS)
replace(MASK_FRACTIONS, params.MASK_FRACTIONS)
replace(EXPR_DISCRETIZATION_METHOD, params.EXPR_DISCRETIZATION_METHOD)
File '/g/data/yr31/zs2131/tasks/2023/RNA_expr_net/doc/before_changing_params/src/data/ARCHSDataset.py' has been processed and output to '/g/data/yr31/zs2131/tasks/2023/RNA_expr_net/tmp/src_refactoring/BB.py'.
