# TensorFlow MNIST Classifier demo

This notebook contains an end-to-end demonstration of Dioptra that can be run on any modern laptop.
Please see the [example README](README.md) for instructions on how to prepare your environment for running this example.

## Setup

Below we import the necessary Python modules and ensure the proper environment variables are set so that all the code blocks will work as expected.

In [None]:
EXPERIMENT_NAME = "mnist_fgm"
EXPERIMENT_DESC = "applying the fast gradient method (FGM) attack to a classifier trained on MNIST"
QUEUE_NAME = 'tensorflow_cpu'
QUEUE_DESC = 'Tensorflow CPU Queue'
PLUGIN_FILES = '../task-plugins/dioptra_custom/vc/'
MODEL_NAME = "mnist_classifier"

# Default address for accessing the RESTful API service
RESTAPI_ADDRESS = "http://localhost:20080"

# Default address for accessing the MLFlow Tracking server
MLFLOW_TRACKING_URI = "http://localhost:35000"

In [None]:
# Import packages from the Python standard library
import importlib.util
import logging
import os
import pprint
import sys
import time
import warnings
from pathlib import Path

# Import third-party packages
import structlog
import yaml
from IPython.display import clear_output, display

# Filter out warning messages
warnings.filterwarnings("ignore")
structlog.configure(
    wrapper_class=structlog.make_filtering_bound_logger(logging.CRITICAL),
)

def register_python_source_file(module_name: str, filepath: Path) -> None:
    """Import a source file directly.

    Args:
        module_name: The module name to associate with the imported source file.
        filepath: The path to the source file.

    Notes:
        Adapted from the following implementation in the Python documentation:
        https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
    """
    spec = importlib.util.spec_from_file_location(module_name, str(filepath))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

# Register the examples/scripts directory as a Python module
register_python_source_file("scripts", Path("..", "scripts", "__init__.py"))

from scripts.client import DioptraClient
from scripts.utils import make_tar

# Set DIOPTRA_RESTAPI_URI variable if not defined, used to connect to RESTful API service
if os.getenv("DIOPTRA_RESTAPI_URI") is None:
    os.environ["DIOPTRA_RESTAPI_URI"] = RESTAPI_ADDRESS

# Set MLFLOW_TRACKING_URI variable, used to connect to MLFlow Tracking service
if os.getenv("MLFLOW_TRACKING_URI") is None:
    os.environ["MLFLOW_TRACKING_URI"] = MLFLOW_TRACKING_URI

## Dataset

We obtained a copy of the MNIST dataset when we ran the `download_data.py` script. If you have not done so already, see [How to Obtain Common Datasets](https://pages.nist.gov/dioptra/getting-started/acquiring-datasets.html).
The training and testing images for the MNIST dataset are stored within the `/dioptra/data/Mnist` directory as PNG files that are organized into the following folder structure:

    Mnist
    ├── testing
    │   ├── 0
    │   ├── 1
    │   ├── 2
    │   ├── 3
    │   ├── 4
    │   ├── 5
    │   ├── 6
    │   ├── 7
    │   ├── 8
    │   └── 9
    └── training
        ├── 0
        ├── 1
        ├── 2
        ├── 3
        ├── 4
        ├── 5
        ├── 6
        ├── 7
        ├── 8
        └── 9

The subfolders under `training/` and `testing/` are the classification labels for the images in the dataset.
This folder structure is a standardized way to encode the label information, and many libraries can make use of it, including the TensorFlow library used in this demo.
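
As a rough illustration of how this layout encodes labels, the sketch below walks one split of the dataset and counts the PNG files under each label folder. The function name `count_images_per_label` is hypothetical and not part of Dioptra.

```python
from pathlib import Path


def count_images_per_label(split_dir):
    """Map each label subfolder name to the number of PNG files it contains."""
    counts = {}
    for label_dir in sorted(Path(split_dir).iterdir()):
        if label_dir.is_dir():
            # The folder name (e.g. "7") is the classification label.
            counts[label_dir.name] = sum(1 for _ in label_dir.glob("*.png"))
    return counts
```

Assuming the default location described above, `count_images_per_label("/dioptra/data/Mnist/training")` would return a dictionary keyed by the folder names `"0"` through `"9"`.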

## Submit and run jobs

To communicate with the Dioptra RESTful API over HTTP, we use the client class defined in the `examples/scripts/client.py` file.
The client reads the environment variable `DIOPTRA_RESTAPI_URI`, which we configured at the top of the notebook, to determine the address of the Dioptra RESTful API.

In [None]:
client = DioptraClient()
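
The fallback logic used in the Setup section for this variable can be sketched in isolation as follows. `resolve_uri` is a hypothetical helper written for illustration. How `DioptraClient` reads the variable internally is not shown in this notebook, so the sketch only mirrors the behavior of the setup cell.

```python
import os


def resolve_uri(var_name, default):
    """Return an environment variable's value, setting it to `default` if unset."""
    # setdefault leaves an existing value untouched, like the `is None` check in Setup.
    return os.environ.setdefault(var_name, default)
```

With this helper, `resolve_uri("DIOPTRA_RESTAPI_URI", "http://localhost:20080")` returns whatever address is already exported in the shell, and falls back to the local default otherwise.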

You must log in to the REST API before performing any operations. Here we create a user if one does not already exist, then log in as that user.

In [None]:
try:
    client.users.create('pluginuser', 'pluginuser@dioptra.nccoe.nist.gov', 'pleasemakesuretoPLUGINthecomputer', 'pleasemakesuretoPLUGINthecomputer')
except Exception:
    pass  # ignore if the user already exists
client.auth.login('pluginuser', 'pleasemakesuretoPLUGINthecomputer')

The following function clears all experiments, entrypoints, jobs, models, plugins, tags, plugin parameter types, and queues from the database, in case a fresh start is desired. It is not called anywhere in this notebook, but is included as a convenience.

In [None]:
def delete_all():
    for d in client.experiments.get_all(pageLength=100000)['data']:
        client.experiments.delete_by_id(d['id'])
    for d in client.entrypoints.get_all(pageLength=100000)['data']:
        client.entrypoints.delete_by_id(d['id'])
    for d in client.jobs.get_all(pageLength=100000)['data']:
        client.jobs.delete_by_id(d['id'])
    for d in client.models.get_all(pageLength=100000)['data']:
        client.models.delete_by_id(d['id'])
    for d in client.plugins.get_all(pageLength=100000)['data']:
        try:
            client.plugins.delete_by_id(d['id'])
        except:
            pass
    for d in client.tags.get_all(pageLength=100000)['data']:
        client.tags.delete_by_id(d['id'])
    for d in client.pluginParameterTypes.get_all(pageLength=100000)['data']:
        try:
            client.pluginParameterTypes.delete_by_id(d['id'])
        except:
            pass
    for d in client.queues.get_all(pageLength=100000)['data']:
        client.queues.delete_by_id(d['id'])

The following functions register the plugins located in the `../examples/task-plugins/` folder, associate them with the entrypoints defined in the `./src/` folder, and then associate those entrypoints with an experiment. When `run_experiment` is called, it creates plugins based on the YML files provided and uploads any additional files in the directory specified by `PLUGIN_FILES` at the top of the notebook.
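
Most of the helpers in the next cell follow the same get-or-create pattern: list the existing resources, return the one whose name matches, and create it only when nothing matches. A minimal sketch of that pattern is shown here, using an in-memory stand-in rather than the real client. `FakeQueues` and `get_or_create` are hypothetical names used only for illustration.

```python
class FakeQueues:
    """In-memory stand-in for a resource collection; not the Dioptra client."""

    def __init__(self):
        self._items = []

    def get_all(self):
        return {"data": list(self._items)}

    def create(self, name, description):
        item = {"id": len(self._items) + 1, "name": name, "description": description}
        self._items.append(item)
        return item


def get_or_create(collection, name, description):
    """Return the existing resource named `name`, creating it if absent."""
    for item in collection.get_all()["data"]:
        if item["name"] == name:
            return item
    return collection.create(name, description)
```

Because the lookup is by name, re-running the notebook reuses existing resources instead of attempting to create duplicates.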

In [None]:
basic_types = ['integer', 'string', 'number', 'any', 'boolean', 'null']

def create_or_get_experiment(group, name, description, entrypoints):
    found = None
    for exp in client.experiments.get_all(search=name,pageLength=100000)['data']:
        if exp['name'] == name:
            found = exp
    if found is not None:
        client.experiments.modify_by_id(found['id'], name, description, entrypoints)
        return found
    else:
        return client.experiments.create(group, name, description, entrypoints)
def create_or_get_entrypoints(group, name, description, taskGraph, parameters, queues, plugins):
    found = None
    for entrypoint in client.entrypoints.get_all(search=name,pageLength=100000)['data']:
        if entrypoint['name'] == name:
            found = entrypoint
    if found is not None:
        client.entrypoints.modify_by_id(found['id'], name, description, taskGraph, parameters, queues)
        client.entrypoints.add_plugins_by_entrypoint_id(found['id'], plugins)
        return found
    else:
        return client.entrypoints.create(group, name, description, taskGraph, parameters, queues, plugins)
def create_or_get_plugin_type(group, name, description, structure):
    ret = None
    for pt in client.pluginParameterTypes.get_all(pageLength=100000)['data']:
        if (pt['name'] == name):
            ret = pt
    if (ret is None):
        ret = client.pluginParameterTypes.create(group, name, description, structure)
    return ret
def find_plugin_type(name, types):
    for t in types.keys():
        if t == name:
            return create_or_get_plugin_type(1, name, name, types[t])['id']
    for t in basic_types:
        if t == name:
            return create_or_get_plugin_type(1, name, 'primitive', {})['id']

    print("Couldn't find type", name, "in types definition.")

def create_or_get_queue(group, name, description):
    ret = None
    for queue in client.queues.get_all(pageLength=100000)['data']:
        if queue['name'] == name:
            ret = queue
    if (ret is None):
        ret = client.queues.create(group, name, description)
    return ret
def plugin_to_py(plugin):
    return '../task-plugins/' + '/'.join(plugin.split('.')[:-1]) + '.py'
def create_inputParam_object(inputs, types):
    ret = []
    for inp in inputs:
        if 'name' in inp:
            inp_name = inp['name']
            inp_type = inp['type']
        else:
            inp_name = list(inp.keys())[0]
            inp_type = inp[inp_name]
        if 'required' in inp:
            inp_req = inp['required']
        else:
            inp_req = True
        inp_type = find_plugin_type(inp_type, types)
        ret += [{
           'name': inp_name,
           'parameterType': inp_type,
           'required': inp_req
        }]
    return ret
def create_outputParam_object(outputs, types):
    ret = []
    for outp in outputs:
        if isinstance(outp, dict):
            outp_name = list(outp.keys())[0]
            outp_type = outp[outp_name]
        else:
            outp_name = outp
            outp_type = outputs[outp_name]
        outp_type = find_plugin_type(outp_type, types)
        ret += [{
           'name': outp_name,
           'parameterType': outp_type,
        }]
    return ret

def read_yaml(filename):
    with open(filename) as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Report the parse error, then re-raise so callers never receive an undefined result
            print(exc)
            raise
def register_basic_types(declared):
    for q in basic_types:
        type_def = create_or_get_plugin_type(1, q, 'primitive', {})
    for q in declared:
        type_def = create_or_get_plugin_type(1, q, 'declared', declared[q])
def get_plugins_to_register(yaml_file, plugins_to_upload=None):
    plugins_to_upload = {} if plugins_to_upload is None else plugins_to_upload
    yaml = read_yaml(yaml_file)
    task_graph = yaml['graph']
    plugins = yaml['tasks']
    types = yaml['types']
    
    register_basic_types(types)
    tasks = []
    for plugin in plugins:
        name = plugin
        definition = plugins[plugin]
        python_file = plugin_to_py(definition['plugin'])
        upload = {}
        upload['name'] = name
        if 'inputs' in definition:
            inputs = definition['inputs']
            upload['inputParams'] = create_inputParam_object(inputs, types)
        else:
            upload['inputParams'] = []
        if 'outputs' in definition:
            outputs = definition['outputs']
            upload['outputParams'] = create_outputParam_object(outputs, types) 
        else:
            upload['outputParams'] = []
        if (python_file in plugins_to_upload):
            plugins_to_upload[python_file] += [upload]
        else:
            plugins_to_upload[python_file] = [upload]
    return plugins_to_upload
def create_or_get_plugin(group, name, description):
    ret = None
    for plugin in client.plugins.get_all(search=name,pageLength=100000)['data']:
        if plugin['name'] == name:
            ret = plugin
    if (ret is None):
        ret = client.plugins.create(group, name, description)
    return ret
def create_or_modify_plugin_file(plugin_id, filename, contents, description, tasks):
    found = None
    for plugin_file in client.plugins.files.get_files_by_plugin_id(plugin_id, pageLength=100000)['data']:
        if plugin_file['filename'] == filename:
            found = plugin_file
    if found is not None:
        return client.plugins.files.modify_files_by_plugin_id_file_id(plugin_id, found['id'], filename, contents, description, tasks)
    else:
        return client.plugins.files.create_files_by_plugin_id(plugin_id, filename, contents, description, tasks)
def register_plugins(group, plugins_to_upload):
    plugins = []
    for plugin_file in plugins_to_upload.keys():
        plugin_path = Path(plugin_file)
        contents = plugin_path.read_text().replace("\r", '')
        tasks = plugins_to_upload[plugin_file]
        filename = plugin_path.name
        description = 'custom plugin for ' + filename
        plugin_id = create_or_get_plugin(group, plugin_path.parent.name, description)['id']
        plugins += [plugin_id]
        uploaded_file = create_or_modify_plugin_file(plugin_id, filename, contents, description, tasks)
    return list(set(plugins))
def create_parameters_object(params, modify):
    ret = []
    type_map = {'int': 'float', 'float': 'float', 'str': 'string'}
    for p in params:
        if (type(params[p]).__name__ in type_map.keys()):
            paramType = type_map[type(params[p]).__name__]
            paramType='string' # TODO: remove if backend can handle types correctly
            defaultValue = str(params[p])
        else:
            defaultValue = str(params[p])
            paramType = 'string'

        if p in modify.keys():
            defaultValue = str(modify[p])
        name = p
        param_obj = {
            'name': name,
            'defaultValue': str(defaultValue),
            'parameterType': paramType
        }
        ret += [param_obj]
    return ret
def get_graph_for_upload(yaml_text):
    i = 0
    for line in yaml_text:
        if line.startswith("graph:"):
            break
        i += 1
    return ''.join(yaml_text[i+1:])
def get_parameters_for_upload(yaml_text):
    i = 0
    for line in yaml_text:
        if line.startswith("parameters:"):
            start = i
        if line.startswith("tasks:"):
            break
        i += 1
    return yaml_text[start:i+1]
def register_entrypoint(group, name, description, queues, plugins, yaml_file, modify_params=None):
    modify_params = {} if modify_params is None else modify_params
    yaml = read_yaml(yaml_file)
    #task_graph = yaml['graph']
    parameters = yaml['parameters']
    
    with open(yaml_file, 'r') as f:
        lines = f.readlines()
    task_graph = get_graph_for_upload(lines).replace('\r','')
    
    entrypoint = create_or_get_entrypoints(1, name, description, task_graph, create_parameters_object(parameters, modify_params), queues, plugins)
    return entrypoint
def add_missing_plugin_files(location, upload):
    p = Path(location)
    for child in p.iterdir():
        if (child.name.endswith('.py')):
            if (str(child) not in upload.keys()):
                upload[str(child)] = []
    return upload
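
`create_inputParam_object` accepts two ways of declaring a task input: an explicit mapping with `name` and `type` keys (and optionally `required`), or a single-key shorthand of the form `{name: type}`. The sketch below isolates that normalization step without contacting the API. `normalize_input` and the sample input names are hypothetical, and type names are returned as-is rather than resolved to registered type IDs.

```python
def normalize_input(inp):
    """Return (name, type_name, required) for either input declaration form."""
    if "name" in inp:
        name, type_name = inp["name"], inp["type"]
    else:
        # Shorthand form: the first key is the input name, its value the type.
        name = next(iter(inp))
        type_name = inp[name]
    # Inputs are treated as required unless declared otherwise.
    return name, type_name, inp.get("required", True)


explicit = normalize_input({"name": "seed", "type": "integer", "required": False})
shorthand = normalize_input({"data_dir": "string"})
```

Both forms reduce to the same triple, which is why the task YML files can mix them freely.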

`run_experiment` uses the helper functions above to do the following:

- create the queue specified by `QUEUE_NAME`, if needed
- upload the plugins used by the specified `entrypoint`
- upload any other plugin files in the directory `PLUGIN_FILES`
- register the entrypoint in Dioptra
- create the experiment (if needed) and associate the entrypoint with the experiment
- start a job for the specified `entrypoint` on the queue `QUEUE_NAME`

Note that any values passed in `parameters` overwrite the defaults in the specified YML file.

In [None]:
def run_experiment(entrypoint, entrypoint_name, entrypoint_desc, job_time_limit, parameters=None):
    upload = get_plugins_to_register(entrypoint, {})
    upload = add_missing_plugin_files(PLUGIN_FILES, upload)
    queue = create_or_get_queue(1, QUEUE_NAME, QUEUE_DESC)
    queues = [queue['id']]
    plugins = register_plugins(1,upload)
    entrypoint = register_entrypoint(1, entrypoint_name, entrypoint_desc, queues, plugins, entrypoint, parameters)
    experiment = create_or_get_experiment(1, EXPERIMENT_NAME, EXPERIMENT_DESC, [entrypoint['id']])
    return client.experiments.create_jobs_by_experiment_id(experiment['id'], entrypoint_desc, queue['id'], entrypoint['id'], {}, job_time_limit)
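
The way caller-supplied parameters replace the YML defaults can be illustrated without the API. In the sketch below, `merge_parameters` is a hypothetical helper that mirrors what `create_parameters_object` does with its two arguments: every default from the YML file is kept, overridden where the caller supplies a value, and stored as a string. The parameter names and values are made up for the example.

```python
def merge_parameters(defaults, overrides):
    """Combine YML defaults with caller overrides; values are stringified."""
    merged = []
    for name, default in defaults.items():
        value = overrides.get(name, default)
        merged.append(
            {"name": name, "defaultValue": str(value), "parameterType": "string"}
        )
    return merged


params = merge_parameters({"epochs_p": 30, "batch_size": 32}, {"epochs_p": 1})
```

As in the notebook's helper, an override whose name does not appear among the defaults is silently ignored, so a misspelled parameter name has no effect.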

`wait_for_job` blocks until the given job has finished, which is useful for jobs that depend on the output of other jobs.

In [None]:
def wait_for_job(job, job_name):
    n = 0
    while job['status'] != 'finished':  
        job = client.jobs.get_by_id(job['id'])
        time.sleep(1)
        clear_output(wait=True)
        display("Waiting for job." + "." * (n % 3) )
        n += 1
    clear_output(wait=True)
    display(f"Job finished. Starting {job_name} job.")
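
As written, `wait_for_job` loops until the status is `finished`, so it does not return if a job ends in some other terminal state. A more defensive variant might stop on failure or after a timeout. The sketch below is one way to do this under stated assumptions: `poll_until_done` is a hypothetical helper, `fetch_status` stands in for a call such as `client.jobs.get_by_id(job['id'])['status']`, and the `'failed'` status name is an assumption that should be checked against the Dioptra documentation.

```python
import time


def poll_until_done(fetch_status, timeout_s=3600.0, interval_s=1.0,
                    failure_states=("failed",)):
    """Poll `fetch_status()` until it returns 'finished'.

    Raises RuntimeError on a failure state and TimeoutError after `timeout_s`.
    The failure state names are assumptions, not taken from the Dioptra API.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = fetch_status()
        if status == "finished":
            return status
        if status in failure_states:
            raise RuntimeError(f"job ended with status {status!r}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still {status!r} after {timeout_s} seconds")
        time.sleep(interval_s)
```

Passing the status lookup as a callable keeps the polling logic independent of the client, which also makes it easy to exercise with a stub.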
    

Next, we need to train our model. This particular entrypoint uses a LeNet-5 model.
Depending on the specs of your computer, it can take 5-20 minutes or longer to complete.
If you are fortunate enough to have access to a dedicated GPU, then the training time will be much shorter.

In [None]:
entrypoint = 'src/train.yml'
entrypoint_name = 'train'
entrypoint_desc = 'training a classifier on MNIST'
job_time_limit = '1h'

training_job = run_experiment(entrypoint, 
                              entrypoint_name, 
                              entrypoint_desc,
                              job_time_limit,
                              {"epochs_p":1})


Now that we have trained a model, next we will apply the fast-gradient method (FGM) evasion attack on it to generate adversarial images.

This workflow is an example of jobs with dependencies: the metric evaluation job cannot start until the adversarial image generation job has completed, and the adversarial image generation job cannot start until the training job has completed.

Note that the `training_job` id is needed to tell the FGM attack which model to generate examples against.

In [None]:
entrypoint = 'src/fgm.yml'
entrypoint_name = 'fgm'
entrypoint_desc = 'generating examples on mnist_classifier using the fgm attack'
job_time_limit = '1h'

wait_for_job(training_job, entrypoint_name)
fgm_job = run_experiment(entrypoint,
                         entrypoint_name,
                         entrypoint_desc,
                         job_time_limit,
                         {"training_job_id": training_job['id']})
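
For intuition about what the attack does: FGM perturbs an input in the direction that increases the model's loss. In its sign-based (L-infinity) form, the adversarial example is `x + eps * sign(g)`, where `g` is the gradient of the loss with respect to the input `x`. The sketch below applies that rule to a tiny logistic-regression model in pure Python. It is a toy illustration only: the weights, input, and `eps` are made up, and it is not the implementation used by the `fgm` entrypoint.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sign(g):
    """Return -1, 0, or 1 according to the sign of g."""
    return (g > 0) - (g < 0)


def loss_and_input_grad(w, b, x, y):
    """Binary cross-entropy loss and its gradient with respect to the input x."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    # For logistic regression, d(loss)/dx_i = (p - y) * w_i.
    grad = [(p - y) * wi for wi in w]
    return loss, grad


def fgm_linf(x, grad, eps):
    """Step each feature by eps in the sign of the gradient, clipped to [0, 1]."""
    return [min(1.0, max(0.0, xi + eps * sign(gi))) for xi, gi in zip(x, grad)]


w, b = [2.0, -1.0, 0.5], 0.0  # made-up model weights
x, y = [0.8, 0.2, 0.5], 1     # made-up input with true label 1

clean_loss, grad = loss_and_input_grad(w, b, x, y)
x_adv = fgm_linf(x, grad, eps=0.1)
adv_loss, _ = loss_and_input_grad(w, b, x_adv, y)
print(f"clean loss {clean_loss:.4f}, adversarial loss {adv_loss:.4f}")
```

Each feature moves by at most `eps`, yet the loss on the perturbed input is higher than on the clean one, which is the effect the attack relies on.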


Finally, we can test out the results of our adversarial attack on the model we trained earlier. This will wait for the FGM job to finish, and then evaluate the model's performance on the adversarial examples. Note that we need to know both the `fgm_job` id as well as the `training_job` id, so that this entrypoint knows which run's adversarial examples to test against which model. 

The previous runs are all stored in Dioptra as well, so you can always go back later and retrieve examples, models, and even the code used to create them.

In [None]:
entrypoint = 'src/infer.yml'
entrypoint_name = 'infer'
entrypoint_desc = 'evaluating performance of mnist_classifier on generated fgm examples'
job_time_limit = '1h'

wait_for_job(fgm_job, entrypoint_name)
infer_job = run_experiment(entrypoint, 
                           entrypoint_name,
                           entrypoint_desc,
                           job_time_limit,
                           {"fgm_job_id": fgm_job['id'], "training_job_id": training_job['id']})

Once the inference job has finished, its metrics can be retrieved from the MLflow Tracking server using the run ID associated with the job.


In [None]:
from mlflow.tracking import MlflowClient
mlflow_client = MlflowClient()
# Strip the dashes from the returned run ID before querying MLflow
mlflow_runid = client.jobs.get_mlflow_run_id(infer_job['id'])['mlflowRunId'].replace('-', '')
print(mlflow_runid)
mlflow_run = mlflow_client.get_run(mlflow_runid)
pprint.pprint(mlflow_run.data.metrics)
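
The `.replace('-', '')` call suggests that the REST API returns the run ID as a dashed UUID string, whereas MLflow identifies runs by a 32-character hexadecimal string without dashes. Assuming that is the case, the standard library's `uuid` module offers an equivalent conversion that also validates the input. `to_mlflow_run_id` is a hypothetical helper, and the ID below is made up for illustration.

```python
import uuid


def to_mlflow_run_id(dashed_id):
    """Convert a UUID string (with or without dashes) to a 32-character hex string."""
    # uuid.UUID raises ValueError if the input is not a well-formed UUID.
    return uuid.UUID(dashed_id).hex


run_id = to_mlflow_run_id("12345678-1234-5678-1234-567812345678")
print(run_id)
```

Unlike a plain string replacement, this raises an error early if the value returned by the API is not a UUID at all.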