In [35]:
import os
import json
import numpy as np
import pandas as pd

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB

In [36]:
# Names of the input files that live inside the data directory.
associated_motifs_filename = "associated_motifs.json"
genes_filenames = "all_genes.xlsx"

# The working directory is assumed to be the root of the git clone, so the
# notebook has to be started from there for the paths below to resolve.
root = os.getcwd()
data_dir = os.path.join(root, "cellular_clarity")

In [37]:
# all_genes_dir = os.path.join(data_dir, genes_filenames["all"])
# genes = pd.read_csv(all_genes_dir)
# selected_genes = np.random.choice(genes["AGI"], 1000, replace=False)
#
# data = {}
# for gene in genes["AGI"]:
#     count = np.random.choice(4, replace=False, p=[0.5, 0.45, 0.045, 0.005])
#     data[gene] = list(np.random.choice(selected_genes, count, replace=False))
#
# # Write the dictionary to a JSON file
# associated_motifs_dir = os.path.join(data_dir, associated_motifs_filename)
# with open(associated_motifs_dir, "w") as file:
#     json.dump(data, file, indent=4) # indent for better readability
#
# print(f"Dictionary written to {associated_motifs_dir}")

In [98]:
# Merge the per-time-point DEG tables into the gene table and write it back.
# NOTE(review): the workbook that is read here is also the one overwritten at
# the end of the cell, so re-running the cell starts from the already-augmented
# table rather than from the raw gene list.
df = pd.read_excel(os.path.join(data_dir, genes_filenames))
df.set_index("AGI", drop=True, inplace=True)
for t in range(6, 42, 6):
    # Time points 6, 12, ..., 36 map to the letters 'B', 'C', ..., 'G'.
    c = chr(ord('A') + t//6)
    _dir = os.path.join(data_dir, "DEGs")
    filename = f"DEGs_{c}_minus_vs_{c}_plus.csv"
    df_tmp = pd.read_csv(str(os.path.join(_dir, filename)))
    # The unnamed first CSV column becomes the index; presumably it holds the
    # AGI identifiers, since it is used below to address rows of df - TODO confirm.
    df_tmp.set_index("Unnamed: 0", drop=True, inplace=True)
    # Per-time-point fold change.
    df.loc[df_tmp.index, f"logFC @ {t} hrs"] = df_tmp["logFC"]
    # Running logFC stored so far for these genes; the "logFC" column is read
    # before this cell assigns it, so it must already exist in the workbook.
    df_tmp["max logFC"] = df.loc[df_tmp.index, "logFC"]
    # Per-time-point DE flag: 0 by default, 1 where |logFC| > 0.75 and FDR < 0.05.
    df[f"DE @ {t} hrs"] = 0
    df.loc[df_tmp.index, f"DE @ {t} hrs"] = ((df_tmp["logFC"].abs() > 0.75) & (df_tmp["FDR"] < 0.05)).astype(int)
    # Keep the stored logFC only where its magnitude is strictly larger than
    # this time point's; otherwise (including when nothing is stored yet, as a
    # comparison against NaN is False) take this time point's logFC.
    df.loc[df_tmp.index, "logFC"] = df_tmp["max logFC"].where(df_tmp["logFC"].abs() < df_tmp["max logFC"].abs(), df_tmp["logFC"])

    # For the genes whose running logFC now equals this time point's value,
    # copy the accompanying statistics from this time point as well.
    df_tmp["max logFC"] = df.loc[df_tmp.index, "logFC"]
    df_tmp = df_tmp[df_tmp["logFC"] == df_tmp["max logFC"]]
    df.loc[df_tmp.index, ["logCPM", "F", "PValue", "FDR"]] = df_tmp[["logCPM", "F", "PValue", "FDR"]]

# DEG is 1.0 for every gene that has a logFC value and 0.0 otherwise; note this
# does not apply the |logFC| / FDR thresholds used for the per-time-point flags.
df["DEG"] = (~df["logFC"].isna()).astype(float)
df.to_excel(os.path.join(data_dir, genes_filenames))
df

Unnamed: 0_level_0,Length,logFC @ 6 hrs,logFC @ 12 hrs,logFC @ 18 hrs,logFC @ 24 hrs,logFC @ 30 hrs,logFC @ 36 hrs,logFC,logCPM,F,PValue,FDR,DEG,Cluster,DE @ 6 hrs,DE @ 12 hrs,DE @ 18 hrs,DE @ 24 hrs,DE @ 30 hrs,DE @ 36 hrs
AGI,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
AT1G01010,1688,,,,,,,,,,,,0.0,,0,0,0,0,0,0
AT1G01020,1571,,,,,,,,,,,,0.0,,0,0,0,0,0,0
AT1G01030,1905,,,,,,,,,,,,0.0,,0,0,0,0,0,0
AT1G01040,6279,,,,,,,,,,,,0.0,,0,0,0,0,0,0
AT1G01046,207,,,,,,,,,,,,0.0,,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ATMG09730,80,,,,,,,,,,,,0.0,,0,0,0,0,0,0
ATMG09740,72,,,,,,,,,,,,0.0,,0,0,0,0,0,0
ATMG09950,75,,,,,,,,,,,,0.0,,0,0,0,0,0,0
ATMG09960,74,,,,,,,,,,,,0.0,,0,0,0,0,0,0


In [99]:
def load_json(filepath):
    """
    Read and decode a JSON document from disk.

    Args:
        filepath (str): Location of the JSON file.

    Returns:
        dict or list: The decoded JSON content. When the file does not exist
        or does not hold valid JSON, an error message is printed and None is
        returned instead.
    """
    parsed = None
    try:
        with open(filepath, 'r') as stream:
            parsed = json.loads(stream.read())
    except FileNotFoundError:
        print(f"Error: File not found: {filepath}")
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in: {filepath}")
    return parsed

In [100]:
def associated_motifs_(dir, filename) -> dict:
    """
    Load the gene-to-motif association table (derived from AME) from a JSON file.

    The file is expected to hold a dict that maps each relevant gene to the
    motifs present in its promoter, each motif being identified by its related TF:
    associated_motifs = {
        "AT5G18090": [
            "AT3G26790",
            "AT4G33280",
        ],
        ...,
        "AT3G18990": [
            "AT5G60130",
        ]
    }

    :param dir: directory that contains the JSON file
    :param filename: name of the JSON file inside ``dir``
    :return: whatever ``load_json`` returns for the joined path
    """
    return load_json(os.path.join(dir, filename))

# Load the gene -> motif mapping from the data directory. The value is None
# when the file is missing or is not valid JSON (see load_json).
associated_motifs = associated_motifs_(data_dir, associated_motifs_filename)
# Show the full mapping as the cell output.
associated_motifs

{'AT1G01010': [],
 'AT1G01020': ['AT5G03315'],
 'AT1G01030': ['AT3G02455'],
 'AT1G01040': ['AT3G53890'],
 'AT1G01046': ['AT5G47077'],
 'AT1G01050': [],
 'AT1G01060': [],
 'AT1G01070': ['AT1G08887'],
 'AT1G01080': ['AT1G18140', 'AT2G26710'],
 'AT1G01090': [],
 'AT1G01100': [],
 'AT1G01110': [],
 'AT1G01120': ['AT2G31240'],
 'AT1G01130': [],
 'AT1G01140': ['AT3G05000', 'AT3G63050'],
 'AT1G01150': ['AT4G02380'],
 'AT1G01160': ['AT5G08215'],
 'AT1G01170': [],
 'AT1G01180': ['AT4G05315'],
 'AT1G01183': [],
 'AT1G01190': ['AT3G51840'],
 'AT1G01200': [],
 'AT1G01210': [],
 'AT1G01220': [],
 'AT1G01225': [],
 'AT1G01230': [],
 'AT1G01240': [],
 'AT1G01250': ['AT1G08887'],
 'AT1G01260': [],
 'AT1G01270': ['AT3G49630'],
 'AT1G01280': ['AT5G03805'],
 'AT1G01290': [],
 'AT1G01300': ['AT1G70730'],
 'AT1G01305': ['AT4G18520'],
 'AT1G01310': [],
 'AT1G01320': ['AT5G62150'],
 'AT1G01335': [],
 'AT1G01340': [],
 'AT1G01350': ['AT1G69490'],
 'AT1G01355': [],
 'AT1G01360': ['AT3G45510'],
 'AT1G01370': ['

In [106]:
def build_data(dir, filenames, associated_motifs, response, cluster=None, random_state=None):
    """
    Build the feature matrix X and the response y (or Y) for model fitting.

    Treatment genes are the rows of the gene table with ``DEG == 1``; when
    ``cluster`` is given they must additionally have ``Cluster == cluster``.
    Every other gene in the table is a control candidate (this includes DEGs
    of other clusters when ``cluster`` is set). Controls are drawn without
    replacement, at most as many as there are treatment genes, with probability
    proportional to the number of "relevant" motifs they carry, where the
    relevant motifs are all motifs associated with at least one treatment gene.
    Control candidates without any relevant motif are never drawn.

    :param dir: directory that contains the gene workbook
    :param filenames: file name of the gene workbook (read with ``pd.read_excel``);
        it must provide the columns "AGI", "DEG", "Cluster" (only when ``cluster``
        is used) and the column(s) selected by ``response``
    :param associated_motifs: dict mapping a gene id to the list of its motifs
    :param response: a string from ["DEG", "logFC", "DEG_ts", "logFC_ts"]
        - "DEG":      vector of binary values
        - "logFC":    vector of continuous values
        - "DEG_ts":   matrix of binary values, one column per time point
        - "logFC_ts": matrix of continuous values, one column per time point
    :param cluster: optional cluster label used to restrict the treatment genes
    :param random_state: optional seed for the control sampling; when None the
        global NumPy random state is used, exactly as before this parameter existed
    :return: ``(X, y)`` where X is a genes x relevant-motifs DataFrame of 0.0 / 1.0
        and y holds the requested response for the same genes, with missing
        values replaced by 0
    :raises ValueError: if ``response`` is not one of the supported names or
        ``associated_motifs`` is None
    """
    time_points = range(6, 42, 6)
    response_columns = {
        "DEG": "DEG",
        "logFC": "logFC",
        "DEG_ts": [f"DE @ {t} hrs" for t in time_points],
        "logFC_ts": [f"logFC @ {t} hrs" for t in time_points],
    }
    # Fail early and clearly instead of hitting an unbound `y` at the very end.
    if response not in response_columns:
        raise ValueError(
            f"Unknown response {response!r}; expected one of {sorted(response_columns)}"
        )
    if associated_motifs is None:
        raise ValueError("associated_motifs is None; was the motif file loaded successfully?")

    genes = pd.read_excel(os.path.join(dir, filenames)).set_index("AGI")

    degs = genes[genes["DEG"] == 1]
    if cluster is not None:
        degs = degs[degs["Cluster"] == cluster]
    treat = sorted(degs.index)
    control_pool = sorted(set(genes.index) - set(degs.index))

    # Motifs seen in at least one treatment gene define the feature columns.
    relevant = set()
    for gene in treat:
        relevant.update(associated_motifs.get(gene, ()))

    # Sampling weight of a control candidate = number of relevant motifs it has.
    weights = pd.Series(
        [len(relevant.intersection(associated_motifs.get(gene, ()))) for gene in control_pool],
        index=control_pool,
        dtype=float,
    )
    weights = weights[weights > 0]

    n_controls = min(len(treat), len(weights))
    if n_controls > 0:
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        picked = rng.choice(
            len(weights), n_controls, replace=False, p=weights.to_numpy() / weights.sum()
        )
        control = sorted(weights.index[picked])
    else:
        # No treatment genes or no eligible control: nothing to draw. Calling
        # choice() with an empty probability vector would raise.
        control = []

    considering_genes = sorted(treat + control)
    relevant_motifs = sorted(relevant)

    # Binary presence matrix; the dict gives O(1) motif -> column lookups.
    column_of = {motif: j for j, motif in enumerate(relevant_motifs)}
    presence = np.zeros((len(considering_genes), len(relevant_motifs)), dtype=float)
    for i, gene in enumerate(considering_genes):
        for motif in associated_motifs.get(gene, ()):
            j = column_of.get(motif)
            if j is not None:
                presence[i, j] = 1.0
    X = pd.DataFrame(presence, index=considering_genes, columns=relevant_motifs)

    target_genes = genes.loc[considering_genes]
    y = target_genes[response_columns[response]].copy()

    # todo: remove this in ideal scenario
    y = y.fillna(0)

    return X, y

# Build the motif-presence matrix X and the binary DEG response y.
# Control genes are sampled with NumPy's global random state inside build_data,
# so the selected rows differ between runs unless that state is seeded first.
X, y = build_data(data_dir, genes_filenames, associated_motifs, "DEG")
X

Unnamed: 0,AT1G01320,AT1G01610,AT1G01760,AT1G01860,AT1G01880,AT1G02040,AT1G02500,AT1G02690,AT1G03080,AT1G03440,...,AT5G67360,AT5G67580,ATCG00550,ATCG00570,ATCG00830,ATCG01130,ATMG00610,ATMG00650,ATMG00830,ATMG01090
AT1G01030,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01210,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01305,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01360,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AT1G01380,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ATMG00700,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ATMG00740,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ATMG00980,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ATMG01210,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [107]:
# Show the response vector returned by build_data (one entry per row of X).
y

AGI
AT1G01030    0
AT1G01210    1
AT1G01305    0
AT1G01360    0
AT1G01380    1
            ..
ATMG00700    0
ATMG00740    0
ATMG00980    0
ATMG01210    0
ATMG09980    0
Name: DEG, Length: 3776, dtype: int64

In [108]:
# Classifiers to compare, keyed by the label used in the printed report.
models = {
    # L1 (lasso) penalty; 'saga' is a solver that supports it.
    "Logistic Regression": LogisticRegression(penalty='l1', solver='saga', max_iter=1000, random_state=0),
    "Gaussian Naive Bayes": GaussianNB(),
}

# One shuffled 10-fold split definition with a fixed seed, shared by all models.
cv = KFold(n_splits=10, shuffle=True, random_state=0)

for key, estimator in models.items():
    print(f"---------------\nmodel: {key}")

    # Cross-validated accuracy of this model on (X, y).
    scores = cross_val_score(estimator, X, y, cv=cv, scoring='accuracy')

    # Report the per-fold accuracies followed by their mean.
    print("Accuracy scores for each fold:", scores)
    print("Average accuracy score:", scores.mean())

---------------
model: Logistic Regression
Accuracy scores for each fold: [0.54761905 0.55026455 0.52116402 0.6031746  0.60582011 0.57142857
 0.58885942 0.58090186 0.58355438 0.55702918]
Average accuracy score: 0.5709815727057107
---------------
model: Gaussian Naive Bayes
Accuracy scores for each fold: [0.65343915 0.62433862 0.56878307 0.60582011 0.64814815 0.62962963
 0.5994695  0.58885942 0.60742706 0.59416446]
Average accuracy score: 0.6120079154561914


In [111]:
# Number of experiment repetitions; the repetition index doubles as the seed
# for both the train/test split and the solver.
k = 20

# One flattened coefficient vector per repetition.
coefficients = []

# Run logistic regression with Lasso (L1) for k repetitions.
# The loop variable is named `seed` rather than `_`: its value is actually
# used, and assigning `_` in IPython hides the built-in "last output" variable.
for seed in range(k):
    # 80% training / 20% testing split, reproducible per repetition.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

    # 'saga' solver works for L1 penalty
    model = LogisticRegression(penalty='l1', solver='saga', max_iter=1000, random_state=seed)

    # fit() returns the estimator itself, so `clf` and `model` are the same object.
    clf = model.fit(X_train, y_train)

    coefficients.append(model.coef_.flatten())

# Rows = repetitions, columns = motifs (the columns of X).
result = pd.DataFrame(index=range(k), columns=X.columns, data=coefficients)
# Keep only the motifs whose coefficient is non-zero in every repetition.
nonzero_cols = np.all(result.to_numpy() != 0, axis=0)
result = result.iloc[:, nonzero_cols]
# NOTE(review): mode() on continuous coefficients is unusual. When all values
# in a column are distinct it returns them sorted ascending, so rows no longer
# correspond to repetitions. When values repeat it keeps only the most frequent
# ones and pads with NaN. Confirm this is intended rather than e.g. a median.
result = result.mode(axis=0)
result

Unnamed: 0,AT1G04430,AT1G05233,AT1G05373,AT1G06210,AT1G07897,AT1G08125,AT1G27430,AT1G32650,AT1G35340,AT1G54120,...,AT5G40320,AT5G42520,AT5G45770,AT5G51900,AT5G52797,AT5G59380,AT5G61770,AT5G66730,AT5G67040,AT5G67200
0,-1.892285,-1.894886,-2.224774,-2.430704,-1.667397,-1.389413,-1.873401,-1.913413,-1.546326,-1.873402,...,-1.86876,-1.926423,-1.205337,-1.687151,-2.000933,-2.459537,-2.09554,-1.447148,-2.08435,-1.902901
1,-1.391371,-1.671097,-1.974939,-2.318395,-1.383455,-1.383436,-1.582343,-1.890538,-1.38768,-1.867214,...,-1.449123,-1.688701,-1.177689,-1.602606,-1.976116,-2.297239,-2.08523,-1.368477,-1.909841,-1.888687
2,-1.390681,-1.401795,-1.9236,-2.078272,-0.985734,-1.377878,-1.195522,-1.216681,-1.02014,-1.860652,...,-1.35417,-1.375486,-1.15935,-1.387574,-1.799022,-2.169485,-1.860646,-1.331808,-1.903645,-1.670378
3,-1.368134,-1.195522,-1.680501,-1.981984,-0.969366,-0.996326,-1.168083,-1.209753,-0.984011,-1.665544,...,-1.304774,-1.341184,-0.928702,-1.116251,-1.777172,-1.729567,-1.686692,-1.31669,-1.671007,-1.645617
4,-1.16748,-1.136769,-1.637493,-1.979375,-0.967211,-0.982194,-1.160091,-1.208866,-0.762684,-1.320515,...,-1.298697,-1.316263,-0.791912,-0.9971,-1.607323,-1.679059,-1.667398,-1.281959,-1.556528,-1.387655
5,-1.112419,-1.134092,-1.536818,-1.597391,-0.940446,-0.977853,-1.132095,-1.18025,-0.753056,-1.163432,...,-1.253701,-1.201742,-0.783886,-0.991208,-1.364869,-1.65341,-1.285554,-1.161424,-1.555941,-1.3797
6,-1.111494,-1.105935,-1.523654,-1.584714,-0.932006,-0.974239,-0.982195,-0.996325,-0.708155,-1.161815,...,-1.095012,-1.155873,-0.783817,-0.988137,-1.210497,-1.62563,-1.280893,-1.139446,-1.538214,-1.371692
7,-0.972688,-0.98641,-1.473688,-1.579582,-0.926421,-0.944327,-0.977948,-0.987353,-0.702298,-0.997108,...,-0.941713,-1.155348,-0.776702,-0.963046,-1.200619,-1.593218,-1.274313,-1.056278,-1.382945,-1.207062
8,-0.969788,-0.960564,-1.469814,-1.575444,-0.925259,-0.709429,-0.952465,-0.98396,-0.678545,-0.985726,...,-0.93572,-1.061369,-0.76949,-0.951498,-1.156424,-1.592343,-1.257601,-1.05564,-1.382675,-1.200998
9,-0.964464,-0.952466,-1.46702,-1.574413,-0.91762,-0.709236,-0.929145,-0.982193,-0.656659,-0.977514,...,-0.910977,-1.024097,-0.748415,-0.948946,-1.133468,-1.58348,-1.119508,-1.015914,-1.379703,-0.997104


In [112]:
# Column-wise mean of the retained coefficients (one value per motif column).
result.mean()

AT1G04430   -0.926056
AT1G05233   -1.011239
AT1G05373   -1.474555
AT1G06210   -1.617662
AT1G07897   -0.814311
AT1G08125   -0.825947
AT1G27430   -0.983940
AT1G32650   -1.036728
AT1G35340   -0.648205
AT1G54120   -1.097485
AT1G58090   -0.961292
AT1G59460   -0.661350
AT1G59700   -0.926646
AT1G63460   -0.694831
AT1G66060   -1.638182
AT2G01350   -1.118387
AT2G15290   -0.886349
AT2G17080   -0.791577
AT2G22430   -1.092760
AT2G22970   -1.453503
AT2G42170   -0.937787
AT2G42330   -0.851910
AT2G47770   -1.116150
AT3G01795   -0.872471
AT3G08535   -1.268989
AT3G15850   -0.860586
AT3G21750   -1.212715
AT3G23940   -1.192191
AT3G51840   -0.889370
AT3G60176   -0.484354
AT4G01160   -1.428389
AT4G07666   -0.644100
AT4G08575   -1.361924
AT4G09850   -0.767888
AT4G13310   -1.189287
AT4G20890   -1.181393
AT4G22360   -1.238246
AT4G26980   -0.982291
AT4G27435   -0.711427
AT4G32520   -0.785717
AT5G03875   -0.801514
AT5G07600   -0.752177
AT5G24318   -0.753348
AT5G28650   -0.910207
AT5G36080   -0.911306
AT5G39880 