[ONNX] Add huggingface models into CI tests (#107247)
1. Add a list of HF models to the CI tests. The PR intends to build them from Config, but some of them cannot be built that way. NOTE: Loading from a pre-trained model can hit a [uint8/bool conflict](huggingface/transformers#21013) when a newer version of transformers is used.
    - Dolly has a torch.fx.Node in an OnnxFunction attribute, which is currently not supported.
    - Falcon and MPT contain user code that Dynamo does not support.
2. Only the GPT2 export with real tensors is updated to build from Config; FakeMode raises unequal-input errors between PyTorch and ORT because [non-persistent buffers are not supported](#107211).
Pull Request resolved: #107247
Approved by: https://github.com/wschin, https://github.com/BowenBao
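
All of the new tests below follow the same pattern: build the model from Config inside fake tensor mode, export with torch.onnx.dynamo_export, and validate the resulting graph. A minimal sketch, distilled from the GPT2 test added in this commit (not a separate API):

import onnx
import torch
import transformers

config = transformers.GPT2Config()
batch, seq = 4, 256

# Instantiate the model from Config inside fake tensor mode so no real weights are allocated.
with torch.onnx.enable_fake_mode() as fake_context:
    model = transformers.GPT2Model(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (batch, seq))
    attention_mask = torch.ones(batch, seq, dtype=torch.bool)

    export_options = torch.onnx.ExportOptions(fake_context=fake_context)
    export_output = torch.onnx.dynamo_export(
        model,
        input_ids=input_ids,
        attention_mask=attention_mask,
        export_options=export_options,
    )

# Check the exported graph without running it through ORT.
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)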
titaiwangms authored and pytorchmergebot committed Aug 23, 2023
1 parent 610f64d commit 400c4de
Showing 3 changed files with 246 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
@@ -23,7 +23,7 @@ pip_install \
pytest-cov==4.0.0 \
pytest-subtests==0.10.0 \
tabulate==0.9.0 \
transformers==4.25.1
transformers==4.31.0

# Using 1.15dev branch for the following not yet released features and fixes.
# - Segfault fix for shape inference.
183 changes: 178 additions & 5 deletions test/onnx/test_fx_to_onnx.py
@@ -6,6 +6,7 @@
import onnx
import pytorch_test_common
import torch
import transformers # type: ignore[import]
from torch import nn
from torch._subclasses import fake_tensor
from torch.nn import functional as F
@@ -406,14 +407,186 @@ def forward(self, x):
fake_model, real_x, export_options=export_options
)

def test_fake_tensor_mode_huggingface_bigscience__bloom_560m(self):
from transformers import AutoModel, AutoTokenizer # type: ignore[import]
# NOTE: For all transformers models, config is preferred over the pre-trained model for testing because:
# 1. The pre-trained model is too big for CI
# 2. The pre-trained model has a uint8/bool issue: https://github.com/huggingface/transformers/issues/21013
def test_fake_tensor_mode_huggingface_gpt2(self):
config = transformers.GPT2Config()
batch, seq = 4, 256

model_name = "bigscience/bloom-560m"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = transformers.GPT2Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_bigscience_bloom(self):
config = transformers.BloomConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.BloomModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_open_llama(self):
config = transformers.OpenLlamaConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.OpenLlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_google_t5(self):
config = transformers.T5Config()
device = "cpu"
batch, seq = 4, 256
with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.T5Model(config).to(device).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones((batch, seq), dtype=torch.bool)
decoder_input_ids = torch.randint(0, config.vocab_size, (batch, seq))
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

def test_fake_tensor_mode_huggingface_openai_whisper(self):
config = transformers.WhisperConfig()
feature_extractor = transformers.WhisperFeatureExtractor()
device = "cpu"
batch = 4
with torch.onnx.enable_fake_mode() as fake_context:
input_features = torch.randn(
(
batch,
feature_extractor.feature_size,
feature_extractor.nb_max_frames,
),
dtype=torch.float32,
)
decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id
model = transformers.AutoModel.from_config(config).to(device).eval()
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_features,
decoder_input_ids=decoder_input_ids,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

# TODO: From Config/Model
@pytorch_test_common.skip_in_ci(
"Not decorated with xfail because CI doesn't have enough memory to run and then fail."
"SymFloat in OnnxFUnction attribute is not supported yet."
)
def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self):
# TODO: Make this test work with config
# Dolly has no config on transformers
model_name = "databricks/dolly-v2-3b"
device = "cpu"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = AutoModel.from_pretrained(model_name)
model = transformers.AutoModel.from_pretrained(model_name).to(device).eval()

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model, **inputs, export_options=export_options
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

@pytorch_test_common.skip_in_ci(
"Not decorated with xfail because CI doesn't have enough memory to run and then fail."
"AssertionError: Mutating module attribute seq_len_cached during export."
"self.seq_len_cached = seq_len"
)
def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
config = transformers.FalconConfig()
batch, seq = 4, 256

with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.FalconModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(export_output.model_proto)
onnx.shape_inference.infer_shapes(export_output.model_proto)

@pytorch_test_common.skip_in_ci(
"Not decorated with xfail because CI doesn't have enough memory to run and then fail."
"torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. "
"Please use functorch.experimental.control_flow.cond to explicitly capture the control flow"
)
def test_fake_tensor_mode_huggingface_mosaicml_mpt_7b(self):
# TODO: Make this test work with config
# mpt-7b has no config on transformers
model_name = "mosaicml/mpt-7b"
device = "cpu"
with torch.onnx.enable_fake_mode() as fake_context:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
)
inputs = tokenizer("Hello world!", return_tensors="pt")
model = (
transformers.AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
.to(device)
.eval()
)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
export_output = torch.onnx.dynamo_export(
90 changes: 67 additions & 23 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -10,7 +10,7 @@

import onnx_test_common
import onnxruntime # type: ignore[import]
import parameterized
import parameterized # type: ignore[import]
import pytorch_test_common
import torch
import torch.onnx
@@ -27,7 +27,7 @@
from torch.testing._internal import common_utils

try:
import torchvision
import torchvision # type: ignore[import]

HAS_TORCHVISION = True
except ImportError:
@@ -560,21 +560,53 @@ def func(x):
func, (torch.randn(3, 4),)
)

def test_gpt2_tiny(self):
model_name = "sshleifer/tiny-gpt2"
# Download pytorch model
model = transformers.AutoModel.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
def test_gpt2_tiny_from_config(self):
# Model
config = transformers.GPT2Config(
num_hidden_layers=4,
vocab_size=8096,
hidden_size=16,
intermediate_size=16,
max_position_embeddings=512,
num_attention_heads=2,
hidden_dropout_prob=0.0,
attention_dropout_prob=0.0,
)
model = transformers.GPT2Model(config).eval()

def input_generator(batch: int, seq: int):
input_ids = torch.randint(0, 8096, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)
return input_ids, attention_mask, position_ids

# Transform input tokens
inputs = tokenizer("Hello world!", return_tensors="pt")
another_inputs = tokenizer("Another Hello world!", return_tensors="pt")
# Encoded inputs
input_ids, attention_mask, position_ids = input_generator(2, 128)

# Another encoded inputs to test dynamic shapes
(
another_input_ids,
another_attention_mask,
another_position_ids,
) = input_generator(3, 256)

self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model,
[],
input_kwargs=inputs,
additional_test_inputs=[((), another_inputs)],
(input_ids,),
input_kwargs={
"attention_mask": attention_mask,
"position_ids": position_ids,
},
additional_test_inputs=[
(
(another_input_ids,),
{
"attention_mask": another_attention_mask,
"position_ids": another_position_ids,
},
)
],
)

def test_prims_device_put(self):
@@ -750,14 +782,20 @@ def create_pytorch_only_extra_kwargs():
create_pytorch_only_extra_kwargs,
)

@pytorch_test_common.xfail(
"[ONNXRuntimeError] : 1 : FAIL : Type Error: Data in initializer 'h_0_attn_bias' "
"has element type tensor(uint8) but usage of initializer in graph expects tensor(bool)"
"https://github.com/huggingface/transformers/issues/21013"
)
@pytorch_test_common.skip_dynamic_fx_test(
"FakeTensor exporting is not supported by dynamic axes."
)
def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"
device = "cpu"

def create_model() -> nn.Module:
return transformers.AutoModel.from_pretrained(model_name)
return transformers.AutoModel.from_pretrained(model_name).to(device).eval()

def create_args():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
@@ -940,11 +978,17 @@ def create_kwargs():
export_within_fake_mode=self.export_within_fake_mode,
)

@pytorch_test_common.xfail(
"[ONNXRuntimeError] : 1 : FAIL : Type Error: Data in initializer 'h_0_attn_bias' "
"has element type tensor(uint8) but usage of initializer in graph expects tensor(bool)"
"https://github.com/huggingface/transformers/issues/21013"
)
def test_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"
device = "cpu"

def create_model() -> nn.Module:
return transformers.AutoModel.from_pretrained(model_name)
return transformers.AutoModel.from_pretrained(model_name).to(device).eval()

def create_args():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
@@ -1012,22 +1056,22 @@ def create_kwargs():
"HF Bloom model does not need `model.load_state_dict` to work."
)
def test_fake_tensor_mode_huggingface_bigscience_bloom_560m(self):
from transformers import AutoModel, AutoTokenizer # type: ignore[import]

model_name = "bigscience/bloom-560m"
config = transformers.BloomConfig()
batch, seq = 4, 256

def create_args():
return tuple()

def create_kwargs(model_name=model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer("Hello world!", return_tensors="pt")
def create_kwargs():
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
return {"input_ids": input_ids, "attention_mask": attention_mask}

def create_model():
return AutoModel.from_pretrained(model_name)
return transformers.BloomModel(config).eval()

self._test_fake_tensor_mode_exporter(
model_name.replace("/", "_"),
"huggingface_bigscience_bloom_560m",
create_model,
create_args,
create_kwargs,
