[ONNX] Add huggingface models into CI tests #107247

Closed
wants to merge 21 commits
Changes from 19 commits
Commits (21)
aae282d
[ONNX] Add huggingface models into CI tests
titaiwangms Aug 15, 2023
dc16fb9
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 15, 2023
1e3a8f7
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 15, 2023
9fccefd
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 15, 2023
aec630d
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 16, 2023
8c25e11
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 16, 2023
6cdc7a0
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 16, 2023
d45c180
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 16, 2023
52cc050
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 16, 2023
7f64709
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 16, 2023
6bbef52
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 17, 2023
b29316c
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 17, 2023
079a6e9
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 17, 2023
38b18dd
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 17, 2023
866bf27
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 18, 2023
ec896d0
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 18, 2023
b9b122d
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 18, 2023
f365371
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 21, 2023
dcea1d4
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 21, 2023
8927ebc
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 22, 2023
487fa54
Update on "[ONNX] Add huggingface models into CI tests"
titaiwangms Aug 23, 2023
17 changes: 1 addition & 16 deletions .ci/docker/common/install_onnx.sh
@@ -23,7 +23,7 @@ pip_install \
pytest-cov==4.0.0 \
pytest-subtests==0.10.0 \
tabulate==0.9.0 \
transformers==4.25.1
transformers==4.31.0

# Using 1.15dev branch for the following not yet released features and fixes.
# - Segfault fix for shape inference.
@@ -32,18 +32,3 @@ pip_install onnx-weekly==1.15.0.dev20230717

# TODO: change this when onnx-script is on testPypi
pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
IMPORT_SCRIPT_FILENAME="/tmp/onnx_import_script.py"
as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2");' > "${IMPORT_SCRIPT_FILENAME}"

# Need a PyTorch version for transformers to work
pip_install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
# Very weird quoting behavior here https://github.com/conda/conda/issues/10972,
# so echo the command to a file and run the file instead
conda_run python "${IMPORT_SCRIPT_FILENAME}"

# Cleaning up
conda_run pip uninstall -y torch
rm "${IMPORT_SCRIPT_FILENAME}" || true
176 changes: 171 additions & 5 deletions test/onnx/test_fx_to_onnx.py
@@ -6,6 +6,7 @@
import onnx
import pytorch_test_common
import torch
import transformers # type: ignore[import]
from torch import nn
from torch._subclasses import fake_tensor
from torch.nn import functional as F
@@ -406,14 +407,179 @@ def forward(self, x):
fake_model, real_x, export_options=export_options
)

def test_fake_tensor_mode_huggingface_bigscience__bloom_560m(self):
from transformers import AutoModel, AutoTokenizer # type: ignore[import]
def test_fake_tensor_mode_huggingface_gpt2(self):
config = transformers.GPT2Config()
batch, seq = 4, 256

model_name = "bigscience/bloom-560m"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = transformers.GPT2Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_bigscience_bloom(self):
config = transformers.BloomConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.BloomModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_open_llama(self):
config = transformers.OpenLlamaConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.OpenLlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_google_t5(self):
config = transformers.T5Config()
device = "cpu"
batch, seq = 4, 256
with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.T5Model(config).to(device).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones((batch, seq), dtype=torch.bool)
decoder_input_ids = torch.randint(0, config.vocab_size, (batch, seq))
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_openai_whisper(self):
config = transformers.WhisperConfig()
feature_extractor = transformers.WhisperFeatureExtractor()
device = "cpu"
batch = 4
with torch.onnx.enable_fake_mode() as fake_context:
input_features = torch.randn(
(
batch,
feature_extractor.feature_size,
feature_extractor.nb_max_frames,
),
dtype=torch.float32,
)
decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id
model = transformers.AutoModel.from_config(config).to(device).eval()
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_features,
decoder_input_ids=decoder_input_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

# TODO: From Config/Model
@pytorch_test_common.skip_in_ci(
"Skip this test in CI because of memory issue."
"SymFloat in OnnxFUnction attribute is not supported yet."
)
def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self):
model_name = "databricks/dolly-v2-3b"
device = "cpu"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = AutoModel.from_pretrained(model_name)
model = transformers.AutoModel.from_pretrained(model_name).to(device).eval()

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model, **inputs, export_options=export_options
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
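
The "From Config/Model" TODO above suggests rewriting the dolly-v2 test against a config-built model, as the other tests in this file do. Since dolly-v2 is a GPT-NeoX architecture, a hedged sketch of such a test might look like the following; the shrunken GPTNeoXConfig sizes are assumptions for illustration, not values taken from this PR:

def test_fake_tensor_mode_huggingface_gpt_neox_from_config(self):
    # Hedged sketch for the TODO above: dolly-v2 is GPT-NeoX based, so building the
    # model from a (shrunken) config avoids downloading the 3B checkpoint in CI.
    config = transformers.GPTNeoXConfig(
        num_hidden_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        intermediate_size=256,
        vocab_size=1024,
    )
    batch, seq = 4, 256

    with torch.onnx.enable_fake_mode() as fake_context:
        model = transformers.GPTNeoXModel(config).eval()
        input_ids = torch.randint(0, config.vocab_size, (batch, seq))
        attention_mask = torch.ones(batch, seq, dtype=torch.bool)

        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
        export_output = torch.onnx.dynamo_export(
            model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            export_options=export_options,
        )
    onnx.checker.check_model(export_output.model_proto)
    onnx.shape_inference.infer_shapes(export_output.model_proto)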

@pytorch_test_common.skip_in_ci(
"AssertionError: Mutating module attribute seq_len_cached during export."
"self.seq_len_cached = seq_len"
"Skip this test in CI because of memory issue."
)
def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
config = transformers.FalconConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.FalconModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

@pytorch_test_common.skip_in_ci(
"Skip this test in CI because of memory issue."
"torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. "
"Please use functorch.experimental.control_flow.cond to explicitly capture the control flow"
)
def test_fake_tensor_mode_huggingface_mosaicml_mpt_7b(self):
model_name = "mosaicml/mpt-7b"
device = "cpu"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = (
transformers.AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
.to(device)
.eval()
)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
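The mosaicml/mpt-7b skip reason above quotes dynamo's suggestion to use functorch.experimental.control_flow.cond for data-dependent branches. As a point of reference, a minimal, self-contained sketch of that API (unrelated to the MPT model itself, and with behavior that may vary across PyTorch versions) might look like:

import torch
from functorch.experimental import control_flow


class DynamicBranch(torch.nn.Module):
    def forward(self, x):
        def true_fn(t):
            return t.sin()

        def false_fn(t):
            return t.cos()

        # Capture the data-dependent branch explicitly so dynamo can trace a single
        # cond op instead of raw Python control flow.
        return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])


# Under dynamo (e.g. torch.compile or torch.onnx.dynamo_export) the branch above is
# captured as one cond node; whether the ONNX exporter handles it end to end for a
# model like MPT is not something this sketch claims.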
91 changes: 67 additions & 24 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -10,7 +10,7 @@

import onnx_test_common
import onnxruntime # type: ignore[import]
import parameterized
import parameterized # type: ignore[import]
import pytorch_test_common
import torch
import torch.onnx
@@ -27,7 +27,7 @@
from torch.testing._internal import common_utils

try:
import torchvision
import torchvision # type: ignore[import]

HAS_TORCHVISION = True
except ImportError:
@@ -560,21 +560,52 @@ def func(x):
func, (torch.randn(3, 4),)
)

def test_gpt2_tiny(self):
model_name = "sshleifer/tiny-gpt2"
# Download pytorch model
model = transformers.AutoModel.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# Transform input tokens
inputs = tokenizer("Hello world!", return_tensors="pt")
another_inputs = tokenizer("Another Hello world!", return_tensors="pt")
def test_gpt2_tiny_from_config(self):
# Model
config = transformers.GPT2Config(
num_hidden_layers=4,
vocab_size=8096,
hidden_size=16,
intermediate_size=16,
max_position_embeddings=512,
num_attention_heads=2,
hidden_dropout_prob=0.0,
attention_dropout_prob=0.0,
)
model = transformers.GPT2Model(config).eval()

# Encoded inputs
batch, seq = 2, 128
input_ids = torch.randint(0, 8096, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

# Another encoded inputs to test dynamic shapes
another_batch, another_seq = 3, 256
another_input_ids = torch.randint(0, 8096, (another_batch, another_seq))
another_attention_mask = torch.ones(
another_batch, another_seq, dtype=torch.bool
)
another_position_ids = torch.arange(0, another_seq, dtype=torch.long)
another_position_ids = another_position_ids.unsqueeze(0).view(-1, another_seq)

self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model,
[],
input_kwargs=inputs,
additional_test_inputs=[((), another_inputs)],
(input_ids,),
input_kwargs={
"attention_mask": attention_mask,
"position_ids": position_ids,
},
additional_test_inputs=[
(
(another_input_ids,),
{
"attention_mask": another_attention_mask,
"position_ids": another_position_ids,
},
)
],
)

def test_prims_device_put(self):
@@ -750,14 +781,20 @@ def create_pytorch_only_extra_kwargs():
create_pytorch_only_extra_kwargs,
)

@pytorch_test_common.xfail(
"[ONNXRuntimeError] : 1 : FAIL : Type Error: Data in initializer 'h_0_attn_bias' "
"has element type tensor(uint8) but usage of initializer in graph expects tensor(bool)"
"https://github.com/huggingface/transformers/issues/21013"
)
@pytorch_test_common.skip_dynamic_fx_test(
"FakeTensor exporting is not supported by dynamic axes."
)
def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"
device = "cpu"

def create_model() -> nn.Module:
return transformers.AutoModel.from_pretrained(model_name)
return transformers.AutoModel.from_pretrained(model_name).to(device).eval()
Collaborator: this still needs to be cached though right?

Collaborator Author (titaiwangms): Oops. Added back the tiny-gpt2 cache.
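
For reference, the re-added cache step presumably mirrors the warm-up script this PR removes from .ci/docker/common/install_onnx.sh above; a minimal Python sketch of that warm-up (model name and default cache location taken from the removed lines, everything else an assumption):

# Hedged sketch: pre-download sshleifer/tiny-gpt2 during the Docker image build so the
# ONNX CI tests can load it without network access. By default the weights are cached
# under ~/.cache/huggingface/hub/ (see the removed install_onnx.sh lines above); the
# exact invocation in the re-added commit may differ.
import transformers

transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")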


def create_args():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
@@ -940,11 +977,17 @@ def create_kwargs():
export_within_fake_mode=self.export_within_fake_mode,
)

@pytorch_test_common.xfail(
"[ONNXRuntimeError] : 1 : FAIL : Type Error: Data in initializer 'h_0_attn_bias' "
"has element type tensor(uint8) but usage of initializer in graph expects tensor(bool)"
"https://github.com/huggingface/transformers/issues/21013"
)
def test_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"
device = "cpu"

def create_model() -> nn.Module:
return transformers.AutoModel.from_pretrained(model_name)
return transformers.AutoModel.from_pretrained(model_name).to(device).eval()

def create_args():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
@@ -1012,22 +1055,22 @@ def create_kwargs():
"HF Bloom model does not need `model.load_state_dict` to work."
)
def test_fake_tensor_mode_huggingface_bigscience_bloom_560m(self):
from transformers import AutoModel, AutoTokenizer # type: ignore[import]

model_name = "bigscience/bloom-560m"
config = transformers.BloomConfig()
batch, seq = 4, 256

def create_args():
return tuple()

def create_kwargs(model_name=model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer("Hello world!", return_tensors="pt")
def create_kwargs():
Collaborator: In _test_fake_tensor_mode_exporter, the exported ONNX model and the PyTorch model are run and their outputs are compared. However, I am wondering if that really checks dynamic_shapes. If not, probably with the next (or 2nd next) PR, let's modify create_args, create_kwargs, and test_fake_tensor_mode_exporter to guard that dynamic_shapes supports different batch and sequence sizes.

Collaborator Author (titaiwangms): There is an assert_dynamic_shapes check inside _test_fake_tensor_mode_exporter that should address your concern (see the sketch after this diff).

input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
return {"input_ids": input_ids, "attention_mask": attention_mask}

def create_model():
return AutoModel.from_pretrained(model_name)
return transformers.BloomModel(config).eval()

self._test_fake_tensor_mode_exporter(
model_name.replace("/", "_"),
"huggingface_bigscience_bloom_560m",
create_model,
create_args,
create_kwargs,
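
Regarding the dynamic-shapes review thread above: a minimal sketch of what an assert_dynamic_shapes-style check could look like, assuming the exporter records dynamic dimensions as symbolic dim_param entries on the ONNX graph inputs; the actual helper used by _test_fake_tensor_mode_exporter may differ:

import onnx


def assert_has_dynamic_batch_and_seq(model_proto: onnx.ModelProto) -> None:
    # Hedged sketch: check that the first graph input keeps symbolic (dim_param)
    # batch and sequence dimensions rather than the fixed sizes used at export time.
    dims = model_proto.graph.input[0].type.tensor_type.shape.dim
    assert len(dims) >= 2, "expected at least (batch, seq) dims on the first input"
    for dim in dims[:2]:
        assert dim.dim_param, f"expected a symbolic dim, got fixed size {dim.dim_value}"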