In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gamma, norm
from ipywidgets import interact, FloatSlider, Dropdown
import torch
import torch.nn as nn

# Shared PyTorch KL-divergence criterion.
# Call convention: kl_loss_fn(input, target) where `input` holds LOG-probabilities
# and `target` holds plain probabilities (log_target=False is spelled out below).
# NOTE: 'batchmean' divides the summed pointwise terms by input.size(0).
kl_loss_fn = nn.KLDivLoss(reduction='batchmean', log_target=False)

def static_distribution(x):
    """Evaluate the fixed reference density, a Gamma with shape k=2 and scale=1.

    Args:
        x: Scalar or array of evaluation points.

    Returns:
        The Gamma(k=2, scale=1) probability density evaluated at ``x``.
    """
    reference = gamma(a=2, scale=1)  # frozen distribution with fixed parameters
    return reference.pdf(x)

def movable_distribution(x, scale, shift, dist_type, mixing_coeff):
    """Evaluate the user-controlled density on ``x``.

    Args:
        x: Scalar or array of evaluation points.
        scale: Scale of the Gamma, or standard deviation of each Bimodal component.
        shift: Amount the whole density is translated along x.
        dist_type: ``'Gamma'`` (shape k=2) or ``'Bimodal'`` (normal modes at -2 and 3
            before shifting).
        mixing_coeff: Weight of the left (-2) Bimodal component; the right one gets
            ``1 - mixing_coeff``. Not used for ``'Gamma'``.

    Returns:
        The density values at ``x``.

    Raises:
        ValueError: If ``dist_type`` is not one of the supported names.
    """
    if dist_type == 'Gamma':
        gamma_points = x - shift
        return gamma.pdf(gamma_points, a=2, scale=scale)

    if dist_type == 'Bimodal':
        mode_points = x - shift
        left_mode = norm.pdf(mode_points, loc=-2, scale=scale)
        right_mode = norm.pdf(mode_points, loc=3, scale=scale)
        return mixing_coeff * left_mode + (1 - mixing_coeff) * right_mode

    raise ValueError("Unsupported distribution type")

def numpy_to_torch(arr):
    """Copy array-like ``arr`` into a new float32 PyTorch tensor.

    Args:
        arr: NumPy array (or anything ``torch.tensor`` accepts).

    Returns:
        A freshly allocated ``torch.float32`` tensor holding the same values.
    """
    target_dtype = torch.float32
    converted = torch.tensor(arr, dtype=target_dtype)
    return converted

# Function to calculate symmetrical KL divergence
def symmetrical_kl_divergence(static_dist, movable_dist):
    """Return the average of the two directed KL divergences between two sampled densities.

    Both inputs are treated as non-negative, unnormalised weights on the same grid.
    Each is divided by its own total so it becomes a discrete probability vector
    (p for ``static_dist``, q for ``movable_dist``), and the result is

        0.5 * (KL(p || q) + KL(q || p)),   KL(a || b) = sum_i a_i * log(a_i / b_i)

    in nats. Terms with a_i == 0 contribute 0 (the limit of a*log a). If b_i == 0
    where a_i > 0, the divergence is infinite.

    Args:
        static_dist: Array-like of non-negative values.
        movable_dist: Array-like of non-negative values, same shape as ``static_dist``.

    Returns:
        A 0-dim ``torch.float32`` tensor holding the symmetrised divergence.

    Raises:
        ValueError: If the shapes differ or either input has no positive total mass.
    """
    # Work in float64 for the sums; cast back to float32 at the end.
    p = torch.as_tensor(np.asarray(static_dist, dtype=np.float64))
    q = torch.as_tensor(np.asarray(movable_dist, dtype=np.float64))
    if p.shape != q.shape:
        raise ValueError(
            f"distributions must have the same shape, got {tuple(p.shape)} and {tuple(q.shape)}"
        )

    # Normalise explicitly instead of using nn.KLDivLoss(reduction='batchmean').
    # On a 1-D grid, 'batchmean' divides by the number of points, which yields a
    # per-point mean over raw pdf samples rather than a divergence.
    p_total = p.sum()
    q_total = q.sum()
    if not (p_total > 0 and q_total > 0):
        raise ValueError("each distribution must have a positive total mass")
    p = p / p_total
    q = q / q_total

    def _directed_kl(a, b):
        # KL(a || b). torch.where drops a == 0 terms, avoiding 0 * -inf = nan.
        terms = torch.where(a > 0, a * (torch.log(a) - torch.log(b)), torch.zeros_like(a))
        return terms.sum()

    combined_kl_div = (_directed_kl(p, q) + _directed_kl(q, p)) / 2
    return combined_kl_div.to(torch.float32)

# Interactive plotting function
def plot_distributions(scale, shift, dist_type, mixing_coeff):
    """Draw both densities and show their symmetrised KL divergence in the title.

    Used as the ``interact`` callback, so it is re-run on every widget change.

    Args:
        scale: Scale of the movable distribution.
        shift: Horizontal offset of the movable distribution.
        dist_type: ``'Gamma'`` or ``'Bimodal'`` (passed to ``movable_distribution``).
        mixing_coeff: Weight of the left Bimodal component; not used for Gamma.
    """
    # One constant drives both the evaluation grid and the visible x-limits.
    x_max = 40
    x = np.linspace(0, x_max, 4000)

    static_dist = static_distribution(x)
    movable_dist = movable_distribution(x, scale, shift, dist_type, mixing_coeff)

    # Floor both densities so the log inside the KL computation never sees 0.
    static_dist = np.clip(static_dist, 1e-10, None)
    movable_dist = np.clip(movable_dist, 1e-10, None)

    combined_kl_div = symmetrical_kl_divergence(static_dist, movable_dist)

    # movable_distribution only uses mixing_coeff for the Bimodal shape, so only
    # advertise it in the legend when it actually affects the curve.
    params = f'scale={scale:.2f}, shift={shift:.2f}'
    if dist_type == 'Bimodal':
        params += f', mix={mixing_coeff:.2f}'

    plt.figure(figsize=(8, 6))
    plt.plot(x, static_dist, label='Static Gamma Distribution (k=2, scale=1)', color='blue')
    plt.plot(x, movable_dist, label=f'{dist_type} Distribution ({params})', color='red')

    plt.title(f'Symmetrical KL Divergence (PyTorch): {combined_kl_div.item():.4f}')
    plt.legend()
    plt.xlabel('x')
    plt.ylabel('Probability Density')
    plt.grid(True)
    plt.xlim((0, x_max))
    plt.show()

# Widgets controlling the movable distribution. Bimodal is preselected so the
# mixing-coefficient slider has a visible effect from the start.
distribution_dropdown = Dropdown(
    options=['Gamma', 'Bimodal'],
    value='Bimodal',
    description='Distribution:',
)
scale_slider = FloatSlider(min=0.5, max=5.0, step=0.1, value=1)
shift_slider = FloatSlider(min=-5.0, max=30.0, step=0.1, value=0)
mix_slider = FloatSlider(min=0.0, max=1.0, step=0.01, value=0.5, description='Mix Coeff')

# Re-render the plot whenever any widget changes; the trailing ';' keeps the
# notebook from echoing interact's return value.
interact(
    plot_distributions,
    scale=scale_slider,
    shift=shift_slider,
    dist_type=distribution_dropdown,
    mixing_coeff=mix_slider,
);


interactive(children=(FloatSlider(value=1.0, description='scale', max=5.0, min=0.5), FloatSlider(value=0.0, de…