modified the python/paddle/amp/auto_cast.py to fix bug (PaddlePaddle#…
tianhaodongbd committed Aug 2, 2023
1 parent eee6376 commit 35e76be
Showing 1 changed file with 42 additions and 0 deletions.
python/paddle/amp/auto_cast.py: 42 additions & 0 deletions
@@ -370,6 +370,7 @@ def amp_guard(
                     "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
                     % (paddle.device.cuda.get_device_name(), prop[0], prop[1])
                 )
+                enable = False
             elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
                 prop = paddle.device.cuda.get_device_capability()
                 cuda_version = paddle.version.cuda()
@@ -382,6 +383,7 @@
                         cuda_version,
                     )
                 )
+                enable = False
 
     amp_dtype = dtype
     amp_global_state().amp_dtype = amp_dtype
@@ -572,6 +574,46 @@ def amp_decorate(
         else:
             return models, optimizers
 
+    # check tracer
+    tracer = _dygraph_tracer()
+    if not tracer:
+        raise ValueError(
+            "current_tracer is None, maybe it is not in imperative mode."
+        )
+
+    # check device_type:
+    if not (
+        tracer._expected_place.is_gpu_place()
+        or tracer._expected_place.is_xpu_place()
+        or tracer._expected_place.is_custom_place()
+    ):
+        if optimizers is None:
+            return models
+        else:
+            return models, optimizers
+    # For xpu:
+    if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
+        if optimizers is None:
+            return models
+        else:
+            return models, optimizers
+    # For custom device:
+    if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
+        if optimizers is None:
+            return models
+        else:
+            return models, optimizers
+    # For gpu float16: Compute Capability should >= 7.
+    # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
+    if tracer._expected_place.is_gpu_place():
+        if (dtype == 'float16' and not _is_gpu_float16_supported()) or (
+            dtype == 'bfloat16' and not _is_gpu_bfloat16_supported()
+        ):
+            if optimizers is None:
+                return models
+            else:
+                return models, optimizers
+
     models_is_list = False
     if isinstance(models, paddle.nn.Layer):
         models_is_list = False
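
For context (not part of this commit), a minimal usage sketch of the behavior these checks enable. It assumes a CUDA build of Paddle, a GPU without bfloat16 support, and that the public paddle.amp.decorate / paddle.amp.auto_cast wrappers route through the amp_decorate / amp_guard paths touched here; the model and shapes are illustrative only.

# Minimal sketch, assuming a GPU whose Compute Capability is below 8.0
# (no bfloat16 support); all names below are illustrative.
import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())

# With this commit, decorate() returns the model and optimizer unchanged when
# the current device cannot run the requested dtype, instead of decorating
# them for an unsupported precision.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', dtype='bfloat16'
)

# Likewise, amp_guard() now sets enable = False after emitting its warning, so
# this forward pass falls back to the default float32 computation.
with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    out = model(paddle.randn([2, 4]))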
