[ONNX] Add huggingface models into CI tests
[ONNX] Add transformers models into the no-runtime test of the fx exporter

ghstack-source-id: fe1a4c2f9d0c554a49047ce4c543444eeb39565c
Pull Request resolved: #107247
titaiwangms committed Aug 16, 2023
1 parent 3577ae3 commit ca682f2
Showing 3 changed files with 221 additions and 19 deletions.
9 changes: 9 additions & 0 deletions .ci/docker/common/install_onnx.sh
@@ -33,10 +33,19 @@ pip_install onnx-weekly==1.15.0.dev20230717
# TODO: change this when onnx-script is on testPypi
pip_install onnxscript-preview==0.1.0.dev20230809 --no-deps

# Hugging Face model requirements
pip_install einops==0.6.1 # mpt-7b

# Cache the transformers models to be used later by ONNX tests. We need to run the transformers
# package to download the models. By default, models are cached at ~/.cache/huggingface/hub/
IMPORT_SCRIPT_FILENAME="/tmp/onnx_import_script.py"
as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2");' > "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("bigscience/bloom-560m"); transformers.AutoTokenizer.from_pretrained("bigscience/bloom-560m");' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True); transformers.AutoTokenizer.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True);' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("openai/whisper-tiny"); transformers.WhisperConfig.from_pretrained("openai/whisper-tiny");transformers.WhisperProcessor.from_pretrained("openai/whisper-tiny");' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("google/flan-t5-small"); transformers.AutoTokenizer.from_pretrained("google/flan-t5-small");' >> "${IMPORT_SCRIPT_FILENAME}"
as_jenkins echo 'transformers.AutoModel.from_pretrained("databricks/dolly-v2-3b"); transformers.AutoTokenizer.from_pretrained("databricks/dolly-v2-3b");' >> "${IMPORT_SCRIPT_FILENAME}"
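# Hedged aside: once cached, later runs can force cache-only loads via
# huggingface_hub's standard offline switch, e.g.:
#   HF_HUB_OFFLINE=1 python -c 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")'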


# Need a PyTorch version for transformers to work
pip_install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
172 changes: 167 additions & 5 deletions test/onnx/test_fx_to_onnx.py
@@ -6,6 +6,7 @@
import onnx
import pytorch_test_common
import torch
import transformers # type: ignore[import]
from torch import nn
from torch._subclasses import fake_tensor
from torch.nn import functional as F
@@ -406,14 +407,175 @@ def forward(self, x):
fake_model, real_x, export_options=export_options
)

def test_fake_tensor_mode_huggingface_gpt2(self):
config = transformers.GPT2Config()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.GPT2Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
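# A hedged note on the pattern above (and in the tests below): inside
# torch.onnx.enable_fake_mode(), tensors are created as FakeTensors that carry
# only metadata (shape/dtype/device), so even multi-billion-parameter models
# can be exported without allocating real weight storage.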

def test_fake_tensor_mode_huggingface_bigscience_bloom(self):
config = transformers.BloomConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.BloomModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_open_llama(self):
config = transformers.OpenLlamaConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.OpenLlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

# TODO: From Config/Model
@pytorch_test_common.xfail(
"SymFloat in OnnxFUnction attribute is not supported yet."
)
def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self):
model_name = "databricks/dolly-v2-3b"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = transformers.AutoModel.from_pretrained(model_name)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model, **inputs, export_options=export_options
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

# TODO: From Config/Model
def test_fake_tensor_mode_huggingface_google_flan_t5_small(self):
model_name = "google/flan-t5-small"
with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.AutoModel.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

decoder_input_ids = tokenizer(
"Studies show that", return_tensors="pt"
).input_ids # Batch size 1
# preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5Model.
# This is not needed for T5ForConditionalGeneration, which does the shift internally via the labels arg.
decoder_input_ids = model._shift_right(decoder_input_ids)
inputs = tokenizer("Hello world!", return_tensors="pt")
inputs["decoder_input_ids"] = decoder_input_ids

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model, **inputs, export_options=export_options
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
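# For reference, a hedged sketch of what T5's _shift_right does (the real
# implementation lives in transformers' T5 modeling code):
#   shifted = ids.new_full(ids.shape, model.config.decoder_start_token_id)
#   shifted[:, 1:] = ids[:, :-1].clone()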

# TODO: From Config/Model
def test_fake_tensor_mode_huggingface_openai_whisper_tiny(self):
from datasets import load_dataset # type: ignore[import]

model_name = "openai/whisper-tiny"
with torch.onnx.enable_fake_mode() as fake_context:
config = transformers.WhisperConfig.from_pretrained(model_name)
processor = transformers.WhisperProcessor.from_pretrained(model_name)
ds = load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
input_features = processor(
[ds[0]["audio"]["array"]], return_tensors="pt"
).input_features
decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id

model = transformers.AutoModel.from_pretrained(model_name)
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_features,
decoder_input_ids=decoder_input_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
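# A hedged note: torch.tensor([[1, 1]]) * config.decoder_start_token_id builds
# a (1, 2) decoder prompt filled with the start token, the minimal
# decoder_input_ids Whisper needs to begin decoding.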

@pytorch_test_common.xfail(
"AssertionError: Mutating module attribute seq_len_cached during export."
"self.seq_len_cached = seq_len"
)
def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
config = transformers.FalconConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.FalconModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)
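# The next test's xfail reason points at functorch's cond; a hedged usage
# sketch (signature as understood at the time, treat as illustrative):
#   from functorch.experimental.control_flow import cond
#   def true_fn(x): return x.sin()
#   def false_fn(x): return x.cos()
#   y = cond(x.sum() > 0, true_fn, false_fn, [x])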

@pytorch_test_common.xfail(
"torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment."
"Please use functorch.experimental.control_flow.cond to explicitly capture the control flow."
"if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:"
)
def test_fake_tensor_mode_huggingface_mosaicml_mpt_7b(self):
model_name = "mosaicml/mpt-7b"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = transformers.AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
59 changes: 45 additions & 14 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -10,7 +10,7 @@

import onnx_test_common
import onnxruntime # type: ignore[import]
import parameterized  # type: ignore[import]
import pytorch_test_common
import torch
import torch.onnx
@@ -27,7 +27,7 @@
from torch.testing._internal import common_utils

try:
import torchvision  # type: ignore[import]

HAS_TORCHVISION = True
except ImportError:
@@ -555,21 +555,52 @@ def func(x):
func, (torch.randn(3, 4),)
)

def test_gpt2_tiny_from_config(self):
# Model
config = transformers.GPT2Config(
num_hidden_layers=4,
vocab_size=8096,
hidden_size=16,
intermediate_size=16,
max_position_embeddings=512,
num_attention_heads=2,
hidden_dropout_prob=0.0,
attention_dropout_prob=0.0,
)
model = transformers.GPT2Model(config).eval()

# Encoded inputs
batch, seq = 2, 128
input_ids = torch.randint(0, 8096, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

# Another encoded inputs to test dynamic shapes
another_batch, another_seq = 3, 256
another_input_ids = torch.randint(0, 8096, (another_batch, another_seq))
another_attention_mask = torch.ones(
another_batch, another_seq, dtype=torch.bool
)
another_position_ids = torch.arange(0, another_seq, dtype=torch.long)
another_position_ids = another_position_ids.unsqueeze(0).view(-1, another_seq)

self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model,
(input_ids,),
input_kwargs={
"attention_mask": attention_mask,
"position_ids": position_ids,
},
additional_test_inputs=[
(
(another_input_ids,),
{
"attention_mask": another_attention_mask,
"position_ids": another_position_ids,
},
)
],
)
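# A hedged note: the second (batch, seq) = (3, 256) inputs exist to exercise
# dynamic shapes; presumably the helper re-runs the exported ONNX model on them
# under onnxruntime, which would fail if batch or sequence dims were baked in
# as constants.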

def test_prims_device_put(self):
