In [5]:
import torch

import torch.nn.functional as F

In [6]:
# Sample 2x3 float input shared by the softmax experiments in this notebook.
l = torch.tensor(
    [
        [0.2640, 0.2640, 0.4719],
        [0.3422, 0.3422, 0.3156],
    ]
)

def custom_softmax(x, dim=0):
    """Softmax along ``dim``, computed explicitly one 1-D slice at a time.

    This is the "by hand" loop version: every 1-D fibre of ``x`` that runs
    along ``dim`` is exponentiated and divided by its own sum.

    Args:
        x (torch.Tensor): Input tensor of any rank.
        dim (int): Dimension along which the outputs sum to 1. Negative
            values index from the end.

    Returns:
        torch.Tensor: Tensor with the same shape as ``x``.
    """
    if x.numel() == 0:
        # No elements to normalise; exp() just yields an empty tensor of the
        # right shape (and avoids reshape(-1, 0) / max() on empty data).
        return torch.exp(x)

    # Move the softmax axis to the end and flatten everything else, so each
    # row of `rows` is exactly one fibre along `dim`. This depends only on
    # `x` itself, never on notebook globals, and works for any rank.
    axis_last = x.reshape(1) if x.dim() == 0 else x.movedim(dim, -1)
    rows = axis_last.reshape(-1, axis_last.shape[-1])

    soft_rows = []
    for row in rows:
        # Softmax is shift-invariant; subtracting the max keeps exp() finite.
        exp_values = torch.exp(row - row.max())
        soft_rows.append(exp_values / exp_values.sum())

    result = torch.stack(soft_rows).reshape(axis_last.shape)
    if x.dim() == 0:
        return result.reshape(x.shape)
    # Undo the axis move so the output lines up with the input layout.
    return result.movedim(-1, dim).contiguous()

def simplified_softmax(x, dim=0):
    """Vectorised softmax of ``x`` along ``dim`` using exp and sum.

    Args:
        x (torch.Tensor): Input tensor.
        dim (int): Dimension along which the outputs sum to 1.

    Returns:
        torch.Tensor: Tensor with the same shape as ``x``.
    """
    if x.numel() == 0:
        # Nothing to normalise; also sidesteps max() over an empty dimension,
        # which raises instead of returning an empty result.
        return torch.exp(x)

    # Softmax is shift-invariant, so subtracting the per-slice maximum leaves
    # the result unchanged while preventing exp() from overflowing to inf
    # (which would turn the division below into inf / inf = nan).
    shifted = x - x.max(dim=dim, keepdim=True).values
    exp_x = torch.exp(shifted)

    # Sum along the specified dimension, keeping it for proper broadcasting
    sum_exp = exp_x.sum(dim=dim, keepdim=True)

    # Divide to get softmax probabilities
    return exp_x / sum_exp

def c_softmax(x, dim=-1):
    """
    Softmax built from elementary tensor operations, with max-shifting.

    Args:
        x (torch.Tensor): Input tensor
        dim (int): Dimension along which the outputs are normalised

    Returns:
        torch.Tensor: ``exp(x - max) / sum(exp(x - max))`` taken along ``dim``
    """
    # Shift by the per-slice maximum so the largest exponent is exp(0) = 1
    peak = x.max(dim=dim, keepdim=True).values
    numerator = (x - peak).exp()

    # keepdim=True lets the denominator broadcast back over ``dim``
    denominator = numerator.sum(dim=dim, keepdim=True)
    return numerator / denominator

In [7]:
# Run every implementation along the same axis so the displayed tuple can be
# compared entry by entry against the F.softmax reference.
dim = 1
(
    custom_softmax(l, dim=dim),
    F.softmax(l, dim=dim),
    c_softmax(l, dim=dim),
    simplified_softmax(l, dim=dim),
)

(tensor([[0.3095, 0.3095, 0.3810],
         [0.3363, 0.3363, 0.3274]]),
 tensor([[0.3095, 0.3095, 0.3810],
         [0.3363, 0.3363, 0.3274]]),
 tensor([[0.3095, 0.3095, 0.3810],
         [0.3363, 0.3363, 0.3274]]),
 tensor([[0.3095, 0.3095, 0.3810],
         [0.3363, 0.3363, 0.3274]]))

tensor([[0.3095, 0.3095, 0.3810],
        [0.3363, 0.3363, 0.3274]])

In [49]:
# Small integer matrix; the expressions below do not read it.
n = torch.tensor([[1, 2], [3, 4]])

# Element-wise exponentials of the sample input.
exp_x = l.exp()

# Softmax by hand: divide each row by its own total (kept as a size-1 axis).
exp_x.div(exp_x.sum(dim=1, keepdim=True))

tensor([[0.3095, 0.3095, 0.3810],
        [0.3363, 0.3363, 0.3274]])

In [51]:
# keepdim=True keeps the reduced axis as size 1, so the column sums stay 2-D.
exp_x.sum(dim=0, keepdim=True).shape

torch.Size([1, 3])