In [16]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

class PredAnnModel:
    """Holds the hyperparameters and configuration for the predictive ANN.

    Only the constructor is defined in this cell: it validates the gene list
    and stores every setting on the instance under the same name.
    """

    # Defaults for the two container-valued settings. They are kept at class
    # level and copied per instance so that mutating one model's schedule can
    # never leak into another model (the classic mutable-default-argument trap).
    DEFAULT_AUC_THRESHOLDS = (0.6, 0.7, 0.8, 0.85, 0.88, 0.89, 0.90, 0.91, 0.92)
    DEFAULT_LR_DICT = {
        0.6:  0.005,
        0.7:  0.001,
        0.8:  0.0005,
        0.85: 0.0001,
        0.88: 0.00005,
        0.89: 0.00001,
        0.9:  0.000005,
        0.91: 0.000001,
        0.92: 0.0000005,
    }

    def __init__(
        self,
        current_genes,
        dropout_rate=0.3,
        balance=True,
        l2_reg=-0.2,
        batch_size=16,
        num_epochs=5000,
        report_frequency=1,
        auc_threshold=0.9,
        clipnorm=2.0,
        simplify_categories=True,
        holdout_size=0.5,
        multiplier=3,
        auc_thresholds=None,
        lr_dict=None,
    ):
        """
        Initializes the PredAnnModel with specified hyperparameters and configuration.

        Parameters:
        - current_genes (list): A non-empty list of genes to be used as model features.
        - dropout_rate (float): Dropout rate to prevent overfitting (default: 0.3).
        - balance (bool): Whether to balance technology and outcome variables during training (default: True).
        - l2_reg (float): Strength of L2 regularization (default: -0.2).
        - batch_size (int): Batch size for training (default: 16).
        - num_epochs (int): Total number of training epochs (default: 5000).
        - report_frequency (int): Frequency of reporting model metrics (AUC and Accuracy) during training (default: 1).
        - auc_threshold (float): AUC threshold for early stopping (default: 0.9).
        - clipnorm (float): Gradient clipping norm to prevent exploding gradients (default: 2.0).
        - simplify_categories (bool): Whether to simplify categories in the dataset (default: True).
        - holdout_size (float): Proportion of samples withheld during training (default: 0.5).
        - multiplier (int): Scales the number of nodes in most network layers (default: 3).
        - auc_thresholds (list | None): Test-set AUC values at which the learning rate should be
          adjusted. None (the default) selects a fresh copy of DEFAULT_AUC_THRESHOLDS.
        - lr_dict (dict | None): Maps each threshold to the learning rate to use once it is reached.
          None (the default) selects a fresh copy of DEFAULT_LR_DICT, whose keys match the
          default thresholds.

        Raises:
        - ValueError: If `current_genes` is not a non-empty list.
        """
        if not isinstance(current_genes, list) or not current_genes:
            raise ValueError("The 'current_genes' parameter must be a non-empty list of genes.")

        self.current_genes = current_genes  # List of genes provided by the user to define model features.
        self.dropout_rate = dropout_rate  # Dropout rate for regularization.
        self.balance = balance  # Balance technology and outcomes during training.
        # NOTE(review): the default is negative, which is unusual for a
        # regularization strength - confirm the sign is intentional.
        self.l2_reg = l2_reg  # Degree of L2 regularization.
        self.batch_size = batch_size  # Batch size for training.
        self.num_epochs = num_epochs  # Total number of training epochs.
        self.report_frequency = report_frequency  # Frequency for collecting metrics during training.
        self.auc_threshold = auc_threshold  # AUC threshold for early stopping.
        self.clipnorm = clipnorm  # Gradient clipping value to prevent exploding gradients.
        self.simplify_categories = simplify_categories  # Whether to reduce data categories (e.g., microarray vs. sequencing).
        self.holdout_size = holdout_size  # Proportion of samples withheld during training.
        self.multiplier = multiplier  # Scales the number of nodes in most layers of the network.

        # Caller-supplied containers are stored as given; only the defaults are
        # copied, so every instance owns an independent default schedule.
        if auc_thresholds is None:
            auc_thresholds = list(self.DEFAULT_AUC_THRESHOLDS)
        if lr_dict is None:
            lr_dict = dict(self.DEFAULT_LR_DICT)
        self.auc_thresholds = auc_thresholds  # AUC values at which the learning rate should be adjusted
        self.lr_dict = lr_dict  # Learning rate to switch to at each AUC threshold



In [39]:
import subprocess
from rpy2.robjects import r

from rc_data_preparation import RcDataPreparation
# Instantiate the project's data-preparation object; later cells read its
# `X_train` attribute. NOTE(review): the "Data successfully loaded." line in
# this cell's output presumably comes from this constructor - confirm.
current_data = RcDataPreparation()

def get_genes_list(
    p_thresh,
    split_train,
    r_script="rc_get_diff_genes.r",
    rds_path='/tmp/work/RCproject_code/sean_ann_python/ann_gene_set.rds',
):
    """Run the gene-selection R script and load the gene set from an RDS file.

    Executes ``Rscript <r_script> <p_thresh> <split_train>`` and, once the
    script exits with status 0, reads ``rds_path`` through rpy2.

    Parameters:
    - p_thresh: Forwarded to the R script as its first argument after ``str()``
      conversion (presumably a p-value cutoff - confirm against the R script).
    - split_train: Forwarded to the R script as its second argument after
      ``str()`` conversion (a Python bool is sent as "True"/"False").
    - r_script (str): Path of the R script to run (default: "rc_get_diff_genes.r").
    - rds_path (str): RDS file read after the script succeeds. Assumed to be
      the file the R script writes; that is not verified here, so a stale file
      from an earlier run would be read silently.

    Returns:
    - The object returned by ``r.readRDS(rds_path)``.

    Raises:
    - RuntimeError: If the R script exits with a non-zero status; the message
      carries the script's stderr.
    """
    # Build the command to run the R script
    command = ["Rscript", r_script, str(p_thresh), str(split_train)]

    result = subprocess.run(command, capture_output=True, text=True)

    # Fail loudly: falling through here used to hit an unbound `current_genes`
    # and raise an UnboundLocalError that hid the real R error.
    if result.returncode != 0:
        raise RuntimeError(f"Error in R script execution:\n{result.stderr}")

    print("R script executed successfully.")

    # Read the generated file
    current_genes = r.readRDS(rds_path)
    print(len(current_genes))

    return current_genes

# Sweep the thresholds in one loop instead of eleven copy-pasted calls, and
# keep every result so the gene sets can be compared afterwards.
P_THRESHOLDS = (0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1)

gene_lists_by_threshold = {
    p_thresh: get_genes_list(p_thresh, True) for p_thresh in P_THRESHOLDS
}

# Later cells use `test_list`; as before, it is the result of the last
# threshold in the sweep (0.1).
test_list = gene_lists_by_threshold[P_THRESHOLDS[-1]]

# next steps

# complete rc_pred_ann_model class
# include feature selection
# execute one instance of the model and return necessary values


Data successfully loaded.
R script executed successfully.
3
R script executed successfully.
55
R script executed successfully.
137
R script executed successfully.
211
R script executed successfully.
305
R script executed successfully.
392
R script executed successfully.
494
R script executed successfully.
599
R script executed successfully.
699
R script executed successfully.
806
R script executed successfully.
910


In [None]:
# NOTE(review): this cell has no execution count, so it was never run in the
# saved state. Only `X_train` is used elsewhere in this notebook - confirm
# that `train` exists on the data-preparation object.
current_data.train

In [33]:
# Select the gene columns directly rather than transposing the whole frame,
# indexing rows, and transposing back: this avoids two full copies and keeps
# each column's dtype. list() gives pandas a plain Python list of labels.
current_data.X_train.loc[:, list(test_list)]

# counts_df.loc[current_genes]

Unnamed: 0,ADD3,ADM,AGL,AHNAK2,ANGPTL4,ARFRP1,ARPIN,ASF1A,ATG9B,ATP10D,...,ZBTB41,ZBTB46,ZDHHC12,ZFAND6,ZFP28,ZNF106,ZNF318,ZNF37A,ZNF91,ZSCAN5A
GSM5732588_GSE190826,1.180988,-0.688915,-0.580224,0.952104,-0.950512,1.052496,0.533552,-0.309182,-1.525149,0.139732,...,0.277266,-0.060575,-0.463336,0.870671,-0.666224,1.865663,0.946850,0.720781,1.776053,-0.569965
GSM4523147_GSE150082,0.309097,1.344393,-0.209032,1.239429,2.095264,-0.224816,1.435388,-0.445756,0.257738,-0.864950,...,0.343656,0.677148,0.194815,0.104485,0.180254,-0.053188,0.215329,0.053610,-0.118891,0.422420
GSM6390460_GSE209746,1.961813,-0.631515,0.399917,1.300648,-1.407259,1.234002,0.328982,-0.067544,0.196523,0.116284,...,0.614654,-0.163843,-0.527112,0.377993,-0.813033,1.240820,1.209928,0.723053,1.710498,-0.707765
GSM1103637_GSE45404-2,-0.358321,-2.217002,0.061467,0.493808,-0.360714,-0.208450,0.015326,0.930978,-0.518636,-2.985596,...,1.075983,-0.599135,0.306217,-0.321206,0.331521,-0.060993,0.637079,0.584704,1.567761,2.200492
GSM6390447_GSE209746,1.427188,-0.124451,0.505614,-0.004200,-0.823146,1.274152,0.278339,-0.206245,-1.146486,0.391364,...,0.521848,-0.427769,-0.384109,0.717061,-0.381875,1.456431,0.731867,1.181563,1.536453,-0.711073
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GSM5732594_GSE190826,1.426613,-0.226100,-0.330146,0.651468,-0.593693,1.080082,0.804484,-0.435530,-0.381600,-0.069083,...,0.768841,-0.435530,-0.754010,1.124957,-0.626080,1.706369,0.845288,0.896642,1.937302,-0.257111
GSM4523152_GSE150082,-1.012080,-0.939375,0.114690,0.592377,-0.213290,0.476715,-0.230018,0.134202,-0.594097,-1.303909,...,-1.071811,-1.202016,0.170496,-0.486505,-0.208698,-0.259765,-0.050148,0.449278,-0.223684,-0.084292
GSM6390449_GSE209746,1.107566,0.320661,0.132498,-0.212398,-0.572163,1.183067,0.109052,-0.064079,-0.480885,-0.021883,...,0.327834,-0.584028,-0.526510,0.905278,-0.693374,1.460496,0.746609,1.026593,1.606066,-0.394899
GSM5732574_GSE190826,1.154417,-0.485386,-0.447063,2.770060,-0.584077,1.533757,0.535338,-0.237766,0.732368,0.856403,...,0.249021,-0.665466,-0.656065,0.744469,-0.965747,1.614586,0.749271,0.418151,1.648676,-0.533465


In [35]:
# Scratch value for checking how a float threshold is rendered by str().
test = 0.05

In [37]:
# Shows that str() renders 0.05 as '0.05' - the same conversion
# get_genes_list applies to p_thresh when building the Rscript command.
str(test)

'0.05'