# Computing individual phenotypes from SHAP profiles

The core function for generating SHAP values is `get_shaps()`, defined below. Some notes on its usage:

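- Abundance data for new subjects should be a pandas DataFrame in wide form, passed as `dat`, with:
  - first column: `subject_id`
  - remaining columns: relative abundance values matching the features (taxon/timepoint) of the cohort models
- Pre-quantized data can be passed as `dat_quant` instead of `dat`
- Running the SHAP analysis several times (`n_runs` > 1) and averaging the results is recommended to improve the quality of the estimates, since the qsampled background data and KernelExplainer's coalition sampling are random
- Each run's SHAP values are saved as a CSV file in `shap_dir`; `get_shaps()` itself returns only the values from the final run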
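
If abundances start out in long form, they can be pivoted into the wide layout described above. A minimal pandas sketch, assuming long-form columns named `subject_id`, `variable`, `week` and `value` (these names are assumptions; check them against the `qbiome` examples) and feature names of the form `<taxon>_<timepoint>`, as in the column headers of the quantized data shown below. `long_to_wide` is a hypothetical helper:

```python
import pandas as pd


def long_to_wide(long_df: pd.DataFrame) -> pd.DataFrame:
    """Pivot long-form abundances into one row per subject (hypothetical helper).

    Assumes columns `subject_id`, `variable` (taxon), `week` (timepoint) and
    `value` (relative abundance); adjust the names to match your data.
    """
    df = long_df.copy()
    # Feature names follow the <taxon>_<timepoint> pattern, e.g. "Bacteroidia_27"
    df["feature"] = df["variable"] + "_" + df["week"].astype(str)
    wide = df.pivot_table(
        index="subject_id", columns="feature", values="value", aggfunc="mean"
    )
    # Move subject_id back to the first column and drop the "feature" axis name
    return wide.reset_index().rename_axis(columns=None)


long_df = pd.DataFrame(
    {
        "subject_id": [1, 1, 2],
        "variable": ["Bacteroidia", "Actinobacteria", "Bacteroidia"],
        "week": [27, 27, 28],
        "value": [0.4, 0.1, 0.3],
    }
)
wide = long_to_wide(long_df)
print(wide)
```

Subject/timepoint combinations without an observation come out as NaN.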

Two approaches are supported: providing modeling data, or providing existing models:
- Data can be provided to retrain the cohort models on each run of the SHAP analysis
  - Either in long form (see the examples in the `qbiome` package), passed as `appr_data` and `sub_data`, or in wide form with abundances already quantized, passed as `appr_data_quant` and `sub_data_quant`
  - A quantizer can be provided, or omitted and fitted from the data (this requires the unquantized `appr_data` and `sub_data`)
- Pretrained models (`ahcg_lst` and `shcg_lst`) can be used instead; they are recycled as necessary to complete the requested number of runs (`n_runs`), and take precedence over data for the corresponding cohort
  - In this case, a quantizer must also be provided
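
The argument handling above can be sketched as two small pure functions. `choose_approach` and `recycle_models` are hypothetical helpers, not part of `qbiome`; the recycling uses the same list-tiling idiom as `get_shaps`, and this sketch requires inputs for both cohorts when the model lists are not given:

```python
from typing import Any, List, Optional, Sequence


def recycle_models(models: Sequence[Any], n_runs: int) -> List[Any]:
    """Repeat pretrained models until there is one per run (same idiom as get_shaps)."""
    if len(models) == 0:
        raise ValueError("at least one pretrained model is required")
    return (list(models) * n_runs)[:n_runs]


def choose_approach(
    appr_data: Optional[Any] = None,
    sub_data: Optional[Any] = None,
    appr_data_quant: Optional[Any] = None,
    sub_data_quant: Optional[Any] = None,
    ahcg_lst: Optional[Sequence[Any]] = None,
    shcg_lst: Optional[Sequence[Any]] = None,
) -> str:
    """Report which approach a set of get_shaps-style arguments selects (hypothetical)."""
    # Pretrained models take precedence when both model lists are given
    if ahcg_lst is not None and shcg_lst is not None:
        return "pretrained models"
    has_appr = appr_data is not None or appr_data_quant is not None
    has_sub = sub_data is not None or sub_data_quant is not None
    if has_appr and has_sub:
        return "retrain from data"
    raise ValueError("provide data for both cohorts or model lists for both cohorts")


print(recycle_models(["ahcg-1", "ahcg-2"], 5))
print(choose_approach(appr_data_quant="A", sub_data_quant="S"))
```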

# Load packages and data

In [1]:
import os
from datetime import datetime

import numpy as np
import pandas as pd
import qbiome.data_formatter
import qbiome.forecaster
import qbiome.hypothesis
import qbiome.qnet_orchestrator
import qbiome.quantizer
import qbiome.qutil
import shap
from quasinet import qnet, qsampling
from tqdm.notebook import tqdm

appr_data_quant = pd.read_csv("uchicago_appropriate_cohort_quantized.csv")
sub_data_quant = pd.read_csv("uchicago_suboptimal_cohort_quantized.csv")

boston_data_quant = pd.read_csv("boston_data_quantized.csv")

qnt = qbiome.quantizer.Quantizer(num_levels=26)
qnt.load_quantizer_states("quantizer.pkl")
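
The quantized tables store each abundance level as a letter (for example `M` in the sample below). The helper `_get_labels` inside `get_shaps`, used when more than 26 levels are requested, builds the label-to-level mapping as restated here; this standalone version assumes the default 26-level labels are `A` to `Z` in order, which is what the helper produces for its first 26 entries. `level_labels` is a hypothetical name:

```python
import math
import string


def level_labels(num_levels: int) -> dict:
    """Map quantization labels to integer levels, restating `_get_labels` from get_shaps."""
    labels = list(string.ascii_uppercase)
    # Beyond 26 levels, continue with doubled letters (AA, BB, ...), then tripled, ...
    for width in range(2, math.ceil(num_levels / 26) + 1):
        labels += [ch * width for ch in string.ascii_uppercase]
    return {lbl: idx for idx, lbl in enumerate(labels[:num_levels])}


labels26 = level_labels(26)
print(labels26["A"], labels26["M"], labels26["Z"])  # 0 12 25
labels30 = level_labels(30)
print(list(labels30)[26:])  # ['AA', 'BB', 'CC', 'DD']
```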

## Data format for new samples

In [2]:
boston_data_quant.head(1)

|  | subject_id | Acidimicrobiia_25 | Acidimicrobiia_26 | Acidimicrobiia_27 | Acidimicrobiia_28 | Acidimicrobiia_29 | Acidimicrobiia_30 | Acidimicrobiia_31 | Acidimicrobiia_32 | Acidimicrobiia_33 | ... | unclassified_Verrucomicrobiota_27 | unclassified_Verrucomicrobiota_28 | unclassified_Verrucomicrobiota_29 | unclassified_Verrucomicrobiota_30 | unclassified_Verrucomicrobiota_31 | unclassified_Verrucomicrobiota_32 | unclassified_Verrucomicrobiota_33 | unclassified_Verrucomicrobiota_34 | unclassified_Verrucomicrobiota_35 | unclassified_Verrucomicrobiota_36 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 136000.0 |  |  |  |  |  |  | M | M | M | ... |  |  |  |  | M | M | M | M | M | M |
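
`get_shaps` turns the quantized table into a numpy array by column position (after dropping `subject_id`), while the qsampled background sequences follow the model's `feature_names` order, so the columns of new samples must match the model features exactly and in the same order. A small check, sketched as a hypothetical helper:

```python
from typing import Sequence

import pandas as pd


def check_feature_alignment(dat_quant: pd.DataFrame, feature_names: Sequence[str]) -> None:
    """Raise ValueError unless the non-ID columns match the model features exactly, in order."""
    cols = [c for c in dat_quant.columns if c != "subject_id"]
    col_set, feat_set = set(cols), set(feature_names)
    missing = [f for f in feature_names if f not in col_set]
    extra = [c for c in cols if c not in feat_set]
    if missing or extra:
        raise ValueError(f"missing features: {missing[:5]}, unexpected columns: {extra[:5]}")
    if cols != list(feature_names):
        raise ValueError("same features, but in a different column order")


toy = pd.DataFrame({"subject_id": [1], "Bacteroidia_27": ["M"], "Actinobacteria_27": [None]})
check_feature_alignment(toy, ["Bacteroidia_27", "Actinobacteria_27"])  # passes silently
```

Once the models in the next section are loaded, it could be called as `check_feature_alignment(boston_data_quant, appr_qnet_lst[0].feature_names)`.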


# Compute SHAP values
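
`get_shaps` explains the risk function with `shap.KernelExplainer`, using qsampled sequences from both cohort models as background data. The quantity being estimated can be written out exactly for a tiny numeric model: each feature's SHAP value is a weighted average of how much adding that feature (taken from the subject) changes the model output, with features outside the coalition filled in from the background rows and averaged. Kernel SHAP approximates this by sampling coalitions (`nsamples`) and fitting a weighted linear regression, because enumerating all 2^n coalitions is infeasible for thousands of taxon/timepoint features. This brute-force sketch uses a hypothetical linear model on numeric data rather than the qnet risk function:

```python
import itertools
import math
from typing import Callable

import numpy as np


def exact_shapley(
    f: Callable[[np.ndarray], np.ndarray], x: np.ndarray, background: np.ndarray
) -> np.ndarray:
    """Exact Shapley values of f at x, with absent features filled in from background rows.

    The value of a coalition S is the mean of f over the background rows after the
    features in S are replaced by x's values. Exponential in len(x): toy sizes only.
    """
    n = len(x)

    def value(coalition: tuple) -> float:
        z = background.astype(float)  # astype returns a copy
        idx = list(coalition)
        z[:, idx] = x[idx]  # coalition members take the explained sample's values
        return float(f(z).mean())  # average over background rows

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
            for coalition in itertools.combinations(others, k):
                phi[i] += weight * (value(coalition + (i,)) - value(coalition))
    return phi


w = np.array([1.0, 2.0, 3.0])
f = lambda z: z @ w  # toy linear model standing in for the risk function
x = np.array([1.0, 1.0, 1.0])
background = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 4.0]])
phi = exact_shapley(f, x, background)
print(phi)
```

For a linear model the result is `w * (x - background.mean(axis=0))`, and in general the values add up to the model output at `x` minus the mean output over the background.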

## Load existing models, if desired

In [3]:
appr_qnet_lst = [qnet.load_qnet("shap-ahcg-1.joblib"), qnet.load_qnet("shap-ahcg-2.joblib")]
sub_qnet_lst = [qnet.load_qnet("shap-shcg-1.joblib"), qnet.load_qnet("shap-shcg-2.joblib")]

## SHAP computation function

In [4]:
def get_shaps(
    dat=None,
    dat_quant=None,
    appr_data=None,
    sub_data=None,
    appr_data_quant=None,
    sub_data_quant=None,
    nsamples=1100,
    quantizer=None,
    ahcg_lst=None,
    shcg_lst=None,
    n_qsamps=3,
    n_runs=1,
    shap_dir="tmp_shap/",
    save_qnets=False,
    verbose=True,
):
    """
    Computes shap values for data `dat`.

            Parameters:
                    dat (pandas.core.frame.DataFrame): dataframe of relative abundances (wide form)
                    dat_quant (pandas.core.frame.DataFrame): dataframe of quantized relative abundances (wide form)
                    appr_data (pandas.core.frame.DataFrame): dataframe of abundance observations used to construct the appropriate cohort model
                    sub_data (pandas.core.frame.DataFrame): dataframe of abundance observations used to construct the suboptimal cohort model
                    appr_data_quant (pandas.core.frame.DataFrame): dataframe of quantized abundance observations used to construct the appropriate cohort model
                    sub_data_quant (pandas.core.frame.DataFrame): dataframe of quantized abundance observations used to construct the suboptimal cohort model
                    nsamples (int): passed to shap.KernelExplainer.shap_values()
                    ahctg_lst (list): optional list of pretrained qnet models used for the appropriate cohort (one per run, recycled if needed). can be omitted and computed from provided data
                    shctg_lst (list): optional list of pretrained qnet models used for the suboptimal cohort (one per run, recycled if needed). can be omitted and computed from provided data
                    n_qsamps (int): number of qsamples to generate from each cohort to use as background data for the Shap computation
                    n_runs (int): number of times to regenerate shap values for each subject in `dat`
                    shap_dir (str): directory to save computed shap values and computed models if `save_qnets` is enabled
                    save_qnets (bool): indicator of whether to save the qnet computed on each run

            Returns:
                    shaps (numpy.ndarray): shap values for final run
                    explainer (shap.KernelExplainer): shap explainer for final run
                    ahctg (Quasinet.Qnet): appropriate qnet model for final run
                    shctg (Quasinet.Qnet): suboptimal qnet model for final run
                    time (str): time of completion of final run
                    
    """
    os.makedirs(shap_dir, exist_ok=True)

    def _get_labels(num_lvls):
        """
        Helper function in case we want to quantize into more than 26 levels
        """
        import string
        lbls = list(string.ascii_uppercase)
        for i in range(2, int(np.ceil(num_lvls / 26)) + 1):
            lbls = lbls + [char * i for char in string.ascii_uppercase]

        lbls = tuple(lbls[:num_lvls])

        return {lbl: idx for idx, lbl in enumerate(lbls)}

    def _get_quantizer(data, num_levels=26):
        """
        Construct a quantizer fitted to the provided `data`
        """
        quantizer = qbiome.quantizer.Quantizer(num_levels=num_levels)
        if num_levels > 26:
            quantizer.labels = _get_labels(num_levels)
        # quantize_df is called for its effect on the quantizer's state;
        # the quantized output itself is not needed here
        quantizer.quantize_df(data)

        return quantizer

    def _quant_df_to_np(df):
        """
        Convert a quantized pandas DataFrame into a numpy string array for qnet computations (missing values become empty strings)
        """
        df_np = np.char.replace(
            df.drop("subject_id", axis=1, errors="ignore").to_numpy(dtype="str"),
            "nan",
            "",
        )
        return df_np

    def risk_np(x, ahcg, shcg):
        """
        Compute risk from quantized numpy sequence x using appropriate and suboptimal qnets
        """
        theta_s = qnet.qdistance(np.full_like(x, fill_value=""), x, shcg, shcg)
        theta_a = qnet.qdistance(np.full_like(x, fill_value=""), x, ahcg, ahcg)
        if theta_a > 0:
            risk = theta_s / theta_a
        elif theta_s > 0:
            risk = 100
        else:
            # risk = np.nan
            risk = 1

        return risk
    
    if dat is None and dat_quant is None:
        raise Exception("Data for new subjects not provided. Pass either dat or dat_quant.")

    if quantizer is None:
        if appr_data is not None and sub_data is not None:
            quantizer = _get_quantizer(pd.concat([appr_data, sub_data]))
        else:
            raise Exception("Quantizer must be provided if appr_data and sub_data are not provided.")

    def __get_qnet(
        data=None, data_quantized=None, num_levels=26, quantizer=quantizer, alpha=0.3, min_samples_split=2
    ):
        """
        Compute qnet from provided data
        """
        if data is None and data_quantized is None:
            raise Exception("Either data or data_quantized must be provided.")
        if quantizer is None:
            quantizer = qbiome.quantizer.Quantizer(num_levels=num_levels)
        if num_levels > 26:
            quantizer.labels = _get_labels(num_levels)
        orchestrator = qbiome.qnet_orchestrator.QnetOrchestrator(quantizer)
        if data_quantized is None:
            data_quantized = quantizer.quantize_df(data)
        features, label_matrix = quantizer.get_qnet_inputs(data_quantized)
        orchestrator.train_qnet(
            features,
            label_matrix,
            alpha=alpha,
            min_samples_split=min_samples_split,
        )

        return orchestrator.model
    
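    # Recycle pretrained models so that there is one per run
    if ahcg_lst is not None:
        ahcg_lst = (list(ahcg_lst) * n_runs)[:n_runs]
    if shcg_lst is not None:
        shcg_lst = (list(shcg_lst) * n_runs)[:n_runs]
    # Each cohort needs either a pretrained model list or data to train on
    if ahcg_lst is None and appr_data is None and appr_data_quant is None:
        raise Exception("Either appr_data, appr_data_quant, or ahcg_lst is required for the appropriate cohort")
    if shcg_lst is None and sub_data is None and sub_data_quant is None:
        raise Exception("Either sub_data, sub_data_quant, or shcg_lst is required for the suboptimal cohort")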

    if verbose:
        print("Starting SHAP runs\n")
    for i in tqdm(range(n_runs), desc="n_runs"):
        if ahcg_lst is None:
            if verbose:
                print("Fitting appropriate cohort model\n")
            if appr_data is not None:
                ahcg = __get_qnet(data=appr_data)
            else:
                ahcg = __get_qnet(data_quantized=appr_data_quant)
        else:
            ahcg = ahcg_lst[i]
        if shcg_lst is None:
            if verbose:
                print("Fitting suboptimal cohort model\n")
            if sub_data is not None:
                shcg = __get_qnet(data=sub_data)
            else:
                shcg = __get_qnet(data_quantized=sub_data_quant)
        else:
            shcg = shcg_lst[i]

        def __risk_fquant_shap(x):
            r = np.array([risk_np(s, ahcg, shcg) for s in x])

            return r

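        shaps = []

        # Background data for KernelExplainer: n_qsamps qsamples from each cohort
        # model, each started from an all-missing sequence (5000 passed as qsample's steps)
        background_data = np.vstack(
            (
                np.array(
                    [
                        qsampling.qsample(
                            np.full(len(ahcg.feature_names), fill_value=""),
                            ahcg,
                            5000,
                        )
                        for _ in range(n_qsamps)
                    ]
                ),
                np.array(
                    [
                        qsampling.qsample(
                            np.full(len(shcg.feature_names), fill_value=""),
                            shcg,
                            5000,
                        )
                        for _ in range(n_qsamps)
                    ]
                ),
            )
        )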

        explainer = shap.KernelExplainer(
            __risk_fquant_shap,
            background_data,
        )

        if dat_quant is None:
            dat_q_np = _quant_df_to_np(quantizer._quantize_df(dat))
        else:
            dat_q_np = _quant_df_to_np(dat_quant)

        if verbose:
            print("Computing SHAP values for subjects this run:\n")
        for s in tqdm(dat_q_np, desc="subjects"):
            shaps.append(explainer.shap_values(s, nsamples=nsamples))

        # Timestamp without spaces or colons, so it is safe to use in file names
        time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")

        np.savetxt(
            os.path.join(shap_dir, "shap_vals_risk_" + time + ".csv"),
            np.array(shaps),
            delimiter=",",
        )

        if save_qnets is True:
            qnet.save_qnet(
                ahcg, shap_dir + "qnet-appr-" + time + ".joblib"
            )
            os.remove(shap_dir + "qnet-appr-" + time + ".joblib")
            qnet.save_qnet(
                shcg, shap_dir + "qnet-sub-" + time + ".joblib"
            )
            os.remove(shap_dir + "qnet-sub-" + time + ".joblib")

    return np.array(shaps), explainer, ahcg, shcg, time
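
The function being explained, `risk_np`, compares two qdistances between an all-missing sequence and the subject's sequence: `theta_s` under the suboptimal model and `theta_a` under the appropriate model. Its branching can be restated on plain numbers, together with the batching contract of the `__risk_fquant_shap` wrapper: KernelExplainer calls the model on a 2D array of perturbed sequences and expects one output per row. The distance functions below are hypothetical toy stand-ins, not `qnet.qdistance`:

```python
from typing import Callable

import numpy as np


def risk_from_distances(theta_s: float, theta_a: float) -> float:
    """Risk score from the two distances, following the branches of risk_np."""
    if theta_a > 0:
        return theta_s / theta_a
    if theta_s > 0:
        return 100.0  # theta_a is zero: cap at the sentinel value used by risk_np
    return 1.0  # both distances are zero: neutral risk


def batch_risk(
    rows: np.ndarray,
    dist_s: Callable[[np.ndarray], float],
    dist_a: Callable[[np.ndarray], float],
) -> np.ndarray:
    """Score each row of a 2D batch, as the KernelExplainer model wrapper must."""
    return np.array([risk_from_distances(dist_s(r), dist_a(r)) for r in rows])


# Toy stand-ins for qnet.qdistance: fraction of non-empty entries, scaled per cohort
dist_s = lambda r: 0.5 * float(np.mean(r != ""))
dist_a = lambda r: 0.25 * float(np.mean(r != ""))
batch = np.array([["M", "M"], ["", "M"], ["", ""]])
print(batch_risk(batch, dist_s, dist_a))
```

A ratio above 1 means the sequence is farther from the all-missing baseline under the suboptimal model than under the appropriate one.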

## Compute SHAP values from existing models and a quantizer

In [5]:
get_shaps(
    dat_quant=boston_data_quant.head(1),
    #appr_data=appr_data,
    #sub_data=sub_data,
    quantizer=qnt,
    ahcg_lst=appr_qnet_lst,
    shcg_lst=sub_qnet_lst,
    nsamples=1100,
    n_qsamps=2,
    n_runs=2,
);

Starting SHAP runs

Computing SHAP values for subjects this run:

Computing SHAP values for subjects this run:

# Classify subject phenotypes

## Classification function

In [6]:
def classify_subjects(
    subject_data=None,
    subject_data_quant=None,
    shap_df=None,
    bacteroidia_regex="^(Bacteroidia_2)",
    actino_regex="^(Actinobacteria_2)",
    shap_dir="tmp_shap/",
    thresh=0.005,
    appr_data=None,
    sub_data=None,
    appr_data_quant=None,
    sub_data_quant=None,
    nsamples=1100,
    quantizer=None,
    ahcg_lst=None,
    shcg_lst=None,
    n_qsamps=3,
    n_runs=1,
    save_qnets=False,
    verbose=True,
):
    """
    Classify subjects into interventional phenotypes.

            Parameters:
                    subject_data (pandas.core.frame.DataFrame): wide form abundance data of subjects for which to compute shap_values. incompatible with shap_df
                    subject_data_quant (pandas.core.frame.DataFrame): wide form quantized abundance data of subjects for which to compute shap_values. incompatible with shap_df
                    shap_df (pandas.core.frame.DataFrame): dataframe of shap values for subjects to be classified. incompatible with subject_data
                    bacteroidia_regex (str): regex identifying relevant Bacteroidia variables on which to base interventional grouping
                    actino_regex (str): regex identifying relevant Actinobacteria variables on which to base interventional grouping
                    thresh (float): require all relevant shap values of a subject < `thresh` to assign subject to the corresponding phenotype
                    appr_data (pandas.core.frame.DataFrame): (used with subject_data only) dataframe of abundance observations used to construct the appropriate cohort model
                    sub_data (pandas.core.frame.DataFrame): (used with subject_data only) dataframe of abundance observations used to construct the suboptimal cohort model
                    nsamples (int): (used with subject_data only) passed to shap.KernelExplainer.shap_values()
                    ahctg_lst (list): (used with subject_data only) optional list of pretrained qnet models used for the appropriate cohort (one per run, recycled if needed). can be omitted and computed from provided data
                    shctg_lst (list): (used with subject_data only) optional list of pretrained qnet models used for the suboptimal cohort (one per run, recycled if needed). can be omitted and computed from provided data
                    n_qsamps (int): (used with subject_data only) number of qsamples to generate from each cohort to use as background data for the Shap computation
                    n_runs (int): (used with subject_data only) number of times to regenerate shap values for each subject in `dat`
                    shap_dir (str): (used with subject_data only) directory to save computed shap values and computed models if `save_qnets` is enabled
                    save_qnets (bool): (used with subject_data only) indicator of whether to save the qnet computed on each run
                    verbose (bool): (used with subject_data only) print status updates during shap computation

            Returns:
                    phenotype_df (pandas.core.frame.DataFrame): assignment of subjects by index into phenotype groupings
    """
    
    if subject_data is None and subject_data_quant is None and shap_df is None:
        raise Exception("Either subject data or a shap dataframe must be provided.")
    if (subject_data is not None or subject_data_quant is not None) and shap_df is not None:
        raise Exception("Provide either subject data or a shap dataframe, not both.")

    if shap_df is None:
        import glob

        # Record the SHAP files already in shap_dir, so that only the runs
        # computed by this call are averaged below
        existing_files = set(glob.glob(os.path.join(shap_dir, "*.csv")))
        get_shaps(
            dat=subject_data,
            dat_quant=subject_data_quant,
            appr_data=appr_data,
            sub_data=sub_data,
            appr_data_quant=appr_data_quant,
            sub_data_quant=sub_data_quant,
            quantizer=quantizer,
            ahcg_lst=ahcg_lst,
            shcg_lst=shcg_lst,
            nsamples=nsamples,
            n_qsamps=n_qsamps,
            n_runs=n_runs,
            shap_dir=shap_dir,
            save_qnets=save_qnets,
            verbose=verbose,
        )
        new_files = sorted(set(glob.glob(os.path.join(shap_dir, "*.csv"))) - existing_files)
        # ndmin=2 keeps one row per subject, even for a single subject
        shaps_mean = np.mean(
            np.stack(
                [np.loadtxt(file, delimiter=",", ndmin=2) for file in new_files],
                axis=-1,
            ),
            axis=-1,
        )

        subjects = subject_data if subject_data is not None else subject_data_quant
        shap_df = pd.DataFrame(
            data=shaps_mean,
            columns=subjects.drop("subject_id", axis=1).columns,
        )
    

    bact_subs = (
        shap_df.filter(regex=bacteroidia_regex, axis=1)
        .max(axis=1)[lambda x: x < thresh]
        .index.values
    )
    act_subs = (
        shap_df.filter(regex=actino_regex, axis=1)
        .max(axis=1)[lambda x: x < thresh]
        .index.values
    )
    act_subs = [x for x in act_subs if x not in bact_subs]
    confl_subs = [
        x for x in shap_df.index.values if x not in bact_subs and x not in act_subs
    ]

    bact_df = (
        pd.DataFrame(bact_subs)
        .rename(columns={0: "subject idx"})
        .assign(phenotype="Bacteroidia intervention")
    )
    act_df = (
        pd.DataFrame(act_subs)
        .rename(columns={0: "subject idx"})
        .assign(phenotype="Actinobacteria intervention")
    )
    confl_df = (
        pd.DataFrame(confl_subs)
        .rename(columns={0: "subject idx"})
        .assign(phenotype="Null intervention")
    )

    return pd.concat([bact_df, act_df, confl_df])
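
In `classify_subjects`, a subject gets the Bacteroidia intervention if the largest SHAP value among its matching Bacteroidia columns (with the default regex, names starting with `Bacteroidia_2`, such as `Bacteroidia_25`) is below `thresh`; otherwise the Actinobacteria intervention if the same holds for its Actinobacteria columns; otherwise the null intervention. A compact restatement that returns one label per row, run on a toy SHAP table with invented values (`assign_phenotypes` is a hypothetical name):

```python
import pandas as pd


def assign_phenotypes(
    shap_df: pd.DataFrame,
    bacteroidia_regex: str = "^(Bacteroidia_2)",
    actino_regex: str = "^(Actinobacteria_2)",
    thresh: float = 0.005,
) -> pd.Series:
    """Vectorized restatement of the rule in classify_subjects (one label per row)."""
    bact_max = shap_df.filter(regex=bacteroidia_regex, axis=1).max(axis=1)
    act_max = shap_df.filter(regex=actino_regex, axis=1).max(axis=1)
    labels = pd.Series("Null intervention", index=shap_df.index)
    labels.loc[act_max < thresh] = "Actinobacteria intervention"
    # Assigned last, so Bacteroidia wins when both groups fall below the threshold
    labels.loc[bact_max < thresh] = "Bacteroidia intervention"
    return labels


toy = pd.DataFrame(
    {
        "Bacteroidia_27": [0.001, 0.010, 0.020, 0.0],
        "Bacteroidia_28": [0.002, 0.000, 0.000, 0.0],
        "Actinobacteria_27": [0.100, -0.010, 0.030, 0.0],
        "Clostridia_27": [0.500, 0.500, 0.500, 0.5],
    }
)
print(assign_phenotypes(toy).tolist())
```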

## Classify from abundance data

In [7]:
classify_subjects(
    subject_data_quant=boston_data_quant.head(1),
    appr_data_quant=appr_data_quant,
    sub_data_quant=sub_data_quant,
    quantizer=qnt,
    # ahcg_lst=appr_qnet_lst,
    # shcg_lst=sub_qnet_lst,
    nsamples=1100,
    n_qsamps=2,
    n_runs=2,
)

Starting SHAP runs

Fitting appropriate cohort model

Fitting suboptimal cohort model

Computing SHAP values for subjects this run:

Fitting appropriate cohort model

Fitting suboptimal cohort model

Computing SHAP values for subjects this run:

|  | subject idx | phenotype |
|---|---|---|
| 0 | 0.0 | Bacteroidia intervention |


## Classify from precomputed SHAP values

In [8]:
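import glob
import os

shap_dir = "./tmp_shap/"

# Average the SHAP values of every run saved in shap_dir (including the runs
# from the cells above); ndmin=2 keeps one row per subject, even for one subject
shaps_mean = np.mean(
    np.stack(
        [
            np.loadtxt(file, delimiter=",", ndmin=2)
            for file in glob.glob(os.path.join(shap_dir, "*.csv"))
        ],
        axis=-1,
    ),
    axis=-1,
)

shap_df = pd.DataFrame(
    data=shaps_mean,
    columns=boston_data_quant.drop("subject_id", axis=1).columns,
)

classify_subjects(shap_df=shap_df)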

|  | subject idx | phenotype |
|---|---|---|
| 0 | 0.0 | Bacteroidia intervention |
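
The averaging step above reads every CSV in `shap_dir`, so its result depends on which earlier runs are still in that directory. A sketch of a helper that instead averages an explicit list of run files, checks that all runs cover the same subjects and features, and keeps single-subject files two-dimensional via `np.loadtxt(..., ndmin=2)`; the function name and file names are hypothetical:

```python
import glob
import os
import tempfile
from typing import Sequence

import numpy as np
import pandas as pd


def average_shap_runs(files: Sequence[str], columns: Sequence[str]) -> pd.DataFrame:
    """Average per-run SHAP CSVs (subjects x features) into one DataFrame."""
    # ndmin=2 keeps a single-subject file as a 1 x n_features matrix
    runs = [np.loadtxt(f, delimiter=",", ndmin=2) for f in files]
    if not runs:
        raise ValueError("no SHAP files to average")
    if len({r.shape for r in runs}) != 1:
        raise ValueError("runs have different shapes; were they computed on the same subjects?")
    mean = np.mean(np.stack(runs, axis=-1), axis=-1)
    return pd.DataFrame(mean, columns=list(columns))


with tempfile.TemporaryDirectory() as run_dir:
    # Two hypothetical runs for one subject and two features
    np.savetxt(os.path.join(run_dir, "run1.csv"), np.array([[0.0, 2.0]]), delimiter=",")
    np.savetxt(os.path.join(run_dir, "run2.csv"), np.array([[1.0, 4.0]]), delimiter=",")
    files = sorted(glob.glob(os.path.join(run_dir, "*.csv")))
    averaged = average_shap_runs(files, ["Bacteroidia_27", "Actinobacteria_27"])

print(averaged)
```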
