<a href="https://colab.research.google.com/github/davidmirror-ops/flyte-school/blob/main/Secure_your_AI_orchestration_platform.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🔐 ML Security Workshop: Serverless

[![Colab Badge](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/unionai-oss/ml-security/blob/main/workshop_serverless.ipynb)

First, go to https://signup.union.ai/ to sign up for a Union account. This
will take a few minutes, after which you should be able to go to
https://serverless.union.ai/ to see the Union Serverless dashboard.

In [None]:
%pip install -U 'flytekit>=1.14.0' 'union==0.1.138' flytekitplugins-onnxscikitlearn 'flytekitplugins-pandera>=0.16.0' joblib openai pandas pyarrow scikit-learn

## Log in to Union Serverless

In [None]:
!union create login --auth device-flow --serverless

🔐 Configuration saved to /root/.union/config.yaml
Login successful into serverless


## Part 1: 🥒 Pickled Model Attack

### 🏋️ Training a model

In [None]:
import sys
from functools import partial
from typing import NamedTuple

import joblib
import pandas as pd

import union
from flytekit.deck import MarkdownRenderer
from flytekit.types.file import FlyteFile

from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score



image = union.ImageSpec.from_env(
    name="ml-security",
    packages=[
        "bandit",
        "flytekit>=1.14.0",
        "joblib",
        "openai",
        "pandas",
        "pyarrow",
        "scikit-learn",
        "union==0.1.138",
    ],
)

task = partial(
    union.task,
    container_image=image,
    cache=True,
    cache_version="4",
)

ModelOutput = NamedTuple("Output", [("model", FlyteFile), ("accuracy", float)])


@task
def load_data() -> tuple[pd.DataFrame, pd.Series]:
    wine = load_wine()
    X = pd.DataFrame(wine.data, columns=wine.feature_names)
    y = pd.Series(wine.target)
    return X, y


@task
def split_data(X: pd.DataFrame, y: pd.Series) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    return train_test_split(X, y, test_size=0.2, random_state=42)


@task
def train_model(X_train: pd.DataFrame, y_train: pd.Series) -> FlyteFile:
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    path = "model.joblib"
    joblib.dump(model, path)
    return FlyteFile(path=path)


@task(enable_deck=True)
def evaluate_model(model: FlyteFile, X_test: pd.DataFrame, y_test: pd.Series) -> float:
    with open(model, "rb") as f:
        model = joblib.load(f)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    deck = union.Deck(name="Accuracy Report", html=MarkdownRenderer().to_html(f"# Test accuracy: {accuracy}"))
    union.current_context().decks.insert(0, deck)
    return accuracy


@union.workflow
def wine_classification_workflow() -> ModelOutput:
    X, y = load_data()
    X_train, X_test, y_train, y_test = split_data(X, y)
    model = train_model(X_train, y_train)
    accuracy = evaluate_model(model, X_test, y_test)
    return model, accuracy

Create a `UnionRemote` client to run our workflows.

In [None]:
from union.remote import UnionRemote

serverless = UnionRemote()

In [None]:
execution = serverless.execute(wine_classification_workflow, inputs={})
execution

Read the model file back into the notebook session:

In [None]:
execution.wait(poll_interval=1)
model_file = execution.outputs["model"]

with open(model_file, "rb") as f:
    model = joblib.load(f)

model

Load some features and make predictions:

In [None]:
features, _ = load_data()
predictions = model.predict(features)
predictions

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2])

### 🍽️ Serving the model in batch mode

Here we define a simple batch prediction task.

In [None]:
Predictions = NamedTuple("Predictions", [("predictions", list[float])])

@union.task(container_image=image)
def batch_predict(model: FlyteFile, data: pd.DataFrame) -> Predictions:
    with open(model, "rb") as f:
        model = joblib.load(f)
    return Predictions([float(x) for x in model.predict(data)])

Run it on Union Serverless:

In [None]:
execution = serverless.execute(
    batch_predict,
    inputs={"model": model_file, "data": features}
)
execution

Fetch the predictions:

In [None]:
execution.wait(poll_interval=1)
predictions = execution.outputs["predictions"]
predictions[:5]

[0.0, 0.0, 0.0, 0.0, 0.0]

### 🥒 The Pickle Attack

In [None]:
class PickleAttack:
    def __init__(self): ...

    def __reduce__(self):
        # os.system will execute the command
        import os
        return (os.system, ('echo "👋 Hello there, I\'m a pickle attack! 🥒"',))


fake_model = PickleAttack()
fake_model_path = "model.joblib"
with open(fake_model_path, "wb") as f:
    joblib.dump(fake_model, f)

fake_model_path

'model.joblib'

In [None]:
execution = serverless.execute(
    batch_predict, inputs={"model": fake_model_path, "data": features}
)
execution
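
The attack works because `pickle` (which `joblib` uses under the hood) lets an object's `__reduce__` method name any callable to invoke at load time. The sketch below reproduces that mechanism locally, with a harmless `print` standing in for `os.system`. The `Payload` class and its message are illustrative and not part of the workshop code.

```python
import contextlib
import io
import pickle


class Payload:
    def __reduce__(self):
        # Tell pickle to "rebuild" this object by calling print(...).
        # An attacker would return something like (os.system, ("<command>",)).
        return (print, ("side effect ran during unpickling",))


blob = pickle.dumps(Payload())

# Capture stdout so we can inspect what happened during loading.
captured = io.StringIO()
with contextlib.redirect_stdout(captured):
    restored = pickle.loads(blob)  # print() runs here, inside loads()

output = captured.getvalue().strip()
print(f"output: {output!r}, restored: {restored!r}")
```

Note that the loaded value is whatever the callable returned, not a `Payload` instance: the caller never gets a model, only the side effect.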

### Mitigation: include md5hash metadata

In [None]:
import hashlib
from dataclasses import dataclass


@dataclass
class Model:
    file: FlyteFile
    md5hash: str

    def __post_init__(self):
        with open(self.file, "rb") as f:
            md5hash = hashlib.md5(f.read()).hexdigest()
        if md5hash != self.md5hash:
            raise ValueError(
                "⛔️ Model md5hash mismatch: expected "
                f"{self.md5hash}, found {md5hash}."
            )

ModelOutput = NamedTuple("Output", [("model", Model), ("accuracy", float)])

@task
def secure_train_model(X_train: pd.DataFrame, y_train: pd.Series) -> Model:
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    path = "model.joblib"
    joblib.dump(model, path)
    # Hash the serialized model so downstream tasks can verify its integrity
    with open(path, "rb") as f:
        md5hash = hashlib.md5(f.read()).hexdigest()
    return Model(file=FlyteFile(path=path), md5hash=md5hash)


@task(enable_deck=True)
def secure_evaluate_model(model: Model, X_test: pd.DataFrame, y_test: pd.Series) -> float:
    with open(model.file, "rb") as f:
        model = joblib.load(f)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    deck = union.Deck(name="Accuracy Report", html=MarkdownRenderer().to_html(f"# Test accuracy: {accuracy}"))
    union.current_context().decks.insert(0, deck)
    return accuracy


@union.workflow
def secure_wine_classification_workflow() -> ModelOutput:
    X, y = load_data()
    X_train, X_test, y_train, y_test = split_data(X, y)
    model = secure_train_model(X_train, y_train)
    accuracy = secure_evaluate_model(model, X_test, y_test)
    return model, accuracy
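
MD5 is enough to catch an accidental or naive file swap, but practical collision attacks against it exist, so a stronger digest such as SHA-256 may be preferable. The following is a minimal sketch of the same idea with SHA-256, chunked reads, and a constant-time comparison. The helper names `file_digest` and `verify_file` are hypothetical and not part of Union or flytekit.

```python
import hashlib
import hmac
import os
import tempfile


def file_digest(path, algorithm="sha256", chunk_size=1 << 20):
    """Hash a file in chunks so large models need not fit in memory."""
    digest = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path, expected_hexdigest):
    """Return True only if the file's digest matches the expected one."""
    return hmac.compare_digest(file_digest(path), expected_hexdigest)


with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.bin")
    with open(model_path, "wb") as f:
        f.write(b"trusted model bytes")
    expected = file_digest(model_path)
    ok_before = verify_file(model_path, expected)

    # Simulate an attacker replacing the file contents.
    with open(model_path, "wb") as f:
        f.write(b"tampered bytes")
    ok_after = verify_file(model_path, expected)

print(ok_before, ok_after)
```

Either way, the check is only as trustworthy as the channel that delivers the expected hash.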

Run the secure training workflow:

In [None]:
execution = serverless.execute(secure_wine_classification_workflow, inputs={})
execution

In [None]:
execution.wait(poll_interval=1)
model_file = execution.outputs["model"]

with open(model_file.file, "rb") as f:
    model = joblib.load(f)

print(f"md5hash: {model_file.md5hash}")
model

md5hash: b087efd0595a961982db5d35bce8a690


Create a secure batch prediction workflow:

In [None]:
@union.task(container_image=image)
def model_guard(model: FlyteFile, md5hash: str) -> Model:
    return Model(file=model, md5hash=md5hash)


@union.task(container_image=image)
def secure_batch_predict(model: Model, data: pd.DataFrame) -> Predictions:
    with open(model.file, "rb") as f:
        model = joblib.load(f)
    return Predictions([float(x) for x in model.predict(data)])


@union.workflow
def secure_batch_prediction_workflow(
    model: FlyteFile,
    md5hash: str,
    data: pd.DataFrame
) -> Predictions:
    checked_model = model_guard(model, md5hash)
    return secure_batch_predict(checked_model, data)

Generate predictions with the correct model:

In [None]:
execution = serverless.execute(
    secure_batch_prediction_workflow,
    inputs={
        "model": model_file.file,
        "md5hash": model_file.md5hash,
        "data": features
    }
)
execution

In [None]:
execution.wait(poll_interval=1)
predictions = execution.outputs["predictions"]
predictions[:5]

[0.0, 0.0, 0.0, 0.0, 0.0]

Call the secure batch prediction workflow with the fake model:

In [None]:
execution = serverless.execute(
    secure_batch_prediction_workflow,
    inputs={
        "model": fake_model_path,
        "md5hash": model_file.md5hash,
        "data": features
    }
)
execution
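
A hash check helps only when the expected hash comes from a trusted source. As an additional layer, the Python `pickle` documentation describes restricting which globals an unpickler may resolve. Below is a minimal sketch of that idea. `RestrictedUnpickler`, `restricted_loads`, and the allow-list are illustrative names, and a real scikit-learn model would need a far larger allow-list, so treat this as a demonstration rather than a drop-in guard.

```python
import builtins
import io
import pickle

# Illustrative allow-list: only these builtins may be resolved while loading.
SAFE_BUILTINS = {"range", "complex", "set", "frozenset", "slice"}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called whenever the pickle stream references a global by name.
        if module == "builtins" and name in SAFE_BUILTINS:
            return getattr(builtins, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data):
    return RestrictedUnpickler(io.BytesIO(data)).load()


class Payload:
    def __reduce__(self):
        # Harmless stand-in for a malicious callable such as os.system.
        return (print, ("should never run",))


safe = restricted_loads(pickle.dumps([1, 2, range(3)]))

try:
    restricted_loads(pickle.dumps(Payload()))
    blocked = False
except pickle.UnpicklingError:
    blocked = True

print(safe, blocked)
```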

### Mitigation: use a secure serialization format (ONNX)

In [None]:
import sys
from functools import partial
from typing import NamedTuple

import onnxruntime as rt
from flytekit.types.file import ONNXFile
from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig
from skl2onnx.common.data_types import FloatTensorType
import pandas as pd

import union
from flytekit.deck import MarkdownRenderer
from flytekit.types.file import FlyteFile

from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score



image = union.ImageSpec.from_env(
    name="ml-security",
    packages=[
        "bandit",
        "flytekit>=1.14.0",
        "joblib",
        "openai",
        "pandas",
        "pyarrow",
        "scikit-learn",
        "union==0.1.138",
        "flytekitplugins-onnxscikitlearn",
    ],
)

task = partial(
    union.task,
    container_image=image,
    cache=True,
    cache_version="4",
)

ModelOutput = NamedTuple("Output", [("model", FlyteFile), ("accuracy", float)])


@task
def load_data() -> tuple[pd.DataFrame, pd.Series]:
    wine = load_wine()
    X = pd.DataFrame(wine.data, columns=wine.feature_names)
    y = pd.Series(wine.target)
    return X, y


@task
def split_data(X: pd.DataFrame, y: pd.Series) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    return train_test_split(X, y, test_size=0.2, random_state=42)


@task
def train_model(X_train: pd.DataFrame, y_train: pd.Series) -> FlyteFile:
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    path = "model.joblib"
    joblib.dump(model, path)
    return FlyteFile(path=path)


@task(enable_deck=True)
def evaluate_model(model: FlyteFile, X_test: pd.DataFrame, y_test: pd.Series) -> float:
    with open(model, "rb") as f:
        model = joblib.load(f)
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    deck = union.Deck(name="Accuracy Report", html=MarkdownRenderer().to_html(f"# Test accuracy: {accuracy}"))
    union.current_context().decks.insert(0, deck)
    return accuracy


@union.workflow
def wine_classification_workflow() -> ModelOutput:
    X, y = load_data()
    X_train, X_test, y_train, y_test = split_data(X, y)
    model = train_model(X_train, y_train)
    accuracy = evaluate_model(model, X_test, y_test)
    return model, accuracy
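
Formats such as ONNX avoid the pickle problem because loading them parses data rather than invoking arbitrary callables. As a dependency-free analogy (this is JSON, not ONNX, and the helper names are hypothetical), the sketch below stores the parameters of a tiny linear model as plain data and rebuilds the predictor from values that are explicitly type-checked on load.

```python
import json


def save_params(weights, bias):
    """Serialize model parameters as plain data: numbers only, no code."""
    return json.dumps({"weights": weights, "bias": bias})


def load_params(text):
    """Parse parameters and coerce them to floats; nothing is executed."""
    params = json.loads(text)
    weights = [float(w) for w in params["weights"]]
    bias = float(params["bias"])
    return weights, bias


def predict(weights, bias, row):
    """Linear score: dot(weights, row) + bias."""
    return sum(w * x for w, x in zip(weights, row)) + bias


blob = save_params([0.5, -1.0], 2.0)
weights, bias = load_params(blob)
score = predict(weights, bias, [4.0, 1.0])
print(score)
```

The inference logic lives in trusted code (`predict` here, the ONNX runtime in the real case), while the file contributes only numbers.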

## Part 2: ☣️ Data poisoning

In [None]:
import typing
import union
import flytekitplugins.pandera
import joblib
import pandas as pd
import pandera as pa
from flytekit import ImageSpec, task, workflow
from flytekit.types.file import JoblibSerializedFile
from pandera.typing import DataFrame, Index, Series
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

custom_image = union.ImageSpec.from_env(
    name="ml-security-02",
    packages=[
        "flytekitplugins-pandera",
        "scikit-learn",
        "pyarrow",
        "pandera",
        "pyspark",
    ],
)

#---Encode the statistical properties of the data into a Pandera schema---#

class RawData(pa.DataFrameModel):
    age: Series[int] = pa.Field(in_range={"min_value": 0, "max_value": 200})
    sex: Series[int] = pa.Field(isin=[0, 1])
    cp: Series[int] = pa.Field(
        isin=[
            1,  # typical angina
            2,  # atypical angina
            3,  # non-anginal pain
            4,  # asymptomatic
        ]
    )
    trestbps: Series[int] = pa.Field(in_range={"min_value": 0, "max_value": 200})
    chol: Series[int] = pa.Field(in_range={"min_value": 0, "max_value": 600})
    fbs: Series[int] = pa.Field(isin=[0, 1])
    restecg: Series[int] = pa.Field(
        isin=[
            0,  # normal
            1,  # having ST-T wave abnormality
            2,  # showing probable or definite left ventricular hypertrophy by Estes' criteria
        ]
    )
    thalach: Series[int] = pa.Field(in_range={"min_value": 0, "max_value": 300})
    exang: Series[int] = pa.Field(isin=[0, 1])
    oldpeak: Series[float] = pa.Field(in_range={"min_value": 0, "max_value": 10})
    slope: Series[int] = pa.Field(
        isin=[
            1,  # upsloping
            2,  # flat
            3,  # downsloping
        ]
    )
    ca: Series[int] = pa.Field(isin=[0, 1, 2, 3])
    thal: Series[int] = pa.Field(
        isin=[
            3,  # normal
            6,  # fixed defect
            7,  # reversible defect
        ]
    )
    target: Series[int] = pa.Field(ge=0, le=4)

    class Config:
        coerce = True


### Fetch the dataset

In [None]:
@union.task(container_image=custom_image)
def fetch_raw_data() -> DataFrame[RawData]:
    print("fetching raw data")
    data_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data"
    return (
        pd.read_csv(data_url, header=None, names=RawData.to_schema().columns.keys())
        .replace({"ca": {"?": None}, "thal": {"?": None}})
        .dropna(subset=["ca", "thal"])
        .astype({"ca": float, "thal": float})
    )


### Parsing the raw data

In [None]:
class ParsedData(RawData):
    target: Series[int] = pa.Field(isin=[0, 1])

#---Check fundamental medical relationships between features and target---#
class TrainingData(ParsedData):
    @pa.dataframe_check(error="Patients with heart disease should have higher average cholesterol")
    def validate_cholesterol(cls, df: pd.DataFrame) -> bool:
        healthy_chol = df[df.target == 0].chol.mean()
        disease_chol = df[df.target == 1].chol.mean()
        return disease_chol > healthy_chol

    @pa.dataframe_check(error="Patients with heart disease should have lower max heart rate (thalach) on average")
    def validate_max_heart_rate(cls, df: pd.DataFrame) -> bool:
        healthy_thalach = df[df.target == 0].thalach.mean()
        disease_thalach = df[df.target == 1].thalach.mean()
        return disease_thalach < healthy_thalach

    @pa.dataframe_check(error="Exercise-induced angina is not more common in disease group")
    def validate_exercise_induced_angina(cls, df: pd.DataFrame) -> bool:
        exang_ratio = df[df.target == 1].exang.mean() / df[df.target == 0].exang.mean()
        return exang_ratio > 2.0

    @pa.dataframe_check
    def validate_feature_correlations(cls, df: pd.DataFrame) -> bool:
        """Ensure key feature correlations with target remain strong"""
        corrs = df.corr()['target'].abs()
        return all(corrs[['cp', 'exang', 'oldpeak']] > 0.3)  # These should be strongly correlated

@union.task(container_image=custom_image)
def parse_raw_data(raw_data: DataFrame[RawData]) -> DataFrame[ParsedData]:
    print("parsing raw data")
    return raw_data.assign(target=lambda _: (_.target > 0).astype(int))

### Splitting the data

In [None]:
DataSplits = typing.NamedTuple(
    "DataSplits",
    [("training_set", DataFrame[ParsedData]), ("test_set", DataFrame[ParsedData])],
)


@union.task(container_image=custom_image)
def split_data(parsed_data: DataFrame[ParsedData], test_size: float, random_state: int) -> DataSplits:
    print("splitting data")
    # Hold out `test_size` of the rows for testing; train on the remainder
    test_set = parsed_data.sample(frac=test_size, random_state=random_state)
    training_set = parsed_data[~parsed_data.index.isin(test_set.index)]
    return training_set, test_set

### Data poisoning attack

In [None]:
@union.task(container_image=custom_image)
def poison_training_data(
    training_set: DataFrame[ParsedData],
    poison_fraction: float,
    random_state: int
) -> DataFrame[TrainingData]:
    print("starting poisin training data")
    if poison_fraction <= 0:
        return training_set
    print("POISONING DATA")
    poisoned = training_set.copy()
    n_poison = int(len(poisoned) * poison_fraction)
    poisoned_indices = poisoned.sample(n=n_poison, random_state=random_state).index
    poisoned.loc[poisoned_indices, 'target'] = 1 - poisoned.loc[poisoned_indices, 'target']
    return poisoned
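
To see why a group-level check can catch this kind of label flipping, here is a minimal standard-library sketch on made-up data (not the Cleveland dataset). `flip_labels` and `heart_rate_check` are hypothetical helpers that mirror `poison_training_data` and `validate_max_heart_rate` in spirit only.

```python
import random


def flip_labels(rows, fraction, seed):
    """Return a copy of rows with `fraction` of the binary targets flipped."""
    rng = random.Random(seed)
    poisoned = [dict(r) for r in rows]
    n_poison = int(len(poisoned) * fraction)
    for i in rng.sample(range(len(poisoned)), n_poison):
        poisoned[i]["target"] = 1 - poisoned[i]["target"]
    return poisoned


def mean(values):
    values = list(values)
    return sum(values) / len(values)


def heart_rate_check(rows):
    """Pass when the disease group has a lower mean max heart rate."""
    healthy = mean(r["thalach"] for r in rows if r["target"] == 0)
    disease = mean(r["thalach"] for r in rows if r["target"] == 1)
    return disease < healthy


# Synthetic data: healthy patients reach 170 bpm, patients with disease 130 bpm.
rows = [{"thalach": 170, "target": 0} for _ in range(10)]
rows += [{"thalach": 130, "target": 1} for _ in range(10)]

clean_ok = heart_rate_check(rows)
mild_ok = heart_rate_check(flip_labels(rows, 0.2, seed=4))
heavy_ok = heart_rate_check(flip_labels(rows, 0.8, seed=4))
print(clean_ok, mild_ok, heavy_ok)
```

On this synthetic data, flipping 80% of the labels inverts the group means and the check fails, while flipping 20% still passes. A single coarse check is therefore not sufficient on its own, which is one reason to combine several checks.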

In [None]:
data = fetch_raw_data()
data.head()

fetching raw data


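|   | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
|---|-----|-----|----|----------|------|-----|---------|---------|-------|---------|-------|----|------|--------|
| 0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | 6 | 0 |
| 1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | 3 | 2 |
| 2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | 7 | 1 |
| 3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | 3 | 0 |
| 4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | 3 | 0 |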


In [None]:
parsed_data = parse_raw_data(data)

parsing raw data


In [None]:
poison_training_data(parsed_data, 0.2, 101)

starting poisin training data
POISONING DATA


### Train the model

In [None]:
def get_features_and_target(dataset):
    X = dataset[[x for x in dataset if x != "target"]]
    y = dataset["target"]
    return X, y

@union.task(container_image=custom_image)
def train_model(training_set: DataFrame[ParsedData], random_state: int) -> JoblibSerializedFile:
    # Pandera validates this input against ParsedData; the stricter TrainingData
    # checks run earlier, on the output of poison_training_data
    model = RandomForestClassifier(n_estimators=100, random_state=random_state)
    X, y = get_features_and_target(training_set)
    model.fit(X, y)
    model_fp = "/tmp/model.joblib"
    joblib.dump(model, model_fp)
    return JoblibSerializedFile(path=model_fp)

### Model evaluation

In [None]:
@union.task(container_image=custom_image)
def evaluate_model(model: JoblibSerializedFile, test_set: DataFrame[ParsedData]) -> float:
    with open(model, "rb") as f:
        model = joblib.load(f)
    X, y = get_features_and_target(test_set)
    preds = model.predict(X)
    return accuracy_score(y, preds)

### Put everything together

In [None]:
@union.workflow
def pipeline(
    data_random_state: int = 42,
    model_random_state: int = 42,
    poison_fraction: float = 0.0,  # Default to no poisoning
    poison_random_state: int = 42
) -> float:
    raw_data = fetch_raw_data()
    parsed_data = parse_raw_data(raw_data=raw_data)
    poisoned_training_set = poison_training_data(
        training_set=parsed_data,
        poison_fraction=poison_fraction,
        random_state=poison_random_state
    )
    training_set, test_set = split_data(
        parsed_data=poisoned_training_set,
        test_size=0.2,
        random_state=data_random_state
    )
    model = train_model(
        training_set=training_set,
        random_state=model_random_state
    )
    return evaluate_model(model=model, test_set=test_set)


### Run the workflow

Run the pipeline locally with 20% of the labels flipped:

In [None]:
pipeline(poison_random_state=4, poison_fraction=0.2)

fetching raw data
parsing raw data
starting poisin training data
POISONING DATA


Run the pipeline on Union Serverless with the default inputs, which apply no poisoning:

In [None]:
execution = serverless.execute(
    pipeline,
    inputs={}
)
execution

Image ml-security-02:BCSs9GJ09Sg45UeNNYy4rw was not found or has expired.
🐳 Submitting a new build...


👍 Build submitted!
⏳ Waiting for build to finish at: https://serverless.union.ai/org/cosmicbboy/projects/system/domains/production/executions/aqhnqh966h58wjrrq7cs
✅ Build completed in 0:01:07!


### Trigger and prevent the attack

In [None]:
mitigated_execution = serverless.execute(
    pipeline,
    inputs={
        "poison_random_state": 4,
        "poison_fraction": 0.8,
    }
)
mitigated_execution

Congratulations 🎉! You've completed the workshop.

To summarize, you've learned the basic concepts, setup, and mitigations for
the pickled model attack and the data poisoning attack using Union together
with popular open source tools for data validation and model serialization.

## 🤔 Learn more

### Tools and Resources

- [bandit](https://github.com/PyCQA/bandit): code scanning tool for Python
- [skops](https://skops.readthedocs.io/en/stable/): model serialization library for scikit-learn
- [onnx](https://onnx.ai/): model serialization format for ML
- [safetensors](https://huggingface.co/docs/safetensors/en/index): model serialization library for PyTorch
- [LLM Guard](https://llm-guard.com/get_started/quickstart/): input/output guardrails for LLMs
- [Llama Guard](https://arxiv.org/abs/2312.06674): model for IO safeguards for LLMs

You can learn more about Union at https://union.ai.

If you have any questions, please reach out to us at support@union.ai.

Thank you for attending!