[MPS] Revamp copy_to_mps_ implementation #87475
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Tensor's view in linear storage is represented by the following parameters:
.shape
,.stride()
and.storage_offset()
.Only tensors that are representable as 1d-views can be copied from host to device (and vice versa) using single
copy(from:sourceOffset:to:destinationOffset:size:)
call.Modify
copy_to_mps_
function to do the following steps:src
tensor to dst data type if neededsrc
tensor todst
tensor shapesrc
tensor if it is not stride contiguous (i.e. can not be represented bysrc.view(src.numel())
)dst
is not stride-contiguous or if its strides are different then potentially clonedsrc
stridessrc
to (potentiall temp)dst
Add test to cover cases where stide-contiguous permuted tensor is copied to MPS, non-stride-contiguous tensor is copied to MPS and if permuted CPU tensor is copied to differently permuted MPS tensor