# softmax layers

## IO of softmax

In [51]:
import tensorflow as tf

from tensorflow.keras.layers import Activation

logit = tf.random.uniform(shape=(8,5), minval=-10, maxval=10)

softmax_value = Activation('softmax')(logit)
softmax_sum = tf.reduce_sum(softmax_value, axis=1)
print(softmax_value) 


tf.Tensor([[1.3706939e-01 8.1360275e-01 3.1551328e-05 4.9295358e-02 9.3676255e-07]], shape=(1, 5), dtype=float32)


## softmax in dense layers

In [59]:
import tensorflow as tf

from tensorflow.keras.layers import Dense

logit = tf.random.uniform(shape=(8,5), minval=-10, maxval=10)
dense = Dense(units=8, activation='softmax')

Y = dense(logit)
print(tf.reduce_sum(Y, axis=1))

tf.Tensor([7 3 7 2 2 5 0 0], shape=(8,), dtype=int64)


# multi-class classifiers

## multi-class classifiers

In [60]:
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense

class TestModel(Model):
  def __init__(self):
    super(TestModel, self).__init__()
    
    self.dense1 = Dense(units=8, activation='relu')
    self.dense2 = Dense(units=5, activation='relu')
    self.dense3 = Dense(units=3, activation='softmax')

  def call(self, x):
    print('X: {}\n{}\n'.format(x.shape, x.numpy()))

    x = self.dense1(x)
    print("A1 : {}\n".format(x.numpy()))

    x = self.dense2(x)
    print("A2 : {}\n".format(x.numpy()))

    x = self.dense3(x)
    print("Y : {}\n".format(x.numpy()))
    print('Sum of vectors : {}\n'.format(tf.reduce_sum(x, axis=1)))

model = TestModel()

X = tf.random.uniform(shape=(8,5), minval=10, maxval=-10)
Y = model(X)

X: (8, 5)
[[ 1.7140179   7.4170184   6.4797664  -9.5126     -6.3325157 ]
 [-2.6711369  -1.211338   -2.8418732  -4.0103455   2.3601031 ]
 [-1.2444477  -9.154881    5.3363156  -4.6406555  -9.119862  ]
 [ 5.1413536   8.956084   -6.7682114  -7.6008034   8.244417  ]
 [-0.03538609  9.387014   -9.266825    2.7079368   6.508317  ]
 [ 4.4124103  -5.0472927   3.8880754   6.1586905  -3.834033  ]
 [ 2.25698     8.252085    1.4029579   0.7229023   7.770705  ]
 [ 1.9938517   6.6955447   4.7116804   1.5245132  -1.0377884 ]]

A1 : [[ 0.          0.          0.          0.          4.071772    1.4953319
   0.          0.        ]
 [ 0.          1.6312283   3.2774274   0.68266416  0.          0.
   2.0403724   0.        ]
 [11.564244    0.          0.43916714  0.          2.6332877   0.
   2.500523    0.        ]
 [ 0.         15.996227    0.72973454  6.4963164   0.          0.
   0.          0.        ]
 [ 0.         13.53635     0.9149485  12.174039    0.          4.8604426
   0.4398177   0.        ]
