# Training and Deploying ML Models using JAX on SageMaker

<div style="border: 2px solid #ff9900; border-radius: 8px; padding: 15px; background-color: #fff3e0; margin-bottom: 10px;">
<strong>⚠️ Compatibility Notice:</strong> This notebook has been tested using <strong>SageMaker Distribution Image 3.7.0</strong> and the <strong>SageMaker Python SDK version 3.4.0</strong>.
</div>

Amazon SageMaker gives you the flexibility to train models with your own training script in our pre-built machine learning containers, or in your own bespoke container. In this tutorial, we refer to these strategies as Bring-Your-Own-Script **(BYOS)** and Bring-Your-Own-Container **(BYOC)**.

### Bring Your Own JAX Script

In this notebook, we'll show how to extend our optimized TensorFlow containers to train machine learning models with the increasingly popular [JAX library](https://github.com/google/jax). We'll train a Fashion-MNIST classification model using vanilla JAX, another using Stax (`jax.experimental.stax`, which newer JAX releases provide as `jax.example_libraries.stax`), and a final model using the [higher-level Trax library](https://github.com/google/trax).

For all three patterns, we'll show how to serialize the JAX models in the standard TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model). This lets us deploy the models with the managed and optimized SageMaker TensorFlow inference containers instead of building a custom serving stack.


### Bring Your Own JAX Container

We've included a Dockerfile in this repo directory to show how you can build your own bespoke JAX container with GPU support for SageMaker. Unfortunately, the NVIDIA/CUDA Docker Hub images are subject to a [deletion policy](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md), so we can't guarantee that this container will keep building over time. Nonetheless, if your workload requires a custom container, you can adapt the Dockerfile to a newer CUDA base image. For more information on running BYOC on SageMaker, see the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-training-container.html).



In [None]:
!pip install -q -U "sagemaker==3.4.0"

In [None]:
import boto3
from importlib.metadata import version as pkg_version
from sagemaker.core.helper.session_helper import get_execution_role
from sagemaker.core import image_uris
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import SourceCode, Compute

role = get_execution_role()
region = boto3.Session().region_name
sm_client = boto3.client("sagemaker")
sm_runtime = boto3.client("sagemaker-runtime")

print(f"SageMaker SDK version: {pkg_version('sagemaker')}")
print(f"Role: {role}")
print(f"Region: {region}")

In [None]:
# Retrieve TF 2.10 training image (GPU) and inference image (CPU)
tf_training_image = image_uris.retrieve(
    framework="tensorflow", region=region, version="2.10",
    instance_type="ml.g5.xlarge", image_scope="training",
)
tf_inference_image = image_uris.retrieve(
    framework="tensorflow", region=region, version="2.10",
    instance_type="ml.m5.xlarge", image_scope="inference",
)
print(f"Training image: {tf_training_image}")
print(f"Inference image: {tf_inference_image}")

## Installing JAX in SageMaker TensorFlow Containers

When using BYOS with managed SageMaker containers, you can install extra dependencies by placing a `requirements.txt` file in the `source_dir` that contains your training scripts. With `ModelTrainer`, set the `requirements` parameter of `SourceCode` so the dependencies are installed before the training script runs.

More precisely, any container that includes the [sagemaker-training-toolkit](https://github.com/aws/sagemaker-training-toolkit) supports installing additional dependencies from `requirements.txt`.
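A missing or misnamed `requirements.txt` only surfaces once the training job is already running. As a minimal sketch, assuming the `requirements` path is interpreted relative to `source_dir` (as with `requirements="requirements.txt"` in the cells below), the hypothetical helper `source_code_kwargs` checks the files locally and only sets `requirements` when the file exists. It returns plain keyword arguments rather than calling the SDK directly.

```python
from pathlib import Path


def source_code_kwargs(source_dir, entry_script, requirements_file="requirements.txt"):
    """Hypothetical helper: build keyword arguments for SourceCode.

    Raises if the entry script is missing, and only includes `requirements`
    when the file exists inside source_dir.
    """
    root = Path(source_dir)
    if not (root / entry_script).is_file():
        raise FileNotFoundError(f"Entry script not found: {root / entry_script}")

    kwargs = {"source_dir": str(root), "entry_script": entry_script}
    if (root / requirements_file).is_file():
        # Relative to source_dir, matching requirements="requirements.txt" in this notebook
        kwargs["requirements"] = requirements_file
    return kwargs
```

You could then write `SourceCode(**source_code_kwargs("training_scripts", "train_jax.py"))`, which fails fast in the notebook if the script directory is incomplete.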


## Serializing Models in SavedModel Format
In the upcoming training jobs, we'll train a vanilla JAX model, a Stax model, and a Trax model on the [fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist).
The full model definitions are in the `training_scripts/` directory, but the serialization methods are worth calling out.

The JAX/Stax models use the jax2tf converter: https://github.com/google/jax/tree/master/jax/experimental/jax2tf

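```python
import tensorflow as tf
from jax.experimental import jax2tf


def save_model_tf(prediction_function, params_to_save):
    # Convert the JAX prediction function into a TensorFlow-callable function
    tf_fun = jax2tf.convert(prediction_function, enable_xla=False)
    # Store the trained parameters as tf.Variables so they are saved with the graph
    param_vars = tf.nest.map_structure(lambda param: tf.Variable(param), params_to_save)

    tf_graph = tf.function(
        lambda inputs: tf_fun(param_vars, inputs),
        autograph=False,
        jit_compile=False,
    )

    # Attach the variables and function to a tf.Module and export a SavedModel.
    # Assumed input signature: one 28x28x1 float32 image (batch size 1, as in the Trax export).
    module = tf.Module()
    module._variables = tf.nest.flatten(param_vars)
    module.f = tf_graph
    signature = tf_graph.get_concrete_function(tf.TensorSpec([1, 28, 28, 1], tf.float32))
    tf.saved_model.save(module, "/opt/ml/model/1", signatures=signature)
```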
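`save_model_tf` expects a prediction function called as `prediction_function(params, inputs)`. As a minimal sketch of what such a function can look like in vanilla JAX, here is a hypothetical two-layer MLP (not necessarily the architecture in `train_jax.py`) whose parameters are a list of `(weights, bias)` pairs and whose inputs are batches of 28x28x1 images:

```python
import jax
import jax.numpy as jnp


def init_params(key, layer_sizes=(784, 128, 10)):
    """Hypothetical MLP parameters: a list of (weights, bias) pairs, one per layer."""
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(layer_sizes[:-1], layer_sizes[1:])):
        params.append((jax.random.normal(k, (n_in, n_out)) * 0.01, jnp.zeros(n_out)))
    return params


def predict(params, inputs):
    """Map a batch of (N, 28, 28, 1) images to (N, 10) log-probabilities."""
    # Flatten each image into a 784-element vector
    x = jnp.reshape(inputs, (inputs.shape[0], -1)).astype(jnp.float32)
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return jax.nn.log_softmax(x @ w + b)
```

Because `predict` takes the parameters as an explicit argument, exporting it is just `save_model_tf(predict, params)`: jax2tf converts the pure function, and the parameters become TensorFlow variables.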


The Trax model uses the trax2keras functionality: https://github.com/google/trax/blob/master/trax/trax2keras.py

```python
import tensorflow as tf
import trax


def save_model_tf(model_to_save):
    """
    Serialize a TensorFlow graph from a trained Trax model
    :param model_to_save: Trax Model
    """
    # Wrap the Trax model as a Keras layer with a fixed batch size of 1
    keras_layer = trax.AsKeras(model_to_save, batch_size=1)
    inputs = tf.keras.Input(shape=(28, 28, 1))
    hidden = keras_layer(inputs)

    # Build a Keras model around the layer and export it in SavedModel format
    keras_model = tf.keras.Model(inputs=inputs, outputs=hidden)
    keras_model.save("/opt/ml/model/1", save_format="tf")
```
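The Trax snippet saves to `/opt/ml/model/1` rather than `/opt/ml/model`: SageMaker packages the contents of `/opt/ml/model` into `model.tar.gz`, and TensorFlow Serving loads models from numbered version subdirectories. As a sketch, the hypothetical helper below checks an extracted artifact directory for that `<version>/saved_model.pb` layout before you deploy it.

```python
from pathlib import Path


def find_savedmodel_versions(model_root):
    """Return the numeric version directories under model_root that contain a SavedModel.

    TensorFlow Serving expects <model_root>/<version>/saved_model.pb,
    where <version> is an integer directory name.
    """
    versions = []
    for child in Path(model_root).iterdir():
        if child.is_dir() and child.name.isdigit() and (child / "saved_model.pb").is_file():
            versions.append(int(child.name))
    return sorted(versions)
```

After extracting a downloaded artifact (for example with `tarfile.open("model.tar.gz").extractall("extracted")`), an empty result from `find_savedmodel_versions("extracted")` means the endpoint would have no model version to load.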

## Train using Vanilla JAX

Note: Our `source_dir` directory contains a `requirements.txt` that installs JAX with CUDA support. We train on an `ml.g5.xlarge` instance, which has one NVIDIA A10G GPU.
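If the CUDA-enabled JAX build fails to find the GPU, JAX falls back to the CPU with only a warning, and training quietly runs much slower. As a minimal sketch, a training script could check the active backend at startup with `jax.default_backend()` and `jax.devices()`; the helper name and the `require_gpu` flag are hypothetical.

```python
import jax


def check_accelerator(require_gpu=False):
    """Report which backend JAX is using; optionally fail if no GPU is visible."""
    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    print(f"JAX backend: {backend}, devices: {jax.devices()}")
    if require_gpu and backend != "gpu":
        raise RuntimeError(f"Expected a GPU backend but JAX is using '{backend}'")
    return backend
```

Calling `check_accelerator(require_gpu=True)` at the top of the script turns a silent slowdown into an immediate, visible job failure.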

In [None]:
vanilla_jax_trainer = ModelTrainer(
    training_image=tf_training_image,
    source_code=SourceCode(
        source_dir="training_scripts",
        entry_script="train_jax.py",
        requirements="requirements.txt",
    ),
    compute=Compute(instance_type="ml.g5.xlarge", instance_count=1),
    role=role,
    base_job_name="jax",
    hyperparameters={"num_epochs": "3"},
)
vanilla_jax_trainer.train(wait=True, logs=False)
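The `hyperparameters` dictionary is delivered to the entry script at runtime as strings. As a minimal sketch, assuming they arrive as `--name value` command-line arguments (the convention used by the SageMaker training toolkit) and that the output directory is exposed through the `SM_MODEL_DIR` environment variable, a script such as `train_jax.py` could read them like this; the default values are hypothetical.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # Values arrive as strings (e.g. "3"); argparse converts them to int
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--train_steps", type=int, default=1000)
    # parse_known_args ignores any extra arguments the platform may pass
    args, _unknown = parser.parse_known_args(argv)
    # Files written under this directory are uploaded as the model artifact
    args.model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
    return args
```

Using `parse_known_args` instead of `parse_args` keeps the script from crashing if the job passes arguments it doesn't define.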

## Train Using JAX Medium-level API Stax

In [None]:
stax_trainer = ModelTrainer(
    training_image=tf_training_image,
    source_code=SourceCode(
        source_dir="training_scripts",
        entry_script="train_stax.py",
        requirements="requirements.txt",
    ),
    compute=Compute(instance_type="ml.g5.xlarge", instance_count=1),
    role=role,
    base_job_name="stax",
    hyperparameters={"num_epochs": "3"},
)
stax_trainer.train(wait=True, logs=False)
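Stax builds a network from an `(init_fun, apply_fun)` pair: `init_fun` creates parameters for a given input shape, and `apply_fun(params, inputs)` has the same `(params, inputs)` signature that `save_model_tf` passes to `jax2tf.convert`. A minimal sketch, assuming a recent JAX release where Stax lives in `jax.example_libraries.stax`; the layer sizes are hypothetical and not necessarily those in `train_stax.py`.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Hypothetical classifier: flatten 28x28x1 images, one hidden layer, 10 log-probabilities
init_fun, apply_fun = stax.serial(
    stax.Flatten,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax,
)

# -1 marks the batch dimension as unknown when computing output shapes
output_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 28, 28, 1))
log_probs = apply_fun(params, jnp.zeros((4, 28, 28, 1)))
```

Since `apply_fun` is already a pure function of `(params, inputs)`, exporting the trained Stax model can reuse the same `save_model_tf(apply_fun, params)` call as the vanilla JAX model.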

## Train Using JAX High-level API Trax

In [None]:
trax_trainer = ModelTrainer(
    training_image=tf_training_image,
    source_code=SourceCode(
        source_dir="training_scripts",
        entry_script="train_trax.py",
        requirements="requirements.txt",
    ),
    compute=Compute(instance_type="ml.g5.xlarge", instance_count=1),
    role=role,
    base_job_name="trax",
    hyperparameters={"train_steps": "1000"},
)
trax_trainer.train(wait=True, logs=False)
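Unlike the other two jobs, the Trax job is configured with `train_steps` rather than `num_epochs`. To compare training budgets, you can convert between the two using the size of the Fashion-MNIST training set (60,000 images). The batch size is not shown in this notebook, so the value used below is hypothetical; check `train_trax.py` for the real one.

```python
import math

FASHION_MNIST_TRAIN_SIZE = 60_000


def steps_to_epochs(train_steps, batch_size, num_examples=FASHION_MNIST_TRAIN_SIZE):
    """Approximate number of passes over the training set for a given step count."""
    return train_steps * batch_size / num_examples


def epochs_to_steps(num_epochs, batch_size, num_examples=FASHION_MNIST_TRAIN_SIZE):
    """Steps needed for num_epochs full passes, counting the last partial batch."""
    return num_epochs * math.ceil(num_examples / batch_size)
```

With a hypothetical batch size of 64, `steps_to_epochs(1000, 64)` is about 1.07 epochs, while matching the other jobs' 3 epochs would take `epochs_to_steps(3, 64)`, or 2814 steps.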

## Deploy Models to Managed TF Containers
Because we've serialized the models in TensorFlow SavedModel format, we can deploy them as endpoints using `ModelBuilder` with the TF inference image and the S3 model artifacts produced by training.
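The cells below read the artifact location from the trainer's `_latest_training_job` attribute, which is private (leading underscore) and may change between SDK releases. As an alternative sketch, the public boto3 `describe_training_job` call returns the same S3 URI for a given training job name, which you can find in the SageMaker console. The helper name is hypothetical.

```python
def model_artifact_uri(sm_client, training_job_name):
    """Return the S3 URI of model.tar.gz for a completed training job."""
    response = sm_client.describe_training_job(TrainingJobName=training_job_name)
    status = response["TrainingJobStatus"]
    if status != "Completed":
        # Failed or still-running jobs have no usable model artifact yet
        raise RuntimeError(f"Training job {training_job_name} is {status}, not Completed")
    return response["ModelArtifacts"]["S3ModelArtifacts"]
```

The result can be passed as `s3_model_data_url` to `ModelBuilder`, for example `model_artifact_uri(sm_client, "<your-training-job-name>")`.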

In [None]:
from sagemaker.serve import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode

In [None]:
vanilla_jax_mb = ModelBuilder(
    s3_model_data_url=vanilla_jax_trainer._latest_training_job.model_artifacts.s3_model_artifacts,
    image_uri=tf_inference_image,
    role_arn=role,
    instance_type="ml.m5.xlarge",
    mode=Mode.SAGEMAKER_ENDPOINT,
)
vanilla_jax_model = vanilla_jax_mb.build()
vanilla_jax_endpoint = vanilla_jax_mb.deploy(
    endpoint_name="jax-vanilla-v3",
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
)

In [None]:
stax_mb = ModelBuilder(
    s3_model_data_url=stax_trainer._latest_training_job.model_artifacts.s3_model_artifacts,
    image_uri=tf_inference_image,
    role_arn=role,
    instance_type="ml.m5.xlarge",
    mode=Mode.SAGEMAKER_ENDPOINT,
)
stax_model = stax_mb.build()
stax_endpoint = stax_mb.deploy(
    endpoint_name="jax-stax-v3",
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
)

In [None]:
trax_mb = ModelBuilder(
    s3_model_data_url=trax_trainer._latest_training_job.model_artifacts.s3_model_artifacts,
    image_uri=tf_inference_image,
    role_arn=role,
    instance_type="ml.m5.xlarge",
    mode=Mode.SAGEMAKER_ENDPOINT,
)
trax_model = trax_mb.build()
trax_endpoint = trax_mb.deploy(
    endpoint_name="jax-trax-v3",
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
)

## Test Inference Endpoints
This section requires TensorFlow in your notebook kernel, because we use it to load the Fashion-MNIST test data. Inference requests are sent with the `boto3` `invoke_endpoint` API.
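The endpoints use the TensorFlow Serving JSON format: the request body is `{"instances": [...]}` with one entry per example, and the response is `{"predictions": [...]}`. A minimal sketch of building a request and parsing a response, using only NumPy and `json` so it can be tried without a live endpoint (the helper names are hypothetical):

```python
import json

import numpy as np


def build_payload(images):
    """Serialize (N, 28, 28) images as a TF Serving JSON request.

    A channel axis is appended so each instance matches the (28, 28, 1) model input.
    """
    batch = np.expand_dims(np.asarray(images), axis=-1)
    return json.dumps({"instances": batch.tolist()})


def parse_predictions(body):
    """Return the predicted class index for each instance in a TF Serving JSON response."""
    predictions = np.asarray(json.loads(body)["predictions"])
    return predictions.argmax(axis=-1)
```

Taking `argmax` along the last axis returns one label per instance rather than a single index over the flattened output.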

In [None]:
import json
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

In [None]:
def test_image(endpoint_name, test_images, test_labels, image_number):
    np_img = np.expand_dims(np.expand_dims(test_images[image_number], axis=-1), axis=0)
    payload = json.dumps({"instances": np_img.tolist()})

    response = sm_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=payload,
    )
    result = json.loads(response["Body"].read().decode("utf-8"))
    pred_y = np.argmax(result["predictions"])

    print("True Label:", test_labels[image_number])
    print("Predicted Label:", pred_y)
    plt.imshow(test_images[image_number])
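The endpoints return integer labels from 0 to 9. Fashion-MNIST defines the ten classes listed below, so a small lookup makes the printed output easier to read; the helper name is hypothetical.

```python
FASHION_MNIST_CLASSES = (
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
)


def label_name(label):
    """Map an integer Fashion-MNIST label (0-9) to its class name."""
    index = int(label)  # also accepts NumPy integers such as the result of np.argmax
    if not 0 <= index < len(FASHION_MNIST_CLASSES):
        raise ValueError(f"Fashion-MNIST labels are 0-9, got {label}")
    return FASHION_MNIST_CLASSES[index]
```

Inside `test_image`, you could print `label_name(pred_y)` alongside the numeric prediction.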

In [None]:
test_image("jax-vanilla-v3", x_test, y_test, 0)

In [None]:
test_image("jax-stax-v3", x_test, y_test, 0)

In [None]:
test_image("jax-trax-v3", x_test, y_test, 0)

## Optional: Delete the running endpoints

In [None]:
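# Clean-Up: look up each endpoint's config and model names before deleting them,
# since the endpoint config and models may not share the endpoint's name
for ep_name in ["jax-vanilla-v3", "jax-stax-v3", "jax-trax-v3"]:
    try:
        config_name = sm_client.describe_endpoint(EndpointName=ep_name)["EndpointConfigName"]
        variants = sm_client.describe_endpoint_config(EndpointConfigName=config_name)["ProductionVariants"]
        sm_client.delete_endpoint(EndpointName=ep_name)
        sm_client.delete_endpoint_config(EndpointConfigName=config_name)
        # Also delete the SageMaker Model resources created during deployment
        for variant in variants:
            sm_client.delete_model(ModelName=variant["ModelName"])
        print(f"Deleted endpoint, endpoint config, and models for: {ep_name}")
    except Exception as e:
        print(f"Failed to clean up {ep_name}: {e}")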

## Conclusion

This notebook demonstrated how to:

1. Write JAX training scripts (vanilla JAX, Stax, and Trax) that export TensorFlow SavedModels
2. Train models using `ModelTrainer` with the SageMaker TensorFlow training container
3. Deploy models using `ModelBuilder` with the SageMaker TensorFlow inference container
4. Test the deployed models via `boto3 invoke_endpoint`
5. Clean up resources

Because each model is exported as a TensorFlow SavedModel, vanilla JAX, Stax, and Trax models can all be served by the managed SageMaker TensorFlow inference containers without building a custom inference container.

## Next Steps

- Experiment with different JAX model architectures
- Train on your own datasets instead of Fashion-MNIST
- Implement more sophisticated preprocessing
- Add model monitoring and logging
- Explore distributed training with multiple instances