In [1]:
import numpy as np

In [2]:
def softmax1(logits, axis=1):
    '''Naive softmax along `axis`; kept as the numerically unstable baseline.

    The raw logits are exponentiated directly, so sufficiently large values
    overflow inside np.exp and the normalised result degrades to nan.

    Args:
        logits: array of scores.
        axis: axis along which the scores are normalised (default 1).

    Returns:
        Array shaped like `logits` holding exp(logits) divided by its sum
        along `axis`.
    '''
    unnormalized = np.exp(logits)
    normalizer = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / normalizer

def softmax2(logits, axis=1):
    '''Numerically stable softmax along `axis`.

    The maximum along `axis` is subtracted before exponentiating, so the
    largest argument handed to np.exp is 0 and the exponentials cannot
    overflow.

    Args:
        logits: array of scores.
        axis: axis along which the scores are normalised (default 1).

    Returns:
        Array shaped like `logits` whose entries along `axis` sum to 1.
    '''
    peak = np.max(logits, axis=axis, keepdims=True)
    weights = np.exp(logits - peak)
    return weights / np.sum(weights, axis=axis, keepdims=True)

In [3]:
# Moderate logits (max 500): np.exp stays finite, so both variants must agree.
logits = np.linspace(-1, 500, num=500).reshape(50, 10)
assert np.allclose(softmax1(logits), softmax2(logits)), "Softmaxes differ"

In [4]:
# Logits up to 5e5: np.exp overflows inside the naive softmax, producing nan.
logits = np.linspace(-1, 5e5, 500).reshape([50, 10])
# Descriptive name instead of `_`, which IPython reserves for the last cell output.
unstable_probs = softmax1(logits)
print(unstable_probs[:3])

[[ 0. nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]
 [nan nan nan nan nan nan nan nan nan nan]]




In [5]:
# Same extreme logits through the stable variant: every row stays finite.
# Descriptive name instead of `_`, which IPython reserves for the last cell output.
stable_probs = softmax2(logits)
print(stable_probs[:3])

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]


In [6]:
# softmax2 is stable version
softmax = softmax2

# last check: the selected implementation must stay finite on extreme logits
logits = np.linspace(-1, 5e5, 500).reshape([50, 10])
# Descriptive name instead of `_`, which IPython reserves for the last cell output.
checked_probs = softmax(logits)
assert np.isfinite(checked_probs).all(), "stable softmax produced non-finite values"

Let's find a stable version of cross entropy

In [7]:
def softmax_crossentropy_with_logits1(logits, reference_answers):
    """Compute crossentropy from logits[batch,n_classes] and ids of correct answers.

    Unstable baseline: the log-sum-exp term exponentiates the raw logits, so
    large values overflow inside np.exp and the loss becomes inf.

    Args:
        logits: array indexed as [batch, n_classes].
        reference_answers: one class id per row of `logits`.

    Returns:
        One loss value per row: log(sum(exp(logits))) minus the logit of the
        reference class.
    """
    # Fancy-index one logit per row: the score assigned to the reference class.
    row_ids = np.arange(len(logits))
    answer_logits = logits[row_ids, reference_answers]

    # This is the unstable part: exp() is applied to the unshifted logits.
    log_partition = np.log(np.sum(np.exp(logits), axis=-1))

    return log_partition - answer_logits

def softmax_crossentropy_with_logits2(logits, reference_answers):
    """Compute crossentropy from logits[batch,n_classes] and ids of correct answers.

    Goes through explicit probabilities: applies the notebook-level `softmax`
    and takes the negative log of the probability of the reference class.
    When that probability underflows to 0 the loss is inf.

    Args:
        logits: array indexed as [batch, n_classes].
        reference_answers: one class id per row of `logits`.

    Returns:
        One loss value per row.
    """
    probabilities = softmax(logits, axis=1)
    batch_rows = range(len(logits))
    answer_probs = probabilities[batch_rows, reference_answers]
    return -np.log(answer_probs)

def softmax_crossentropy_with_logits3(logits, reference_answers):
    """Compute crossentropy from logits[batch,n_classes] and ids of correct answers.

    Stable variant: each row is shifted by its own maximum before the
    log-sum-exp, so np.exp never receives a positive argument.

    Args:
        logits: array indexed as [batch, n_classes]; must provide `.max`.
        reference_answers: one class id per row of `logits`.

    Returns:
        One loss value per row.
    """
    # Shifting creates a new array; the caller's logits are left untouched.
    shifted = logits - logits.max(axis=1, keepdims=True)

    row_ids = np.arange(len(shifted))
    answer_logits = shifted[row_ids, reference_answers]
    log_partition = np.log(np.sum(np.exp(shifted), axis=-1))

    return log_partition - answer_logits

In [8]:
# Check that the three implementations agree on small, well-behaved logits.

logits = np.linspace(-1, 1, num=500).reshape(50, 10)
answers = np.arange(50) % 10  # reference class ids cycle through 0..9

loss1 = softmax_crossentropy_with_logits1(logits, answers)
loss2 = softmax_crossentropy_with_logits2(logits, answers)
loss3 = softmax_crossentropy_with_logits3(logits, answers)

assert np.allclose(loss1, loss2), "loss1 != loss2"
assert np.allclose(loss2, loss3), "loss2 != loss3"

In [9]:
# Stability test inputs: logits reaching 1e5 make a naive exp() overflow.

logits = np.linspace(-1, 1e5, num=500).reshape(50, 10)
answers = np.arange(50) % 10  # reference class ids cycle through 0..9

In [10]:
# Unstable variant 1 on the large logits: the first three losses print as inf.
print(softmax_crossentropy_with_logits1(logits, answers)[:3])

[inf inf inf]




In [11]:
# Variant 2 (log of softmax probabilities): the first three losses also print as inf.
print(softmax_crossentropy_with_logits2(logits, answers)[:3])

[inf inf inf]




In [12]:
# Variant 3 (max-shifted log-sum-exp): the first three losses print as finite values.
print(softmax_crossentropy_with_logits3(logits, answers)[:3])

[1803.6252505  1603.22244489 1402.81963928]


In [13]:
# softmax_crossentropy_with_logits3 is the stable version; expose it under the generic name
softmax_crossentropy_with_logits = softmax_crossentropy_with_logits3

# last check: the alias is run on the large stability-test logits
print(softmax_crossentropy_with_logits(logits, answers)[:3])

[1803.6252505  1603.22244489 1402.81963928]


Let's find a stable version of cross entropy gradient

In [14]:
def grad_softmax_crossentropy_with_logits1(logits, reference_answers):
    """Compute crossentropy gradient from logits[batch,n_classes] and ids of correct answers.

    Unstable baseline: the softmax is built from exp() of the raw logits, so
    large values overflow inside np.exp and the affected entries become nan.

    Args:
        logits: array indexed as [batch, n_classes].
        reference_answers: one class id per row of `logits`.

    Returns:
        Array shaped like `logits` holding (softmax - one_hot) / batch_size,
        i.e. the gradient of the cross-entropy averaged over the batch.
    """
    one_hot = np.zeros_like(logits)
    one_hot[np.arange(len(logits)), reference_answers] = 1

    # Exponentiate once and reuse the result for numerator and denominator.
    # The local is deliberately not called `softmax`, so it does not shadow
    # the notebook-level softmax function.
    exps = np.exp(logits)
    probabilities = exps / exps.sum(axis=-1, keepdims=True)

    return (probabilities - one_hot) / logits.shape[0]

def grad_softmax_crossentropy_with_logits2(logits, reference_answers):
    """Compute crossentropy gradient from logits[batch,n_classes] and ids of correct answers.

    Delegates the probabilities to the notebook-level `softmax`, then
    subtracts a one-hot encoding of the reference answers and divides by the
    number of rows.

    Args:
        logits: array indexed as [batch, n_classes].
        reference_answers: one class id per row of `logits`.

    Returns:
        Array shaped like `logits` holding (softmax - one_hot) / batch_size.
    """
    one_hot = np.zeros_like(logits)
    row_ids = np.arange(len(logits))
    one_hot[row_ids, reference_answers] = 1

    probabilities = softmax(logits)
    return (probabilities - one_hot) / logits.shape[0]

In [15]:
# Check that both gradient implementations agree on small, well-behaved logits.

logits = np.linspace(-1, 1, num=500).reshape(50, 10)
answers = np.arange(50) % 10  # reference class ids cycle through 0..9

grad1 = grad_softmax_crossentropy_with_logits1(logits, answers)
grad2 = grad_softmax_crossentropy_with_logits2(logits, answers)

assert np.allclose(grad1, grad2), "grad1 != grad2"

In [16]:
# Stability test inputs: logits reaching 1e5 make a naive exp() overflow.

logits = np.linspace(-1, 1e5, num=500).reshape(50, 10)
answers = np.arange(50) % 10  # reference class ids cycle through 0..9

In [17]:
# Unstable gradient on the large logits: the printed rows contain nan entries.
print(grad_softmax_crossentropy_with_logits1(logits, answers)[:3])

[[-0.02  0.    0.    0.     nan   nan   nan   nan   nan   nan]
 [  nan   nan   nan   nan   nan   nan   nan   nan   nan   nan]
 [  nan   nan   nan   nan   nan   nan   nan   nan   nan   nan]]




In [18]:
# Gradient built on the stable softmax: the printed rows are finite.
print(grad_softmax_crossentropy_with_logits2(logits, answers)[:3])

[[-2.00000000e-002  0.00000000e+000  0.00000000e+000  0.00000000e+000
   0.00000000e+000  0.00000000e+000  1.58318655e-263  1.71145201e-176
   1.85010919e-089  2.00000000e-002]
 [ 0.00000000e+000 -2.00000000e-002  0.00000000e+000  0.00000000e+000
   0.00000000e+000  0.00000000e+000  1.58318655e-263  1.71145201e-176
   1.85010919e-089  2.00000000e-002]
 [ 0.00000000e+000  0.00000000e+000 -2.00000000e-002  0.00000000e+000
   0.00000000e+000  0.00000000e+000  1.58318655e-263  1.71145201e-176
   1.85010919e-089  2.00000000e-002]]


In [19]:
# grad_softmax_crossentropy_with_logits2 is the stable version; expose it under the generic name
grad_softmax_crossentropy_with_logits = grad_softmax_crossentropy_with_logits2

# last check: the alias is run on the large stability-test logits
print(grad_softmax_crossentropy_with_logits(logits, answers)[:3])

[[-2.00000000e-002  0.00000000e+000  0.00000000e+000  0.00000000e+000
   0.00000000e+000  0.00000000e+000  1.58318655e-263  1.71145201e-176
   1.85010919e-089  2.00000000e-002]
 [ 0.00000000e+000 -2.00000000e-002  0.00000000e+000  0.00000000e+000
   0.00000000e+000  0.00000000e+000  1.58318655e-263  1.71145201e-176
   1.85010919e-089  2.00000000e-002]
 [ 0.00000000e+000  0.00000000e+000 -2.00000000e-002  0.00000000e+000
   0.00000000e+000  0.00000000e+000  1.58318655e-263  1.71145201e-176
   1.85010919e-089  2.00000000e-002]]
