In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt

# 稀疏张量

稀疏张量主要用于高效存储多个0的情况<br>
- values: A 1D tensor with shape [N] 包含几个不为0的数.
- indices: A 2D tensor with shape [N, rank], 不为0数的下标.
- dense_shape: A 1D tensor with shape [shape], 矩阵的形状.


In [2]:
st1=tf.SparseTensor(
    indices=[[0,3],[2,4]],
    values=[10,20],
    dense_shape=[3,10]
)

所形成的稀疏张量如下：
![稀疏张量](https://tensorflow.google.cn/guide/images/sparse_tensor.png)

In [13]:
print(st1)

def pprint_tensor(st):
    s = "<SparseTensor shape=%s \n values={" % (st.dense_shape.numpy().tolist(),)
    for (index,value) in zip(st.indices,st.values):
        s += f"\n  %s: %s" % (index.numpy().tolist(), value.numpy().tolist())
    return s+"}>"
pprint_tensor(st1)

SparseTensor(indices=tf.Tensor(
[[0 3]
 [2 4]], shape=(2, 2), dtype=int64), values=tf.Tensor([10 20], shape=(2,), dtype=int32), dense_shape=tf.Tensor([ 3 10], shape=(2,), dtype=int64))


'<SparseTensor shape=[3, 10] \n values={\n  [0, 3]: 10\n  [2, 4]: 20}>'

也可以同时用稀疏矩阵的直接矩阵表达形式来进行构造


In [8]:
st2=tf.sparse.from_dense([[1, 0, 0, 8], [0, 0, 0, 0], [0, 0, 3, 0]])
print(pprint_tensor(st2))

<SparseTensor shape=[3, 4] 
 values={
  [0, 0]: 1
  [0, 3]: 8
  [2, 2]: 3}>


In [19]:
st_a = tf.SparseTensor(indices=[[0, 2], [3, 4]],
                       values=[31, 2], 
                       dense_shape=[4, 10])

st_b = tf.SparseTensor(indices=[[0, 2], [7, 0]],
                       values=[56, 38],
                       dense_shape=[4, 10])

# 稀疏矩阵加法
st_sum = tf.sparse.add(st_a, st_b)

print(pprint_tensor(st_sum))
# 稀疏矩阵乘法
st_c = tf.SparseTensor(indices=([0, 1], [1, 0], [1, 1]),
                       values=[13, 15, 17],
                       dense_shape=(2,2))
mb = tf.constant([[4], [6]])
product = tf.sparse.sparse_dense_matmul(st_c, mb)
print(st_c,mb,product)
# 稀疏矩阵连接
sparse_pattern_A = tf.SparseTensor(indices = [[2,4], [3,3], [3,4], [4,3], [4,4], [5,4]],
                         values = [1,1,1,1,1,1],
                         dense_shape = [8,5])
sparse_pattern_B = tf.SparseTensor(indices = [[0,2], [1,1], [1,3], [2,0], [2,4], [2,5], [3,5], 
                                              [4,5], [5,0], [5,4], [5,5], [6,1], [6,3], [7,2]],
                         values = [1,1,1,1,1,1,1,1,1,1,1,1,1,1],
                         dense_shape = [8,6])
sparse_pattern_C = tf.SparseTensor(indices = [[3,0], [4,0]],
                         values = [1,1],
                         dense_shape = [8,6])

sparse_patterns_list = [sparse_pattern_A, sparse_pattern_B, sparse_pattern_C]
sparse_pattern = tf.sparse.concat(axis=1, sp_inputs=sparse_patterns_list)
print(tf.sparse.to_dense(sparse_pattern))
# 稀疏矩阵元素加法
st2_plus_5 = tf.SparseTensor(
    st2.indices,
    st2.values + 5,
    st2.dense_shape)
print(tf.sparse.to_dense(st2_plus_5))


<SparseTensor shape=[4, 10] 
 values={
  [0, 2]: 87
  [3, 4]: 2
  [7, 0]: 38}>
SparseTensor(indices=tf.Tensor(
[[0 1]
 [1 0]
 [1 1]], shape=(3, 2), dtype=int64), values=tf.Tensor([13 15 17], shape=(3,), dtype=int32), dense_shape=tf.Tensor([2 2], shape=(2,), dtype=int64)) tf.Tensor(
[[4]
 [6]], shape=(2, 1), dtype=int32) tf.Tensor(
[[ 78]
 [162]], shape=(2, 1), dtype=int32)
tf.Tensor(
[[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0]
 [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0]
 [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0]
 [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]], shape=(8, 17), dtype=int32)
tf.Tensor(
[[ 6  0  0 13]
 [ 0  0  0  0]
 [ 0  0  8  0]], shape=(3, 4), dtype=int32)


# 和keras进行联用

仅需要在输入层设置`sparse=True`即可<br>
下面以神经网络为例子


In [23]:
x=tf.keras.Input(shape=(4,),sparse=True)
y=tf.keras.layers.Dense(4)(x)
model=tf.keras.Model(x,y)
sparse_data=tf.SparseTensor(
    indices= [(0,0),(0,1),(0,2),
               (4,3),(5,0),(5,1)],
    values=[1,1,1,1,1,1],
    dense_shape=(6,4)
)
model(sparse_data)
model.predict(sparse_data)

array([[ 0.47235915,  0.9101693 ,  0.37997937,  0.62414026],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.5417007 , -0.6604419 ,  0.23671693,  0.24461526],
       [ 0.4109371 ,  1.1005002 ,  0.26989496, -0.15205604]],
      dtype=float32)