[ONNX] Add huggingface models into CI tests #107247

Closed
wants to merge 21 commits

Commits
aae282d  [ONNX] Add huggingface models into CI tests (titaiwangms, Aug 15, 2023)
dc16fb9, 1e3a8f7, 9fccefd, aec630d, 8c25e11, 6cdc7a0, d45c180, 52cc050, 7f64709, 6bbef52, b29316c, 079a6e9, 38b18dd, 866bf27, ec896d0, b9b122d, f365371, dcea1d4, 8927ebc, 487fa54  Update on "[ONNX] Add huggingface models into CI tests" (titaiwangms, Aug 15–23, 2023)
9 changes: 9 additions & 0 deletions .ci/docker/common/install_onnx.sh
@@ -33,10 +33,19 @@ pip_install onnx-weekly==1.15.0.dev20230717
# TODO: change this when onnx-script is on testPypi
pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps

# Hugging Face model requirements
pip_install einops==0.6.1 # mpt-7b

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
IMPORT_SCRIPT_FILENAME="/tmp/onnx_import_script.py"
as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2");' > "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("bigscience/bloom-560m"); transformers.AutoTokenizer.from_pretrained("bigscience/bloom-560m");' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True); transformers.AutoTokenizer.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True);' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("openai/whisper-tiny"); transformers.WhisperConfig.from_pretrained("openai/whisper-tiny");transformers.WhisperProcessor.from_pretrained("openai/whisper-tiny");' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("google/flan-t5-small"); transformers.AutoTokenizer.from_pretrained("google/flan-t5-small");' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("databricks/dolly-v2-3b"); transformers.AutoTokenizer.from_pretrained("databricks/dolly-v2-3b");' >> "${IMPORT_SCRIPT_FILENAME}"


# Need a PyTorch version for transformers to work
pip_install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
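
Because the models are cached at image-build time, later CI steps can resolve them without network access. A minimal sanity check of the cache, as a sketch — the TRANSFORMERS_OFFLINE switch is standard transformers behavior, but this snippet is illustrative and not part of the PR:

import os

# Force transformers to resolve models from the local cache only; a missing
# cache entry then raises instead of silently re-downloading.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

import transformers  # noqa: E402  (the env var must be set before import)

model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
print(type(model).__name__, "loaded from cache")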
172 changes: 167 additions & 5 deletions test/onnx/test_fx_to_onnx.py
@@ -6,6 +6,7 @@
import onnx
import pytorch_test_common
import torch
import transformers # type: ignore[import]
from torch import nn
from torch._subclasses import fake_tensor
from torch.nn import functional as F
@@ -406,14 +407,175 @@ def forward(self, x):
fake_model, real_x, export_options=export_options
)

def test_fake_tensor_mode_huggingface_gpt2(self):
config = transformers.GPT2Config()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.GPT2Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_bigscience_bloom(self):
config = transformers.BloomConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.BloomModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_open_llama(self):
config = transformers.OpenLlamaConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.OpenLlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

# TODO: From Config/Model
@pytorch_test_common.xfail(
"SymFloat in OnnxFUnction attribute is not supported yet."
)
def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self):
model_name = "databricks/dolly-v2-3b"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = transformers.AutoModel.from_pretrained(model_name)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model, **inputs, export_options=export_options
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

# TODO: From Config/Model
def test_fake_tensor_mode_huggingface_google_flan_t5_small(self):
model_name = "google/flan-t5-small"
with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.AutoModel.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

decoder_input_ids = tokenizer(
"Studies show that", return_tensors="pt"
).input_ids # Batch size 1
# preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
# This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
decoder_input_ids = model._shift_right(decoder_input_ids)
inputs = tokenizer("Hello world!", return_tensors="pt")
inputs["decoder_input_ids"] = decoder_input_ids

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model, **inputs, export_options=export_options
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
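
For reference, T5's _shift_right amounts to prepending the decoder start token (the pad token for T5) and shifting the targets right by one. A minimal hand-rolled equivalent, as a sketch — it omits the label-masking details of the real transformers implementation:

import torch

def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    # Same shape as the input: position 0 becomes the start token and the
    # last target token is dropped.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    return shifted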

# TODO: From Config/Model
def test_fake_tensor_mode_huggingface_openai_whisper_tiny(self):
from datasets import load_dataset # type: ignore[import]

model_name = "openai/whisper-tiny"
with torch.onnx.enable_fake_mode() as fake_context:
config = transformers.WhisperConfig.from_pretrained(model_name)
processor = transformers.WhisperProcessor.from_pretrained(model_name)
ds = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
input_features = processor(
[ds[0]["audio"]["array"]], return_tensors="pt"
).input_features
decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id

model = transformers.AutoModel.from_pretrained(model_name)
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_features,
decoder_input_ids=decoder_input_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

@pytorch_test_common.xfail(
"AssertionError: Mutating module attribute seq_len_cached during export."
"self.seq_len_cached = seq_len"
)
def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
config = transformers.FalconConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.FalconModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
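
The xfail above stems from the model mutating a plain module attribute (seq_len_cached) inside forward, which the exporter rejects. An illustrative reduction of that pattern — not the actual Falcon source:

import torch

class RotaryCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seq_len_cached = -1  # plain Python attribute, not a registered buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            # Assigning to self during tracing is the "Mutating module
            # attribute ... during export" failure the test guards against.
            self.seq_len_cached = seq_len
        return x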

@pytorch_test_common.xfail(
"torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment."
"Please use functorch.experimental.control_flow.cond to explicitly capture the control flow."
"if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:"
)
def test_fake_tensor_mode_huggingface_mosaicml_mpt_7b(self):
model_name = "mosaicml/mpt-7b"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
…
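
The mpt-7b xfail message above points at data-dependent control flow and suggests functorch.experimental.control_flow.cond. A minimal sketch of that API on a toy module — illustrative only, unrelated to the mpt-7b code:

import torch
from functorch.experimental.control_flow import cond

def true_fn(x: torch.Tensor) -> torch.Tensor:
    return x.sin()

def false_fn(x: torch.Tensor) -> torch.Tensor:
    return x.cos()

class Branching(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both branches are captured in the graph, so the predicate may be
        # data-dependent, unlike a plain Python `if`.
        return cond(x.sum() > 0, true_fn, false_fn, [x])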
59 changes: 45 additions & 14 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -10,7 +10,7 @@

import onnx_test_common
import onnxruntime # type: ignore[import]
import parameterized # type: ignore[import]
import pytorch_test_common
import torch
import torch.onnx
@@ -27,7 +27,7 @@
from torch.testing._internal import common_utils

try:
import torchvision # type: ignore[import]

HAS_TORCHVISION = True
except ImportError:
@@ -555,21 +555,52 @@ def func(x):
func, (torch.randn(3, 4),)
)

def test_gpt2_tiny_from_config(self):
# Model
config = transformers.GPT2Config(
num_hidden_layers=4,
vocab_size=8096,
hidden_size=16,
intermediate_size=16,
max_position_embeddings=512,
num_attention_heads=2,
hidden_dropout_prob=0.0,
attention_dropout_prob=0.0,
)
model = transformers.GPT2Model(config).eval()

# Encoded inputs
batch, seq = 2, 128
input_ids = torch.randint(0, 8096, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

# Another encoded inputs to test dynamic shapes
another_batch, another_seq = 3, 256
another_input_ids = torch.randint(0, 8096, (another_batch, another_seq))
another_attention_mask = torch.ones(
another_batch, another_seq, dtype=torch.bool
)
another_position_ids = torch.arange(0, another_seq, dtype=torch.long)
another_position_ids = another_position_ids.unsqueeze(0).view(-1, another_seq)

self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model,
(input_ids,),
input_kwargs={
"attention_mask": attention_mask,
"position_ids": position_ids,
},
additional_test_inputs=[
(
(another_input_ids,),
{
"attention_mask": another_attention_mask,
"position_ids": another_position_ids,
},
)
],
)
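
The second input set exercises dynamic batch and sequence dimensions end to end. Outside the test helper, the same check could be run directly against ONNX Runtime — a sketch, assuming export_output came from torch.onnx.dynamo_export on the model above and that the graph inputs line up with (input_ids, attention_mask, position_ids):

import onnxruntime as ort  # type: ignore[import]

session = ort.InferenceSession(
    export_output.model_proto.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
# Read graph input names from the session instead of hard-coding them.
feeds = {
    graph_input.name: tensor.numpy()
    for graph_input, tensor in zip(
        session.get_inputs(),
        (another_input_ids, another_attention_mask, another_position_ids),
    )
}
ort_outputs = session.run(None, feeds)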

def test_prims_device_put(self):
…