[ONNX] Add huggingface models into CI tests
[ONNX] Add transformers models to the no-runtime test of the FX exporter

ghstack-source-id: 01663e84842d7dd46fd7a72154ea1b2d5bb655d7
Pull Request resolved: #107247
titaiwangms committed Aug 15, 2023
1 parent e9cb717 commit f1e11a0
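
The new tests all follow the same no-runtime, fake-mode export pattern: the tokenizer and model are loaded under torch.onnx.enable_fake_mode(), the model is exported with torch.onnx.dynamo_export using ExportOptions(fake_context=...), and the resulting proto is validated with the ONNX checker and shape inference instead of being executed. A minimal sketch of that pattern follows; the tiny-gpt2 checkpoint name is an illustrative stand-in, not one of the checkpoints pinned by the tests in this diff:

    # Minimal sketch of the no-runtime, fake-mode export pattern shared by the new tests.
    # "sshleifer/tiny-gpt2" is an illustrative stand-in for the checkpoints added below.
    import onnx
    import torch
    from transformers import AutoModel, AutoTokenizer

    model_name = "sshleifer/tiny-gpt2"
    with torch.onnx.enable_fake_mode() as fake_context:
        # Inside fake mode, parameters are created as FakeTensors, so large
        # checkpoints can be traced without materializing real weight memory.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        inputs = tokenizer("Hello world!", return_tensors="pt")
        model = AutoModel.from_pretrained(model_name)

    # Export outside the fake-mode context, passing the recorded fake context along.
    export_options = torch.onnx.ExportOptions(fake_context=fake_context)
    export_output = torch.onnx.dynamo_export(
        model, **inputs, export_options=export_options
    )

    # "No runtime": validate the ONNX proto without running it in ONNX Runtime.
    onnx.checker.check_model(export_output.model_proto)
    onnx.shape_inference.infer_shapes(export_output.model_proto)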
Showing 3 changed files with 262 additions and 32 deletions.
1 change: 0 additions & 1 deletion test/onnx/onnx_test_common.py
@@ -385,7 +385,6 @@ def run_ort(
         raise AssertionError(
             f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
         )
-
     ort_input = {k: v.cpu().numpy() for k, v in zip(input_names, pytorch_inputs)}
     return session.run(None, ort_input)

137 changes: 137 additions & 0 deletions test/onnx/test_fx_to_onnx.py
@@ -422,6 +422,143 @@ def test_fake_tensor_mode_huggingface_bigscience__bloom_560m(self):
         onnx.checker.check_model(export_output.model_proto)
         onnx.shape_inference.infer_shapes(export_output.model_proto)
 
+    def test_fake_tensor_mode_huggingface_open_llama_3b_v2(self):
+        from transformers import AutoModel, AutoTokenizer  # type: ignore[import]
+
+        model_name = "openlm-research/open_llama_3b_v2"
+        with torch.onnx.enable_fake_mode() as fake_context:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            inputs = tokenizer("Hello world!", return_tensors="pt")
+            model = AutoModel.from_pretrained(model_name)
+
+        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
+        export_output = torch.onnx.dynamo_export(
+            model, **inputs, export_options=export_options
+        )
+        onnx.checker.check_model(export_output.model_proto)
+        onnx.shape_inference.infer_shapes(export_output.model_proto)
+
+    # SymFloat in attribute
+    def test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b(self):
+        from transformers import AutoModel, AutoTokenizer  # type: ignore[import]
+
+        model_name = "databricks/dolly-v2-3b"
+        with torch.onnx.enable_fake_mode() as fake_context:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            inputs = tokenizer("Hello world!", return_tensors="pt")
+            model = AutoModel.from_pretrained(model_name)
+
+        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
+        export_output = torch.onnx.dynamo_export(
+            model, **inputs, export_options=export_options
+        )
+        onnx.checker.check_model(export_output.model_proto)
+        onnx.shape_inference.infer_shapes(export_output.model_proto)
+
+    # AssertionError: Mutating module attribute seq_len_cached during export.
+    @pytorch_test_common.xfail(
+        "AssertionError: Mutating module attribute seq_len_cached during export."
+        "self.seq_len_cached = seq_len"
+    )
+    def test_fake_tensor_mode_huggingface_tiiuae_falcon_7b(self):
+        from transformers import AutoModel, AutoTokenizer  # type: ignore[import]
+
+        model_name = "tiiuae/falcon-7b"
+        with torch.onnx.enable_fake_mode() as fake_context:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            inputs = tokenizer("Hello world!", return_tensors="pt")
+            model = AutoModel.from_pretrained(model_name)
+        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
+        export_output = torch.onnx.dynamo_export(
+            model,
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            export_options=export_options,
+        )
+        onnx.checker.check_model(export_output.model_proto)
+        onnx.shape_inference.infer_shapes(export_output.model_proto)
+
+    # torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment.
+    # Please use functorch.experimental.control_flow.cond to explicitly capture the control flow
+    @pytorch_test_common.xfail(
+        "torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment."
+        "Please use functorch.experimental.control_flow.cond to explicitly capture the control flow."
+        "if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:"
+    )
+    def test_fake_tensor_mode_huggingface_mosaicml_mpt_7b(self):
+        from transformers import (  # type: ignore[import]
+            AutoModelForCausalLM,
+            AutoTokenizer,
+        )
+
+        model_name = "mosaicml/mpt-7b"
+        with torch.onnx.enable_fake_mode() as fake_context:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            inputs = tokenizer("Hello world!", return_tensors="pt")
+            model = AutoModelForCausalLM.from_pretrained(model_name)
+
+        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
+        export_output = torch.onnx.dynamo_export(
+            model, **inputs, export_options=export_options
+        )
+        onnx.checker.check_model(export_output.model_proto)
+        onnx.shape_inference.infer_shapes(export_output.model_proto)
+
+    def test_fake_tensor_mode_huggingface_google_flan_t5_small(self):
+        from transformers import AutoModel, AutoTokenizer  # type: ignore[import]
+
+        model_name = "google/flan-t5-small"
+        with torch.onnx.enable_fake_mode() as fake_context:
+            model = AutoModel.from_pretrained(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            decoder_input_ids = tokenizer(
+                "Studies show that", return_tensors="pt"
+            ).input_ids  # Batch size 1
+            # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
+            # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
+            decoder_input_ids = model._shift_right(decoder_input_ids)
+            inputs = tokenizer("Hello world!", return_tensors="pt")
+            inputs["decoder_input_ids"] = decoder_input_ids
+
+        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
+        export_output = torch.onnx.dynamo_export(
+            model, **inputs, export_options=export_options
+        )
+        onnx.checker.check_model(export_output.model_proto)
+        onnx.shape_inference.infer_shapes(export_output.model_proto)
+
+    def test_fake_tensor_mode_huggingface_openai_whisper_tiny(self):
+        from datasets import load_dataset  # type: ignore[import]
+        from transformers import (  # type: ignore[import]
+            AutoModel,
+            WhisperConfig,
+            WhisperProcessor,
+        )
+
+        model_name = "openai/whisper-tiny"
+        with torch.onnx.enable_fake_mode() as fake_context:
+            config = WhisperConfig.from_pretrained(model_name)
+            processor = WhisperProcessor.from_pretrained(model_name)
+            ds = load_dataset(
+                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+            )
+            input_features = processor(
+                [ds[0]["audio"]["array"]], return_tensors="pt"
+            ).input_features
+            decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id
+
+            model = AutoModel.from_pretrained(model_name)
+        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
+        export_output = torch.onnx.dynamo_export(
+            model,
+            input_features,
+            decoder_input_ids=decoder_input_ids,
+            export_options=export_options,
+        )
+        onnx.checker.check_model(export_output.model_proto)
+        onnx.shape_inference.infer_shapes(export_output.model_proto)
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
156 changes: 125 additions & 31 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -14,7 +14,6 @@
 import pytorch_test_common
 import torch
 import torch.onnx
-import transformers  # type: ignore[import]
 from torch import nn
 
 from torch._subclasses import fake_tensor
@@ -556,20 +555,57 @@ def func(x):
         )
 
     def test_gpt2_tiny(self):
-        model_name = "sshleifer/tiny-gpt2"
-        # Download pytorch model
-        model = transformers.AutoModel.from_pretrained(model_name)
-        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
-
-        # Transform input tokens
-        inputs = tokenizer("Hello world!", return_tensors="pt")
-        another_inputs = tokenizer("Another Hello world!", return_tensors="pt")
+        from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+        device = "cpu"
+        # Model
+        config = GPT2Config(
+            num_hidden_layers=4,
+            vocab_size=8096,
+            hidden_size=16,
+            intermediate_size=16,
+            max_position_embeddings=512,
+            num_attention_heads=2,
+            hidden_dropout_prob=0.0,
+            attention_dropout_prob=0.0,
+        )
+        model = GPT2Model(config).to(device).eval()
+
+        # Encoded inputs
+        batch, seq = 2, 128
+        input_ids = torch.randint(0, 8096, (batch, seq)).to(device)
+        attention_mask = torch.ones(batch, seq, dtype=torch.bool).to(device)
+        position_ids = torch.arange(0, seq, dtype=torch.long).to(device)
+        position_ids = position_ids.unsqueeze(0).view(-1, seq)
+
+        # Another encoded inputs to test dynamic shapes
+        another_batch, another_seq = 3, 256
+        another_input_ids = torch.randint(0, 8096, (another_batch, another_seq)).to(
+            device
+        )
+        another_attention_mask = torch.ones(
+            another_batch, another_seq, dtype=torch.bool
+        ).to(device)
+        another_position_ids = torch.arange(0, another_seq, dtype=torch.long).to(device)
+        another_position_ids = another_position_ids.unsqueeze(0).view(-1, another_seq)
 
         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
             model,
-            [],
-            input_kwargs=inputs,
-            additional_test_inputs=[((), another_inputs)],
+            (input_ids,),
+            input_kwargs={
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+            },
+            additional_test_inputs=[
+                (
+                    (another_input_ids,),
+                    {
+                        "attention_mask": another_attention_mask,
+                        "position_ids": another_position_ids,
+                    },
+                )
+            ],
         )
 
     def test_prims_device_put(self):
@@ -636,7 +672,8 @@ def _test_fx_symbolic_tracer_large_scale_exporter(
         with ctx, ftm:
             # Toy model with parameters and buffers as FakeTensor's.
             fake_model = create_model()
-            fake_model.load_state_dict(torch.load(tmp_file.name))
+            model_state_dict = torch.load(tmp_file.name)
+            fake_model.load_state_dict(model_state_dict)
             # Toy inputs as FakeTensor's.
             fake_args = create_args()
             # Export ONNX model without initializers while ctx.paths records
@@ -745,21 +782,62 @@ def create_pytorch_only_extra_kwargs():
             create_pytorch_only_extra_kwargs,
         )
 
+    # @pytorch_test_common.skip_dynamic_fx_test(
+    #     "FakeTensor exporting is not supported by dynamic axes."
+    # )
+    # def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2_other(self):
+    #     model_name = "sshleifer/tiny-gpt2"
+
+    #     def create_model() -> nn.Module:
+    #         return transformers.AutoModel.from_pretrained(model_name)
+
+    #     def create_args():
+    #         tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    #         kwargs = tokenizer("Hello world!", return_tensors="pt")
+    #         input_ids = kwargs["input_ids"]
+    #         attention_mask = kwargs["attention_mask"]
+    #         return input_ids, None, attention_mask
+
+    #     def create_pytorch_only_extra_kwargs():
+    #         return {"return_dict": False}
+
+    #     self._test_fx_symbolic_tracer_large_scale_exporter(
+    #         "tiny_gpt2",
+    #         create_model,
+    #         create_args,
+    #         create_pytorch_only_extra_kwargs,
+    #     )
+
     @pytorch_test_common.skip_dynamic_fx_test(
         "FakeTensor exporting is not supported by dynamic axes."
     )
     def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self):
-        model_name = "sshleifer/tiny-gpt2"
+        from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+        device = "cpu"
+        config = GPT2Config(
+            num_hidden_layers=4,
+            vocab_size=8096,
+            hidden_size=16,
+            intermediate_size=16,
+            max_position_embeddings=512,
+            num_attention_heads=2,
+            hidden_dropout_prob=0.0,
+            attention_dropout_prob=0.0,
+        )
+        batch, seq = 2, 128
 
         def create_model() -> nn.Module:
-            return transformers.AutoModel.from_pretrained(model_name)
+            return GPT2Model(config).to(device).eval()
 
         def create_args():
-            tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
-            kwargs = tokenizer("Hello world!", return_tensors="pt")
-            input_ids = kwargs["input_ids"]
-            attention_mask = kwargs["attention_mask"]
-            return input_ids, None, attention_mask
+            input_ids = torch.randint(0, 8096, (batch, seq)).to(device)
+            attention_mask = torch.ones(batch, seq, dtype=torch.bool).to(device)
+            position_ids = torch.arange(0, seq, dtype=torch.long).to(device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq)
+            # TODO(titaiwang): initializers are not given to the model
+            return input_ids, None, attention_mask, None, position_ids
 
         def create_pytorch_only_extra_kwargs():
             return {"return_dict": False}
@@ -774,9 +852,7 @@ def create_pytorch_only_extra_kwargs():

 def _parameterized_class_attrs_and_values_with_fake_options():
     input_values = []
-    input_values.extend(
-        itertools.product((True, False), (True, False), (True, False), (True, False))
-    )
+    input_values.extend(itertools.product((False,), (False,), (False,), (True, False)))
     return {
         "attrs": [
             "op_level_debug",
@@ -936,20 +1012,38 @@ def create_kwargs():
         )
 
     def test_large_scale_exporter_with_tiny_gpt2(self):
-        model_name = "sshleifer/tiny-gpt2"
+        from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+        device = "cpu"
+        config = GPT2Config(
+            num_hidden_layers=4,
+            vocab_size=8096,
+            hidden_size=16,
+            intermediate_size=16,
+            max_position_embeddings=512,
+            num_attention_heads=2,
+            hidden_dropout_prob=0.0,
+            attention_dropout_prob=0.0,
+        )
+        batch, seq = 2, 128
 
         def create_model() -> nn.Module:
-            return transformers.AutoModel.from_pretrained(model_name)
+            return GPT2Model(config).to(device).eval()
 
         def create_args():
-            tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
-            kwargs = tokenizer("Hello world!", return_tensors="pt")
-            input_ids = kwargs["input_ids"]
-            attention_mask = kwargs["attention_mask"]
-            return input_ids, None, attention_mask
+            input_ids = torch.randint(0, 8096, (batch, seq)).to(device)
+            return (input_ids,)
 
         def create_kwargs():
-            return {"return_dict": False}
+            attention_mask = torch.ones(batch, seq, dtype=torch.bool).to(device)
+            position_ids = torch.arange(0, seq, dtype=torch.long).to(device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq)
+            return {
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+                "return_dict": False,
+            }
 
         self._test_fake_tensor_mode_exporter(
             "tiny_gpt2",
