## mindspore.ops.diag(input) -> Tensor
- Input: `input` must be a MindSpore Tensor.
- Returns: a MindSpore Tensor.

When the input is a 1-D array:

```python
import numpy as np
import mindspore as ms
import torch
import jax.numpy as jnp

input = np.array([1, 2, 3, 4])

y1 = ms.ops.diag(ms.tensor(input))
y2 = torch.diag(torch.tensor(input))
y3 = jnp.diag(input)
print('mindspore output ->\n', y1)
print('torch     output ->\n', y2)
print('jax       output ->\n', y3)
```

```
mindspore output ->
 [[1 0 0 0]
 [0 2 0 0]
 [0 0 3 0]
 [0 0 0 4]]
torch     output ->
 tensor([[1, 0, 0, 0],
        [0, 2, 0, 0],
        [0, 0, 3, 0],
        [0, 0, 0, 4]])
jax       output ->
 [[1 0 0 0]
 [0 2 0 0]
 [0 0 3 0]
 [0 0 0 4]]
```


When printed, the MindSpore and JAX results show no container type. torch wraps its result in `tensor(...)`.
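For a 1-D input all three frameworks agree: the result is an n-by-n matrix with the input on the main diagonal. The NumPy-only sketch below reproduces that behavior so it can be checked without any of the three frameworks installed. `diag_1d` is a hypothetical helper for illustration, not part of any of the APIs compared here.

```python
import numpy as np

def diag_1d(v):
    """Hypothetical helper: build an n x n matrix with 1-D v on the main diagonal."""
    v = np.asarray(v)
    n = v.shape[0]
    out = np.zeros((n, n), dtype=v.dtype)
    # Fancy indexing writes out[i, i] = v[i] for every i in one step.
    out[np.arange(n), np.arange(n)] = v
    return out

m = diag_1d([1, 2, 3, 4])
print(m)  # same 4x4 matrix as the outputs above
```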

When the input is a 2-D array:

```python
input = np.array([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y1 = ms.ops.diag(ms.tensor(input))
y2 = torch.diag(torch.tensor(input))
y3 = jnp.diag(input)
print('mindspore output ->\n', y1)
print('torch     output ->\n', y2)
print('jax       output ->\n', y3)
```

```
mindspore output ->
 [[[[1 0 0]
   [0 0 0]
   [0 0 0]]

  [[0 2 0]
   [0 0 0]
   [0 0 0]]

  [[0 0 3]
   [0 0 0]
   [0 0 0]]]


 [[[0 0 0]
   [4 0 0]
   [0 0 0]]

  [[0 0 0]
   [0 5 0]
   [0 0 0]]

  [[0 0 0]
   [0 0 6]
   [0 0 0]]]


 [[[0 0 0]
   [0 0 0]
   [7 0 0]]

  [[0 0 0]
   [0 0 0]
   [0 8 0]]

  [[0 0 0]
   [0 0 0]
   [0 0 9]]]]
torch     output ->
 tensor([1, 5, 9])
jax       output ->
 [1 5 9]
```


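For input with two or more dimensions, MindSpore's output is inconsistent with the other frameworks. torch and JAX extract the main diagonal (`[1 5 9]`). `ms.ops.diag` instead builds a tensor of twice the input's rank, here of shape (3, 3, 3, 3). It places each element `input[i, j]` at `output[i, j, i, j]` and fills everything else with zeros.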
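The MindSpore result above can be reproduced in NumPy. The sketch below reconstructs the observed behavior: an input of shape (D1, ..., Dk) becomes an output of shape (D1, ..., Dk, D1, ..., Dk). It is an illustration under that assumption, not MindSpore's implementation, and `tensor_diag` is a hypothetical name. `np.diagonal` gives what torch and JAX return for the same 2-D input.

```python
import numpy as np

def tensor_diag(x):
    """Hypothetical NumPy sketch of the rank-2k behavior observed for ms.ops.diag.

    For x of shape (D1, ..., Dk) it returns shape (D1, ..., Dk, D1, ..., Dk),
    with out[i1, ..., ik, i1, ..., ik] = x[i1, ..., ik] and zeros elsewhere.
    """
    x = np.asarray(x)
    n = x.size
    # Diagonal matrix over the flattened (row-major) input ...
    flat = np.zeros((n, n), dtype=x.dtype)
    flat[np.arange(n), np.arange(n)] = x.ravel()
    # ... reshaped so each flat index splits back into the original k indices.
    return flat.reshape(x.shape + x.shape)

x = np.arange(1, 10).reshape(3, 3)
ms_like = tensor_diag(x)      # mimics ms.ops.diag
torch_like = np.diagonal(x)   # what torch.diag / jnp.diag return for 2-D input
print(ms_like.shape)          # (3, 3, 3, 3)
print(ms_like[1, 0, 1, 0])    # 4
print(torch_like)             # [1 5 9]
```

For a 1-D input the same function reduces to the ordinary diagonal matrix. That is why the frameworks only diverge once the input has two or more dimensions.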

```python
input = np.array([1, 2, 3, 4])
x = torch.tensor(input)
out2 = torch.empty(0, dtype=x.dtype)  # empty buffer: torch resizes it to 4x4 without a warning

torch.diag(x, out=out2)
print('torch     output ->\n', out2)
```

```
torch     output ->
 tensor([[1, 0, 0, 0],
        [0, 2, 0, 0],
        [0, 0, 3, 0],
        [0, 0, 0, 4]])
```




torch can also write the result into a preallocated tensor through its `out` parameter. MindSpore does not support this.
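A framework without `out` can still reuse a buffer by writing into it in place. The NumPy sketch below shows that pattern using a hypothetical helper, `diag_into`. torch resizes an `out` tensor of the wrong shape, but this version deliberately simplifies that behavior and raises on a shape mismatch instead.

```python
import numpy as np

def diag_into(v, out):
    """Hypothetical helper: write diag(v) into the preallocated array `out`."""
    v = np.asarray(v)
    n = v.shape[0]
    if out.shape != (n, n):
        # torch would resize `out`; this sketch refuses instead.
        raise ValueError(f"out has shape {out.shape}, expected {(n, n)}")
    out.fill(0)                              # clear stale contents
    out[np.arange(n), np.arange(n)] = v      # out[i, i] = v[i], written in place
    return out

buf = np.empty((4, 4), dtype=np.int64)
result = diag_into([1, 2, 3, 4], buf)
print(result is buf)  # True: the result lives in the caller's buffer
```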

```python
y2 = torch.diag(torch.tensor(input), 1)
y3 = jnp.diag(input, 1)
print('torch     output ->\n', y2)
print('jax       output ->\n', y3)
```

```
torch     output ->
 tensor([[0, 1, 0, 0, 0],
        [0, 0, 2, 0, 0],
        [0, 0, 0, 3, 0],
        [0, 0, 0, 0, 4],
        [0, 0, 0, 0, 0]])
jax       output ->
 [[0 1 0 0 0]
 [0 0 2 0 0]
 [0 0 0 3 0]
 [0 0 0 0 4]
 [0 0 0 0 0]]
```


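torch (`diagonal` parameter) and JAX (`k` parameter) can also shift the diagonal with an offset argument. MindSpore does not support this.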
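The offset semantics can be stated simply. A length-m vector placed on diagonal k yields a square matrix of size m + |k|. Positive k puts it above the main diagonal and negative k below it. This matches the 5x5 result above for m = 4, k = 1. The NumPy sketch below spells the rule out with a hypothetical helper, `diag_offset`, and checks it against `np.diag(v, k)`.

```python
import numpy as np

def diag_offset(v, k=0):
    """Hypothetical sketch: put 1-D v on the k-th diagonal of a square matrix."""
    v = np.asarray(v)
    m = v.shape[0]
    size = m + abs(k)                 # large enough to hold the shifted diagonal
    out = np.zeros((size, size), dtype=v.dtype)
    i = np.arange(m)
    if k >= 0:
        out[i, i + k] = v             # on or above the main diagonal
    else:
        out[i - k, i] = v             # below the main diagonal (-k > 0)
    return out

v = np.array([1, 2, 3, 4])
for k in (1, 0, -2):
    assert np.array_equal(diag_offset(v, k), np.diag(v, k))
print(diag_offset(v, 1))  # same 5x5 matrix as the torch/jax outputs above
```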

```python
y1 = ms.ops.diag(input)
```

```
TypeError: Failed calling Diag with "Diag()(input=<class 'numpy.ndarray'>)".
The valid calling should be: 
"Diag()(input=<Tensor>)".

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/pipeline/pynative/pynative_utils.cc:1294 PrintTypeCastError
```


```python
y2 = torch.diag(input)
```

```
TypeError: diag(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
```

```python
y3 = jnp.diag(torch.tensor(input))
```

```
TypeError: Cannot interpret 'torch.int64' as a data type
```

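When the input type is wrong, torch gives the most concise and clear error message. MindSpore's message, which includes a C++ call stack, could be improved.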
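One way to act on that suggestion is to validate argument types up front. The error would then be a one-line message in torch's format, naming the function, argument, position, expected type and actual type. The sketch below uses a hypothetical helper, `require_type`, and a stand-in `Tensor` class for illustration. Neither is an API of any of the three frameworks.

```python
import numpy as np

def require_type(value, expected, fn_name, arg_name, position):
    """Hypothetical validator that raises a torch-style one-line TypeError."""
    if not isinstance(value, expected):
        actual = f"{type(value).__module__}.{type(value).__qualname__}"
        raise TypeError(
            f"{fn_name}(): argument '{arg_name}' (position {position}) "
            f"must be {expected.__name__}, not {actual}"
        )
    return value

class Tensor:
    """Stand-in tensor type for this sketch only."""

try:
    require_type(np.array([1, 2, 3, 4]), Tensor, "diag", "input", 1)
    message = None
except TypeError as err:
    message = str(err)
print(message)  # diag(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
```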