# Develop support for a new model with NeuronX Distributed Inference

In this notebook you will learn how to develop support for a new model with NeuronX Distributed Inference (NxD Inference, or NxDI). NeuronX Distributed (NxD) is a Python package developed by Annapurna Labs that enables you to shard, compile, train, and host PyTorch models on Trainium and Inferentia instances. Two key packages build on it: [NxD Inference](https://github.com/aws-neuron/neuronx-distributed-inference/tree/main) and [NxD Training](https://github.com/aws-neuron/neuronx-distributed-training). This notebook focuses on inference. You will learn how to develop support for a new model in NxD Inference, using Llama 3.2 1B as the example.

#### Overview
1. Check dependencies for the AWS Neuron SDK.
2. Accept the Meta usage terms and download the model from Hugging Face.
3. Learn how to invoke the model step-by-step.
   - Load the model from a local path.
   - Shard and compile it for Trainium.
   - Load the tokenizer and tokenize the prompts.
   - Invoke the model with prompts.
4. Learn how to modify the underlying APIs to work with your own models.

#### Prerequisites
This notebook was developed on a trn1.2xlarge instance using the latest Amazon Linux DLAMI. Both the Amazon Linux and Ubuntu Neuron DLAMIs ship with preinstalled Python virtual environments that include all the basic software packages. The virtual environment used to develop this notebook is located at the same path on both the Amazon Linux and Ubuntu DLAMIs: `/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference`.

### Step 1. Import NxD Inference packages

If you are running this notebook in the virtual environment for NxD Inference, then the package should already be installed. Let's verify that with the following import.

In [None]:
import neuronx_distributed_inference
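
If the import above fails, it can help to see which packages are missing before debugging further. The following is a minimal sketch using only the Python standard library. The helper name `is_installed` and the list of package names to check are assumptions made for illustration, not part of NxD Inference.

```python
import importlib.util


def is_installed(package_name):
    """Return True if a top-level package can be found, without importing it."""
    return importlib.util.find_spec(package_name) is not None


# Packages this notebook imports later on (list assumed for illustration).
for name in ("torch", "transformers", "neuronx_distributed_inference"):
    status = "found" if is_installed(name) else "MISSING"
    print(f"{name}: {status}")
```

If any package is reported as missing, check that the notebook kernel is running inside the NxD Inference virtual environment described in the prerequisites.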

### Step 2. Accept the Meta usage terms and download the model

If you would like to use the model directly from Meta, navigate to the Hugging Face Hub page for Llama 3.2 1B [here](https://huggingface.co/meta-llama/Llama-3.2-1B). Log in to the Hub, accept the usage terms, and request access to the model. Once access has been granted, copy your Hugging Face token and paste it into the download command below.

If you do not have your token readily available, you can proceed with the alternative model shown below.

In [None]:
# helpful packages to speed up the download
!pip install hf_transfer "huggingface_hub[cli]"

We'll download the `NousResearch/Llama-3.2-1B` model here.

In [None]:
!hf download NousResearch/Llama-3.2-1B --local-dir /home/ubuntu/environment/models/llama/

### Step 3. Establish model configs
Next, you'll point to the local model files and establish config objects. Each of these configs is needed to invoke the model successfully.

In [None]:
# the original checkpoint
model_path = '/home/ubuntu/environment/models/llama/'

In [None]:
# where your NxD trace will go
traced_model_path = '/home/ubuntu/environment/models/traced_llama'

In [None]:
import torch
from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig
from neuronx_distributed_inference.models.llama.modeling_llama import LlamaInferenceConfig, NeuronLlamaForCausalLM
from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config
from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params

# torch.manual_seed(0)

In [None]:
# update the generation config to address a trailing comma
!cp generation_config.json $model_path/

In [None]:
# Initialize configs 
generation_config = GenerationConfig.from_pretrained(model_path)

# Some sample overrides for generation
generation_config_kwargs = {
    "do_sample": True,
    "top_k": 1,
    "pad_token_id": generation_config.eos_token_id,
}
generation_config.update(**generation_config_kwargs)
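
The overrides above enable sampling but set `top_k` to 1, which makes generation effectively greedy: only the single most likely token survives the filter. The sketch below illustrates this with a small NumPy sampler. It is an illustration of the general top-k / top-p / temperature technique under assumed ordering of the steps, not the sampler that NxD Inference or Hugging Face actually runs; the function name `sample_token` and the logits are made up for the example.

```python
import numpy as np


def sample_token(logits, top_k, top_p=1.0, temperature=1.0, rng=None):
    """Sample one token id after temperature scaling, top-k, then top-p filtering."""
    rng = rng if rng is not None else np.random.default_rng()
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # Keep the indices of the top_k largest logits, most likely first.
    order = np.argsort(scaled)[::-1][:top_k]
    # Softmax over the surviving logits.
    probs = np.exp(scaled[order] - scaled[order].max())
    probs /= probs.sum()
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    keep = int(np.searchsorted(np.cumsum(probs), top_p)) + 1
    probs = probs[:keep] / probs[:keep].sum()
    return int(order[rng.choice(len(probs), p=probs)])


logits = [0.1, 2.5, -1.0, 1.7]  # made-up scores for a 4-token vocabulary
greedy = int(np.argmax(logits))
samples = {sample_token(logits, top_k=1) for _ in range(20)}
print(samples == {greedy})
```

With `top_k=1` every draw returns the argmax token, so the comparison prints `True`. Raising `top_k` or lowering `top_p` changes how many candidates remain before the random draw.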

In [None]:
neuron_config = NeuronConfig(
    tp_degree=2,
    batch_size=2,
    max_context_length=32,
    seq_len=64,
    on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
    enable_bucketing=True,
    flash_decoding_enabled=False
)

# Build the Llama Inference config
config = LlamaInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config(model_path),
)

### Step 4. Shard and compile the model
The NeuronX compiler will optimize your model for Trainium hardware, ultimately generating the assembly code that executes your operations. We will invoke that compiler now. Generally, it's suggested to compile for some of the larger input and output shapes for your model, while using bucketing to optimize performance. Both of those are handled for you automatically with NxD.

With NxD, this step also shards your checkpoint for the TP degree that you defined above. Compilation can take some time; for a 1B model it should run for a few minutes.

In [None]:
model = NeuronLlamaForCausalLM(model_path, config)
model.compile(traced_model_path)

Once compilation is complete, your new model is saved and ready to load!
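
To make the bucketing idea from this step more concrete, here is a minimal sketch of the general technique: the model is compiled for a small set of fixed sequence lengths, and each request is padded up to the smallest bucket that fits. The power-of-two spacing, the minimum size of 8, and the helper names `make_buckets` and `pick_bucket` are assumptions for illustration. NxD chooses its own buckets when `enable_bucketing=True`, and they may differ from these.

```python
def make_buckets(min_len, max_len):
    """Return bucket sizes that double from min_len and always end at max_len."""
    buckets = []
    size = min_len
    while size < max_len:
        buckets.append(size)
        size *= 2
    buckets.append(max_len)
    return buckets


def pick_bucket(buckets, length):
    """Return the smallest bucket that can hold a sequence of the given length."""
    for bucket in buckets:
        if length <= bucket:
            return bucket
    raise ValueError(f"length {length} exceeds the largest bucket {buckets[-1]}")


# 32 matches max_context_length in the NeuronConfig above; 8 is an assumed minimum.
context_buckets = make_buckets(8, 32)
print(context_buckets)
print(pick_bucket(context_buckets, 9))
```

A 9-token prompt would be padded to the 16-token bucket here rather than to the full 32, which is where the performance benefit comes from.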

### Step 5. Load the tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.save_pretrained(traced_model_path)

### Step 6. Load the traced model

In [None]:
model = NeuronLlamaForCausalLM(traced_model_path)
model.load(traced_model_path)
tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

### Step 7. Define the prompts and prepare them for sampling

In [None]:
prompts = ["I believe the meaning of life is", "The color of the sky is"]

# Example: parameter sweeps for sampling
sampling_params = prepare_sampling_params(batch_size=neuron_config.batch_size,
                                         top_k=[10, 5],
                                         top_p=[0.5, 0.9],
                                         temperature=[0.9, 0.5])

inputs = tokenizer(prompts, padding=True, return_tensors="pt")
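
The two prompts tokenize to different lengths, so `padding=True` pads the shorter one up to the longest in the batch and returns an attention mask marking the real tokens. The sketch below reproduces that behavior for right-side padding in plain Python. The token ids and the pad id are made up for illustration, and `pad_right` is a hypothetical helper, not a tokenizer API.

```python
def pad_right(batch, pad_id):
    """Right-pad each id list to the longest one and build matching attention masks."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return input_ids, attention_mask


# Made-up token ids standing in for the two tokenized prompts.
batch = [[101, 7, 8, 9, 10], [101, 4, 5]]
input_ids, attention_mask = pad_right(batch, pad_id=0)
print(input_ids)
print(attention_mask)
```

The mask is what lets the model ignore the padded positions, which is why it is passed to `generate` alongside `input_ids`.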

### Step 8. Create a Generation Adapter and run inference

In [None]:
generation_model = HuggingFaceGenerationAdapter(model)
outputs = generation_model.generate(
    inputs.input_ids,
    generation_config=generation_config,
    attention_mask=inputs.attention_mask,
    max_length=model.config.neuron_config.max_length,
    sampling_params=sampling_params,
)
output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print("Generated outputs:")
for i, output_token in enumerate(output_tokens):
    print(f"Output {i}: {output_token}")


---
# Develop support for a new model with NxDI
Now that you've run inference with this model, let's take a closer look at how this works. The cells you just ran are based on a script available in our repository [here](https://github.com/aws-neuron/neuronx-distributed-inference/tree/main). You can step through this repository to understand how the objects are developed, inherited, and made available for inference. The full developer guide on the topic is available [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/onboarding-models.html#nxdi-onboarding-models). Let's look at some of the key points!

#### 1/ NeuronConfig class
You can inherit from our base `NeuronConfig` class and extend it with your own model parameters. In the notebook you just ran, this is how we defined the following parameters:
- Tensor Parallel (TP) Degree
- Batch size
- Max context length (input shape)
- Sequence length (total length of input plus generated output)
- On device sampling
- Enabling bucketing
- Flash decoding


This object and these parameters are sent to the compiler when you call `model.compile`. This ensures that the compiler registers your design choices before it starts optimizing. It also enables NxDI to shard the model for your preferred TP degree, which lets you quickly test a variety of TP degrees (TP=8, 32, 64, etc.).
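
When comparing TP degrees, it can be useful to see how many attention heads each rank would own. The sketch below assumes the simplest case, where query heads and key/value heads must both split evenly across ranks. That even-split rule is an assumption made for this illustration: the library may handle other cases differently (for example by replicating or padding heads). The head counts are illustrative values, and `per_rank_heads` is a hypothetical helper.

```python
def per_rank_heads(num_attention_heads, num_key_value_heads, tp_degree):
    """Return per-rank head counts under an even split, or None if it does not divide."""
    if num_attention_heads % tp_degree or num_key_value_heads % tp_degree:
        return None
    return {
        "attention_heads": num_attention_heads // tp_degree,
        "key_value_heads": num_key_value_heads // tp_degree,
    }


# Illustrative head counts (assumed, not read from the checkpoint).
for tp_degree in (2, 8, 32):
    print(tp_degree, per_rank_heads(32, 8, tp_degree))
```

In this simplified model, TP=2 and TP=8 divide cleanly, while TP=32 returns `None` because 8 key/value heads cannot be split 32 ways without extra handling.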

#### 2/ InferenceConfig class
Next, you can inherit from our base `InferenceConfig` class and extend it with the rest of your modeling parameters. In the notebook you ran above, we took two important steps with this config:
1. Passed the base `NeuronConfig` into it.
2. Passed in the rest of the model config, loaded from the Hugging Face pretrained config.

Your inference config class is where you define modeling parameters like the following:
- hidden size
- num attention heads
- num hidden layers
- num key value heads
- vocab size

You'll use this `config` object to save and compile your model. Let's learn how!
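
The two-step pattern described above (a hardware-facing config plus a loader that fills in the modeling parameters) can be sketched in plain Python. Everything here is hypothetical and only mirrors the shape of the call used earlier in this notebook: `SketchInferenceConfig`, `load_from_dict`, and the example values are assumptions for illustration, not the real `InferenceConfig` implementation.

```python
class SketchInferenceConfig:
    """Toy stand-in: holds a neuron config and validates the modeling parameters."""

    REQUIRED = (
        "hidden_size",
        "num_attention_heads",
        "num_hidden_layers",
        "num_key_value_heads",
        "vocab_size",
    )

    def __init__(self, neuron_config, load_config):
        self.neuron_config = neuron_config
        load_config(self)  # the loader sets modeling attributes on this object
        missing = [name for name in self.REQUIRED if not hasattr(self, name)]
        if missing:
            raise ValueError(f"missing required attributes: {missing}")


def load_from_dict(values):
    """Return a loader that copies key/value pairs onto a config object."""
    def _load(config):
        for key, value in values.items():
            setattr(config, key, value)
    return _load


# Illustrative values only (assumed, not read from the checkpoint).
model_values = {
    "hidden_size": 2048,
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "vocab_size": 128256,
}
sketch_config = SketchInferenceConfig({"tp_degree": 2}, load_from_dict(model_values))
print(sketch_config.hidden_size, sketch_config.neuron_config["tp_degree"])
```

Validating required attributes at construction time surfaces a missing parameter before compilation starts, rather than partway through tracing.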

#### 3/ NeuronModel
This is how you integrate your modeling code into the Neuron SDK. If you'd like to reuse our `NeuronAttentionBase`, you can inherit from it directly through the library and pass your parameters through the `InferenceConfig` you defined above. This is how the example code in this notebook works, and it is the fastest way of getting your model online with NxDI.

In the example code you ran, you also used our code for `NeuronLlamaMLP`. This is a layer in the network which inherits from `nn.Module` directly, and it's where you can define the structure of your computations. The `NeuronLlamaMLP` uses a predefined `ColumnParallelLinear` object for both the gate and up projections, while using a predefined `RowParallelLinear` object for the down projection. It also defines a forward pass on that layer.

The rest of the model is defined similarly: either you inherit from our base objects and just pass in your `InferenceConfig`, or you define a new layer inheriting from `nn.Module` and write its sublayers as `RowParallelLinear`, `ColumnParallelLinear`, or something else. The benefit of expressing your layers with the `Row` and `Column` parallel layers as presented here is that we can handle the distribution of your model for you.
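
To see why the column-then-row arrangement works for an MLP, the following sketch simulates it on a single CPU with NumPy. The gate and up weights are split along their output dimension (the column-parallel part), the down weight is split along its input dimension (the row-parallel part), and the per-shard partial outputs are summed, which stands in for the all-reduce across ranks. This is a simulation of the general tensor-parallel technique under assumed shapes and a gated SiLU activation; it does not use or reproduce the `ColumnParallelLinear` and `RowParallelLinear` classes themselves.

```python
import numpy as np


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def mlp_reference(x, w_gate, w_up, w_down):
    """Unsharded gated MLP: down(silu(gate(x)) * up(x))."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


def mlp_tensor_parallel(x, w_gate, w_up, w_down, tp_degree):
    """Same MLP computed shard by shard, then summed (simulated all-reduce)."""
    gate_shards = np.split(w_gate, tp_degree, axis=1)  # column parallel
    up_shards = np.split(w_up, tp_degree, axis=1)      # column parallel
    down_shards = np.split(w_down, tp_degree, axis=0)  # row parallel
    partials = [
        (silu(x @ gate) * (x @ up)) @ down
        for gate, up, down in zip(gate_shards, up_shards, down_shards)
    ]
    return np.sum(partials, axis=0)


rng = np.random.default_rng(0)
hidden, intermediate = 8, 16  # tiny sizes chosen for the example
x = rng.standard_normal((2, hidden))
w_gate = rng.standard_normal((hidden, intermediate))
w_up = rng.standard_normal((hidden, intermediate))
w_down = rng.standard_normal((intermediate, hidden))

sharded = mlp_tensor_parallel(x, w_gate, w_up, w_down, tp_degree=2)
print(np.allclose(sharded, mlp_reference(x, w_gate, w_up, w_down)))
```

Because the activation is applied elementwise, each shard can compute its slice of the intermediate result independently, and only one reduction is needed at the end of the layer.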

For a more complete guide check out our documentation on the subject [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/api_guide.html#api-guide).

### Notebook Wrap-Up

For more advanced topics:
- **Profiling**: See [Neuron Profiling Tools](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/tools/neuron-profile/index.html).
- **Distributed Serving**: Explore vLLM or other serving frameworks.
- **Performance Benchmarking**: Use `llmperf` or custom scripts.

Thank you for using AWS Trainium, and happy LLM experimentation!
