Add support for decision transformer (Closes #794) #795

Merged on Jul 1, 2024 (2 commits into main)

Conversation

xenova
Owner

@xenova xenova commented Jun 6, 2024

Example code:

import { AutoModel, Tensor } from '@xenova/transformers';

// Load model
const model_id = 'onnx-community/decision-transformer-gym-hopper-expert';
const model = await AutoModel.from_pretrained(model_id, { quantized: false });

// Helper function to generate random tensor
function rand(dims) {
    // Start the product at 1 so an empty `dims` (a scalar) does not throw
    const data = Float32Array.from({ length: dims.reduce((a, b) => a * b, 1) }, () => Math.random());
    return new Tensor('float32', data, dims);
}

// Define config
const batch_size = 2;
const episode_length = 16;
const state_dim = model.config.state_dim;
const act_dim = model.config.act_dim;

// Generate random input
const states = rand([batch_size, episode_length, state_dim]);
const actions = rand([batch_size, episode_length, act_dim]);
const rewards = rand([batch_size, episode_length, 1]);
const returns_to_go = rand([batch_size, episode_length, 1]);
const timesteps = new Tensor('int64', new BigInt64Array([BigInt(episode_length)]), [1, 1]);
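// Mask of ones (attend to every timestep); random values are not a meaningful mask
const attention_mask = new Tensor('float32', new Float32Array(batch_size * episode_length).fill(1), [batch_size, episode_length]);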

// Call model
const input = { states, actions, rewards, returns_to_go, timesteps, attention_mask };
const output = await model(input);
console.log(output);
// {
//     state_preds: Tensor {
//         dims: [2, 16, 11],
//         type: 'float32',
//         data: Float32Array(352)[ ... ],
//         size: 352
//     },
//     action_preds: Tensor {
//         dims: [2, 16, 3],
//         type: 'float32',
//         data: Float32Array(96)[ ... ],
//         size: 96
//     },
//     return_preds: Tensor {
//         dims: [2, 16, 1],
//         type: 'float32',
//         data: Float32Array(32)[ ... ],
//         size: 32
//     },
//     last_hidden_state: Tensor {
//         dims: [2, 48, 128],
//         type: 'float32',
//         data: Float32Array(12288)[ ... ],
//         size: 12288
//     }
// }
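The dims in the logged output can be checked by hand. The prediction heads return one row per timestep, while `last_hidden_state` has a sequence length of 3 * episode_length (48 here), because the model interleaves one return, one state, and one action token per timestep. A minimal Python sketch of that arithmetic, assuming a hidden size of 128, `state_dim` 11, and `act_dim` 3 as read off the output above; `expected_dims` is a hypothetical helper, not part of either library:

```python
from math import prod


def expected_dims(batch_size, episode_length, state_dim, act_dim, hidden_size):
    """Return the expected shape of each model output (hypothetical helper)."""
    return {
        "state_preds": (batch_size, episode_length, state_dim),
        "action_preds": (batch_size, episode_length, act_dim),
        "return_preds": (batch_size, episode_length, 1),
        # Three tokens (return, state, action) per timestep
        "last_hidden_state": (batch_size, 3 * episode_length, hidden_size),
    }


# Values assumed from the logged output above
dims = expected_dims(batch_size=2, episode_length=16, state_dim=11, act_dim=3, hidden_size=128)

# Number of elements in each output tensor
sizes = {name: prod(shape) for name, shape in dims.items()}
```

The `sizes` values correspond to the `size` fields in the logged tensors.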

Export models to ONNX:

Requirements:

pip install torch transformers onnx==1.13.1

Code:

import torch
from transformers import DecisionTransformerModel

# 1. Load model
model_id = "edbeeching/decision-transformer-gym-hopper-medium"
model = DecisionTransformerModel.from_pretrained(model_id)

# 2. Define inputs
# states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
#     The states for each step in the trajectory
# actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
#     The actions taken by the "expert" policy for the current state; these are masked for autoregressive
#     prediction
# rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
#     The rewards for each state, action
# returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
#     The returns for each state in the trajectory
# timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
#     The timestep for each step in the trajectory
# attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
#     Masking, used to mask the actions when performing autoregressive prediction
batch_size = 2
episode_length = 16
state_dim = model.config.state_dim
act_dim = model.config.act_dim

states=torch.randn((batch_size, episode_length, state_dim))
actions=torch.randn((batch_size, episode_length, act_dim))
rewards=torch.randn((batch_size, episode_length, 1))
returns_to_go=torch.randn((batch_size, episode_length, 1))
timesteps=torch.tensor(0, dtype=torch.long).reshape(1, 1)
attention_mask=torch.ones((batch_size, episode_length))  # attend to every timestep

# 3. Define outputs
# last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
#     Sequence of hidden-states at the output of the last layer of the model.
# state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
#     Environment state predictions
# action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
#     Model action predictions
# return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
#     Predicted returns for each state
# hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
#     Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
#     shape `(batch_size, sequence_length, hidden_size)`.
#
#     Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
#     Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
#     sequence_length)`.
#
#     Attention weights after the attention softmax, used to compute the weighted average in the self-attention
#     heads.

dynamic_axes = {0 : 'batch_size', 1: 'episode_length'}

# 4. Export the model
torch.onnx.export(model, # model being run
                  (states, actions, rewards, returns_to_go, timesteps, attention_mask), # model input (or a tuple for multiple inputs)
                  "model.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=13,          # the ONNX opset version to target
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['states', 'actions', 'rewards', 'returns_to_go', 'timesteps', 'attention_mask'],   # the model's input names
                  output_names = ['state_preds', 'action_preds', 'return_preds', 'last_hidden_state'], # the model's output names
                  dynamic_axes={
                      'states' : dynamic_axes,
                      'actions' : dynamic_axes,
                      'rewards' : dynamic_axes,
                      'returns_to_go' : dynamic_axes,
                      'timesteps' : dynamic_axes,
                      'attention_mask' : dynamic_axes,

                      'state_preds' : dynamic_axes,
                      'action_preds' : dynamic_axes,
                      'return_preds' : dynamic_axes,
                      'last_hidden_state' : dynamic_axes,
                  }
)
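The per-tensor `dynamic_axes` entries in the export call all repeat the same mapping, so they can be generated from the name lists instead of written out. A minimal sketch, assuming the same input and output names as above; `shared_axes` and `all_dynamic_axes` are hypothetical names introduced here for illustration:

```python
# Same names as in the export call above
input_names = ["states", "actions", "rewards", "returns_to_go", "timesteps", "attention_mask"]
output_names = ["state_preds", "action_preds", "return_preds", "last_hidden_state"]

# Axis 0 is the batch dimension, axis 1 the episode (sequence) dimension
shared_axes = {0: "batch_size", 1: "episode_length"}

# Give every tensor its own copy so editing one entry later does not change the others
all_dynamic_axes = {name: dict(shared_axes) for name in input_names + output_names}
```

The result can be passed as `dynamic_axes=all_dynamic_axes` to `torch.onnx.export`. Axis 1 of `last_hidden_state` is 3 * episode_length long, so a separate label for that one entry may describe it more accurately; the sketch keeps the original labels.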

@xenova xenova mentioned this pull request Jun 6, 2024
2 tasks
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Models are in the `onnx-community` org on HF
@xenova xenova merged commit fc34517 into main Jul 1, 2024
4 checks passed