diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 968231fc42e..8164d1ebb02 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -49,7 +49,8 @@ Tensor& copy_out( // Use direct copy fast path if broadcast is not needed and tensors are // non-empty if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) && - src.numel() > 0) { + src.numel() > 0 && out.nbytes() >= src.nbytes() && + tensors_have_same_dtype(src, out)) { std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes()); } else { ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() { @@ -91,8 +92,9 @@ Tensor& copy_( // Use direct copy fast path if broadcast is not needed and tensors are // non-empty if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) && - src.numel() > 0) { - std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes()); + src.numel() > 0 && in.nbytes() >= src.nbytes() && + tensors_have_same_dtype(src, in)) { + std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), src.nbytes()); } else { ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() { utils::apply_bitensor_elementwise_fn<