# softmax and safe_softmax

In [49]:
import torch 
# Demo input: the integers 0..3.
x = torch.arange(4)
print(f"x is {x}")

# Plain softmax: exponentiate every element, then normalise by the total.
exp_x = torch.exp(x)
x_softmax = exp_x / exp_x.sum()
print(f"x_softmax is {x_softmax}\n")

# Safe softmax: divide numerator and denominator by exp(max(x)), i.e. subtract
# x_max inside the exponent. The result is unchanged, but every exponent is
# now <= 0.
x_max = torch.max(x)
exp_shifted = torch.exp(x - x_max)
x_safe_softmax = exp_shifted / exp_shifted.sum()
print(f"x_safe_softmax is {x_safe_softmax}\n")


x is tensor([0, 1, 2, 3])
x_softmax is tensor([0.0321, 0.0871, 0.2369, 0.6439])

x_safe_softmax is tensor([0.0321, 0.0871, 0.2369, 0.6439])



# online softmax

In [50]:
"""
online softmax: first compute the softmax statistics (max and sum of exponentials) of the first N numbers, then append one more element x_{N+1}
"""
# Statistics of the prefix (every element except the last one): its max and
# the sum of exponentials taken relative to that max.
x_pre = x[:-1]
x_last = x[-1]
x_max_pre = torch.max(x_pre)
x_sum_pre = torch.exp(x_pre - x_max_pre).sum()

# Fold in the appended element: the running max can only grow.
x_max_cur = torch.max(x_max_pre, x_last)
# The prefix sum was taken relative to x_max_pre, which is not necessarily the
# global max, so rescale it by exp(x_max_pre - x_max_cur) before adding the
# contribution of the new element.
rescaled_prefix_sum = x_sum_pre * torch.exp(x_max_pre - x_max_cur)
appended_term = torch.exp(x_last - x_max_cur)
x_sum_cur = rescaled_prefix_sum + appended_term

x_online_softmax = (x - x_max_cur).exp() / x_sum_cur
print(f"x_online_softmax is {x_online_softmax}")
assert torch.allclose(x_safe_softmax, x_online_softmax)

x_online_softmax is tensor([0.0321, 0.0871, 0.2369, 0.6439])


# 1d block online softmax

先写出 block_num = 2 的版本（见下方被注释掉的代码），再用 for 循环推广到任意块数。

In [51]:
# Split x into consecutive blocks of 2 elements along dim 0.
x_blocks = torch.split(x, split_size_or_sections=2, dim=0)

# Hand-written two-block version that the loop below generalises:
# x_max_block0 = x_blocks[0].max()
# x_sum_block0 = torch.exp(x_blocks[0] - x_max_block0).sum()
# x_max_block1 = x_blocks[1].max()
# x_sum_block1 = torch.exp(x_blocks[1] - x_max_block1).sum()
# x_max_global = torch.max(x_max_block0, x_max_block1)
# x_sum_global = (x_sum_block0 * torch.exp(x_max_block0 - x_max_global)
#                 + x_sum_block1 * torch.exp(x_max_block1 - x_max_global))

# Running statistics over the blocks seen so far.
# The running max starts at -inf (the identity element of max) so that it is
# always the true max of the data. Starting at 0 would clamp the max at 0 for
# all-negative input and lose the protection that subtracting the max gives.
x_max_old = torch.tensor(float("-inf"))
x_sum_old = torch.tensor(0.0)
for x_block in x_blocks:
    x_max_block = x_block.max()
    x_max_new = torch.max(x_max_old, x_max_block)
    # Rescale the old sum from the old max to the new max, then add this
    # block's exponentials taken relative to the new max. On the first
    # iteration exp(-inf) == 0, so the (zero) old sum contributes nothing.
    x_sum_new = (
        x_sum_old * torch.exp(x_max_old - x_max_new)
        + torch.exp(x_block - x_max_new).sum()
    )
    x_max_old = x_max_new
    x_sum_old = x_sum_new

print(f"x is {x},x_max_new is {x_max_new},x_sum_new is {x_sum_new}")
x_block_online_softmax = torch.exp(x - x_max_old) / x_sum_old
print(f"x_block_online_softmax is {x_block_online_softmax}")
assert torch.allclose(x_block_online_softmax, x_softmax)

x is tensor([0, 1, 2, 3]),x_max_new is 3.0,x_sum_new is 1.5530017614364624
x_block_online_softmax is tensor([0.0321, 0.0871, 0.2369, 0.6439])


# 2d batch block online softmax

$S=QK^T$

Q (N,d)

S (N,N)

对 $S$ 的每一行（每一行对应一个 query）并行且独立地计算 softmax：行与行之间互不影响，因为它们属于不同的 query，不应该互相产生影响

其实就是在 数据增加到k行的情况下继续刚刚的操作

In [52]:
# 4x4 demo matrix; the softmax is taken independently over each row.
x = torch.arange(16, dtype=torch.float32).reshape(4, 4)

# Two column blocks of shape (4, 2).
x_batch_blocks = torch.split(x, split_size_or_sections=2, dim=1)

# keepdim=True keeps the per-row statistics as (4, 1) so that they broadcast
# against the (4, 2) blocks and the (4, 4) input.
# max() along a dim returns (values, indices); only the values are needed.
x_max_batch_block0, _ = x_batch_blocks[0].max(dim=1, keepdim=True)
# dim=1 reduces across the columns of each row.
x_sum_batch_block0 = torch.exp(x_batch_blocks[0] - x_max_batch_block0).sum(dim=1, keepdim=True)

x_max_batch_block1, _ = x_batch_blocks[1].max(dim=1, keepdim=True)
x_sum_batch_block1 = torch.exp(x_batch_blocks[1] - x_max_batch_block1).sum(dim=1, keepdim=True)

# Merge the two blocks: take the per-row max of the block maxima, then rescale
# each block's partial sum from its own max to the merged max before adding.
x_max_batch_block_update = torch.maximum(x_max_batch_block0, x_max_batch_block1)
x_sum_batch_block_update = (
    x_sum_batch_block0 * torch.exp(x_max_batch_block0 - x_max_batch_block_update)
    + x_sum_batch_block1 * torch.exp(x_max_batch_block1 - x_max_batch_block_update)
)

x_batch_online_softmax = torch.exp(x - x_max_batch_block_update) / x_sum_batch_block_update
# torch.softmax is used for the reference so that this cell does not depend on
# an `F` alias being imported somewhere else.
x_direct_softmax = torch.softmax(x, dim=1)
# print(f"x_batch_online_softmax is {x_batch_online_softmax}\n x_direct_softmax is {x_direct_softmax}")
assert torch.allclose(x_batch_online_softmax, x_direct_softmax)

改成用 for循环来实现，实现块数量扩展

In [53]:
import torch.nn.functional as F

# 4x4 demo matrix; the softmax is taken independently over each row.
x = torch.arange(16, dtype=torch.float32).reshape(4, 4)

# Column blocks of width 2; the loop below works for any number of blocks.
x_batch_blocks = torch.split(x, split_size_or_sections=2, dim=1)

# Per-row running statistics, kept as (rows, 1) so that they broadcast against
# each block and against x.
# The running max starts at -inf (the identity element of max) so that it is
# always the true row max. Starting at 0 would clamp the max at 0 for rows
# whose entries are all negative and lose the protection that subtracting the
# max gives.
x_max_old = torch.full((x.shape[0], 1), float("-inf"))
x_sum_old = torch.zeros(x.shape[0], 1)
for x_batch_block in x_batch_blocks:
    # max() along a dim returns (values, indices); only the values are needed.
    x_max_batch_block, _ = x_batch_block.max(dim=1, keepdim=True)
    x_max_new = torch.maximum(x_max_batch_block, x_max_old)
    # Rescale the old sum from the old max to the new max, then add this
    # block's exponentials taken relative to the new max. On the first
    # iteration exp(-inf) == 0, so the (zero) old sum contributes nothing.
    x_sum_new = (
        x_sum_old * torch.exp(x_max_old - x_max_new)
        + torch.exp(x_batch_block - x_max_new).sum(dim=1, keepdim=True)
    )
    x_max_old = x_max_new
    x_sum_old = x_sum_new
x_batch_online_softmax = torch.exp(x - x_max_old) / x_sum_old

# Reference result from the library implementation.
x_direct_softmax = F.softmax(x, dim=1)

print(f"x_batch_online_softmax is {x_batch_online_softmax}\n x_direct_softmax is {x_direct_softmax}")
assert torch.allclose(x_batch_online_softmax, x_direct_softmax)

x_batch_online_softmax is tensor([[0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]])
 x_direct_softmax is tensor([[0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]])
