# Partial least squares - discriminant analysis

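Partial least squares discriminant analysis (PLS-DA) is a supervised method that applies PLS regression to a categorical target: the class labels are encoded numerically (here 0/1) and the model finds a small number of latent components of the features that covary most strongly with those labels.

TODO: add links to further reading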

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from enrichment import get_ora
from sklearn.cross_decomposition import PLSRegression
from sklearn.metrics import accuracy_score


# Apply seaborn's default theme to every figure drawn in this notebook
sns.set_theme()


def separate(data, col, into, sep="_", **kwargs):
    """Split one string column into several new, named columns.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing the column to split.
    col : str
        Name of the string column whose values are split.
    into : sequence of str
        Names for the new columns, in the order of the split pieces.
        It should name every piece the split produces.
    sep : str, default "_"
        Separator passed to ``Series.str.split``.
    **kwargs
        Extra keyword arguments forwarded to ``Series.str.split``
        (for example ``n`` to limit the number of splits).

    Returns
    -------
    pandas.DataFrame
        A new frame with the original columns plus the new ones;
        ``data`` itself is not modified.
    """
    # One column per piece, labelled 0, 1, 2, ... by pandas
    pieces = data.get(col).str.split(sep, expand=True, **kwargs)
    # Map the positional labels onto the requested names
    named_pieces = pieces.rename(columns=dict(enumerate(into)))
    return data.assign(**named_pieces)

## Load the data

We load the normalised counts from DESeq2 analysis:

In [None]:
# Normalised counts table; the first CSV column is used as the row index
df_norm = pd.read_csv(
    "../Output/DESeq2/normalised_counts.csv",
    index_col=0,
)

We also load the MIC values and create additional columns for the level of
antibiotic resistance (ABR) for both antibiotics (0 for low ABR and 1 for high ABR):

In [None]:
# MIC values strictly above this cut-off are labelled as high resistance (1)
mic_threshold = 30

# Strain identifiers are read as strings so they are not parsed as numbers
df_mic = pd.read_csv("../Data/mic.csv", dtype={"strain": str}).assign(
    cza_mic_level=lambda mic: mic["cza_mic"].gt(mic_threshold).astype(int),
    mem_mic_level=lambda mic: mic["mem_mic"].gt(mic_threshold).astype(int),
)

We can plot the MIC values as shown below:

In [None]:
# Reshape the MIC table to long format: one row per strain, condition and
# antibiotic, with the antibiotic name upper-cased and stripped of "_mic"
df_mic_long = pd.melt(
    df_mic,
    id_vars=["strain", "condition"],
    value_vars=["cza_mic", "mem_mic"],
    var_name="mic",
    value_name="mic_value",
).assign(mic=lambda long: long["mic"].str.removesuffix("_mic").str.upper())

# Bar chart of MIC per strain: one column of panels per antibiotic and one
# row per condition (condition 'P' is left out)
g = sns.FacetGrid(
    df_mic_long.query("condition != 'P'"),
    col="mic",
    row="condition",
    margin_titles=True,
)
g.map_dataframe(sns.barplot, x="strain", y="mic_value")
g.set(ylabel="MIC value", xlabel="Strain")
# Horizontal reference line at the low/high resistance cut-off
g.refline(y=mic_threshold)

# Switch the y-axis of every panel to a log2 scale, with ticks at
# 1, 4, 16, 64 and 256 and limits wide enough to show the whole range
yticks = 2.0 ** np.arange(0, 10, 2)
for ax in g.axes.flatten():
    ax.set_yscale("log", base=2)
    ax.set(yticks=yticks, yticklabels=yticks, ylim=(0.5, 2**9))

## Create the model

To create the PLS-DA model, we first need to transform the dataframe so that each gene is a feature in the model:

In [None]:
# Samples become rows and genes become columns; the sample name is then split
# into its parts and the MIC data are attached per strain and condition
df_norm_rot = separate(
    df_norm.T.reset_index(names="sample"),
    "sample",
    ["strain", "condition", "replicate"],
    sep="_",
).merge(df_mic, on=["strain", "condition"], how="left")
df_norm_rot.head()

We extract the feature columns (genes) and the target variable (MIC level for one antibiotic) from the table above:

In [None]:
# Every column that is not a gene: sample annotations and MIC measurements
meta_cols = [
    "sample",
    "strain",
    "condition",
    "replicate",
    "cza_mic",
    "mem_mic",
    "cza_mic_level",
    "mem_mic_level",
]

# Feature matrix: whatever remains after removing the metadata columns
x = df_norm_rot.drop(columns=meta_cols).to_numpy()
# Target: the 0/1 resistance level for one antibiotic
y = df_norm_rot["mem_mic_level"].to_numpy()

We create the PLS object and fit it to the data above:

In [None]:
# Two components are requested, matching the two factors plotted further down
pls = PLSRegression(n_components=2)
# PLS-DA here means fitting PLS regression against the 0/1 level stored in y
pls.fit(x, y)

We can plot the factor values and see how the samples separate depending on the MIC level:

In [None]:
# Project the data onto the two fitted PLS components
x_pls, y_pls = pls.transform(x, y)

# Sample scores with the metadata columns attached (joined on the row index)
df_pls = pd.DataFrame(x_pls, columns=["Factor1", "Factor2"]).join(
    df_norm_rot.get(meta_cols)
)

# Colour shows the resistance level and marker style shows the condition
fig, ax = plt.subplots()
sns.scatterplot(
    data=df_pls.replace({"mem_mic_level": {0: "low", 1: "high"}}),
    x="Factor1",
    y="Factor2",
    hue="mem_mic_level",
    style="condition",
    ax=ax,
)

To calculate the accuracy of the model, we predict the MIC levels and use the `accuracy_score()` function from the scikit-learn library:

In [None]:
# PLS regression returns continuous predictions; threshold them at 0.5 to
# obtain 0/1 class labels. np.ravel flattens the output so this works whether
# predict() returns a 1-D array or a single-column 2-D array.
y_pred = (np.ravel(pls.predict(x)) > 0.5).astype(int)
# Fraction of samples whose predicted level matches the observed level in y
accuracy_score(y, y_pred)

In [None]:
# Loadings of each feature on the two PLS components, labelled with the
# feature (gene) column names in the same order as the model input
df_loadings = pd.DataFrame(
    pls.x_loadings_,
    columns=["PLS1", "PLS2"],
).assign(feature=df_norm_rot.drop(columns=meta_cols).columns)

**Tasks:** 
- can you determine what the top features have in common for this model?
- try to fit the other antibiotic
- is the accuracy of this model reasonable? What could you do to check this?