In [1]:
import torch
import torch.nn.functional as f

In [13]:
# Element-wise `==` between two identical float16 tensors yields a bool tensor.
a = torch.tensor([1, 2, 3], dtype=torch.float16)
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
c = torch.ones(3, dtype=torch.bool)  # same contents as the expected `a == b`

a == b

tensor([True, True, True])

In [15]:
def is_valid_softmax(x: torch.Tensor):
    return torch.sum(x, dim=0).item() == 1


def are_tensors_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if `a` and `b` have the same shape and identical elements.

    The comparison is exact element-wise `==` (so NaN is never equal to NaN).
    Tensors of different shapes are reported as unequal rather than being
    broadcast against each other. 0-d and multi-dimensional tensors are
    supported; Python's builtin `all()` over a tensor is not used because it
    fails on those.

    Args:
        a: First tensor.
        b: Second tensor.

    Returns:
        True if shapes match and every element compares equal, else False.
    """
    if a.shape != b.shape:
        return False
    return bool(torch.all(a == b))


a = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float16)
a1, a2 = torch.split(a, 3) # this creates [1, 2, 3] and [4, 5, 6]

s1 = f.softmax(a1, dim=0)
s2 = f.softmax(a2, dim=0)
s = f.softmax(a, dim=0)

s_merged = torch.concat((s1, s2))
s_merged, s

(tensor([0.0900, 0.2448, 0.6650, 0.0900, 0.2448, 0.6650], dtype=torch.float16),
 tensor([0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338], dtype=torch.float16))

In [22]:
# s_merged concatenates two separately normalised halves, so it sums to ~2, not 1.
print(is_valid_softmax(x=s_merged))

False


In [23]:
# The softmax taken over the full vector is a proper distribution.
print(is_valid_softmax(x=s))

True


In [17]:
# Python float used below to probe bfloat16 rounding. NOTE: this rebinds `a`,
# which an earlier cell used for a tensor.
a = 61_000.0

In [19]:
# bfloat16 keeps only 8 significant bits: in [32768, 65536) representable
# values are 256 apart, so 61000 rounds to the nearest one, 60928.
b = torch.full((2,), a, dtype=torch.bfloat16)
b

tensor([60928., 60928.], dtype=torch.bfloat16)

In [67]:
import torch
import numpy as np

def compare_softmax_denominators(x: torch.Tensor):
    """
    Compares the standard softmax denominator (dₙ) with the surrogate sequence (dₙ')
    to demonstrate their relationship.

    For each prefix x[0..i]:
        dᵢ  = Σ_{j≤i} exp(x_j - m_N)   (shifted by the global maximum m_N)
        dᵢ' = Σ_{j≤i} exp(x_j - mᵢ)    (shifted by the running maximum mᵢ)
    so that dᵢ = dᵢ' · exp(mᵢ - m_N), and dₙ' == dₙ at the last step.

    dᵢ' is computed with the one-pass online recurrence
        dᵢ' = dᵢ₋₁' · exp(mᵢ₋₁ - mᵢ) + exp(xᵢ - mᵢ),
    which never looks ahead at later elements. Overall cost is O(N).

    Args:
        x: Non-empty 1-D input tensor of shape (N,). Integer inputs are allowed.
    Returns:
        d: Standard denominators at each step, float32 tensor of shape (N,)
        d_surrogate: Surrogate denominators at each step, float32 tensor of shape (N,)
        m: Running maximum mᵢ at each step, float32 tensor of shape (N,)
        m_global: Global maximum m_N as a 0-d tensor with x's dtype
    Raises:
        ValueError: if x is not 1-D or is empty.
    """
    if x.ndim != 1 or x.numel() == 0:
        raise ValueError(f"expected a non-empty 1-D tensor, got shape {tuple(x.shape)}")
    N = len(x)

    # Compute in at least float32 so integer / half-precision inputs can be
    # exponentiated without truncation or extra rounding.
    work_dtype = torch.promote_types(x.dtype, torch.float32)
    xw = x.to(work_dtype)

    # Global maximum (used for standard sequence); kept in x's dtype.
    m_global = torch.max(x)

    # Running maximum mᵢ = max(x[0..i]) in a single pass.
    m_run = torch.cummax(xw, dim=0).values

    # Every term of dᵢ shares the same shift m_N, so a cumulative sum suffices.
    d = torch.cumsum(torch.exp(xw - m_global.to(work_dtype)), dim=0)

    # Surrogate denominators via the online recurrence: rescale the previous
    # partial sum to the new running max, then add the newest term.
    # Both exponents are <= 0, so nothing overflows.
    d_surrogate = torch.empty_like(xw)
    running = torch.zeros((), dtype=work_dtype, device=x.device)
    prev_max = m_run[0]
    for i in range(N):
        cur_max = m_run[i]
        running = running * torch.exp(prev_max - cur_max) + torch.exp(xw[i] - cur_max)
        d_surrogate[i] = running
        prev_max = cur_max

    return d.float(), d_surrogate.float(), m_run.float(), m_global

def analyze_sequences(x: torch.Tensor):
    """
    Analyzes and prints detailed comparison of standard vs surrogate sequences.

    The "Scale Factor" column is scaleᵢ = exp(m_N - mᵢ), which maps the
    standard denominator onto the surrogate one: dᵢ' = dᵢ · scaleᵢ.
    The verification recovers dᵢ = dᵢ' / scaleᵢ and checks it against dᵢ at
    every step.

    Args:
        x: 1-D input tensor, passed to compare_softmax_denominators.
    """
    d, d_surrogate, m, m_global = compare_softmax_denominators(x)

    # detach/cpu so printing also works for grad-tracking or non-CPU tensors
    print("\nInput sequence:", x.detach().cpu().numpy())
    print("\nRunning maximum (mᵢ):", m.detach().cpu().numpy())
    print("Global maximum (mₙ):", m_global.item())

    # Scale factor between d and d' for every step at once.
    scale = torch.exp(m_global - m)

    print("\nStep-by-step comparison:")
    print("idx | x[i] |    dᵢ   |   dᵢ'   | Scale Factor")
    print("-" * 50)

    for i in range(len(x)):
        print(f"{i:3d} | {x[i]:5.2f} | {d[i]:7.4f} | {d_surrogate[i]:7.4f} | {scale[i]:7.4f}")

    # Verify at every step, not just the last. At the final step mₙ == m_N,
    # so scale is exactly 1 and a final-only check is vacuous. Divide (rather
    # than multiply) by the scale, since dᵢ' = dᵢ · scaleᵢ.
    recovered_d = d_surrogate / scale

    print("\nVerification:")
    print(f"Final dₙ:          {d[-1]:.6f}")
    print(f"Final dₙ' / scale: {recovered_d[-1]:.6f}")
    print(f"Match (all steps): {torch.allclose(d, recovered_d)}")

# Example 1: strictly increasing input, so the running max changes at every step.
print("\nExample 1: Simple increasing sequence")
x1 = torch.tensor([1., 3., 5.])
analyze_sequences(x1)

# Example 2: non-monotonic input. Built from integer literals, so this tensor
# is int64 (hence no decimal point in the printed inputs / global max).
print("\nExample 2: Varying sequence")
x2 = torch.tensor([3, 1, 4, 7, 2], dtype=torch.int64)
analyze_sequences(x2)

def demonstrate_softmax_equivalence(x: torch.Tensor):
    """
    Demonstrates that both sequences lead to the same softmax probabilities.

    Standard (two passes): m_N = max(x), d_N = Σ exp(x_j - m_N),
        pᵢ = exp(xᵢ - m_N) / d_N.
    Surrogate (one pass): keep a running max mᵢ and denominator
        dᵢ' = dᵢ₋₁' · exp(mᵢ₋₁ - mᵢ) + exp(xᵢ - mᵢ),
    then pᵢ = exp(xᵢ - m_N) / d_N' using the final running values.

    Prints both probability vectors side by side and whether they agree
    under torch.allclose.

    Args:
        x: Non-empty 1-D tensor of logits.
    Raises:
        ValueError: if x is not 1-D or is empty.
    """
    if x.ndim != 1 or x.numel() == 0:
        raise ValueError(f"expected a non-empty 1-D tensor, got shape {tuple(x.shape)}")
    N = len(x)

    # Use at least float32 for both paths so allclose compares the two
    # algorithms rather than half-precision rounding noise.
    xf = x.to(torch.promote_types(x.dtype, torch.float32))

    # Standard softmax: global max first, then normalise.
    m_global = torch.max(xf)
    exp_x = torch.exp(xf - m_global)
    d_global = torch.sum(exp_x)
    softmax_standard = exp_x / d_global

    # Surrogate softmax: a single pass that never looks ahead for the global max.
    m_run = torch.tensor(float("-inf"), dtype=xf.dtype, device=xf.device)
    d_run = torch.zeros((), dtype=xf.dtype, device=xf.device)
    for xi in xf:
        m_new = torch.maximum(m_run, xi)
        # rescale the old partial sum to the new max, then add the new term
        d_run = d_run * torch.exp(m_run - m_new) + torch.exp(xi - m_new)
        m_run = m_new
    softmax_surrogate = torch.exp(xf - m_run) / d_run

    print("\nSoftmax Equivalence:")
    print("idx |  Standard  | Surrogate")
    print("-" * 35)
    for i in range(N):
        print(f"{i:3d} | {softmax_standard[i]:9.6f} | {softmax_surrogate[i]:9.6f}")

    print(f"\nOutputs match: {torch.allclose(softmax_standard, softmax_surrogate)}")

# Example 3: compare standard vs surrogate softmax probabilities on a small input.
print("\nExample 3: Softmax Equivalence")
x3 = torch.tensor([1., 3., 2., 5.])
demonstrate_softmax_equivalence(x3)


Example 1: Simple increasing sequence

Input sequence: [1. 3. 5.]

Running maximum (mᵢ): [1. 3. 5.]
Global maximum (mₙ): 5.0

Step-by-step comparison:
idx | x[i] |    dᵢ   |   dᵢ'   | Scale Factor
--------------------------------------------------
  0 |  1.00 |  0.0183 |  1.0000 | 54.5981
  1 |  3.00 |  0.1537 |  1.1353 |  7.3891
  2 |  5.00 |  1.1537 |  1.1537 |  1.0000

Verification:
Final dₙ:          1.153651
Final dₙ' × scale: 1.153651
Match: True

Example 2: Varying sequence

Input sequence: [3 1 4 7 2]

Running maximum (mᵢ): [3. 3. 4. 7. 7.]
Global maximum (mₙ): 7

Step-by-step comparison:
idx | x[i] |    dᵢ   |   dᵢ'   | Scale Factor
--------------------------------------------------
  0 |  3.00 |  0.0183 |  1.0000 | 54.5981
  1 |  1.00 |  0.0208 |  1.1353 | 54.5981
  2 |  4.00 |  0.0706 |  1.4177 | 20.0855
  3 |  7.00 |  1.0706 |  1.0706 |  1.0000
  4 |  2.00 |  1.0773 |  1.0773 |  1.0000

Verification:
Final dₙ:          1.077319
Final dₙ' × scale: 1.077319
Match: True

Exam

0.9991452300000001