# PyTorch MNIST Membership Inference demo

This notebook contains an end-to-end demonstration of Dioptra that can be run on any modern laptop.

## Setup

Below we import the necessary Python modules and ensure the proper environment variables are set so that all the code blocks will work as expected.

In [None]:
# Import packages from the Python standard library
import importlib.util
import os
import sys
import pprint
import time
import warnings
from pathlib import Path


def register_python_source_file(module_name: str, filepath: Path) -> None:
    """Import a source file directly.

    Args:
        module_name: The module name to associate with the imported source file.
        filepath: The path to the source file.

    Notes:
        Adapted from the following implementation in the Python documentation:
        https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
    """
    spec = importlib.util.spec_from_file_location(module_name, str(filepath))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)


# Filter out warning messages
warnings.filterwarnings("ignore")

# Experiment name
EXPERIMENT_NAME = "pytorch_mnist_mi"

# Default address for accessing the RESTful API service
RESTAPI_ADDRESS = "http://localhost:80"

# Set DIOPTRA_RESTAPI_URI variable if not defined, used to connect to RESTful API service
if os.getenv("DIOPTRA_RESTAPI_URI") is None:
    os.environ["DIOPTRA_RESTAPI_URI"] = RESTAPI_ADDRESS

# Default address for accessing the MLflow Tracking server
MLFLOW_TRACKING_URI = "http://localhost:35000"

# Set MLFLOW_TRACKING_URI variable if not defined, used to connect to MLflow Tracking service
if os.getenv("MLFLOW_TRACKING_URI") is None:
    os.environ["MLFLOW_TRACKING_URI"] = MLFLOW_TRACKING_URI

# Path to workflows archive
WORKFLOWS_TAR_GZ = Path("workflows.tar.gz")

# Register the examples/scripts directory as a Python module
register_python_source_file("scripts", Path("..", "scripts", "__init__.py"))

from scripts.client import DioptraClient
from scripts.utils import make_tar

# Import third-party Python packages
import numpy as np
from mlflow.tracking import MlflowClient

# Create random number generator
rng = np.random.default_rng(54399264723942495723666216079516778448)

## Submit and run jobs

The entrypoints that we will be running in this example are implemented in the Python source files under `src/` and the `src/MLproject` file.
To run these entrypoints within Dioptra's architecture, we need to package those files up into an archive and submit it to the Dioptra RESTful API to create a new job.
For convenience, we provide the `make_tar` helper function defined in `examples/scripts/utils.py`.
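
As a rough illustration of what a packaging helper of this kind might do, the sketch below bundles a directory into a gzip-compressed tarball using the standard library. The function name `make_tar_sketch` and the throwaway directory are hypothetical; this is not the implementation in `examples/scripts/utils.py`, which may differ.

```python
import tarfile
import tempfile
from pathlib import Path


def make_tar_sketch(source_dirs, tarball_path):
    """Bundle each directory in source_dirs into one gzip-compressed tarball."""
    tarball_path = Path(tarball_path)
    with tarfile.open(tarball_path, "w:gz") as archive:
        for source_dir in source_dirs:
            source_dir = Path(source_dir)
            # arcname keeps only the directory name, so the archive unpacks to "src/..."
            archive.add(str(source_dir), arcname=source_dir.name)
    return tarball_path


# Demonstrate on a throwaway directory rather than the real src/ folder
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp, "src")
    src.mkdir()
    (src / "MLproject").write_text("name: demo\n")
    tarball = make_tar_sketch([src], Path(tmp, "workflows.tar.gz"))
    with tarfile.open(tarball, "r:gz") as archive:
        archived_names = sorted(archive.getnames())

print(archived_names)
```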

In [None]:
make_tar(["src"], WORKFLOWS_TAR_GZ)

To connect to the endpoint, we use a client class defined in `examples/scripts/client.py` that communicates with the Dioptra RESTful API over HTTP.
The client reads the environment variable `DIOPTRA_RESTAPI_URI`, which we configured at the top of the notebook, to determine the address of the Dioptra RESTful API.
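
The lookup described above can be sketched as follows. The helper names and the `api/experiment/` resource path are assumptions made for illustration only; the actual client may build its URLs differently.

```python
import os
from urllib.parse import urljoin


def resolve_restapi_uri(environ=os.environ, default="http://localhost:80"):
    """Return the API base address from the environment, falling back to a default."""
    return environ.get("DIOPTRA_RESTAPI_URI", default)


def build_endpoint(base_uri, resource):
    """Join a base address and a resource path into a full URL."""
    return urljoin(base_uri.rstrip("/") + "/", resource.lstrip("/"))


# "api/experiment/" is a hypothetical resource path used only for this example
experiment_url = build_endpoint("http://localhost:80", "api/experiment/")
print(experiment_url)
```

Passing the environment in as a mapping keeps the lookup easy to exercise without modifying `os.environ`.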

In [None]:
restapi_client = DioptraClient()

We need to register an experiment under which to collect our job runs.
The code below checks if the relevant experiment exists.
If it does, the code returns information about the existing experiment; if it doesn't, the code registers a new experiment.
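
The get-or-register pattern can be tried without a running service. The sketch below uses a hypothetical in-memory stand-in for the client; the class `InMemoryClient` and the response fields `experimentId` and `name` are assumptions, not the real API's schema.

```python
class InMemoryClient:
    """Hypothetical stand-in for the REST client; stores experiments in a dict."""

    def __init__(self):
        self._experiments = {}

    def get_experiment_by_name(self, name):
        # Returns None when the experiment has not been registered yet
        return self._experiments.get(name)

    def register_experiment(self, name):
        record = {"experimentId": len(self._experiments) + 1, "name": name}
        self._experiments[name] = record
        return record


def get_or_register(client, name):
    """Return the existing experiment, registering it first if it is missing."""
    response = client.get_experiment_by_name(name=name)
    if response is None or "Not Found" in response.get("message", ""):
        response = client.register_experiment(name=name)
    return response


client = InMemoryClient()
first = get_or_register(client, "pytorch_mnist_mi")
second = get_or_register(client, "pytorch_mnist_mi")
print(first == second)
```

Calling the helper twice returns the same record, so re-running the cell does not create a duplicate experiment.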

In [None]:
response_experiment = restapi_client.get_experiment_by_name(name=EXPERIMENT_NAME)

if response_experiment is None or "Not Found" in response_experiment.get("message", ""):
    response_experiment = restapi_client.register_experiment(name=EXPERIMENT_NAME)

response_experiment

Next, we need to train our model.
We will be using the V100 GPUs that are available on the DGX Workstation, which we can use by submitting our job to the `"pytorch_gpu"` queue.
We will train one LeNet-5 model for use in the membership inference attack.
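
The job's parameters are passed as a single string of `-P key=value` pairs. A small helper such as the one sketched below (the name `format_entry_point_kwargs` is hypothetical and not part of the client) can build that string from a dictionary:

```python
def format_entry_point_kwargs(params):
    """Render a dict of parameters as a space-separated '-P key=value' string."""
    return " ".join(f"-P {key}={value}" for key, value in params.items())


train_kwargs = format_entry_point_kwargs(
    {
        "register_model_name": "le_net",
        "batch_size": 256,
        "model_architecture": "le_net",
    }
)
print(train_kwargs)
```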


In [None]:
response_le_net_train = restapi_client.submit_job(
    workflows_file=WORKFLOWS_TAR_GZ,
    experiment_name=EXPERIMENT_NAME,
    entry_point="train",
    entry_point_kwargs=" ".join([
        "-P register_model_name=le_net",
        "-P batch_size=256",
        "-P model_architecture=le_net",
    ]),
    queue="pytorch_gpu",
    timeout="1h",
)

print("Training job for LeNet-5 neural network submitted")
print("")
pprint.pprint(response_le_net_train)

Now that we have a model trained on MNIST, we can run the membership inference attack against it. The attack trains a classifier (the attack model) on data that the original model was trained on and data that it was not trained on. The attack model is then used to predict whether new data was part of the original training set. The attack accepts the following parameters:
* batch_size - how many samples to try to use from the training/testing set
* balance_sets - if True, truncates the larger of the training and testing sets so that both sets have the same number of members
* attack_model_type - 'nn', 'gb', or 'rf', for neural network, gradient boosting, or random forest. This sets the type of classifier used in the membership inference attack.
* split - decimal value giving the fraction of the available data that should be used for training.
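
To make `balance_sets` and `split` concrete, the sketch below assembles a labeled dataset for an attack model with NumPy. It is one interpretation of the parameter descriptions above, not the implementation used by the `mi` entry point; the function name, the random placeholder features, and the `seed` argument are all assumptions.

```python
import numpy as np


def build_attack_dataset(member_x, nonmember_x, balance_sets=True, split=0.9, seed=0):
    """Label members 1 and non-members 0, then split into attack train/test sets."""
    if balance_sets:
        # Truncate the larger set so both sets have the same number of rows
        n = min(len(member_x), len(nonmember_x))
        member_x, nonmember_x = member_x[:n], nonmember_x[:n]

    features = np.concatenate([member_x, nonmember_x])
    labels = np.concatenate(
        [np.ones(len(member_x), dtype=int), np.zeros(len(nonmember_x), dtype=int)]
    )

    # Shuffle so members and non-members are mixed before splitting
    order = np.random.default_rng(seed).permutation(len(features))
    features, labels = features[order], labels[order]

    n_train = int(split * len(features))
    return (features[:n_train], labels[:n_train]), (features[n_train:], labels[n_train:])


# Placeholder features standing in for the target model's outputs on each sample
members = np.random.default_rng(1).random((100, 10))
nonmembers = np.random.default_rng(2).random((60, 10))
(train_x, train_y), (test_x, test_y) = build_attack_dataset(members, nonmembers)
print(len(train_x), len(test_x))
```

With 100 member and 60 non-member samples, balancing keeps 60 of each, and `split=0.9` assigns 108 of the 120 rows to training the attack model.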


In [None]:
def mlflow_run_id_is_not_known(response_infer):
    return response_infer["mlflowRunId"] is None and response_infer["status"] not in [
        "failed",
        "finished",
    ]

response_infer = restapi_client.submit_job(
    workflows_file=WORKFLOWS_TAR_GZ,
    experiment_name=EXPERIMENT_NAME,
    entry_point="mi",
    entry_point_kwargs=" ".join(
        [
            "-P model_name=le_net",
            "-P model_version=1",
            "-P batch_size=50000",
            "-P attack_model_type=nn",
            "-P balance_sets=True",
            "-P split=0.9",
        ]
    ),
    queue="pytorch_gpu",
)

print("Membership inference attack (LeNet-5 architecture) job submitted")
print("")
print("")

while mlflow_run_id_is_not_known(response_infer):
    time.sleep(1)
    response_infer = restapi_client.get_job_by_id(
        response_infer["jobId"]
    )
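
The polling loop above runs until the job reports an MLflow run ID or reaches a terminal status. A variant with an upper bound on the waiting time can be sketched as follows; the helper `wait_for_job`, the scripted fake responses, and the `"queued"` and `"started"` status values are assumptions for illustration.

```python
import time


def wait_for_job(fetch_job, job_id, poll_interval=1.0, timeout=60.0):
    """Poll fetch_job until a run ID appears, the job ends, or the timeout passes."""
    deadline = time.monotonic() + timeout
    job = fetch_job(job_id)
    while job["mlflowRunId"] is None and job["status"] not in ("failed", "finished"):
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Job {job_id} did not report a run ID in {timeout} s")
        time.sleep(poll_interval)
        job = fetch_job(job_id)
    return job


# Scripted responses standing in for successive calls to the REST API
responses = iter(
    [
        {"jobId": "abc", "status": "queued", "mlflowRunId": None},
        {"jobId": "abc", "status": "started", "mlflowRunId": None},
        {"jobId": "abc", "status": "started", "mlflowRunId": "run-123"},
    ]
)
final = wait_for_job(lambda job_id: next(responses), "abc", poll_interval=0.01)
print(final["mlflowRunId"])
```

With the real client, `fetch_job` would be `restapi_client.get_job_by_id`.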