In [5]:
import torch

def js_divergence(p, q, eps=1e-12):
    """
    Compute the Jensen-Shannon divergence between two 1D probability distributions p and q.
    
    Parameters:
        p (torch.Tensor): 1D tensor for the predicted probability distribution.
        q (torch.Tensor): 1D tensor for the target probability distribution.
        eps (float): Small constant to avoid numerical issues (log(0)).
        
    Returns:
        torch.Tensor: A scalar representing the JS divergence.
    """
    # Calculate the mixture distribution
    m = 0.5 * (p + q)
    
    # Compute KL divergences
    kl_pm = torch.sum(p * torch.log(p / torch.clamp(m, min=eps)))
    kl_qm = torch.sum(q * torch.log(q / torch.clamp(m, min=eps)))
    
    # Return the Jensen-Shannon divergence
    js = 0.5 * (kl_pm + kl_qm)
    return js

# -------------------------------
# Test Case 1: Low divergence
# p1 and q1 are very similar distributions.
p1 = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float32)
q1 = torch.tensor([0.21, 0.29, 0.5], dtype=torch.float32)

js_loss_low = js_divergence(p1, q1)
print("Test Case 1 - Low divergence:", js_loss_low.item())

# -------------------------------
# Test Case 2: High divergence
# p2 and q2 are very different distributions.
p2 = torch.tensor([0.9, 0.05, 0.05], dtype=torch.float32)
q2 = torch.tensor([0.000001, 0.90000999999, 0.9999999], dtype=torch.float32)

js_loss_high = js_divergence(p2, q2)
print("Test Case 2 - High divergence:", js_loss_high.item())


Test Case 1 - Low divergence: 0.00010334770195186138
Test Case 2 - High divergence: 0.8066102862358093
