## mindspore.ops.split(tensor, split_size_or_sections, axis=0) -> tuple(Tensor)
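Splits the input Tensor into chunks along the specified axis.
- Inputs:
    * tensor: a MindSpore Tensor.
    * split_size_or_sections: int, tuple(int), or list(int); the length of each chunk along `axis`.
    * axis: int, default 0.
- Returns: tuple(Tensor).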
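To make the size-based semantics concrete, here is a minimal NumPy sketch. It assumes `ms.ops.split` follows the same convention as `torch.split` (the first comparison below shows this for the tuple case): an int gives a fixed chunk length, with a shorter final chunk when the axis length is not divisible, and a tuple or list gives the exact length of each chunk. `split_by_size` is a hypothetical helper, not a MindSpore API, and raising `ValueError` on a bad sum is this sketch's own choice.

```python
import numpy as np


def split_by_size(x, split_size_or_sections, axis=0):
    """Hypothetical NumPy reference for size-based splitting (torch.split convention)."""
    x = np.asarray(x)
    length = x.shape[axis]
    if isinstance(split_size_or_sections, int):
        size = split_size_or_sections
        if size <= 0:
            raise ValueError("split size must be positive")
        # int: chunks of `size` elements; the last chunk keeps the remainder, if any
        sizes = [size] * (length // size)
        if length % size:
            sizes.append(length % size)
    else:
        # tuple/list: explicit chunk lengths that must add up to the axis length
        sizes = list(split_size_or_sections)
        if sum(sizes) != length:
            raise ValueError(f"sizes {sizes} sum to {sum(sizes)}, expected {length}")
    chunks, start = [], 0
    for chunk_len in sizes:
        # Slice only along `axis`, keep every other dimension whole
        index = [slice(None)] * x.ndim
        index[axis] = slice(start, start + chunk_len)
        chunks.append(x[tuple(index)])
        start += chunk_len
    return tuple(chunks)


x = np.arange(1, 10)
print([c.tolist() for c in split_by_size(x, (2, 7))])  # [[1, 2], [3, 4, 5, 6, 7, 8, 9]]
print([c.tolist() for c in split_by_size(x, 4)])       # [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
```

The first call reproduces the chunk lengths that MindSpore and torch return in the first example.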

1. Parameter comparison:

| mindspore | torch | jax |
| :----: | :----: | :----: |
| tensor | tensor | ary |
| split_size_or_sections | split_size_or_sections | indices_or_sections |
| axis | dim | axis |

2. Return value comparison

Real-valued input:

```python
import numpy as np
import mindspore as ms
import torch
import jax.numpy as jnp

input = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
print('input shape:', input.shape)

y1 = ms.ops.split(ms.tensor(input), (2, 7))
y2 = torch.split(torch.tensor(input), (2, 7))
y3 = jnp.split(input, (2, 7))
print('mindspore output:\n', y1)
print('\n')
print('torch output:\n', y2)
print('\n')
print('jax output:\n', y3)
```

```
input shape: (9,)
mindspore output:
 (Tensor(shape=[2], dtype=Int64, value= [1, 2]), Tensor(shape=[7], dtype=Int64, value= [3, 4, 5, 6, 7, 8, 9]))


torch output:
 (tensor([1, 2]), tensor([3, 4, 5, 6, 7, 8, 9]))


jax output:
 [Array([1, 2], dtype=int32), Array([3, 4, 5, 6, 7], dtype=int32), Array([8, 9], dtype=int32)]
```


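MindSpore matches torch: split_size_or_sections gives the length of each chunk, so `(2, 7)` yields chunks of 2 and 7 elements.
In JAX, indices_or_sections gives the indices at which to split, so `(2, 7)` cuts before positions 2 and 7 and yields three chunks.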
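The two conventions convert into each other: the split indices are the running sums of the chunk lengths (without the final total), and the chunk lengths are the gaps between consecutive cut points, counting both ends of the axis. This also accounts for the JAX output above: indices `(2, 7)` on a length-9 axis give chunks of length 2, 5 and 2. The sketch uses NumPy, whose `np.split` has the same index semantics as `jnp.split`; both helper names are hypothetical.

```python
import numpy as np


def sizes_to_indices(sizes):
    # Chunk lengths (torch/MindSpore) -> cut points (NumPy/JAX): running sums minus the total
    return np.cumsum(sizes)[:-1].tolist()


def indices_to_sizes(indices, length):
    # Cut points (NumPy/JAX) -> chunk lengths (torch/MindSpore): gaps between 0, the cuts, and the end
    return np.diff([0, *indices, length]).tolist()


x = np.arange(1, 10)  # same values as the example input

print(indices_to_sizes((2, 7), len(x)))        # [2, 5, 2], the chunk lengths JAX produced
idx = sizes_to_indices((2, 7))                 # [2]
print([c.tolist() for c in np.split(x, idx)])  # [[1, 2], [3, 4, 5, 6, 7, 8, 9]]
```

By the same conversion, `jnp.split(input, (2,))` should reproduce the MindSpore and torch result for sizes `(2, 7)`.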

```python
import numpy as np
import mindspore as ms
import torch
import jax.numpy as jnp

input = np.array([[1, 2, 3, 4],
                  [5, 6, 7, 8]])
print('input shape:', input.shape)

y1 = ms.ops.split(ms.tensor(input), 2, 1)
y2 = torch.split(torch.tensor(input), 2, 1)
y3 = jnp.split(input, 2, 1)
print('mindspore output:\n', y1)
print('\n')
print('torch output:\n', y2)
print('\n')
print('jax output:\n', y3)
```

```
input shape: (2, 4)
mindspore output:
 (Tensor(shape=[2, 2], dtype=Int64, value=
[[1, 2],
 [5, 6]]), Tensor(shape=[2, 2], dtype=Int64, value=
[[3, 4],
 [7, 8]]))


torch output:
 (tensor([[1, 2],
        [5, 6]]), tensor([[3, 4],
        [7, 8]]))


jax output:
 [Array([[1, 2],
       [5, 6]], dtype=int32), Array([[3, 4],
       [7, 8]], dtype=int32)]
```


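* MindSpore and JAX print each chunk's dtype, while torch omits it because `int64` is its default integer dtype. JAX returns `int32` because 64-bit types are disabled unless `jax_enable_x64` is set.
* MindSpore and torch return a tuple; JAX returns a list.
* MindSpore's repr puts `value` last, after `shape` and `dtype`, which makes it harder to read; putting the key information first is recommended.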
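The second example also hides a difference in how an int is read. For MindSpore and torch, `2` is the length of each chunk; for JAX (as for `numpy.split`), `2` is the number of equal sections. With 4 columns both readings give two chunks of width 2, so the outputs agree only by coincidence. The NumPy sketch below widens the input to 6 columns to separate the two readings; it emulates the chunk-length reading with explicit cut points rather than calling MindSpore or torch.

```python
import numpy as np

x = np.arange(12).reshape(2, 6)  # 2 rows, 6 columns

# JAX/NumPy reading: int = number of equal sections -> 2 chunks of width 3
sections = np.split(x, 2, axis=1)

# MindSpore/torch reading: int = chunk length -> cut every 2 columns -> 3 chunks of width 2
chunk_len = 2
cuts = list(range(chunk_len, x.shape[1], chunk_len))  # [2, 4]
chunks = np.split(x, cuts, axis=1)

print([s.shape for s in sections])  # [(2, 3), (2, 3)]
print([c.shape for c in chunks])    # [(2, 2), (2, 2), (2, 2)]
```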

3. Error message comparison

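```python
# ms.ops.split expects a MindSpore Tensor; `input` is still the NumPy array from the previous cell
y1 = ms.ops.split(input, 2)
```

```
TypeError: expect `tensor` is a Tensor, but got <class 'numpy.ndarray'>
```

```python
y2 = torch.split(input, 2)
```

```
AttributeError: 'numpy.ndarray' object has no attribute 'split'
```

```python
# jnp functions accept NumPy arrays, so a Python list is passed to trigger the type check.
# Note: this cell calls jnp.squeeze rather than jnp.split, so the message names squeeze.
input = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y3 = jnp.squeeze(input, 2)
```

```
TypeError: squeeze requires ndarray or scalar arguments, got <class 'list'> at position 0.
```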

When the input type is wrong, torch's error message is concise and clear. MindSpore's message should be improved.
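As an illustration of the recommendation above, the hypothetical wrapper below validates the input type before doing any work and names the parameter, the expected type, the received type, and a likely fix in one short line. It guards `numpy.split` only so the example runs without MindSpore; the wrapper name and message wording are assumptions for illustration, not MindSpore behavior.

```python
import numpy as np


def split_checked(x, indices_or_sections, axis=0):
    """Hypothetical wrapper: validate the input type up front, then split."""
    if not isinstance(x, np.ndarray):
        # Name the parameter, the expected type, the received type, and a likely fix
        raise TypeError(
            f"split_checked(): argument 'x' must be numpy.ndarray, "
            f"got {type(x).__name__}; convert it first, e.g. np.asarray(x)"
        )
    return np.split(x, indices_or_sections, axis=axis)


try:
    split_checked([1, 2, 3, 4], 2)
except TypeError as err:
    print(err)
```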