Skip to content

Commit

Permalink
[CI] Switch inference accuracy and performance tests to bfloat16
Browse files Browse the repository at this point in the history
ghstack-source-id: 3a321704605a1e14e909a35beae2d0fd464d06a6
Pull Request resolved: #103535
  • Loading branch information
desertfire committed Jun 15, 2023
1 parent f61b248 commit 7cb1f26
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 40 deletions.
3 changes: 2 additions & 1 deletion .ci/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ test_perf_for_dashboard() {
fi
if [[ "$DASHBOARD_TAG" == *inference-true* ]]; then
modes+=(inference)
dtype=bfloat16
fi
# TODO: All the accuracy tests can be skipped once the CI accuracy checking is stable enough
local targets=(accuracy performance)
Expand Down Expand Up @@ -432,7 +433,7 @@ test_dynamo_benchmark() {
if [[ "${TEST_CONFIG}" == *cpu_accuracy* ]]; then
test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --float32 "$@"
else
test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --amp "$@"
test_single_dynamo_benchmark "inference" "$suite" "$shard_id" --inference --bfloat16 "$@"
test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@"
fi
fi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ AllenaiLongformerBase,pass,4
BartForCausalLM,pass,0
BertForMaskedLM,pass,0
BertForQuestionAnswering,pass,0
BlenderbotForCausalLM,pass_due_to_skip,0
BlenderbotSmallForCausalLM,pass,0
BlenderbotSmallForConditionalGeneration,pass,0
CamemBert,pass,0
Expand All @@ -17,7 +18,6 @@ DistillGPT2,pass,0
ElectraForCausalLM,pass,0
ElectraForQuestionAnswering,pass,0
GPT2ForSequenceClassification,pass,2
GoogleFnet,pass,0
LayoutLMForMaskedLM,pass,0
LayoutLMForSequenceClassification,pass,2
M2M100ForConditionalGeneration,pass,0
Expand All @@ -34,7 +34,7 @@ PLBartForConditionalGeneration,pass,0
PegasusForCausalLM,pass,0
PegasusForConditionalGeneration,pass,0
RobertaForCausalLM,pass,0
RobertaForQuestionAnswering,pass,0
RobertaForQuestionAnswering,fail_accuracy,0
Speech2Text2ForCausalLM,pass,0
T5ForConditionalGeneration,pass,0
T5Small,pass,0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@ BartForCausalLM,pass,0
BartForConditionalGeneration,pass,0
BertForMaskedLM,pass,0
BertForQuestionAnswering,pass,0
BlenderbotForCausalLM,pass_due_to_skip,0
BlenderbotSmallForCausalLM,pass,0
BlenderbotSmallForConditionalGeneration,pass,0
CamemBert,pass,0
DebertaForMaskedLM,pass,0
DebertaForQuestionAnswering,pass,0
DebertaV2ForMaskedLM,pass_due_to_skip,0
DebertaV2ForQuestionAnswering,eager_2nd_run_OOM,0
DebertaV2ForQuestionAnswering,pass,0
DistilBertForMaskedLM,pass,0
DistilBertForQuestionAnswering,pass,0
DistillGPT2,pass,0
ElectraForCausalLM,pass,0
ElectraForQuestionAnswering,pass,0
GPT2ForSequenceClassification,pass,2
GoogleFnet,pass,0
LayoutLMForMaskedLM,pass,0
LayoutLMForSequenceClassification,pass,2
M2M100ForConditionalGeneration,pass,0
Expand All @@ -36,7 +36,7 @@ PLBartForConditionalGeneration,pass,0
PegasusForCausalLM,pass,0
PegasusForConditionalGeneration,pass,0
RobertaForCausalLM,pass,0
RobertaForQuestionAnswering,pass,0
RobertaForQuestionAnswering,fail_accuracy,0
Speech2Text2ForCausalLM,pass,0
T5ForConditionalGeneration,pass,0
T5Small,pass,0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ SelecSls42b,pass,0
adv_inception_v3,pass,0
beit_base_patch16_224,pass,0
botnet26t_256,pass,0
cait_m36_384,pass,0
coat_lite_mini,pass,0
convit_base,pass,0
convmixer_768_32,pass,0
convnext_base,pass,0
crossvit_9_240,pass,0
Expand Down Expand Up @@ -57,4 +57,3 @@ tnt_s_patch16_224,pass,0
twins_pcpvt_base,pass,0
vit_base_patch16_224,pass,0
volo_d1_224,pass,0
xcit_large_24_p8_224,pass,0
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ SelecSls42b,pass,0
adv_inception_v3,pass,0
beit_base_patch16_224,pass,0
botnet26t_256,pass,0
cait_m36_384,pass,0
coat_lite_mini,pass,0
convit_base,pass,0
convmixer_768_32,pass,0
convnext_base,pass,0
crossvit_9_240,pass,0
Expand Down Expand Up @@ -57,4 +57,3 @@ tnt_s_patch16_224,pass,0
twins_pcpvt_base,pass,0
vit_base_patch16_224,pass,0
volo_d1_224,pass,0
xcit_large_24_p8_224,pass,0
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name,accuracy,graph_breaks
BERT_pytorch,pass,0
Background_Matting,pass,0
LearningToPaint,pass,0
Super_SloMo,pass,0
alexnet,pass,0
attention_is_all_you_need_pytorch,pass,0
basic_gnn_edgecnn,pass,0
Expand All @@ -12,8 +11,6 @@ basic_gnn_sage,pass,0
cm3leon_generate,pass,6
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,4
doctr_reco_predictor,fail_accuracy,4
drq,pass,0
fastNLP_Bert,pass,5
functorch_dp_cifar10,pass,0
Expand All @@ -28,7 +25,7 @@ hf_Reformer,pass,27
hf_T5_generate,pass,118
hf_T5_large,pass_due_to_skip,0
lennard_jones,pass,0
llama,pass,0
llama,fail_accuracy,0
maml_omniglot,pass,0
mnasnet1_0,pass,0
mobilenet_v2,pass,0
Expand All @@ -54,6 +51,5 @@ timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
tts_angular,pass,2
vgg16,pass,0
yolov3,pass,2
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name,accuracy,graph_breaks
BERT_pytorch,pass,0
Background_Matting,pass,0
LearningToPaint,pass,0
Super_SloMo,pass,0
alexnet,pass,0
attention_is_all_you_need_pytorch,pass,0
basic_gnn_edgecnn,pass,0
Expand All @@ -12,8 +11,6 @@ basic_gnn_sage,pass,0
cm3leon_generate,pass,67
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,4
doctr_reco_predictor,fail_accuracy,4
drq,pass,0
fastNLP_Bert,pass,5
functorch_dp_cifar10,pass,0
Expand All @@ -28,7 +25,7 @@ hf_Reformer,pass,27
hf_T5_generate,pass,118
hf_T5_large,pass_due_to_skip,0
lennard_jones,pass,0
llama,pass,0
llama,fail_accuracy,0
maml_omniglot,pass,0
mnasnet1_0,pass,0
mobilenet_v2,pass,0
Expand Down Expand Up @@ -57,7 +54,6 @@ timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
tts_angular,pass,2
vgg16,pass,0
vision_maskrcnn,fail_accuracy,58
yolov3,pass,2
53 changes: 33 additions & 20 deletions benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,10 @@ def non_deterministic_models(self):
def fp32_only_models(self):
return set()

@property
def force_amp_for_fp16_bf16_models(self):
return set()

@property
def skip_not_suitable_for_training_models(self):
return set()
Expand Down Expand Up @@ -1273,19 +1277,39 @@ def iter_models(self, args):
def deepcopy_model(self, model):
return copy.deepcopy(model)

def cast_based_on_args(self, model, example_inputs):
if self.args.float32 or self.args.only in self.fp32_only_models:
if not self.args.float32:
log.warning("Model %s supports float32 only", self.args.only)
model, example_inputs = cast_to_fp32(model, example_inputs)
elif self.args.float16:
if self.args.only in self.force_amp_for_fp16_bf16_models:
log.warning(
"Model %s does not support float16, running with amp instead",
self.args.only,
)
self.args.amp = True
self.setup_amp()
else:
model, example_inputs = cast_to_fp16(model, example_inputs)
elif self.args.bfloat16:
if self.args.only in self.force_amp_for_fp16_bf16_models:
log.warning(
"Model %s does not support bfloat16, running with amp instead",
self.args.only,
)
self.args.amp = True
self.setup_amp()
else:
model, example_inputs = cast_to_bf16(model, example_inputs)

def validate_model(self, model, example_inputs):
"""
Runs the eager model with example inputs to ensure that eager passes.
"""
model = self.deepcopy_model(model)
example_inputs = clone_inputs(example_inputs)
if self.args.float32:
model, example_inputs = cast_to_fp32(model, example_inputs)
elif self.args.float16:
model, example_inputs = cast_to_fp16(model, example_inputs)
elif self.args.bfloat16:
model, example_inputs = cast_to_bf16(model, example_inputs)

self.cast_based_on_args(model, example_inputs)
try:
self.model_iter_fn(model, example_inputs)
except Exception as e:
Expand All @@ -1294,12 +1318,7 @@ def validate_model(self, model, example_inputs):
def maybe_cast(self, model, example_inputs):
model = self.deepcopy_model(model)
example_inputs = clone_inputs(example_inputs)
if self.args.float32:
model, example_inputs = cast_to_fp32(model, example_inputs)
elif self.args.float16:
model, example_inputs = cast_to_fp16(model, example_inputs)
elif self.args.bfloat16:
model, example_inputs = cast_to_bf16(model, example_inputs)
self.cast_based_on_args(model, example_inputs)
return model, example_inputs

def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
Expand Down Expand Up @@ -2667,13 +2686,6 @@ def run(runner, args, original_dir=None):
current_batch_size = batch_size
set_model_name(name)

if args.float32:
model, example_inputs = cast_to_fp32(model, example_inputs)
elif args.float16:
model, example_inputs = cast_to_fp16(model, example_inputs)
elif args.bfloat16:
model, example_inputs = cast_to_bf16(model, example_inputs)

# Look for stuff that looks like batch size, and mark it dynamic.
# Better integration would integrate directly with benchmark suite
# but cannot conveniently do this
Expand Down Expand Up @@ -2705,6 +2717,7 @@ def detect_and_mark_batch(t):
args.per_process_memory_fraction
)

runner.cast_based_on_args(model, example_inputs)
runner.run_one_model(
name,
model,
Expand Down
9 changes: 9 additions & 0 deletions benchmarks/dynamo/timm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def pip_install(package):
"sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
"convit_base",
"xcit_large_24_p8_224",
}


def refresh_model_names():
import glob
Expand Down Expand Up @@ -169,6 +174,10 @@ def __init__(self):
super().__init__()
self.suite_name = "timm_models"

@property
def force_amp_for_fp16_bf16_models(self):
return FORCE_AMP_FOR_FP16_BF16_MODELS

@download_retry_decorator
def _download_model(self, model_name):
model = create_model(
Expand Down
11 changes: 11 additions & 0 deletions benchmarks/dynamo/torchbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ def setup_torchbench_cwd():
"pytorch_unet": 2,
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
"doctr_det_predictor",
"doctr_reco_predictor",
"Super_SloMo",
"tts_angular",
}


class TorchBenchmarkRunner(BenchmarkRunner):
def __init__(self):
Expand Down Expand Up @@ -249,6 +256,10 @@ def skip_not_suitable_for_training_models(self):
def failing_fx2trt_models(self):
return TRT_NOT_YET_WORKING

@property
def force_amp_for_fp16_bf16_models(self):
return FORCE_AMP_FOR_FP16_BF16_MODELS

@property
def skip_accuracy_checks_large_models_dashboard(self):
if self.args.dashboard or self.args.accuracy:
Expand Down

0 comments on commit 7cb1f26

Please sign in to comment.