In [1]:
import re
import os
import sys
import pickle
import time
import random
from datetime import datetime
import logging
from functools import reduce

import scanpy as sc
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import RFE,VarianceThreshold,SelectKBest,chi2, SelectFromModel, f_classif, mutual_info_classif
from sklearn.model_selection import cross_validate, GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
import matplotlib.pyplot as plt

In [2]:
def convert_data(data = None, path = None, assay = ".X",label_column = None):
    """
    A function to convert sparse matrix into pandas data frame object. 
    
    Parameters:
    data: a Annot Data object. If speciefied, path should be specified in 'None'.
    path: path to h5ad file. If speciefied, data should be specified in 'None'.
    assay: ".X" or "raw" assay in Annot Data object could be specified
    label_column: the name of cell type column in the h5ad file. 
    If specified, the cell type column will be added into output.
    """
    if path != None:
        data = sc.read(path)
    
    if assay == ".X":
        counts = pd.DataFrame.sparse.from_spmatrix(data.X)
    else:
        counts = pd.DataFrame.sparse.from_spmatrix(data.raw.X)
        
    features = data.raw.var_names.tolist()
    index  = data.raw.obs_names.tolist()
    
    counts.columns = features
    counts.index = index
    
    if label_column != None:
        try:
            labels = data.obs[label_column].tolist()
            counts["cell_type"] = labels
        except:
            raise ValueError("The length of cell type column is not consistent with matrix")

    return counts

In [3]:
def quantile_normalize(data):
    """
    A function to do quantile normalization.  
    
    """
    data = data.loc[:, data.columns != "cell_type"]
    ranks = (data.rank(method = "first").stack())
    rank_mean = (data.stack().groupby(ranks).mean())
    # add interproblated values in between ranks
    finer_ranks = ((rank_mean.index + 0.5).tolist() + rank_mean.index.tolist())
    rank_mean = rank_mean.reindex(finer_ranks).sort_index().interpolate()
    data = data.rank(method = "average").stack().map(rank_mean).unstack()
    
    return data
    

In [4]:
def record_time():
    """
    A function to call out current time.
    """
    current_second_time = datetime.now()
    return current_second_time.strftime("%Y-%m-%d %H:%M:%S")

In [5]:
class log_file:
    """
    A class to easily write log information into different log files.
    
    filename: log file name 
    mode: "a" (append log information) or "w" (clean up old log information and then write new information)
    """
    n_class = 0
    def __init__(self, filename, mode):
        self.time = record_time()
        self.filename = filename
        self.mode = mode
        self.logger = None
        self.n_object = 0
        log_file.n_class += 1
        
    def set(self):
        logger = logging.getLogger("my_logger")
        logger.setLevel(logging.INFO)
        log_file_name = "{}_{}.log".format(self.filename, self.time)
        file_handler = logging.FileHandler(log_file_name, mode = self.mode)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        self.logger = logger
        
    def start(self):
        if log_file.n_class == 1 and self.n_object == 0:
            self.set()
        if log_file.n_class > 1 and self.n_object == 0:
            self.set()
        
    def write(self, category, content):
        self.start()
        self.n_object += 1
        if category == "info":
            self.logger.info(content)
        elif category == "error":
            self.logger.error(content)
        elif category == "warning":
            self.logger.warning(content)
        elif category == "debug":
            self.logger.debug(content)
        elif category == "critical":
            self.logger.critical(content)
    
    def mode_reset(self, mode):
        self.mode = mode
        
    def clear():
        log_file.n_class = 0

In [6]:
def remove_time(string):
    """
    A function to remove time information in the stdout message"
    """
    pattern = re.compile("\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}")
    res = re.sub(pattern, "", string)
    return res.rstrip().lstrip()

In [7]:
# create a log_file object for downstream recording log information
#log_file.clear()
my_logger = log_file("mylogfile4", "a")

In [9]:
def pre_processing(path, 
                   label_column,
                   assay = ".X", 
                   convert = True,
                   log_normalize = True, 
                   scale_data = False, 
                   quantile_normalize = False,
                   save = False,
                   logger = None):
    """
    A function to do preliminary normalization, encoding label and convert data into pandas data frame for downstream sklearn-based
    machine learning workflow.
    
    Parameters:
    path: path to the h5ad file.
    assay: ".X" or "raw" assay in Annot Data object could be specified.Default is ".X".
    label_column: The name of cell type column in the h5ad file.If specified, the cell type column will be added into output.
    convert: Bool value to decide whether convert Annot Data object into pandas data frame object.
    log_normalize: Bool value to decide whether standard log normalization to be done.
    scale_data: Bool value to decide whether standadlize data or not.
    quantile_normalize: Bool value to decide whether quantile normalize data or not.
    save: Bool value to decide whether write the pre-processed data into the disk.
    logger: A log_file object to write log information into disk. Default is None. 
    """
    if logger != "None":
        logger.write("info", "start pre processing")
        logger.write("critical", "Parameters used for pre_processing")
        parameters = {"path": path,
                      "label_column": label_column,
                      "assay": assay,
                      "convert": convert,
                      "log_normalize": log_normalize,
                      "scale_data": scale_data,
                      "quantile_normalize": quantile_normalize,
                      "save": save
                     }
        for key,value in parameters.items():
            logger.write("critical", "{}: {}".format(key, value))
    
    data = sc.read(path)
    
    if log_normalize:
        sc.pp.normalize_total(data, target_sum = 1e4)
        sc.pp.log1p(data)
    
    if scale_data:
        sc.pp.scale(data)
        
    if convert:
        counts = convert_data(data = data, assay = assay, label_column = label_column)
    else:
        counts = data.X
        
    if quantile_normalize:
        quant_norm_data = convert_data(data, assay = assay)
        counts = quantile_normalize(data.X)
    
    # convert string label into numeric label
    labels = data.obs[label_column].unique().tolist()
    labels.sort()
    label = data.obs[label_column].apply(lambda x: labels.index(x))
    
    res = {"matrix": counts, 
           "convert_label": label, 
           "original_label": data.obs[label_column], 
           "sort_uniq_label": labels} 
    
    if save:
        file_name = "Preprocessing_data_{}.pkl".format(record_time())
        with open(file_name, "wb") as output:
            pickle.dump(res, output)
    
    if logger != "None": 
        logger.write("info", "finish pre processing")
    
    return res

In [10]:
data = pre_processing("./cellhint_demo_folder/cellhint_demo_folder/Spleen.h5ad", 
                      assay = ".X", 
                      logger = my_logger,
                      label_column = "cell_type",
                      convert = True,
                      scale_data = False,
                      save = False)



In [14]:
def feature_selection(data, 
                      label_column,
                      filename,
                      logger = None,
                      mode = "ensemble",
                      random_foreast_threshold = None,
                      #SVM_threshold = None, 
                      variance_threshold = "zero",
                      mutual_info = True, 
                      chi_square_test = False,
                      F_test = False,
                      model = "random_foreast",
                      n_estimators = 100,
                      random_state = 10,
                      kernel = "linear",
                      decision_function_shape = "ovo",
                      n_features_to_select = None,
                      step = 100,
                      save = True):
    """
    A function to do feature seletion based on filtering, embedding and wrapping method respectively or combing those methods together.
    
    Parameters:
    data: A pandas data frame object.
    label_column: The name of cell type column in the data.
    logger: A log_file object to write log information into disk. Default is None. 
    random_foreast_threshold. A float or int value to set the cutoff (feature_importance_) by random foreast model-basedd embedding feature selection.It needs to be specified when model is set in 'random_foreast'. Default is `1 / the number of all features`
    variance_threshold: A string to decide which variance cutoff is used to filter out features."zero" or "median" could be selected. 
    mutual_info: Bool value decide whether a mutual information method is employed to filtering out features further.
    chi_sqaure_test: A Bool value decide whether a chi square test method is employed to filtering out features further.
    F_test: Bool value decide whether a F test method is employed to filtering out features further.
    model: String to decide which model is used by embedding-based feature selection. "random_foreast", "logistic" and "svm" could be selected.
    n_estimators: The number of trees in the forest.
    random_state: Controls both the randomness of the bootstrapping of the samples used when building trees (if ``bootstrap=True``) and the sampling of thefeatures to consider when looking for the best split at each node.
    kernel: Specifies the kernel type to be used in the support vector machine algorithm.
    decision_function_shape: Whether to return a one-vs-rest ('ovr') decision function of shape (n_samples, n_classes) as all other classifiers, or the originalone-vs-one ('ovo') decision function of libsvm which has shape.(n_samples, n_classes * (n_classes - 1) / 2)
    n_featurs_to_selct: int or float, default=None.The number of features to select. If `None`, half of the features are selected. If integer, the parameter is the absolute number of features to select. If float between 0 and 1, it is the fraction of features to select.
    step:int or float, default=1. If greater than or equal to 1, then ``step`` corresponds to the (integer) number of features to remove at each iteration. If within (0.0, 1.0), then ``step`` corresponds to the percentage (rounded down) of features to remove at each iteration.
    save: Bool value to decide whether write the pre-processed data into the disk.
    """
    message = "{} {}".format(record_time(), "start feature selection")
    print(message)
    
    if logger != None:
        logger.write("info", remove_time(message))
    
        parameters = {
                     "label_column": label_column,
                     "file_name": filename,
                     "mode": mode,
                     "random_foreast_threshold": random_foreast_threshold,
                     "variance_threshold": variance_threshold,
                     "mutual_info": mutual_info,
                     "chi_square_test": chi_square_test,
                      "F_test": F_test,
                      "model" : model,
                      "n_estimators": n_estimators,
                      "random_state": random_state,
                      "kernel": kernel,
                      "decision_function_shape": decision_function_shape,
                      "n_features_to_select": n_features_to_select,
                      "step" : step,
                      "save" : save
                     }

        for key, value in parameters.items():
            logger.write("critical", "{}: {}".format(key, value))
    
    # step 1 - convert category label into numeric label
    message = "{} step 1 - converting categoric label into numeric label".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
  
    
    le = LabelEncoder().fit(data[label_column])
    label = le.transform(data[label_column])
    data[label_column] = label
        
    X = data.iloc[:, 1:-1]
    y = data.iloc[:, -1]
    
    # step 2 - do feature selection
    message = "{} step 2 - do feature selection".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
    
    message = "{} ======== filtering based selection ========".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
    
        
    # filtering-based feature selection
    # filter out by variance 
    if variance_threshold == "zero":
        var_selector = VarianceThreshold()
        X_var = var_selector.fit_transform(X)
        retained_features_by_filter = X.columns[var_selector.get_support()]
        message = "* {} {} features remained after filter out features with 0 variance".format(record_time(), X_var.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    elif variance_threshold == "median":
        var_selector = VarianceThreshold(np.median(np.var(np.array(X), axis = 0)))
        X_var = var_selector.fit_transform(X)
        retained_features_by_filter = X.columns[var_selector.get_support()]
        message = "* {} {} features remained after filter out features below median variance of all features".format(record_time(), X_var.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    # filter by chi sqaure test
    if chi_square_test and F_test == False:
        chivalue, pvalues_chi = chi2(X_var, y)
        k = chivalue.shape[0] - (pvalues_chi > 0.05).sum()
        selector = SelectKBest(chi2, k = k)
        X_fschi = selector.fit_transform(X_var, y)
        retained_features_by_filter = retained_features_by_filter[selector.get_support()]
        message = "** {} {} features remained after further chi sqaure test filtering".format(record_time(), X_fschi.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    # filter by F test
    if F_test and chi_square_test == False:
        F, pvalues_f = f_classif(X_var, y)
        k = F.shape[0] - (pvalues_f > 0.05).sum()
        selector = SelectKBest(f_classif, k = k)
        X_fsF = selector.fit_transform(X_var, y)
        retained_features_by_filter = retained_features_by_filter[selector.get_support()]
        message = "** {} {} features remained after further chi sqaure test filtering".format(record_time(), X_fsF.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    # filter by mutual infomation
    if (F_test == False and chi_square_test == False) and mutual_info:
        res = mutual_info_classif(X_var, y)
        k = res.shape[0] - sum(res <= 0)
        selector = SelectKBest(mutual_info_classif, k = k)
        X_fsmic = selector.fit_transform(X_var, y)
        retained_features_by_filter = retained_features_by_filter[selector.get_support()]
        message = "** {} {} features remained after further mutual information filtering".format(record_time(), X_fsmic.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    message = "{} ======== embedding based selection ========".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))

    # embedding-based on feature selection
    # select by random foreast model
    if model == "random_foreast":
        RFC_ = RandomForestClassifier(n_estimators = n_estimators, random_state = random_state)
        # when random_foreast_threshold is None, 
        #  `1 / number of features` will be used as threshold
        if random_foreast_threshold == None:
            random_foreast_threshold = 1 / X.shape[1] 
        RFC_embedding_selector = SelectFromModel(RFC_, threshold = random_foreast_threshold)
        X_RFC_embedding = RFC_embedding_selector.fit_transform(X, y)
        retained_features_by_embedding = RFC_embedding_selector.get_feature_names_out()
        message = "* {} {} features remained after random foreast based embedding filtering".format(record_time(), X_RFC_embedding.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    # select by logistic regression model
    elif model == "logistic":
        logistic_ = LogisticRegression(multi_class = "multinomial", random_state = random_state, max_iter=200)
        log_embedding_selector = SelectFromModel(logistic_, norm_order = 1)
        X_log_embedding = log_embedding_selector.fit_transform(X, y)
        retained_features_by_embedding = log_embedding_selector.get_feature_names_out()
        message = "* {} {} features remained after logistic regression based embedding filtering".format(record_time(), X_log_embedding.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    # select by SVM model
    elif model  == "svm":
        SVC_ = SVC(decision_function_shape = decision_function_shape, kernel = kernel)
        SVC_embedding_selector = SelectFromModel(SVC_, norm_order = 1)
        X_SVC_embedding = SVC_embedding_selector.fit_transform(np.array(X), y)
        retained_features_by_embedding = X.columns[SVC_embedding_selector.get_support()]
        message = "* {} {} features remained after svm based embedding filtering".format(record_time(), X_SVC_embedding.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    # selected by wrapping method
    message = "{} ======== wrapping based selection ========".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
    
    if n_features_to_select == None:
        # when features to select is None,
        # 50% of all features will be used as threshold
        n_features_to_select = int(X.shape[1] * 0.5)

    if model == "random_foreast":
        RFC_ = RandomForestClassifier(n_estimators = n_estimators, random_state = random_state)
        RFC_wrapping_selector = RFE(RFC_, n_features_to_select = n_features_to_select, step = step)
        X_RFC_wrapping = RFC_wrapping_selector.fit_transform(X, y)
        retained_features_by_wrapping = X.columns[RFC_wrapping_selector.support_]
        message = "* {} {} features remained after RFE - random foreast based wrapping filtering".format(record_time(), X_RFC_wrapping.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    elif model == "logistic":
        logistic_ = LogisticRegression(multi_class = "multinomial", random_state = random_state, max_iter=200)
        log_wrapping_selector = RFE(logistic_, n_features_to_select = n_features_to_select, step = step)
        X_log_wrapping = log_wrapping_selector.fit_transform(X, y)
        retained_features_by_wrapping = X.columns[log_wrapping_selector.support_]
        message = "* {} {} features remained after RFE - logistic regression based wrapping filtering".format(record_time(), X_log_wrapping.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    elif model == "svm":
        SVC_ = SVC(decision_function_shape = decision_function_shape, kernel = kernel)
        SVC_wrapping_selector = RFE(SVC_, n_features_to_select = n_features_to_select, step = step)
        X_SVC_wrapping = SVC_wrapping_selector.fit_transform(np.array(X), y)
        retained_features_by_wrapping = X.columns[SVC_wrapping_selector.support_]
        message = "* {} {} features remained after RFE - svm based wrapping filtering".format(record_time(), X_SVC_wrapping.shape[1])
        print(message)
        if logger != None:
            logger.write("info", remove_time(message))

    message = "{} ======== final feature selection ========".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
    
    retained_features_by_filter = set(retained_features_by_filter)
    retained_features_by_embedding = set(retained_features_by_embedding)
    retained_features_by_wrapping = set(retained_features_by_wrapping)
    final_feture_selection = reduce(lambda x,y: x.intersection(y), [retained_features_by_embedding, retained_features_by_filter, retained_features_by_wrapping])
    message = "* {} {} features remained after intersecting the key features found by filtering, embedding and wrapping-based feature selection methods".format(record_time(), len(final_feture_selection))
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
            
    output = {"retained_features_by_filtering": retained_features_by_filter,
             "retained_features_by_embedding": retained_features_by_embedding,
             "retained_features_by_wrapping": retained_features_by_wrapping,
             "final_feature_selection": final_feture_selection}
    
    if save == True:
        filename = filename + "_" + model + "_" + "feature_selection" + "_" + record_time() + ".pkl"
        with open(filename, "wb") as file:
            pickle.dump(output, file)
    
    message = "{} finish feature selection".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
    
    return output
    

In [13]:
test = feature_selection(data=data["matrix"].iloc[0:5000, 73000:74370], 
                  label_column="cell_type", 
                  filename="test",
                  variance_threshold="median",
                  model="random_foreast",
                  chi_square_test=True,
                  F_test=False,
                  mutual_info=False,
                  n_features_to_select=None,
                  logger = my_logger
                 )

2024-01-02 14:11:53 start feature selection
2024-01-02 14:11:53 step 1 - converting categoric label into numeric label
2024-01-02 14:11:53 step 2 - do feature selection
* 2024-01-02 14:11:53 684 features remained after filter out features below median variance of all features
** 2024-01-02 14:11:53 534 features remained after further chi sqaure test filtering
* 2024-01-02 14:11:58 363 features remained after random foreast based embedding filtering
* 2024-01-02 14:12:39 684 features remained after RFE - random foreast based wrapping filtering
* 2024-01-02 14:12:39 324 features remained after intersecting the key features found by filtering, embedding and wrapping-based feature selection methods
2024-01-02 14:12:39 finish feature selection


In [40]:
def model_training(data,
                   label_column,
                   features,
                   model,
                   logger = None,
                   test_size = 0.3,
                   random_state = 10,
                   cv = 10,
                   save = True
                  ):
    
    """
    A function to do model training.
    
    Parameters:
    d data: A pandas data frame object.
    label_column: The name of cell type column in the data.
    features: Feaures should be kept for model training in the data.
    model: Algorithm to train model. "random_foreast", "svm" or "logistic" could be selected. 
    test_size: 
    logger: A log_file object to write log information into disk. Default is None. 
    test_size: Percentage of data remained for testing model.
    random_state:
    cv: The number of cross validation for grid serach.
    save: Bool value to decide whether the result will be written into disk. Default is True.
    """
    
    message = "{} start model training".format(record_time())
    print(message)
    print("{} model traning based on {} algorithm".format(record_time(), model))
    
    if logger != None:
        logger.write("info", remove_time(message))
        
        parameters = {"label_column": label_column,
                     "features": features,
                     "model": model,
                     "test_size": test_size,
                     "random_state": random_state,
                     "cv": cv,
                     "save": save}
        
        for key,value in parameters.items():
            logger.write("critical", "{}: {}".format(key, value))
        
        logger.write("info", "model traning based on {} algorithm".format(model))
    
    X = data.loc[:, data.columns != label_column]
    y = data.loc[:, data.columns == label_column]
    
    le = LabelEncoder().fit(y)
    y_trans = le.transform(y)
    
    # only keep informative features
    X = data.loc[:, data.columns.isin(features)] 
    
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y_trans, test_size = test_size)
    
    if model == "random_foreast":
        
        parameters = {"n_estimators" : np.arange(10, 101, 10),
                     "criterion" : ["gini", "entropy"],
                     # "max_depth": np.linspace(10, 50, 5),
                     "max_features": np.linspace(0.2, 1, 5),
                     "min_samples_leaf": np.arange(10, 300, 20),
                     "min_samples_split": np.arange(2, 100, 10)}
        
        message = "{} grid search below paramters getting the best model".format(record_time())
        print(message)
        
        if logger != None:
            logger.write("info", remove_time(message))
        
        for key,value in parameters.items():
            print("{}: {}".format(key, value))
            if logger != None:
                logger.write("critical", "{}: {}".format(key, value))
           
        
        RFC_ = RandomForestClassifier(random_state = random_state)
        GS = GridSearchCV(RFC_, parameters, cv = cv)
        GS.fit(Xtrain, Ytrain)
        best_parameters = GS.best_params_
        best_core = GS.best_score_
        score_on_test_data = GS.score(Xtest, Ytest)
    
    elif model == "logistic":
        
        parameters = {"penalty": ["l1", "l2"],
                      "C": np.linspace(0,1,5),
                      "multi_class": ["ovr", "multinomial"]
                     }
        
        message = "{} grid search below paramters getting the best model".format(record_time())
        print(message)
        
        if logger != None:
            logger.write("info", remove_time(message))
        
        for key,value in parameters.items():
            print("{}: {}".format(key, value))
            if logger != None:
                logger.write("critical", "{}: {}".format(key, value))
        
        logistic_ = LogisticRegression()
        GS = GridSearchCV(logistic_, parameters, cv = cv)
        GS.fit(Xtrain, Ytrain)
        best_parameters = GS.best_params_
        best_core = GS.best_score_
        score_on_test_data = GS.score(Xtest, Ytest)
    
    elif model == "svm":
        # scale data for SVM
        Xtrain = StandardScaler().fit_transform(np.array(Xtrain))
        Xtest = StandardScaler().fit_transform(np.array(Xtest))
        
        parameters = {"C": np.linspace(0.01,30,50),
                     "kernel": ["rbf", "poly", "sigmoid", "linear"],
                     "gamma": ["auto", "scale"],
                     "coef0": np.linspace(0,5,10)}
        
        message = "{} grid search below paramters getting the best model".format(record_time())
        print(message)
        
        if logger != None:
            logger.write("info", remove_time(message))
        
        for key,value in parameters.items():
            print("{}: {}".format(key, value))
            if logger != None:
                logger.write("critical", "{}: {}".format(key, value))
        
        SVM_ = SVC()
        GS = GridSearchCV(SVM_, parameters, cv = cv)
        GS.fit(Xtrain, Ytrain)
        best_parameters = GS.best_params_
        best_core = GS.best_score_
        score_on_test_data = GS.score(Xtest, Ytest)
        
    
    output = {"model" : GS,
              "best_score": best_core,
              "best_parameters": best_parameters,
              "score_on_test_data": score_on_test_data,
             "features_used_for_training": features}
    
    if save == True:
        filename = model + "_" + "training_model" + "_" + record_time() + ".pkl"
        with open(filename, "wb") as file:
            pickle.dump(output, file)
    
    message = "{} finish model training".format(record_time())
    print(message)
    if logger != None:
        logger.write("info", remove_time(message))
    
    return(output)
        

In [None]:
test2 = model_training(data = data["matrix"].iloc[0:500, 73000:74370],
                      label_column = "cell_type",
                      features = test["final_feature_selection"],
                      model = "random_foreast",
                      logger = my_logger,
                      cv = 2)

2024-01-02 17:15:00 start model training
2024-01-02 17:15:00 model traning based on random_foreast algorithm
2024-01-02 17:15:00 grid search below paramters getting the best model
n_estimators: [ 10  20  30  40  50  60  70  80  90 100]
criterion: ['gini', 'entropy']
max_features: [0.2 0.4 0.6 0.8 1. ]
min_samples_leaf: [ 10  30  50  70  90 110 130 150 170 190 210 230 250 270 290]
min_samples_split: [ 2 12 22 32 42 52 62 72 82 92]


  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)


In [38]:
test2

{'model': GridSearchCV(cv=2, estimator=RandomForestClassifier(random_state=10),
              param_grid={'criterion': ['gini', 'entropy'],
                          'max_features': array([0.2, 0.4, 0.6, 0.8, 1. ]),
                          'n_estimators': array([ 10,  20,  30,  40,  50,  60,  70,  80,  90, 100])}),
 'best_score': 0.2914285714285714,
 'best_parameters': {'criterion': 'gini',
  'max_features': 0.8,
  'n_estimators': 30},
 'score_on_test_data': 0.26666666666666666,
 'features_used_for_training': {'ZADH2',
  'ZAP70',
  'ZBED5',
  'ZBED5-AS1',
  'ZBP1',
  'ZBTB1',
  'ZBTB10',
  'ZBTB11',
  'ZBTB14',
  'ZBTB16',
  'ZBTB18',
  'ZBTB2',
  'ZBTB20',
  'ZBTB24',
  'ZBTB25',
  'ZBTB37',
  'ZBTB38',
  'ZBTB4',
  'ZBTB40',
  'ZBTB43',
  'ZBTB44',
  'ZBTB45',
  'ZBTB49',
  'ZBTB7A',
  'ZBTB7B',
  'ZBTB8OS',
  'ZC2HC1A',
  'ZC3H12A',
  'ZC3H12D',
  'ZC3H13',
  'ZC3H14',
  'ZC3H15',
  'ZC3H18',
  'ZC3H3',
  'ZC3H4',
  'ZC3H6',
  'ZC3H7A',
  'ZC3H7B',
  'ZC3H8',
  'ZC3HAV1',
  'ZC3HC