# Welcome to the MALDI-UI session

In this notebook, we will work with **MALDI-TOF mass spectrometry data**, specifically the **DRIAMS B database**, which contains routine MALDI-TOF MS data from the Canton Hospital Basel-Land.

Our objective for this practical session is to create a simple UI, where the user can:
- Load in dataset of spectra
- Perform basic preprocessing of the MALDI-TOF MS data
- Show a selected spectra
- Check for quality metrics
- See model prediction
- Show important features for the prediction

Let's get started!


# Set-up

In [30]:
!pip install gradio plotly maldi-nn shap



In [31]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [119]:
from tensorflow.keras.models import load_model
import joblib
import json
import shap

mlp_model = load_model("/content/drive/MyDrive/ESCMID-AI-Data-Advances-Techniques/Models/mlp_keras_model.h5")
xgb_model = joblib.load("/content/drive/MyDrive/ESCMID-AI-Data-Advances-Techniques/Models/xgboost_top5_species_model.pkl")

# Do we implement shapley values as well?

# SHAPLEY values
# Load saved background
background = np.load("/content/drive/MyDrive/ESCMID-AI-Data-Advances-Techniques/Models/mlp_shap_background.npy")
feature_names = [f"{2000 + i*3}-{2000 + (i+1)*3} Da" for i in range(6000)]

# For XGBoost (TreeExplainer is fast and efficient)
xgb_explainer = shap.Explainer(xgb_model)

# For MLP
mlp_explainer = shap.DeepExplainer(mlp_model, background)


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.


The structure of `inputs` doesn't match the expected structure.
Expected: input_layer
Received: inputs=['Tensor(shape=(100, 6000))']



# Gradio app

In [120]:
import gradio as gr
import pandas as pd
import plotly.graph_objs as go
import maldi_nn.spectrum as maldi_spectrum
from maldi_nn.spectrum import SpectrumObject
import numpy as np
import os


# --- Helper to read spectrum from txt ---
def read_spectrum_file(file):
    spec_obj = SpectrumObject.from_tsv(file)
    spec_obj.intensity = np.array(spec_obj.intensity[1:], dtype=float)
    spec_obj.mz = np.array(spec_obj.mz[1:], dtype=float)
    return spec_obj

# --- Preprocessing ---
def preprocess_spectra(files, var_stabilizer, smoother, baseline, normalizer, binner):
    if not files:
        return "No files uploaded", []

    spectra = [read_spectrum_file(file) for file in files]
    steps = []
    if var_stabilizer:
        steps.append(maldi_spectrum.VarStabilizer(method="sqrt"))
    if smoother:
        steps.append(maldi_spectrum.Smoother(halfwindow=10))
    if baseline:
        steps.append(maldi_spectrum.BaselineCorrecter(method="SNIP", snip_n_iter=20))
    if normalizer:
        steps.append(maldi_spectrum.Normalizer(sum=1))
    if binner:
        steps.append(maldi_spectrum.Binner(step=3))

    preprocessor = maldi_spectrum.SequentialPreprocessor(*steps)
    spectra_preprocessed = [preprocessor(spectrum) for spectrum in spectra]

    return "Preprocessing complete", spectra_preprocessed

# --- Plotting ---
def plot_preprocessed_spectra(preprocessed_spectra, files, selected_indexes):
    if not preprocessed_spectra or not selected_indexes:
        return go.Figure()

    selected_indexes = [int(i) for i in selected_indexes]
    fig = go.Figure()

    for i in selected_indexes:
        spectrum = preprocessed_spectra[i]
        filename = os.path.splitext(os.path.basename(files[i].name))[0]
        fig.add_trace(go.Scatter(x=spectrum.mz, y=spectrum.intensity, mode="lines", name=filename))

    fig.update_layout(title="Preprocessed Spectra", xaxis_title="m/z", yaxis_title="Intensity")
    return fig

# --- Prediction ---
def predict_species_from_spectra(preprocessed_spectra, selected_indexes, model_choice):
    if not preprocessed_spectra or not selected_indexes:
        return "No spectra or indexes selected"

    predictions = []
    selected_indexes = [int(i) for i in selected_indexes]

    for i in selected_indexes:
        spectrum = preprocessed_spectra[i]
        features = spectrum.intensity.reshape(1, -1)

        # Dummy model predictions – replace with actual logic
        if model_choice == "MLP":
            pred_probs = mlp_model.predict(features)[0]
            pred_class = np.argmax(pred_probs)
        else:
            pred_probs = xgb_model.predict_proba(features)[0]
            pred_class = np.argmax(pred_probs)

        predictions.append(f"Spectrum {i}: Predicted class {pred_class} (probs: {pred_probs})")

    return "\n".join(predictions)

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("## MALDI-TOF MS Tool: Preprocessing, Visualization, and Species Prediction")

    file_input = gr.File(label="Upload .txt spectra files", file_types=[".txt"], file_count="multiple")

    with gr.Row():
        var_stabilizer = gr.Checkbox(label="Variance Stabilizer (sqrt)", value=True)
        smoother = gr.Checkbox(label="Smoother (halfwindow=10)", value=True)
        baseline = gr.Checkbox(label="Baseline Correction (SNIP)", value=True)
        normalizer = gr.Checkbox(label="Normalization (sum=1)", value=True)
        binner = gr.Checkbox(label="Binner (step=3)", value=True)

    selected_indexes = gr.CheckboxGroup(choices=[], label="Select Spectra")

    # Update selection options after file upload
    def update_choices(files):
        return gr.update(choices=[str(i) for i in range(len(files))], value=[])

    file_input.change(update_choices, inputs=file_input, outputs=selected_indexes)

    # State to hold preprocessed spectra between steps
    preprocessed_spectra_state = gr.State()

    # Preprocessing section
    preprocess_button = gr.Button("Run Preprocessing")
    preprocessing_output = gr.Textbox(label="Preprocessing Status")
    preprocess_button.click(
        fn=preprocess_spectra,
        inputs=[file_input, var_stabilizer, smoother, baseline, normalizer, binner],
        outputs=[preprocessing_output, preprocessed_spectra_state]
    )

    # Plotting section
    plot_button = gr.Button("Plot Selected Spectra")
    plot_output = gr.Plot(label="Processed Spectra")
    plot_button.click(
        fn=plot_preprocessed_spectra,
        inputs=[preprocessed_spectra_state, file_input, selected_indexes],
        outputs=plot_output
    )

    # Prediction section
    model_selector = gr.Dropdown(choices=["MLP", "XGBoost"], label="Select ML Model", value="MLP")
    predict_button = gr.Button("Run Prediction")
    prediction_output = gr.Textbox(label="Predicted Species")
    predict_button.click(
        fn=predict_species_from_spectra,
        inputs=[preprocessed_spectra_state, selected_indexes, model_selector],
        outputs=prediction_output
    )

demo.launch(share=True, debug=True)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://0ffdfcd199e119e271.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 109ms/step
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7863 <> https://0ffdfcd199e119e271.gradio.live


