In [1]:
import torch
import torch.nn.functional as F

def softmax_one(x, dim):
    x = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(x)
    return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True))

softmax = F.softmax


In [2]:
scores = torch.FloatTensor([ 0.4202, -2.1783,  1.0923,  1.0282,  1.6933, -0.2471,  0.2865, -0.2099,
         1.6876,  0.8158,  0.8784, -0.8089,  0.9769, -0.6090, -1.2419,  1.1895,
         0.0505,  2.3182, -0.1766, -0.5399,  0.1346, -0.0824,  1.1482,  1.0165,
         0.9373, -0.0946, -0.3365,  1.2053,  0.4422, -0.5857, -2.3086, -1.1120])

print(scores)
print(F.softmax(scores, dim=-1))
print(softmax_one(scores, dim=-1))


tensor([ 0.4202, -2.1783,  1.0923,  1.0282,  1.6933, -0.2471,  0.2865, -0.2099,
         1.6876,  0.8158,  0.8784, -0.8089,  0.9769, -0.6090, -1.2419,  1.1895,
         0.0505,  2.3182, -0.1766, -0.5399,  0.1346, -0.0824,  1.1482,  1.0165,
         0.9373, -0.0946, -0.3365,  1.2053,  0.4422, -0.5857, -2.3086, -1.1120])
tensor([0.0239, 0.0018, 0.0468, 0.0439, 0.0853, 0.0123, 0.0209, 0.0127, 0.0848,
        0.0355, 0.0378, 0.0070, 0.0417, 0.0085, 0.0045, 0.0515, 0.0165, 0.1594,
        0.0132, 0.0091, 0.0180, 0.0144, 0.0495, 0.0434, 0.0401, 0.0143, 0.0112,
        0.0524, 0.0244, 0.0087, 0.0016, 0.0052])
tensor([0.0206, 0.0015, 0.0403, 0.0378, 0.0736, 0.0106, 0.0180, 0.0110, 0.0732,
        0.0306, 0.0326, 0.0060, 0.0359, 0.0074, 0.0039, 0.0445, 0.0142, 0.1375,
        0.0113, 0.0079, 0.0155, 0.0125, 0.0427, 0.0374, 0.0346, 0.0123, 0.0097,
        0.0452, 0.0211, 0.0075, 0.0013, 0.0045])


In [6]:
import math

scores = torch.FloatTensor([-math.inf, -math.inf,  1.0923,  1.0282,  1.6933, -0.2471,  0.2865, -0.2099,
         1.6876,  0.8158,  0.8784, -0.8089,  0.9769, -0.6090, -1.2419,  1.1895,
         0.0505,  2.3182, -0.1766, -0.5399,  0.1346, -0.0824,  1.1482,  1.0165,
         0.9373, -0.0946, -0.3365,  1.2053,  0.4422, -0.5857, -2.3086, -1.1120])

print(scores)
print(F.softmax(scores, dim=-1))
print(softmax_one(scores, dim=-1))


tensor([   -inf,    -inf,  1.0923,  1.0282,  1.6933, -0.2471,  0.2865, -0.2099,
         1.6876,  0.8158,  0.8784, -0.8089,  0.9769, -0.6090, -1.2419,  1.1895,
         0.0505,  2.3182, -0.1766, -0.5399,  0.1346, -0.0824,  1.1482,  1.0165,
         0.9373, -0.0946, -0.3365,  1.2053,  0.4422, -0.5857, -2.3086, -1.1120])
tensor([0.0000, 0.0000, 0.0480, 0.0450, 0.0876, 0.0126, 0.0214, 0.0131, 0.0871,
        0.0364, 0.0388, 0.0072, 0.0428, 0.0088, 0.0047, 0.0529, 0.0169, 0.1636,
        0.0135, 0.0094, 0.0184, 0.0148, 0.0508, 0.0445, 0.0411, 0.0146, 0.0115,
        0.0537, 0.0251, 0.0090, 0.0016, 0.0053])
tensor([0.0000, 0.0000, 0.0413, 0.0387, 0.0753, 0.0108, 0.0184, 0.0112, 0.0748,
        0.0313, 0.0333, 0.0062, 0.0368, 0.0075, 0.0040, 0.0455, 0.0146, 0.1406,
        0.0116, 0.0081, 0.0158, 0.0127, 0.0436, 0.0382, 0.0353, 0.0126, 0.0099,
        0.0462, 0.0215, 0.0077, 0.0014, 0.0046])


In [19]:
dim = 8
scores = torch.randn((dim,dim))
print(scores)
mask = torch.full((dim,dim), float("-inf"))
mask = torch.triu(mask, diagonal=1)
scores = scores + mask
print(scores)
print(F.softmax(scores, dim=-1))
print(softmax_one(scores, dim=-1))

tensor([[ 0.8841,  0.7231, -0.6898, -1.8028,  0.8936, -1.9256, -0.9440, -0.1231],
        [ 0.3500, -1.6226, -0.4200,  0.8387, -0.1396,  0.4772,  0.0494,  0.7336],
        [ 0.5531,  1.0931, -0.0467,  0.7880,  1.8305, -0.0795, -1.2711,  1.1082],
        [-1.7921, -0.0812,  1.7433, -0.4884,  1.3150, -0.0415, -0.9348, -2.4931],
        [-0.7898, -0.1579, -1.0661,  1.0456,  0.2785, -0.5035, -1.1072, -0.2374],
        [-0.3552, -0.5652, -0.1526, -0.4452, -0.2718, -0.9043, -0.4288, -0.3006],
        [-1.2123,  1.8237,  0.6699,  0.1181,  0.2071, -0.7324,  1.1192,  0.6670],
        [ 0.3052, -0.7259, -0.4036,  0.4451, -0.8335, -1.4056,  1.2433, -1.2111]])
tensor([[ 0.8841,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.3500, -1.6226,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.5531,  1.0931, -0.0467,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.7921, -0.0812,  1.7433, -0.4884,    -inf,    -inf,    -inf,    -inf],
        [-0.789

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

out = F.softmax(scores, dim=-1).tolist()
out2 = softmax_one(scores, dim=-1).tolist()
rows = len(out)
fig, axes = plt.subplots(rows,1,figsize=(6,rows*2))



for i in range(rows):  
    
    data = pd.DataFrame({'softmax': out[i],
                   'softmax_one': out2[i]})
    sns.lineplot(data=data, ax = axes[i])
