# Cleanlab MNIST Example

This notebook contains an end-to-end demonstration of Dioptra that can be run on any modern laptop.
Please see the [example README](README.md) for instructions on how to prepare your environment for running this example.

### Setup

Below we import the necessary Python modules and make sure the required environment variables are set so that all the code cells work as expected.

In [None]:
EXPERIMENT_NAME = "cleanlab_exp"
EXPERIMENT_DESC = "experimenting with cleanlab cleaning techniques with a classifier trained on MNIST"
QUEUE_NAME = 'tensorflow_cpu'
QUEUE_DESC = 'Tensorflow CPU Queue'
MODEL_NAME = "mnist_classifier"

# Default address for accessing the RESTful API service
RESTAPI_ADDRESS = "http://localhost:5000"

# Default address for accessing the MLFlow Tracking server
MLFLOW_TRACKING_URI = "http://localhost:35000"

In [None]:
# Import packages from the Python standard library
import importlib.util
import json
import logging
import os
import pprint
import sys
import time
import warnings
from pathlib import Path

# Import third-party packages
import structlog
from IPython.display import clear_output, display

# Filter out warning messages and only show structlog errors
warnings.filterwarnings("ignore")
structlog.configure(
    wrapper_class=structlog.make_filtering_bound_logger(logging.ERROR),
)

# Import the Dioptra client after warnings and logging are configured
from dioptra.client import (
    connect_json_dioptra_client,
    connect_response_dioptra_client,
    select_files_in_directory,
    select_one_or_more_files,
)

# Set DIOPTRA_API variable if not defined, used to connect to RESTful API service
if os.getenv("DIOPTRA_API") is None:
    os.environ["DIOPTRA_API"] = RESTAPI_ADDRESS

# Set MLFLOW_TRACKING_URI variable, used to connect to MLFlow Tracking service
if os.getenv("MLFLOW_TRACKING_URI") is None:
    os.environ["MLFLOW_TRACKING_URI"] = MLFLOW_TRACKING_URI
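The two `if os.getenv(...) is None` checks above only set a default when the variable is not already defined. A minimal equivalent sketch uses `os.environ.setdefault`, which writes a value only when the key is missing; the two URLs are the same defaults defined in the first cell.

```python
import os

# Same defaults as in the first cell of this notebook.
RESTAPI_ADDRESS = "http://localhost:5000"
MLFLOW_TRACKING_URI = "http://localhost:35000"

# setdefault only writes the value if the variable is not already set,
# so an address exported in the shell before starting Jupyter still wins.
os.environ.setdefault("DIOPTRA_API", RESTAPI_ADDRESS)
os.environ.setdefault("MLFLOW_TRACKING_URI", MLFLOW_TRACKING_URI)
```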

In [None]:
def wait_for_job(job, job_name, quiet=False):
    """Poll the Dioptra API until `job` finishes or fails, then return the updated job."""
    n = 0
    while job['status'] not in ['finished', 'failed']:
        job = client.jobs.get_by_id(job['id'])
        time.sleep(1)
        if not quiet:
            clear_output(wait=True)
            display("Waiting for job." + "." * (n % 3))
        n += 1
    # Raise on failure even in quiet mode so a failed job is never silently ignored
    if job['status'] == 'failed':
        raise RuntimeError("Previous job failed. Please see the tensorflow-cpu logs for details.")
    if not quiet:
        clear_output(wait=True)
        display(f"Job finished. Starting {job_name} job.")
    return job
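`wait_for_job` polls until the job reaches `finished` or `failed`, so it waits indefinitely if a job hangs. Below is a sketch of the same polling loop with an overall deadline. `wait_for_status`, its `fetch` callable, and the `timeout_s` and `poll_s` parameters are hypothetical names for illustration; in this notebook `fetch` would be something like `lambda: client.jobs.get_by_id(job['id'])`.

```python
import time
from typing import Any, Callable, Dict

TERMINAL_STATES = ("finished", "failed")


def wait_for_status(
    fetch: Callable[[], Dict[str, Any]],
    timeout_s: float = 3600.0,
    poll_s: float = 1.0,
) -> Dict[str, Any]:
    """Poll `fetch()` until the job reaches a terminal state (hypothetical helper).

    Raises TimeoutError if no terminal state is reached within `timeout_s`
    seconds, and RuntimeError if the job ends in the 'failed' state.
    """
    deadline = time.monotonic() + timeout_s
    job = fetch()
    while job["status"] not in TERMINAL_STATES:
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"job {job.get('id')} still '{job['status']}' after {timeout_s}s"
            )
        time.sleep(poll_s)
        job = fetch()
    if job["status"] == "failed":
        raise RuntimeError(f"job {job.get('id')} failed")
    return job
```

Using `time.monotonic()` for the deadline keeps the timeout correct even if the system clock changes while the notebook is waiting.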
    

### Dataset

We obtained a copy of the MNIST dataset when we ran the `download_data.py` script. If you have not done so already, see [How to Obtain Common Datasets](https://pages.nist.gov/dioptra/getting-started/acquiring-datasets.html).
The training and testing images for the MNIST dataset are stored within the `/dioptra/data/Mnist` directory as PNG files organized into the following folder structure:

    Mnist
    ├── testing
    │   ├── 0
    │   ├── 1
    │   ├── 2
    │   ├── 3
    │   ├── 4
    │   ├── 5
    │   ├── 6
    │   ├── 7
    │   ├── 8
    │   └── 9
    └── training
        ├── 0
        ├── 1
        ├── 2
        ├── 3
        ├── 4
        ├── 5
        ├── 6
        ├── 7
        ├── 8
        └── 9

The subfolders under `training/` and `testing/` are named after the classification labels of the images they contain.
This folder structure is a standard way to encode label information, and many libraries can use it, including TensorFlow, which we use in this demo.
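As a rough illustration of how the folder names double as labels, the sketch below walks a directory laid out like `Mnist/training` and pairs each PNG file with the name of its parent folder. It uses only `pathlib`; `list_labeled_images` is a hypothetical helper, not part of Dioptra or TensorFlow. TensorFlow's `tf.keras.utils.image_dataset_from_directory`, for example, infers labels from subdirectory names in the same way.

```python
from pathlib import Path
from typing import List, Tuple


def list_labeled_images(split_dir: Path) -> List[Tuple[Path, int]]:
    """Return (image_path, label) pairs for a directory like Mnist/training.

    Each immediate subfolder name ("0" to "9") is treated as the class label
    for every PNG file inside it.
    """
    pairs = []
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        label = int(class_dir.name)
        for image_path in sorted(class_dir.glob("*.png")):
            pairs.append((image_path, label))
    return pairs


# Example, using the data path described above:
# pairs = list_labeled_images(Path("/dioptra/data/Mnist/training"))
```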

### Log in to Dioptra and set up the REST API client

To connect with the endpoint, we use the client provided by the `dioptra.client` module, which communicates with the Dioptra RESTful API over HTTP.
We create the client below.
The client reads the `DIOPTRA_API` environment variable, which we configured at the top of the notebook, to determine how to connect to the Dioptra RESTful API.

In [None]:
# Alternative: a client that returns HTTP response objects instead of decoded JSON
#client = connect_response_dioptra_client()
client = connect_json_dioptra_client()

You must log in to the REST API before you can perform any other actions. Here we create a user if one does not already exist, and then log in with it.

In [None]:
try:
    client.users.create(
        username='user',
        email='email',
        password='pass'
    )
except Exception:
    pass  # ignore the error if the user already exists

client.auth.login(
    username='user',
    password='pass'
)


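### Upload all the entrypoints

The cell below imports the resources in the `../extra/` directory, as described by `dioptra.toml`, into group 1, overwriting any existing resources with the same names. It then looks up the IDs of the six entrypoints used in this notebook.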

In [None]:
# Import resources from the local filesystem
logging.basicConfig(level=logging.DEBUG)  # Set the root logger level to DEBUG

response = client.workflows.import_resources(
    group_id=1,
    source=select_files_in_directory("../extra/", recursive=True),
    config_path="dioptra.toml",
    resolve_name_conflicts_strategy="overwrite",
)
resources = response["resources"]

train_ep = resources["entrypoints"]["Train"]
metrics_ep = resources["entrypoints"]["Metrics"]
clean_ep = resources["entrypoints"]["Clean"]
poison_ep = resources["entrypoints"]["Poison"]
filter_ep = resources["entrypoints"]["Filter"]
train_mlflow_ep = resources["entrypoints"]["Train from MLFlow"]

entrypoints = [train_ep, metrics_ep, clean_ep, poison_ep, filter_ep, train_mlflow_ep]
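The lookups above raise a bare `KeyError` as soon as one of these names is missing from the imported resources. A small check, sketched below with a hypothetical `require_entrypoints` helper, reports every missing name at once. It assumes `resources["entrypoints"]` maps entrypoint names to IDs, as the indexing above implies.

```python
from typing import Dict, Iterable, List


def require_entrypoints(available: Dict[str, int], names: Iterable[str]) -> List[int]:
    """Return the IDs for `names`, or raise a KeyError listing every missing name."""
    names = list(names)
    missing = [name for name in names if name not in available]
    if missing:
        raise KeyError(
            f"entrypoints not found: {missing}; available: {sorted(available)}"
        )
    return [available[name] for name in names]


# In the notebook this might be called as:
# entrypoints = require_entrypoints(
#     resources["entrypoints"],
#     ["Train", "Metrics", "Clean", "Poison", "Filter", "Train from MLFlow"],
# )
```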

In [None]:
try:
    experiment = client.experiments.create(group_id=1, name=EXPERIMENT_NAME, description=EXPERIMENT_DESC)
except Exception:
    # The experiment most likely exists already, so look it up by name
    experiment = client.experiments.get(search=f"name:'{EXPERIMENT_NAME}'")["data"][0]

try:
    queue = client.queues.create(group_id=1, name=QUEUE_NAME, description=QUEUE_DESC)
except Exception:
    # The queue most likely exists already, so look it up by name
    queue = client.queues.get(search=f"name:'{QUEUE_NAME}'")["data"][0]

experiment_id = experiment['id']
queue_id = queue['id']

client.experiments.entrypoints.create(experiment_id=experiment_id, entrypoint_ids=entrypoints)

for entrypoint in entrypoints:
    client.entrypoints.queues.create(entrypoint_id=entrypoint, queue_ids=[queue_id])
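The experiment and queue setup follows the same create-or-fetch pattern: try to create the resource and, if that fails (typically because the name is already taken), look it up by name. A generic sketch of that pattern is below. `get_or_create` is a hypothetical helper; in this notebook `create` and `find` would wrap calls such as `client.queues.create(...)` and `client.queues.get(search=...)["data"][0]`.

```python
from typing import Any, Callable, Dict


def get_or_create(
    create: Callable[[], Dict[str, Any]],
    find: Callable[[], Dict[str, Any]],
) -> Dict[str, Any]:
    """Create a resource, falling back to looking up an existing one.

    Only Exception subclasses are caught, so KeyboardInterrupt still stops
    the cell; if `find` also fails, its error propagates to the caller.
    """
    try:
        return create()
    except Exception:
        return find()
```

Catching `Exception` rather than using a bare `except:` keeps interrupts working. A stricter version would catch only the client's "already exists" error, whose exact type is not shown in this notebook.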



### Helper functions to submit jobs

In [None]:
def run_job(experiment_id, queue_id, ep, title, prev_job_id=False, latest_model=False, args=None, prev_job=None, job_time_limit='1h'):
    """Submit a job for entrypoint `ep`, first waiting for `prev_job` if one is given."""
    args = {} if args is None else dict(args)  # copy so the caller's dict is not modified
    if prev_job is not None:
        wait_for_job(prev_job, title, quiet=False)
    if prev_job_id and prev_job is not None and 'id' in prev_job:
        args['job_id'] = str(prev_job['id'])
    if latest_model:
        args['model_name'] = MODEL_NAME
        args['model_version'] = str(-1)
    job = client.experiments.jobs.create(
        experiment_id=experiment_id,
        description=f"{title} job for {experiment_id}",
        queue_id=queue_id,
        entrypoint_id=ep,
        values=args,
        timeout=job_time_limit
    )
    return job
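`run_job` assembles the `values` dictionary sent with each job from three sources: the caller's `args`, the ID of the previous job, and the most recent version of the registered model. The pure function below reproduces that logic so the resulting parameters can be inspected without submitting anything. `build_job_values` is a hypothetical name, and `"mnist_classifier"` mirrors `MODEL_NAME` from the first cell. As in the notebook, every value is passed as a string, including the model version `"-1"` that `run_job` sends when `latest_model=True`.

```python
from typing import Any, Dict, Optional

MODEL_NAME = "mnist_classifier"  # same value as in the first cell


def build_job_values(
    args: Optional[Dict[str, str]] = None,
    prev_job: Optional[Dict[str, Any]] = None,
    prev_job_id: bool = False,
    latest_model: bool = False,
) -> Dict[str, str]:
    """Mirror how run_job builds the `values` passed to jobs.create (sketch)."""
    values = dict(args or {})  # copy so the caller's dict is not mutated
    if prev_job_id and prev_job is not None and "id" in prev_job:
        values["job_id"] = str(prev_job["id"])
    if latest_model:
        values["model_name"] = MODEL_NAME
        values["model_version"] = str(-1)  # what run_job sends for the latest model
    return values
```

For example, the poisoning step corresponds to `build_job_values({"percent_poison": ".99"}, latest_model=True)`.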


In [None]:
def poison(experiment_id, queue_id, poison_ep, percent, job_time_limit='1h'):
    arg_dict = {
        "percent_poison": percent,
    }
    poison_job = run_job(experiment_id, queue_id, poison_ep, "poison", latest_model=True, args=arg_dict, job_time_limit=job_time_limit)
    return poison_job

In [None]:
def clean(experiment_id, queue_id, clean_ep, prev_job, job_time_limit='1h'):
    clean_job = run_job(experiment_id, queue_id, clean_ep, "clean", prev_job_id=True, latest_model=True, args=None, prev_job=prev_job, job_time_limit=job_time_limit)
    return clean_job

In [None]:
def filter_data(experiment_id, queue_id, filter_ep, prev_job, job_time_limit='1h'):
    filter_job = run_job(experiment_id, queue_id, filter_ep, "filter_data", prev_job_id=True, latest_model=False, args=None, prev_job=prev_job, job_time_limit=job_time_limit)
    return filter_job

In [None]:
def train_mlflow(experiment_id, queue_id, train_mlflow_ep, data_dir, tar_name, model_name, prev_job, job_time_limit='1h'):
    arg_dict = {
        "data_dir": data_dir,
        "tar_name": tar_name,
        "register_model_name": model_name,
        "epochs": "3",
    } 
    train_job = run_job(experiment_id, queue_id, train_mlflow_ep, "train_mlflow", prev_job_id=True, latest_model=False, args=arg_dict, prev_job=prev_job, job_time_limit=job_time_limit)
    return train_job

### Training a Regular MNIST LeNet Model

This trains an MNIST LeNet model. The poisoning step later uses this model: its job is submitted with the latest version of the `mnist_classifier` model (`MODEL_NAME`).

In [None]:
job_time_limit = '1h'

training_job = client.experiments.jobs.create(
    experiment_id=experiment_id, 
    description=f"training job for {experiment_id}", 
    queue_id=queue_id,
    entrypoint_id=train_ep, 
    values={"epochs":"3"}, 
    timeout=job_time_limit
)

### Training based on Cleanlab outputs after poisoning

The next few jobs poison a dataset and train a model on the poisoned data.

They then run cleanlab's analysis on the poisoned dataset, filter out the examples it flags as potentially problematic, and train a new model on the remaining data.
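Cleanlab finds likely label errors from a model's predicted class probabilities using confident learning. The sketch below is a heavily simplified, pure-Python illustration of that idea, not cleanlab's implementation and not necessarily what the Clean and Filter entrypoints do. Each class gets a threshold equal to the average predicted probability of that class among examples carrying that label, and an example is flagged when the most probable class that clears its threshold differs from the given label. `flag_label_issues` is a hypothetical name; with cleanlab itself, the comparable call is `cleanlab.filter.find_label_issues(labels, pred_probs)`.

```python
from typing import List, Sequence


def flag_label_issues(
    labels: Sequence[int], pred_probs: Sequence[Sequence[float]]
) -> List[bool]:
    """Simplified confident-learning filter (illustrative sketch).

    labels[i] is the given (possibly noisy) class of example i and
    pred_probs[i][j] is the model's predicted probability of class j.
    Returns one boolean per example: True if it looks mislabeled.
    """
    if not pred_probs:
        return []
    n_classes = len(pred_probs[0])

    # Per-class threshold: mean self-confidence of examples given that label.
    thresholds = []
    for j in range(n_classes):
        own = [p[j] for p, y in zip(pred_probs, labels) if y == j]
        thresholds.append(sum(own) / len(own) if own else float("inf"))

    flags = []
    for p, y in zip(pred_probs, labels):
        # Classes whose predicted probability reaches that class's threshold.
        confident = [j for j in range(n_classes) if p[j] >= thresholds[j]]
        if not confident:
            flags.append(False)  # no confident guess, so keep the example
            continue
        best = max(confident, key=lambda j: p[j])
        flags.append(best != y)
    return flags
```

Examples flagged this way would then be dropped before retraining, which is the role the filtering and "Train from MLFlow" steps play below.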

#### Poisoning Attack

In [None]:
poisoning = poison(experiment_id, queue_id, poison_ep, ".99", job_time_limit='1h')

#### Training a Poisoned Model

In [None]:
poison_model = train_mlflow(experiment_id, queue_id, train_mlflow_ep, "poison_testing", "poison_testing.tar.gz", "poisoned_mnist_classifier", poisoning)

#### Cleanlab Report

In [None]:
cleaning = clean(experiment_id, queue_id, clean_ep, poisoning, job_time_limit='1h')

#### Filter Dataset

In [None]:
filtering = filter_data(experiment_id, queue_id, filter_ep, cleaning, job_time_limit='1h')

#### Training on Filtered Data

In [None]:
filtered_model = train_mlflow(experiment_id, queue_id, train_mlflow_ep, "cleaned_data", "cleaned_data.tar.gz", "cleaned_mnist_classifier", filtering)