In [1]:
import torch
from torch import Tensor
from jaxtyping import Float

# Random tensors for trying out the ratio metric below; the jaxtyping
# annotations record the axis names of each tensor.
a: Float[Tensor, "b t h"] = torch.randn(2, 3, 4)
b: Float[Tensor, "b t h"] = torch.randn(2, 3, 4)
attn: Float[Tensor, "b t"] = torch.randn(2, 3)
# "_ref" counterparts with the same shapes, each drawn by its own randn call.
a_ref: Float[Tensor, "b t h"] = torch.randn(2, 3, 4)
b_ref: Float[Tensor, "b t h"] = torch.randn(2, 3, 4)
attn_ref: Float[Tensor, "b t"] = torch.randn(2, 3)

In [4]:
# Mean over all elements of log(|a - b| / |a_ref + b_ref|).
# NOTE(review): the numerator subtracts but the denominator adds the reference
# tensors -- confirm `a_ref + b_ref` (rather than `a_ref - b_ref`) is intended.
(torch.abs(a - b) / torch.abs(a_ref + b_ref)).log().mean()

tensor(-0.3211)

In [5]:
# should be closer: b2 is the midpoint of a and b, so |a - b2| is half of
# |a - b| and the mean log-ratio should come out lower than in the cell above.
b2 = (a + b) / 2
(torch.abs(a - b2) / torch.abs(a_ref + b_ref)).log().mean()

tensor(-1.0143)

: 

In [11]:
import torch
import numpy as np

def generate_small_distances(n=1000000, scale=1e-10, n_zeros=3, generator=None):
    """Sample ``n`` tiny non-negative values, with the leading entries set to zero.

    Args:
        n: Number of values to draw.
        scale: Multiplier applied to the absolute standard-normal samples.
        n_zeros: Number of leading entries overwritten with exactly 0
            (if larger than ``n``, every entry is zeroed).
        generator: Optional ``torch.Generator``. Pass a seeded generator to get
            a reproducible draw; ``None`` uses torch's global RNG.

    Returns:
        A 1-D tensor of length ``n``.

    Raises:
        ValueError: If ``n_zeros`` is negative.
    """
    if n_zeros < 0:
        # A negative slice bound would silently zero all but the last entries.
        raise ValueError(f"n_zeros must be non-negative, got {n_zeros}")
    r = torch.abs(torch.randn(n, generator=generator)) * scale
    # Presumably the exact zeros are there so downstream ratio/log comparisons
    # also hit the 0/0 and log(0) edge cases -- confirm with the author.
    r[:n_zeros] = 0
    return r

def compare_stability(dist_ab, dist_ab_ref, eps=1e-10):
    """Evaluate three ratio formulations using additive ``eps`` smoothing.

    Returns a tuple ``(log_of_ratio, plain_ratio, difference_of_logs)``:
      1. ``log((dist_ab + eps) / (dist_ab_ref + eps))``
      2. ``dist_ab / (dist_ab_ref + eps)`` -- only the denominator is smoothed
      3. ``log(dist_ab + eps) - log(dist_ab_ref + eps)``
    """
    numer = dist_ab + eps
    denom = dist_ab_ref + eps

    log_of_ratio = torch.log(numer / denom)
    plain_ratio = dist_ab / denom
    difference_of_logs = torch.log(numer) - torch.log(denom)

    return log_of_ratio, plain_ratio, difference_of_logs

# Generate two sets of very small distances (each call draws fresh samples).
dist_ab, dist_ab_ref = generate_small_distances(), generate_small_distances()

# Compare stability of the three formulations on those samples.
result1, result2, result3 = compare_stability(dist_ab, dist_ab_ref)

# Function to count non-finite values
def count_non_finite(tensor):
    """Return how many elements of ``tensor`` are NaN or +/-inf, as a Python int."""
    is_bad = torch.isfinite(tensor).logical_not()
    return is_bad.sum().item()

# Report, for each approach, how many result entries are NaN or +/-inf.
print(f"Approach 1 non-finite values: {count_non_finite(result1)}")
print(f"Approach 2 non-finite values: {count_non_finite(result2)}")
print(f"Approach 3 non-finite values: {count_non_finite(result3)}")

# Calculate mean and std for each approach, ignoring non-finite values
def safe_stats(tensor):
    """Return ``(mean, std)`` of the finite entries of ``tensor`` as Python floats."""
    keep = torch.isfinite(tensor)
    finite_only = tensor[keep]
    avg = finite_only.mean().item()
    spread = finite_only.std().item()
    return avg, spread

# Finite-only (mean, std) of each approach's result, in approach order.
(mean1, std1), (mean2, std2), (mean3, std3) = map(
    safe_stats, (result1, result2, result3)
)

print(f"\nApproach 1: mean = {mean1:.6e}, std = {std1:.6e}")
print(f"Approach 2: mean = {mean2:.6e}, std = {std2:.6e}")
print(f"Approach 3: mean = {mean3:.6e}, std = {std3:.6e}")

Approach 1 non-finite values: 0
Approach 2 non-finite values: 0
Approach 3 non-finite values: 0

Approach 1: mean = -1.100063e-04, std = 4.488434e-01
Approach 2: mean = 4.900776e-01, std = 4.144593e-01
Approach 3: mean = -1.100062e-04, std = 4.488434e-01


In [12]:

def compare_stability2(dist_ab, dist_ab_ref, eps=1e-10):
    """Evaluate three ratio formulations, flooring inputs at ``eps`` via clamping.

    Returns a tuple ``(log_of_ratio, plain_ratio, difference_of_logs)``:
      1. ``log(clamp_min(dist_ab, eps) / clamp_min(dist_ab_ref, eps))``
      2. ``dist_ab / clamp_min(dist_ab_ref, eps)`` -- numerator is not clamped
      3. ``log(clamp_min(dist_ab, eps)) - log(clamp_min(dist_ab_ref, eps))``
    """
    numer = dist_ab.clamp_min(eps)
    denom = dist_ab_ref.clamp_min(eps)

    log_of_ratio = torch.log(numer / denom)
    plain_ratio = dist_ab / denom
    difference_of_logs = torch.log(numer) - torch.log(denom)

    return log_of_ratio, plain_ratio, difference_of_logs


# Compare stability
# Runs the clamp-based variant on dist_ab / dist_ab_ref, which this cell does
# not create itself (they must already exist from the earlier cell), and
# overwrites result1..result3 from that cell.
result1, result2, result3 = compare_stability2(dist_ab, dist_ab_ref)

# Function to count non-finite values
# NOTE(review): same helper as the one defined in the previous cell; this
# redefinition shadows it.
def count_non_finite(tensor):
    """Count the NaN / infinite elements of ``tensor`` (returned as a Python int)."""
    finite_mask = torch.isfinite(tensor)
    return torch.sum(~finite_mask).item()

# Report, for each approach, how many result entries are NaN or +/-inf.
print(f"Approach 1 non-finite values: {count_non_finite(result1)}")
print(f"Approach 2 non-finite values: {count_non_finite(result2)}")
print(f"Approach 3 non-finite values: {count_non_finite(result3)}")

# Calculate mean and std for each approach, ignoring non-finite values
# NOTE(review): same helper as the one defined in the previous cell; this
# redefinition shadows it.
def safe_stats(tensor):
    """Compute ``(mean, std)`` over only the finite elements, as Python floats."""
    finite_only = tensor[torch.isfinite(tensor)]
    stats = (finite_only.mean().item(), finite_only.std().item())
    return stats

# Finite-only (mean, std) of each approach's result, in approach order.
(mean1, std1), (mean2, std2), (mean3, std3) = map(
    safe_stats, (result1, result2, result3)
)

print(f"\nApproach 1: mean = {mean1:.6e}, std = {std1:.6e}")
print(f"Approach 2: mean = {mean2:.6e}, std = {std2:.6e}")
print(f"Approach 3: mean = {mean3:.6e}, std = {std3:.6e}")

Approach 1 non-finite values: 0
Approach 2 non-finite values: 0
Approach 3 non-finite values: 0

Approach 1: mean = -5.770955e-05, std = 3.300098e-01
Approach 2: mean = 7.220803e-01, std = 5.714583e-01
Approach 3: mean = -5.770929e-05, std = 3.300095e-01
