Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#4 from JZ-LIANG/AutoParallel/support-…
Browse files Browse the repository at this point in the history
…blomm

add bf16 o2
  • Loading branch information
zhaoyinglia authored Mar 8, 2023
2 parents d276420 + ef4cc88 commit 48fd4a9
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def forward(ctx, *args, **kwargs):
check_variable_and_dtype(
Out_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum',
)

Expand Down Expand Up @@ -645,7 +645,7 @@ def backward(ctx, *args, **kwargs):
check_variable_and_dtype(
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)

Expand Down Expand Up @@ -687,12 +687,15 @@ def backward(ctx, *args, **kwargs):
},
)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)

Expand Down
Loading

0 comments on commit 48fd4a9

Please sign in to comment.