# Model Training

In this notebook, we'll train a LightGBM model using Amazon SageMaker, so
we have an example trained model to explain.

You can also bring your own trained models to explain. See the
Customization section for more details.

**Note**: When running this notebook on SageMaker Studio, you should make
sure the 'SageMaker JumpStart Data Science 1.0' image/kernel is used. You
can run all cells or step through them one at a time.

<p align="center">
  <img src="https://github.com/awslabs/sagemaker-explaining-credit-decisions/raw/master/docs/architecture_diagrams/stage_2.png" width="1000px">
</p>

We start by importing the packages that will be used throughout
the notebook. One of the most important packages used in this
solution is the Amazon SageMaker Python SDK (i.e. `import sagemaker`). We
also import modules from our own custom package that can be found at
`./package`.

In [None]:
from pathlib import Path
from sagemaker.sklearn import SKLearn
import shutil
import sys

sys.path.insert(0, '../package')
from package import config, utils
from package.sagemaker import containers

## Container
We now build our custom Docker image that will be used for model training
and deployment. It extends the official Amazon SageMaker framework image
for Scikit-learn, by adding additional packages such as
[LightGBM](https://lightgbm.readthedocs.io/en/latest/) and
[SHAP](https://github.com/slundberg/shap). After building the image, we
upload it to our solution's Amazon ECR repository.

In [None]:
!rm -f ../containers/model.zip

In [None]:
shutil.make_archive('../containers/model', 'zip', '../containers/model/')

In [None]:
!aws s3 cp ../containers/model.zip s3://$config.S3_BUCKET/containers/model.zip --quiet

In [None]:
containers.build(project_name=config.CODE_BUILD_PROJECT)

## Model Training
Amazon SageMaker provides two methods for training and deploying models.
You can start by quickly testing and debugging models on the Amazon
SageMaker Notebook instance using local mode. After this, you can scale
up training with SageMaker mode on dedicated instances and deploy the
model on dedicated instances too. Since this is a pre-developed solution,
we'll be using SageMaker mode.

Up next, we configure our SKLearn estimator. We will use it to coordinate
model training and deployment. We reference our custom container (see
`image_uri`) and our custom code (see `entry_point`, `source_dir` and
`dependencies`). At this stage, we also reference the instance type (and
instance count) that will be used during training, and the
hyperparameters we wish to use. Lastly, we set the `output_path` for
trained model artifacts and the `code_location` for a snapshot of the
training script that was used.
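
To illustrate how the hyperparameters reach the training script: in script mode, Amazon SageMaker passes each entry of `hyperparameters` to the entry point as a command-line argument (for example `--tree-n-estimators 42`). The sketch below shows one way a script could parse them with `argparse`. The function name `parse_hyperparameters` and the default values are hypothetical and chosen for illustration only; the solution's actual `entry_point.py` may handle them differently.

```python
import argparse


def parse_hyperparameters(argv=None):
    """Parse tree hyperparameters passed as command-line arguments.

    Hypothetical helper: the argument names mirror the keys of the
    `hyperparameters` dict, but the defaults are illustrative assumptions.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--tree-n-estimators", type=int, default=100)
    parser.add_argument("--tree-max-depth", type=int, default=-1)
    parser.add_argument("--tree-min-child-samples", type=int, default=20)
    parser.add_argument("--tree-boosting-type", type=str, default="gbdt")
    # parse_known_args tolerates any extra arguments the script receives.
    args, _ = parser.parse_known_args(argv)
    return args


args = parse_hyperparameters(
    ["--tree-n-estimators", "42", "--tree-boosting-type", "dart"]
)
print(args.tree_n_estimators, args.tree_boosting_type)
```

Note that `argparse` converts the dashes in each argument name to underscores, so `--tree-n-estimators` becomes `args.tree_n_estimators`.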

**Note**: When customizing the solution, you can enable enhanced logging
by setting `container_log_level=logging.DEBUG` on the `SKLearn`
estimator object (after `import logging`).

In [None]:
hyperparameters = {
    "tree-n-estimators": 42,
    "tree-max-depth": 2,
    "tree-min-child-samples": 1,
    "tree-boosting-type": "dart"
}

current_folder = utils.get_current_folder(globals())
estimator = SKLearn(
    image_uri=config.ECR_IMAGE,
    entry_point='entry_point.py',
    source_dir=str(Path(current_folder, '../containers/model/src').resolve()),
    dependencies=[str(Path(current_folder, '../package/package').resolve())],
    hyperparameters=hyperparameters,
    role=config.IAM_ROLE,
    instance_count=1,
    instance_type='ml.c5.xlarge',
    output_path='s3://' + str(Path(config.S3_BUCKET, config.OUTPUTS_S3_PREFIX)),
    code_location='s3://' + str(Path(config.S3_BUCKET, config.OUTPUTS_S3_PREFIX)),
    base_job_name=config.SOLUTION_PREFIX,
    tags=[{'Key': config.TAG_KEY, 'Value': config.SOLUTION_PREFIX}]
)

With our estimator now initialized, we can start the Amazon SageMaker
training job. Since our entry point script expects a number of data
channels to be defined, we provide them when calling `fit`. When
referencing `s3://` folders, the contents of these folders will be
automatically downloaded from Amazon S3 before the entry point script is
run. When using local mode, it's possible to avoid this data transfer and
reference a local folder using the `file://` prefix instead, e.g.
`{'schemas': 'file://' + str(schema_folder)}`.
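
As a rough sketch of how the channel mapping could be assembled for either case, the helpers below build `s3://` URIs with forward slashes and, optionally, a `file://` URI for the `schemas` channel. The helper names (`s3_uri`, `local_uri`, `build_channels`) and the bucket and prefix values are hypothetical; in the notebook these values come from `config`.

```python
import posixpath
from pathlib import Path


def s3_uri(bucket, *parts):
    """Join a bucket and key parts into an s3:// URI using forward slashes."""
    cleaned = [part.strip("/") for part in parts]
    return "s3://" + posixpath.join(bucket, *cleaned)


def local_uri(folder):
    """Build a file:// URI from a local folder, for use in local mode."""
    return "file://" + str(Path(folder).resolve())


def build_channels(bucket, schemas_prefix, datasets_prefix, local_schemas=None):
    """Return the channel mapping for `fit`; optionally read schemas locally."""
    channels = {
        name: s3_uri(bucket, datasets_prefix, name)
        for name in ("data_train", "label_train", "data_test", "label_test")
    }
    if local_schemas is None:
        channels["schemas"] = s3_uri(bucket, schemas_prefix)
    else:
        channels["schemas"] = local_uri(local_schemas)
    return channels


# Placeholder bucket and prefixes, for illustration only.
channels = build_channels("example-bucket", "schemas", "datasets")
print(channels["data_train"])  # s3://example-bucket/datasets/data_train
```

Joining with `posixpath` rather than `pathlib.Path` keeps the separators as forward slashes regardless of the operating system the notebook runs on.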

You can expect this step to take approximately 5 minutes.

In [None]:
estimator.fit({
    'schemas': 's3://' + str(Path(config.S3_BUCKET, config.SCHEMAS_S3_PREFIX)),
    'data_train': 's3://' + str(Path(config.S3_BUCKET, config.DATASETS_S3_PREFIX, 'data_train')),
    'label_train': 's3://' + str(Path(config.S3_BUCKET, config.DATASETS_S3_PREFIX, 'label_train')),
    'data_test': 's3://' + str(Path(config.S3_BUCKET, config.DATASETS_S3_PREFIX, 'data_test')),
    'label_test': 's3://' + str(Path(config.S3_BUCKET, config.DATASETS_S3_PREFIX, 'label_test'))
})

Our Amazon SageMaker training job has now completed, and we should have a
number of trained model artifacts that can be deployed and used for
explanations.

## Customization

We have provided an example of model training above, so that we have a
trained model to explain, but our solution is customizable if you have
your own models. You can choose to retrain your models on Amazon
SageMaker or train your models in some other way of your choosing.

When retraining models on Amazon SageMaker, you should modify the
training script found at `./package/sagemaker/estimator_fns.py`. You
should modify the `train_fn` function as required and change any of the
other training functions (found in `./package/machine_learning` for
example). You may need to modify the dependencies too, depending on your
model; these can be adjusted in `./containers/model/requirements.txt`.

When bringing your own trained model, you will need to upload all of the
model assets to Amazon S3 (in the solution bucket): e.g. trained
preprocessors, model weights and feature schemas (i.e. data schema after
feature engineering). Amazon SageMaker expects all of these model assets
to be packaged up as a `model.tar.gz`.
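
A minimal sketch of the packaging step, using only the Python standard library, is shown below. The helper name `package_model_assets` and the placeholder file names are hypothetical; the exact files and layout the solution's inference code expects are not specified here, so adapt them to your own model.

```python
import tarfile
import tempfile
from pathlib import Path


def package_model_assets(assets_dir, output_path):
    """Bundle every file under assets_dir into a gzipped tar archive.

    Files are stored relative to assets_dir, so they sit at the archive root.
    """
    assets_dir = Path(assets_dir)
    with tarfile.open(output_path, "w:gz") as archive:
        for path in sorted(assets_dir.rglob("*")):
            if path.is_file():
                arcname = path.relative_to(assets_dir).as_posix()
                archive.add(path, arcname=arcname)
    return Path(output_path)


# Demonstration with placeholder files standing in for real model assets.
with tempfile.TemporaryDirectory() as tmp:
    assets = Path(tmp, "assets")
    assets.mkdir()
    (assets / "model.txt").write_text("placeholder model weights")
    (assets / "schema.json").write_text("{}")
    archive_path = package_model_assets(assets, Path(tmp, "model.tar.gz"))
    with tarfile.open(archive_path) as archive:
        print(sorted(archive.getnames()))  # ['model.txt', 'schema.json']
```

The resulting archive can then be uploaded to the solution bucket, for example with `aws s3 cp` as used earlier in this notebook.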

## Next Stage

Up next, we'll deploy the model explainer to an HTTP endpoint using Amazon
SageMaker and visualize the explanations.

[Click here to continue.](./3_endpoint.ipynb)