In [15]:
import os
import pickle

import numpy as np
import pandas as pd
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import (
    KFold,
)
from sklearn.naive_bayes import BernoulliNB
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.tree import DecisionTreeClassifier

In [16]:
pd.set_option("display.max_columns", None)
np.random.seed(0)

In [31]:
interaction_matrix = pd.read_csv("../../data/interactions/interaction_matrix.csv", sep=";").set_index("bacteria")

display(interaction_matrix.isna().sum(axis=0).to_frame().T)
display(interaction_matrix.head())

Unnamed: 0,55989_P2,LF82_P8,AL505_Ev3,LF73_P4,BCH953_P2,BCH953_P4,BCH953_P5,LF73_P1,LF73_P3,NIC06_P2,T4LD,AN17_P8,LF110_P1,LF110_P2,LF110_P3,LF110_P4,LF82_P1,LF82_P2,LF82_P3,LF82_P4,LF82_P5,LF82_P6,LF82_P9,NRG_11A2,NRG_11B1,LF31_P1,536_P6,536_P7,536_P9,536_P1,DIJ07_P1,DIJ07_P2,AN24_P2,AN24_P3,NAN33_P5,427_P2,T7_Portugal,LM07_P1,NAN33_P1,NAN33_P2,NAN33_P4,NAN33_P6,409_P1,423_P5,427_P3,427_P4,55989_P1,AN17_P1,AN24_P4,BDX03_P1,BDX03_P2,LI10_P1,LI10_P2,LM33_P1,MT1B1_3A1,409_P3,409_P5,409_P6,412_P2,416_P4,416_P5,423_P1,423_P10,423_P7,423_P9,AL505_Sd2,DIJ06_P1,LF31_P3,LI10_P3,LI10_P4,LI10_P5,LI10_P6,NIC06_P3,412_P1,412_P3,412_P4,412_P5,409_P8,536_P11,BCH953_P1,BCH953_P3,LF7074_P1,LF7074_P2,LF7074_P3,LM02_P1,LM08_P1,LM08_P2,NRG_12A1B,411_P1,536_P12,BDX09_P1,411_P2,LF50_P3,LM40_P1,LM40_P2,LM40_P3
0,1,1,2,2,1,1,1,2,1,1,1,2,2,1,1,1,1,1,1,1,1,1,1,4,0,1,1,0,0,1,1,0,2,5,8,2,0,2,2,2,2,3,2,2,2,2,1,2,1,2,5,1,1,1,2,2,2,2,2,2,2,2,2,2,2,2,1,1,1,2,2,2,2,2,2,2,2,2,1,1,1,1,1,1,2,2,1,0,2,1,1,2,1,6,1,1


Unnamed: 0_level_0,55989_P2,LF82_P8,AL505_Ev3,LF73_P4,BCH953_P2,BCH953_P4,BCH953_P5,LF73_P1,LF73_P3,NIC06_P2,T4LD,AN17_P8,LF110_P1,LF110_P2,LF110_P3,LF110_P4,LF82_P1,LF82_P2,LF82_P3,LF82_P4,LF82_P5,LF82_P6,LF82_P9,NRG_11A2,NRG_11B1,LF31_P1,536_P6,536_P7,536_P9,536_P1,DIJ07_P1,DIJ07_P2,AN24_P2,AN24_P3,NAN33_P5,427_P2,T7_Portugal,LM07_P1,NAN33_P1,NAN33_P2,NAN33_P4,NAN33_P6,409_P1,423_P5,427_P3,427_P4,55989_P1,AN17_P1,AN24_P4,BDX03_P1,BDX03_P2,LI10_P1,LI10_P2,LM33_P1,MT1B1_3A1,409_P3,409_P5,409_P6,412_P2,416_P4,416_P5,423_P1,423_P10,423_P7,423_P9,AL505_Sd2,DIJ06_P1,LF31_P3,LI10_P3,LI10_P4,LI10_P5,LI10_P6,NIC06_P3,412_P1,412_P3,412_P4,412_P5,409_P8,536_P11,BCH953_P1,BCH953_P3,LF7074_P1,LF7074_P2,LF7074_P3,LM02_P1,LM08_P1,LM08_P2,NRG_12A1B,411_P1,536_P12,BDX09_P1,411_P2,LF50_P3,LM40_P1,LM40_P2,LM40_P3
bacteria,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1
ECOR-54,0.0,1.0,0.0,1.0,1.0,0.0,0.0,4.0,4.0,0.0,0.0,0.0,3.0,3.0,3.0,3.0,2.0,1.0,1.0,2.0,2.0,2.0,4.0,3.0,3.0,0.0,0.0,1.0,2.0,3.0,4.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
DIJ06,0.0,1.0,0.0,2.0,0.0,0.0,0.0,4.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ECOR-52,0.0,0.0,0.0,1.0,0.0,0.0,0.0,4.0,4.0,0.0,0.0,0.0,1.0,1.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,1.0,1.0,1.0,3.0,4.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ECOR-51,3.0,2.0,2.0,3.0,3.0,4.0,4.0,3.0,3.0,2.0,1.0,4.0,3.0,3.0,3.0,3.0,3.0,2.0,1.0,2.0,2.0,2.0,4.0,2.0,2.0,1.0,1.0,1.0,2.0,3.0,3.0,3.0,0.0,1.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,1.0,4.0,1.0,1.0,0.0,2.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,4.0,1.0,2.0,1.0,1.0,0.0,4.0,0.0,2.0,0.0,2.0,2.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0
ECOR-55,2.0,3.0,1.0,4.0,0.0,2.0,1.0,3.0,4.0,1.0,0.0,4.0,3.0,4.0,4.0,3.0,3.0,4.0,2.0,4.0,4.0,4.0,4.0,1.0,1.0,2.0,0.0,2.0,2.0,3.0,2.0,3.0,0.0,0.0,0.0,2.0,0.0,1.0,0.0,1.0,0.0,0.0,3.0,2.0,1.0,1.0,0.0,4.0,1.0,2.0,0.0,2.0,2.0,0.0,0.0,3.0,3.0,1.0,0.0,3.0,3.0,0.0,1.0,0.0,1.0,2.0,2.0,0.0,4.0,3.0,4.0,4.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,2.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0


In [14]:
bact_features = pd.read_csv("../../data/genomics/bacteria/picard_collection.csv", sep=";").set_index("bacteria")
# bact_embeddings = pd.read_csv("../../data/genomics/bacteria/umap_phylogeny/coli_umap_8_dims.tsv", sep="\t").set_index("bacteria")
# bact_features = pd.merge(bact_features, bact_embeddings, left_index=True, right_index=True)
# bact_features = bact_features.filter(regex=r"(UMAP|O-type|LPS|ST_Warwick|Klebs|ABC_serotype)", axis=1)

display(bact_features.head())

Unnamed: 0_level_0,Gembase,Host,Origin,Pathotype,Clermont_Phylo,ST_Warwick,O-type,H-type,Mouse_killed_10,Capsule_ABC,Capsule_GroupIV_e,Capsule_GroupIV_e_stricte,Capsule_GroupIV_s,Capsule_Wzy_stricte,LPS_type,Collection,Klebs_capsule_type,n_defense_systems,n_infections,ABC_serotype
bacteria,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
001-023,ESCO.0622.00103,Human,Faeces,Commensal,G,1163,O149,H23,0.0,0.0,0.0,1.0,1.0,1.0,R4,Original,,9.0,32,
001-031-c1,ESCO.0622.00308,Human,Faeces,Commensal,B2,452,O81,H27,0.0,1.0,0.0,0.0,1.0,1.0,R3,Original,,8.0,5,Unknown
003-026,ESCO.0622.00119,Human,Faeces,Commensal,G,1163,O33,H23,0.0,0.0,0.0,1.0,1.0,1.0,R4,Original,,6.0,28,
013-008,ESCO.0622.00326,Human,Faeces,Commensal,B2,452,O81,H27,0.0,1.0,0.0,0.0,1.0,1.0,R3,Original,,6.0,5,Unknown
025-010,ESCO.0622.00213,Human,Faeces,Commensal,E,543,O169,H9,0.0,0.0,0.0,1.0,1.0,0.0,R4,Original,,10.0,30,


In [5]:
phage_features = pd.read_csv("../../data/genomics/phages/guelin_collection.csv", sep=";").set_index("phage")
phage_features = phage_features.loc[interaction_matrix.columns, ["Morphotype", "Genus", "Phage_host"]]
phage_features.index.name = "phage"

display(phage_features.head())

Unnamed: 0_level_0,Morphotype,Genus,Phage_host
phage,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
55989_P2,Myoviridae,Dhakavirus,55989
LF82_P8,Myoviridae,Mosigvirus,LF82
AL505_Ev3,Myoviridae,Krischvirus,AL505
LF73_P4,Myoviridae,Krischvirus,LF73
BCH953_P2,Myoviridae,Tequatrovirus,BCH953


---

In [6]:
def get_alias(model):
        aliases = {
            LogisticRegression: "LogReg",
            RandomForestClassifier: "RF",
            DummyClassifier: "Dummy",
            MLPClassifier: "MLP",
            BernoulliNB: "NaiveBayes",
            DecisionTreeClassifier: "DecTree",
        }
        name = aliases[type(model)]
        if type(model) == LogisticRegression:
            name += "_" + model.penalty
        elif type(model) == RandomForestClassifier:
            name += "_" + str(model.n_estimators) + "_" + str(model.max_depth)
        elif type(model) == DummyClassifier:
            name += "_" + model.strategy
        elif type(model) == MLPClassifier:
            hidden_layer_sizes = list(
                str(x) for x in model.get_params()["hidden_layer_sizes"]
            )
            name += (
                "_"
                + "-".join(hidden_layer_sizes)
                + "_lr="
                + str(model.get_params()["learning_rate_init"])
            )
        if hasattr(model, "class_weight") and model.class_weight is not None:
            name += "_weight=" + str(model.class_weight[1])
        return name

In [None]:
def perform_group_cross_validation(
        X,
        y,
        models,
        models_params,
        n_splits=10,
        index_names=None,
        do_scale=False,
    ):
        kfold = KFold(n_splits=n_splits)
        umap_dim = X.shape[1] // 2

        # Train feature scaler on the whole dataset (if required)
        if do_scale:
            std_scaler = MinMaxScaler()
            std_scaler.fit(X)

        performance, predictions, logs = [], [], []
        model_list = {}
        for i, (train_idx, test_idx) in enumerate(kfold.split(X, y)):  # K-fold cross-validation
            X_train, X_test, y_train, y_test = X.iloc[train_idx], X.iloc[test_idx], y.iloc[train_idx], y.iloc[test_idx]

            # check that train set observations and validation set observations are disjoint
            assert set(X_train.index).intersection(set(X_test.index)) == set()

            for model_type, param in zip(models, models_params):
                model = model_type(**param)
                alias = get_alias(model)

                # Fit model (train set)
                model.fit(X_train, y_train)

                # Model evaluation (train and test set)
                for ds, ds_name in zip(
                    [[X_train, y_train], [X_test, y_test]], ["train", "test"]
                ):
                    xset, yset = ds

                    # Feature scaling (if required)
                    if do_scale:
                        xset = pd.DataFrame(std_scaler.transform(xset), columns=X.columns)

                    # Predictions
                    y_pred, y_pred_proba = model.predict(xset), model.predict_proba(xset)

                    # Metrics
                    if np.unique(yset).shape[0] > 1:  # Cannot compute metrics if only one class is predicted
                        # tp, fp, tn, fn = confusion_matrix(yset, y_pred).ravel()
                        # precision, recall, f1 = precision_score(yset, y_pred), recall_score(yset, y_pred), f1_score(yset, y_pred)
                        # average_prec = average_precision_score(yset, y_pred_proba[:, 1])
                        # roc_auc = roc_auc_score(yset, y_pred_proba[:, 1])

                        tn = fp = fn = tp = np.nan  # Not meaningful for multiclass, but kept for compatibility
                        precision = precision_score(yset, y_pred, average="weighted", zero_division=0)
                        recall = recall_score(yset, y_pred, average="weighted", zero_division=0)
                        f1 = f1_score(yset, y_pred, average="weighted", zero_division=0)
                        try:
                            average_prec = average_precision_score(yset, y_pred_proba, average="weighted")
                        except Exception:
                            average_prec = np.nan
                        try:
                            roc_auc = roc_auc_score(yset, y_pred_proba, multi_class="ovr", average="weighted")
                        except Exception:
                            roc_auc = np.nan

                        performance.append(
                            {
                                "model": alias,
                                "fold": i,
                                "dataset": ds_name,
                                "precision": precision,
                                "recall": recall,
                                "f1": f1,
                                "roc_auc": roc_auc,
                                "avg_precision": average_prec,
                                "tp": tp,
                                "fp": fp,
                                "tn": tn,
                                "fn": fn,
                            }
                        )
                    else:
                        performance.append(
                            {
                                "model": np.nan,
                                "fold": np.nan,
                                "dataset": np.nan,
                                "precision": np.nan,
                                "recall": np.nan,
                                "f1": np.nan,
                                "roc_auc": np.nan,
                                "avg_precision": np.nan,
                                "tp": np.nan,
                                "fp": np.nan,
                                "tn": np.nan,
                                "fn": np.nan,
                            }
                        )

                    # Collect predictions (test set only)
                    if ds_name == "test":  # and not alias.startswith("Dummy"):
                        preds = index_names.iloc[test_idx].copy()
                    else:
                        preds = index_names.iloc[train_idx].copy()
                    preds["y_pred"] = model.predict(xset)
                    preds["y_pred_proba"] = model.predict_proba(xset)[:, 1]
                    preds["fold"] = i
                    preds["model"] = alias
                    preds["dataset"] = ds_name
                    predictions.append(preds)  # add bacteria-phage name as index instead of integer (avoid ambiguity)

                model_list[f"{p}_{alias}_fold={i}"] = model
                del model
            logs.append(
                {
                    "fold": i,
                    "train_size": train_idx.shape[0],
                    "test_size": test_idx.shape[0],
                    "train_idx": train_idx,
                    "test_idx": test_idx,
                }
            )

        logs = pd.DataFrame(logs)
        performance = pd.DataFrame(performance)
        all_cv_predictions = pd.concat([pred for pred in predictions])[["fold", "model", "dataset", "bacteria", "phage", "y_pred_proba", "y_pred"]]

        return logs, performance, all_cv_predictions, model_list

In [None]:
save_dir = "outputs"

for p in phage_features.index:
    print(f"Processing phage {p}...")

    # Filter phages according to phylogeny
    phage_feat = phage_features.loc[[p]]
    interaction_mat = interaction_matrix[[p]]

    phage_feat = phage_feat.drop(["Morphotype", "Genus"], axis=1)

    # wide to long
    interaction_matrix_long = (
        interaction_mat.unstack()
        .reset_index()
        .rename({"level_0": "phage", 0: "y"}, axis=1)
        .sort_values(["bacteria", "phage"])
    )  # force row order

    # Concat features and target
    interaction_with_features = pd.merge(
        interaction_matrix_long, bact_features, left_on=["bacteria"], right_index=True
    )

    # Add phage host features to predictors
    phage_host_features = pd.merge(
        phage_feat,
        bact_features.filter(regex="(ST_Warwick|O-type|H-type)", axis=1),
        left_on="Phage_host",
        right_index=True,
    ).rename(
        {
            "Clermont_Phylo": "Clermont_host",
            "LPS_type": "LPS_host",
            "O-type": "O_host",
            "H-type": "H_host",
            "ST_Warwick": "ST_host",
        },
        axis=1,
    )

    if not p.startswith("LF110"):  # do not have the data for LF110 host strain
        interaction_with_features = pd.merge(
            interaction_with_features,
            phage_host_features.drop(["Phage_host"], axis=1),
            left_on="phage",
            right_index=True,
        )

    # Recode O-type : only keep main categories to avoid having too many levels
    if "O-type" in bact_features.columns:
        otypes_to_recode = (
            bact_features.groupby("O-type")
            .filter(lambda x: x.shape[0] < 3)["O-type"]
            .unique()
        )  # less than 5 observations for the O-type value
        interaction_with_features.loc[
            interaction_with_features["O-type"].isin(otypes_to_recode), "O-type"
        ] = "Other"
        if not p.startswith("LF110"):
            interaction_with_features["same_O_as_host"] = (
                interaction_with_features["O-type"]
                == interaction_with_features["O_host"]
            )
            interaction_with_features = interaction_with_features.drop("O_host", axis=1)

    # Recode ST : only keep main categories to avoid having too many levels
    if "ST_Warwick" in bact_features.columns:
        st_to_recode = (
            bact_features.groupby("ST_Warwick")
            .filter(lambda x: x.shape[0] < 3)["ST_Warwick"]
            .unique()
        )  # less than 5 observations for the O-type value
        interaction_with_features.loc[
            interaction_with_features["ST_Warwick"].isin(st_to_recode), "ST_Warwick"
        ] = "Other"
        if not p.startswith("LF110"):
            interaction_with_features["same_ST_as_host"] = (
                interaction_with_features["ST_Warwick"]
                == interaction_with_features["ST_host"]
            )

    if "ABC_serotype" in bact_features.columns:
        if not p.startswith("LF110"):
            interaction_with_features["same_ABC_as_host"] = (
                interaction_with_features["ABC_serotype"]
                == interaction_with_features["ABC_serotype"]
            )

    if (
        "same_O_as_host" in interaction_with_features.columns
        and "same_ST_as_host" in interaction_with_features.columns
        and not p.startswith("LF110")
    ):
        interaction_with_features["same_O_and_ST_as_host"] = (
            interaction_with_features["same_O_as_host"]
            * interaction_with_features["same_ST_as_host"]
        )

    # Drop missing observations
    na_observations = interaction_with_features.loc[
        interaction_with_features["y"].isna()
    ].index
    interaction_with_features = interaction_with_features.drop(na_observations, axis=0)

    # Dummy encoding of categorical variables and standardization for numerical variables
    X, y, bact_phage_names = (
        interaction_with_features.drop(["bacteria", "phage", "y"], axis=1),
        interaction_with_features["y"],
        interaction_with_features[["bacteria", "phage"]],
    )

    num, factors = [], []
    for col_dtype, col in zip(X.dtypes, X.dtypes.index):
        if col_dtype == "float64":
            num.append(col)
        else:
            factors.append(col)
    X_oh = pd.concat(
        [
            (X[num] - X[num].mean(axis=0)) / X[num].std(axis=0),
            pd.get_dummies(X[factors], sparse=False),
        ],
        axis=1,
    )

    # Perform cross-validation
    import warnings

    from sklearn.exceptions import UndefinedMetricWarning

    warnings.filterwarnings(
        action="ignore", category=UndefinedMetricWarning
    )  # shutdown sklearn warning regarding ill-defined precision

    n_splits = 10
    redo_predictions = True
    if redo_predictions:  # avoid overwriting predictions by mistake
        # Make predictions
        models_to_test = [
            RandomForestClassifier,
            RandomForestClassifier,
            LogisticRegression,
            LogisticRegression,
            DummyClassifier,
        ]

        # choose class weight
        perc_pos_class = y.sum() / y.shape[0]
        if 0.60 <= perc_pos_class:
            cw = {0: 1, 1: 0.8}
        elif 0.4 <= perc_pos_class < 0.6:
            cw = {0: 1, 1: 1}
        elif 0.3 <= perc_pos_class < 0.4:
            cw = {0: 1, 1: 1.5}
        elif 0.2 <= perc_pos_class < 0.3:
            cw = {0: 1, 1: 2}
        else:
            cw = {0: 1, 1: 3}

        # cw = "balanced"

        params = [
            {"max_depth": 3, "n_estimators": 250, "class_weight": cw},
            {"max_depth": 6, "n_estimators": 250, "class_weight": cw},
            {"class_weight": cw, "max_iter": 10000},
            {"class_weight": cw, "penalty": "l1", "solver": "saga", "max_iter": 10000},
            {"strategy": "stratified"},
        ]
        logs, performance, cv_predictions, trained_models = (
            perform_group_cross_validation(
                X_oh,
                y,
                n_splits=n_splits,
                index_names=bact_phage_names,
                models=models_to_test,
                models_params=params,
                do_scale=False,
            )
        )

        performance["phage"] = p
        cv_predictions["phage"] = p

        performance = performance.set_index("phage")
        cv_predictions = cv_predictions.set_index("phage")

        cv_predictions = pd.merge(
            cv_predictions,
            interaction_with_features[["bacteria", "phage", "y"]],
            on=["bacteria", "phage"],
        )  # add real interaction values

        overwrite_files = True  # overwrite log and performance files
        if overwrite_files:
            logs.to_csv(
                f"{save_dir}/results/logs/logs_{p}_Group{n_splits}Fold_CV.csv",
                sep=";",
                index=False,
            )
            performance.to_csv(
                f"{save_dir}/results/performances/performance_{p}_Group{n_splits}Fold_CV.csv",
                sep=";",
            )
            cv_predictions.to_csv(
                f"{save_dir}/results/predictions/predictions_{p}_core_features_Group{n_splits}Fold_CV.csv",
                sep=";",
                index=False,
            )

            if not os.path.isdir(f"{save_dir}/results/models/{p}"):
                os.mkdir(f"{save_dir}/results/models/{p}")

            for k, mod in enumerate(trained_models):
                save_name = (
                    str(k)
                    + "_"
                    + mod.split("_")[0]
                    + "_"
                    + mod.split("_")[1]
                    + "_"
                    + mod.split("_")[-1]
                )
                with open(
                    f"{save_dir}/results/models/{p}/{mod}.pickle", "wb"
                ) as save_file:
                    pickle.dump(trained_models[mod], save_file)

            # print("Saved performances, predictions, log files and models !")

        # Feature importance retried by random forest classifier
        # print(f"Bacterial features : Clermont_Phylo, ST_Warwick, LPS_type, O-type, H-type.")
        # print(f"Phage features : Morphotype, Genus, Phage_host.")

        # get best model on test set
        perf_by_model = (
            performance.loc[performance["dataset"] == "test"]
            .groupby("model")["avg_precision"]
            .mean()
        )
        model_name = perf_by_model.sort_values(ascending=False).index[0]

        print(f"Best model: {model_name}")

        clfs = []
        for mod in os.listdir(save_dir + f"/results/models/{p}"):
            if mod.startswith(p + "_" + model_name) and mod.endswith("pickle"):
                clfs.append(
                    pickle.load(open(save_dir + f"/results/models/{p}/" + mod, "rb"))
                )

        # save feature importance
        if model_name.startswith("RF"):
            feature_importances = pd.DataFrame(
                [clf.feature_importances_ for clf in clfs], columns=X_oh.columns
            ).melt()
        elif model_name.startswith("LogReg"):
            feature_importances = pd.DataFrame(
                [clf.coef_[0] for clf in clfs], columns=X_oh.columns
            ).melt()
        else:
            continue

        sorted_by_average_importance = (
            feature_importances.groupby("variable")
            .mean()
            .sort_values("value", ascending=False)
            .reset_index()
            .rename({"value": "average_importance"}, axis=1)
        )
        feature_importances = pd.merge(
            feature_importances, sorted_by_average_importance, on="variable"
        )
        feature_importances["phage"] = p
        feature_importances["model"] = model_name
        feature_importances.to_csv(
            f"{save_dir}/results/feature_importances/{p}_feature_importance.csv",
            sep=";",
            index=False,
        )

Processing phage 55989_P2...
Best model: Dummy_stratified
Processing phage LF82_P8...
Best model: Dummy_stratified
Processing phage AL505_Ev3...
Best model: Dummy_stratified
Processing phage LF73_P4...
Best model: Dummy_stratified
Processing phage BCH953_P2...
Best model: Dummy_stratified
Processing phage BCH953_P4...
Best model: Dummy_stratified
Processing phage BCH953_P5...
Best model: Dummy_stratified
Processing phage LF73_P1...
Best model: Dummy_stratified
Processing phage LF73_P3...
Best model: Dummy_stratified
Processing phage NIC06_P2...
Best model: Dummy_stratified
Processing phage T4LD...
Best model: Dummy_stratified
Processing phage AN17_P8...
Best model: Dummy_stratified
Processing phage LF110_P1...
Best model: Dummy_stratified
Processing phage LF110_P2...
Best model: Dummy_stratified
Processing phage LF110_P3...
Best model: Dummy_stratified
Processing phage LF110_P4...
Best model: Dummy_stratified
Processing phage LF82_P1...
Best model: Dummy_stratified
Processing phage LF8



Best model: Dummy_stratified
Processing phage NAN33_P6...
Best model: Dummy_stratified
Processing phage 409_P1...


ValueError: The classes, [2.0, 3.0, 4.0], are not in class_weight