In [22]:
%run ../settings.py
POST_SETTINGS = {
        "image": "img/cross-entropy.jpg",
        "title": "Measuring Differences in Knowledge",
        "description": "Understanding Cross-Entropy and KL-Divergence",
        "category": "article",
        "tags": ["Information Theory", "Articles", "Tutorials"]
    }

# Cross Entropy intuition from a single bit

Suppose you made it to the final round of this season's MasterChef,\
and you're facing an interesting challenge: cooking for your good friend Alice and your mortal enemy Bob...<br> 

Alice and Bob each write down all the foods they're allergic to on little pieces of paper and put them in a bag, along with slips from the other audience members.<br>
We now have a bag of words containing the following sets:

$A = \{ \text{peanuts}, \text{shrimp}, \text{potato} \}$

$B = \{ \text{peanuts}, \text{kale}, \text{chicken} \}$

$(A \cup B)^\complement = \{ \text{salmon}, \text{apple}, \text{toast} \}$

The show host tells you to draw 3 random words $\vec{z} = \{ z_1, z_2, z_3 \}$ from the bag to make tonight's banquet. \
Since Bob is planning to resurrect Hitler using misaligned AGI, we're interested in finding out how likely our 3 ingredients are to kill Bob.

<br>

The simplest way to do this is to measure the probability $p$ that $\vec{z}$ is exactly Bob's set $B$, using a random variable $X \sim \text{Ber}(p)$ \
that equals 1 when $\vec{z} = B$ and 0 otherwise, so that $p = P(X = 1)$.
Since there are 8 unique foods in the bag, there are $\binom{8}{3} = \frac{8!}{3!\,(8-3)!} = 56$ possible ways of choosing $\vec{z}$, \
but only one way to choose exactly the 3 foods in $B$, making $p = P(X = 1) = \frac{1}{56}$.
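To double-check the counting argument, here is a small Python sketch that enumerates every possible draw. It assumes, as the text does, that each of the 8 unique foods appears once in the bag (peanuts is counted once even though both Alice and Bob wrote it down) and that the 3 words are drawn without replacement.

```python
from fractions import Fraction
from itertools import combinations
from math import comb

# The 8 unique foods in the bag (peanuts is shared by A and B, counted once)
A = {"peanuts", "shrimp", "potato"}   # Alice's allergies
B = {"peanuts", "kale", "chicken"}    # Bob's allergies
others = {"salmon", "apple", "toast"} # (A ∪ B)^c
bag = sorted(A | B | others)

# Every unordered draw of 3 distinct foods
draws = [set(z) for z in combinations(bag, 3)]
n_draws = len(draws)
n_bob = sum(1 for z in draws if z == B)  # draws that are exactly Bob's set

p = Fraction(n_bob, n_draws)  # exact P(X = 1)
print(len(bag), n_draws, comb(8, 3), n_bob, p)  # → 8 56 56 1 1/56
```

Using `Fraction` keeps the result exact, so it prints `1/56` rather than a rounded float.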


In [23]:
import matplotlib.pyplot as plt
from matplotlib_venn import venn2  # third-party package: matplotlib-venn

A = {"peanuts", "shrimp", "potato"}   # Alice's allergies
B = {"peanuts", "kale", "chicken"}    # Bob's allergies

# Region contents, keyed by matplotlib-venn's region ids:
# '10' = only in the first set (B), '01' = only in the second set (A), '11' = in both
regions = {
    "10": B - A,   # {"kale", "chicken"}
    "01": A - B,   # {"shrimp", "potato"}
    "11": A & B,   # {"peanuts"}
}
colors = {"10": "#99d8c9", "01": "#fc9272", "11": "#fdae6b"}

fig, ax = plt.subplots(figsize=(3.5, 3.5))
fig.patch.set_facecolor("none")  # transparent figure background

venn = venn2([B, A], set_labels=("B", "A"), ax=ax)
for region_id, foods in regions.items():
    label = venn.get_label_by_id(region_id)
    if label is not None:
        label.set_text("\n".join(sorted(foods)))  # show food names instead of counts
    patch = venn.get_patch_by_id(region_id)
    if patch is not None:
        patch.set_color(colors[region_id])
        patch.set_edgecolor("black")
        patch.set_linewidth(1.5)

plt.show()


In [24]:
%%HTML
<div id="cross-entropy-demo" style="max-width: 800px; margin: 0 auto; font-family: Arial, sans-serif;">
    <h2 style="text-align: center;">Cross-Entropy: Interactive Visualization</h2>
    
    <div style="margin: 20px 0; text-align: center;">
        <p>Cross-entropy measures how well a predicted probability distribution matches the true distribution. For a fixed true distribution, lower values indicate better predictions.</p>
    </div>
    
    <div style="display: flex; justify-content: space-between; align-items: start; margin: 30px 0;">
        <!-- Left side: Distribution Controls -->
        <div style="width: 45%;">
            <h3>Distributions</h3>
            <div style="margin-bottom: 20px;">
                <h4>True Distribution (Ground Truth)</h4>
                <div id="true-dist-controls" style="display: flex; flex-direction: column; gap: 10px;">
                    <!-- Sliders will be added here -->
                </div>
            </div>
            
            <div>
                <h4>Predicted Distribution</h4>
                <div id="pred-dist-controls" style="display: flex; flex-direction: column; gap: 10px;">
                    <!-- Sliders will be added here -->
                </div>
            </div>
            
            <div style="margin-top: 20px;">
                <button id="random-button" style="padding: 8px 15px; background: #4CAF50; color: white; border: none; border-radius: 5px; cursor: pointer; margin-right: 10px;">Random Distributions</button>
                <button id="reset-button" style="padding: 8px 15px; background: #f44336; color: white; border: none; border-radius: 5px; cursor: pointer;">Reset</button>
            </div>
        </div>
        
        <!-- Right side: Visualization -->
        <div style="width: 50%;">
            <h3>Cross-Entropy Visualization</h3>
            <div id="visualization" style="position: relative; height: 300px; border: 1px solid #ddd; margin-bottom: 20px; border-radius: 5px;">
                <canvas id="chart-canvas" width="400" height="300"></canvas>
            </div>
            
            <div style="background: #f5f5f5; padding: 15px; border-radius: 5px;">
                <h4>Results</h4>
                <div id="results" style="display: flex; flex-direction: column; gap: 10px;">
                    <div>
                        <strong>Cross-Entropy:</strong> <span id="cross-entropy-value">0</span>
                    </div>
                    <div>
                        <strong>KL Divergence:</strong> <span id="kl-divergence-value">0</span>
                    </div>
                    <div>
                        <strong>Entropy of True Distribution:</strong> <span id="entropy-value">0</span>
                    </div>
                </div>
            </div>
        </div>
    </div>
    
    <div style="margin: 30px 0; padding: 15px; background: #e8f5e9; border-radius: 5px;">
        <h3>Explanation</h3>
        <p>Cross-entropy (H) between true distribution P and predicted distribution Q is calculated as (in bits, using base-2 logarithms as this demo does):</p>
        <p style="font-family: monospace; font-weight: bold; text-align: center;">H(P,Q) = -∑ P(x) * log₂(Q(x))</p>
        <div id="explanation-details" style="margin-top: 10px;">
            <p>When predictions match reality perfectly, cross-entropy equals the entropy of the true distribution.</p>
            <p>When predictions differ from reality, cross-entropy exceeds that entropy by exactly the KL divergence: H(P,Q) = H(P) + KL(P‖Q).</p>
            <p>Try adjusting the sliders to see how changing distributions affects the cross-entropy!</p>
        </div>
    </div>
</div>

<style>
.slider-container {
    display: flex;
    align-items: center;
    gap: 10px;
}

.slider-container input {
    flex-grow: 1;
}

.slider-value {
    width: 40px;
    text-align: right;
}

.bar {
    position: absolute;
    bottom: 0;
    border-radius: 3px 3px 0 0;
    transition: height 0.3s ease, background-color 0.3s ease;
}

.probability-label {
    position: absolute;
    bottom: -25px;
    width: 100%;
    text-align: center;
    font-size: 12px;
}

.bar-true {
    background-color: rgba(54, 162, 235, 0.6);
    border: 1px solid rgba(54, 162, 235, 1);
}

.bar-pred {
    background-color: rgba(255, 99, 132, 0.6);
    border: 1px solid rgba(255, 99, 132, 1);
}

.highlight {
    animation: pulse 1s infinite alternate;
}

@keyframes pulse {
    0% { opacity: 0.7; }
    100% { opacity: 1; }
}

.bar-label {
    position: absolute;
    top: -20px;
    width: 100%;
    text-align: center;
    font-size: 12px;
    color: #333;
}

.category-label {
    position: absolute;
    bottom: -40px;
    width: 100%;
    text-align: center;
    font-size: 14px;
    font-weight: bold;
}

.measurement-point {
    position: absolute;
    width: 10px;
    height: 10px;
    background-color: yellow;
    border: 1px solid #333;
    border-radius: 50%;
    transform: translate(-50%, -50%);
    z-index: 10;
}

.explanation-highlight {
    background-color: #ffeb3b;
    padding: 2px 4px;
    border-radius: 3px;
}
</style>


In [25]:

%%javascript
// Initialize variables
const NUM_CATEGORIES = 4;
const CATEGORIES = ['A', 'B', 'C', 'D'];
let trueDistribution = Array(NUM_CATEGORIES).fill(1/NUM_CATEGORIES);
let predDistribution = Array(NUM_CATEGORIES).fill(1/NUM_CATEGORIES);

// Set up the controls
function setupControls() {
    const trueDistControls = document.getElementById('true-dist-controls');
    const predDistControls = document.getElementById('pred-dist-controls');
    
    // Clear existing controls
    trueDistControls.innerHTML = '';
    predDistControls.innerHTML = '';
    
    // Create sliders for true distribution
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        const container = document.createElement('div');
        container.className = 'slider-container';
        
        const label = document.createElement('label');
        label.textContent = `Category ${CATEGORIES[i]}:`;
        label.style.width = '100px';
        
        const slider = document.createElement('input');
        slider.type = 'range';
        slider.min = '0';
        slider.max = '100';
        slider.value = Math.round(trueDistribution[i] * 100);
        slider.className = 'true-slider';
        slider.dataset.index = i;
        
        const valueDisplay = document.createElement('span');
        valueDisplay.className = 'slider-value';
        valueDisplay.textContent = slider.value + '%';
        
        container.appendChild(label);
        container.appendChild(slider);
        container.appendChild(valueDisplay);
        
        trueDistControls.appendChild(container);
        
        // Add event listener
        slider.addEventListener('input', function() {
            valueDisplay.textContent = this.value + '%';
            trueDistribution[i] = parseInt(this.value) / 100;
            normalizeDistribution(trueDistribution, 'true');
            updateVisualization();
        });
    }
    
    // Create sliders for predicted distribution
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        const container = document.createElement('div');
        container.className = 'slider-container';
        
        const label = document.createElement('label');
        label.textContent = `Category ${CATEGORIES[i]}:`;
        label.style.width = '100px';
        
        const slider = document.createElement('input');
        slider.type = 'range';
        slider.min = '0';
        slider.max = '100';
        slider.value = Math.round(predDistribution[i] * 100);
        slider.className = 'pred-slider';
        slider.dataset.index = i;
        
        const valueDisplay = document.createElement('span');
        valueDisplay.className = 'slider-value';
        valueDisplay.textContent = slider.value + '%';
        
        container.appendChild(label);
        container.appendChild(slider);
        container.appendChild(valueDisplay);
        
        predDistControls.appendChild(container);
        
        // Add event listener
        slider.addEventListener('input', function() {
            valueDisplay.textContent = this.value + '%';
            predDistribution[i] = parseInt(this.value) / 100;
            normalizeDistribution(predDistribution, 'pred');
            updateVisualization();
        });
    }
}

// Normalize distribution to ensure it sums to 1
function normalizeDistribution(distribution, type) {
    const sum = distribution.reduce((a, b) => a + b, 0);
    
    for (let i = 0; i < distribution.length; i++) {
        // If every value is 0, fall back to a uniform distribution
        distribution[i] = sum > 0 ? distribution[i] / sum : 1 / distribution.length;
    }
    
    // Keep the sliders and their % labels in sync with the normalized values
    // (including the all-zero fallback, so sliders never disagree with the bars)
    const sliders = document.querySelectorAll(`.${type}-slider`);
    const valueDisplays = document.querySelectorAll(`#${type}-dist-controls .slider-value`);
    
    sliders.forEach((slider, index) => {
        slider.value = Math.round(distribution[index] * 100);
        valueDisplays[index].textContent = slider.value + '%';
    });
}

// Calculate cross-entropy: -∑ P(x) * log(Q(x))
function calculateCrossEntropy(trueProbs, predProbs) {
    let crossEntropy = 0;
    for (let i = 0; i < trueProbs.length; i++) {
        // Avoid log(0) by adding a small epsilon
        const epsilon = 1e-15;
        const predProb = Math.max(predProbs[i], epsilon);
        if (trueProbs[i] > 0) {
            crossEntropy -= trueProbs[i] * Math.log2(predProb);
        }
    }
    return crossEntropy;
}

// Calculate KL Divergence: ∑ P(x) * log(P(x)/Q(x))
function calculateKLDivergence(trueProbs, predProbs) {
    let kl = 0;
    for (let i = 0; i < trueProbs.length; i++) {
        // Avoid division by zero and log(0)
        const epsilon = 1e-15;
        const predProb = Math.max(predProbs[i], epsilon);
        const trueProb = Math.max(trueProbs[i], epsilon);
        if (trueProb > epsilon) {
            kl += trueProb * Math.log2(trueProb / predProb);
        }
    }
    return kl;
}

// Calculate entropy: -∑ P(x) * log(P(x))
function calculateEntropy(probs) {
    let entropy = 0;
    for (let i = 0; i < probs.length; i++) {
        const epsilon = 1e-15;
        const prob = Math.max(probs[i], epsilon);
        if (prob > epsilon) {
            entropy -= prob * Math.log2(prob);
        }
    }
    return entropy;
}

// Update the visualization
function updateVisualization() {
    const canvas = document.getElementById('chart-canvas');
    const ctx = canvas.getContext('2d');
    const width = canvas.width;
    const height = canvas.height;
    
    // Clear canvas
    ctx.clearRect(0, 0, width, height);
    
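    // Plot area shared with the bars: probability 0 at y = height - 40, probability 1 at y = 40
    const baseline = height - 40;
    const plotHeight = height - 80;
    
    // Draw grid lines at every 0.1 of probability
    ctx.strokeStyle = '#e0e0e0';
    ctx.beginPath();
    for (let i = 0; i <= 10; i++) {
        const y = baseline - (i / 10) * plotHeight;
        ctx.moveTo(30, y);
        ctx.lineTo(width, y);
    }
    ctx.stroke();
    
    // Draw y-axis labels aligned with the grid lines (and therefore with the bar heights)
    ctx.fillStyle = '#666';
    ctx.font = '12px Arial';
    ctx.textAlign = 'right';
    for (let i = 0; i <= 10; i++) {
        const y = baseline - (i / 10) * plotHeight;
        ctx.fillText((i / 10).toFixed(1), 25, y + 4);
    }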
    
    // Draw bars
    const barWidth = 30;
    const padding = 40;
    const totalWidth = NUM_CATEGORIES * (2 * barWidth + 10) + 2 * padding;
    const startX = (width - totalWidth) / 2 + padding;
    
    // Draw category labels
    ctx.textAlign = 'center';
    ctx.fillStyle = '#333';
    ctx.font = '14px Arial';
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        const x = startX + i * (2 * barWidth + 10) + barWidth;
        ctx.fillText(CATEGORIES[i], x, height - 10);
    }
    
    // Draw true distribution bars
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        const barHeight = trueDistribution[i] * (height - 80);
        const x = startX + i * (2 * barWidth + 10);
        const y = height - 40 - barHeight;
        
        ctx.fillStyle = 'rgba(54, 162, 235, 0.7)';
        ctx.fillRect(x, y, barWidth, barHeight);
        
        ctx.strokeStyle = 'rgba(54, 162, 235, 1)';
        ctx.strokeRect(x, y, barWidth, barHeight);
        
        // Add label
        ctx.fillStyle = '#333';
        ctx.font = '12px Arial';
        ctx.fillText(`${(trueDistribution[i] * 100).toFixed(0)}%`, x + barWidth/2, y - 5);
    }
    
    // Draw predicted distribution bars
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        const barHeight = predDistribution[i] * (height - 80);
        const x = startX + i * (2 * barWidth + 10) + barWidth;
        const y = height - 40 - barHeight;
        
        ctx.fillStyle = 'rgba(255, 99, 132, 0.7)';
        ctx.fillRect(x, y, barWidth, barHeight);
        
        ctx.strokeStyle = 'rgba(255, 99, 132, 1)';
        ctx.strokeRect(x, y, barWidth, barHeight);
        
        // Add label
        ctx.fillStyle = '#333';
        ctx.font = '12px Arial';
        ctx.fillText(`${(predDistribution[i] * 100).toFixed(0)}%`, x + barWidth/2, y - 5);
    }
    
    // Draw legend
    ctx.fillStyle = 'rgba(54, 162, 235, 0.7)';
    ctx.fillRect(width - 180, 20, 15, 15);
    ctx.strokeStyle = 'rgba(54, 162, 235, 1)';
    ctx.strokeRect(width - 180, 20, 15, 15);
    
    ctx.fillStyle = 'rgba(255, 99, 132, 0.7)';
    ctx.fillRect(width - 180, 45, 15, 15);
    ctx.strokeStyle = 'rgba(255, 99, 132, 1)';
    ctx.strokeRect(width - 180, 45, 15, 15);
    
    ctx.fillStyle = '#333';
    ctx.textAlign = 'left';
    ctx.fillText('True Distribution', width - 160, 32);
    ctx.fillText('Predicted Distribution', width - 160, 57);
    
    // Calculate and display results
    const crossEntropy = calculateCrossEntropy(trueDistribution, predDistribution);
    const klDivergence = calculateKLDivergence(trueDistribution, predDistribution);
    const entropy = calculateEntropy(trueDistribution);
    
    document.getElementById('cross-entropy-value').textContent = crossEntropy.toFixed(4);
    document.getElementById('kl-divergence-value').textContent = klDivergence.toFixed(4);
    document.getElementById('entropy-value').textContent = entropy.toFixed(4);
    
    // Highlight values based on how well predictions match reality
    const crossEntropyEl = document.getElementById('cross-entropy-value');
    if (crossEntropy < entropy + 0.1) {
        crossEntropyEl.style.color = 'green';
    } else if (crossEntropy < entropy + 0.5) {
        crossEntropyEl.style.color = 'orange';
    } else {
        crossEntropyEl.style.color = 'red';
    }
    
    // Update explanation
    updateExplanation(crossEntropy, klDivergence, entropy);
}

// Generate random distributions
function randomizeDistributions() {
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        trueDistribution[i] = Math.random();
        predDistribution[i] = Math.random();
    }
    
    normalizeDistribution(trueDistribution, 'true');
    normalizeDistribution(predDistribution, 'pred');
    updateVisualization();
}

// Reset to uniform distributions
function resetDistributions() {
    for (let i = 0; i < NUM_CATEGORIES; i++) {
        trueDistribution[i] = 1 / NUM_CATEGORIES;
        predDistribution[i] = 1 / NUM_CATEGORIES;
    }
    
    normalizeDistribution(trueDistribution, 'true');
    normalizeDistribution(predDistribution, 'pred');
    updateVisualization();
}

// Add detailed explanation based on current values
function updateExplanation(crossEntropy, klDivergence, entropy) {
    const explanationDiv = document.getElementById('explanation-details');
    
    // Relationship between cross-entropy, entropy, and KL divergence
    const relationship = `
        <p>Cross-entropy (${crossEntropy.toFixed(2)}) = Entropy (${entropy.toFixed(2)}) + KL Divergence (${klDivergence.toFixed(2)})</p>
    `;
    
    // Add details about current state
    let details = '';
    if (Math.abs(crossEntropy - entropy) < 0.01) {
        details = `
            <p><span class="explanation-highlight">Perfect prediction!</span> Your predicted distribution matches the true distribution almost exactly (KL divergence below 0.01 bits).</p>
            <p>When predictions match reality perfectly, cross-entropy equals the entropy of the true distribution.</p>
        `;
    } else if (klDivergence < 0.2) {
        details = `
            <p><span class="explanation-highlight">Good prediction!</span> Your predicted distribution is close to the true distribution.</p>
            <p>The extra "penalty" (KL divergence) is small: ${klDivergence.toFixed(4)}</p>
        `;
    } else {
        // Find the worst predicted category
        let maxDiff = 0;
        let worstCategory = 0;
        for (let i = 0; i < NUM_CATEGORIES; i++) {
            const diff = Math.abs(trueDistribution[i] - predDistribution[i]);
            if (diff > maxDiff) {
                maxDiff = diff;
                worstCategory = i;
            }
        }
        
        details = `
            <p><span class="explanation-highlight">Poor prediction!</span> Your predicted distribution is quite different from the true distribution.</p>
            <p>The largest mismatch is in category ${CATEGORIES[worstCategory]}, where the true probability is ${(trueDistribution[worstCategory] * 100).toFixed(0)}% but you predicted ${(predDistribution[worstCategory] * 100).toFixed(0)}%.</p>
            <p>This results in a high KL divergence penalty: ${klDivergence.toFixed(4)}</p>
        `;
    }
    
    // Real-world usage examples
    const examples = `
        <p><strong>Real-world application:</strong> Cross-entropy is commonly used as a loss function in classification tasks like image recognition, language modeling, and recommendation systems.</p>
    `;
    
    explanationDiv.innerHTML = relationship + details + examples;
}

// Add event listeners
document.getElementById('random-button').addEventListener('click', randomizeDistributions);
document.getElementById('reset-button').addEventListener('click', resetDistributions);

// Initial setup
setupControls();
updateVisualization();

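The three numbers the interactive visualization reports can also be reproduced outside the browser. Here is a minimal Python sketch of the same formulas in bits. The distributions `P` and `Q` are illustrative values chosen so the arithmetic comes out exact; they are not taken from the widget. Unlike the JavaScript, which clamps $Q(x)$ at `1e-15`, this version returns infinity when $Q$ assigns zero probability to an outcome that $P$ can produce.

```python
import math

def entropy(p):
    """H(P) = -sum P(x) log2 P(x), skipping zero-probability outcomes."""
    return -sum(px * math.log2(px) for px in p if px > 0)

def cross_entropy(p, q):
    """H(P, Q) = -sum P(x) log2 Q(x); infinite if Q(x) = 0 where P(x) > 0."""
    total = 0.0
    for px, qx in zip(p, q):
        if px > 0:
            if qx == 0:
                return math.inf
            total -= px * math.log2(qx)
    return total

def kl_divergence(p, q):
    """KL(P || Q) = sum P(x) log2(P(x) / Q(x)); infinite if Q(x) = 0 where P(x) > 0."""
    total = 0.0
    for px, qx in zip(p, q):
        if px > 0:
            if qx == 0:
                return math.inf
            total += px * math.log2(px / qx)
    return total

# Four categories, like the A-D sliders (illustrative values)
P = [0.5, 0.25, 0.125, 0.125]   # "true" distribution
Q = [0.25, 0.25, 0.25, 0.25]    # "predicted" distribution (the widget's reset state)

h, ce, kl = entropy(P), cross_entropy(P, Q), kl_divergence(P, Q)
print(f"H(P)={h:.4f}  H(P,Q)={ce:.4f}  KL={kl:.4f}")
# → H(P)=1.7500  H(P,Q)=2.0000  KL=0.2500
```

The output shows the decomposition the widget displays: $2.0 = 1.75 + 0.25$, i.e. $H(P,Q) = H(P) + \mathrm{KL}(P \| Q)$.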

However, the foods Bob and Alice are allergic to overlap (both are allergic to peanuts), and Alice is currently training a machine learning model that cures cancer, \
so we need a better way of analyzing which of the ingredients will avert certain doom.


Let's add another random variable $Y \sim \text{Ber}(q)$ that equals 1 when $\vec{z}$ is exactly Alice's set $A$, so that $q = P(Y = 1)$. \
By the same counting argument, $q = \frac{1}{56}$.


# Cross Entropy
Cross-entropy tells us how surprised we are, on average, when outcomes are actually drawn from $p$ but we judge them using $q$. <br>
In other words: it is the average number of bits needed to encode outcomes drawn from $p$ using a code optimized for $q$.

>$H(p,q) = -\sum_{x} p(x) \log q(x)$ 
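For two Bernoulli variables like $X$ and $Y$, the sum has only two terms: $H(p,q) = -\left[p \log q + (1-p)\log(1-q)\right]$. Below is a minimal Python sketch using base-2 logs (bits). The helper name `bernoulli_cross_entropy` is hypothetical, and the second call uses $q = 0.5$, a made-up, badly calibrated guess included only for contrast.

```python
import math

def bernoulli_cross_entropy(p, q):
    """H(p, q) in bits for outcomes from Ber(p) scored with Ber(q); needs 0 < q < 1."""
    return -(p * math.log2(q) + (1 - p) * math.log2(1 - q))

p = 1 / 56   # P(X = 1): the draw is exactly Bob's set B
q = 1 / 56   # P(Y = 1): the draw is exactly Alice's set A

same = bernoulli_cross_entropy(p, q)          # p == q, so this is just the entropy H(p)
mismatched = bernoulli_cross_entropy(p, 0.5)  # hypothetical q = 0.5 for comparison
print(f"{same:.4f} bits vs {mismatched:.4f} bits")
```

Because $p = q$ here, the cross-entropy collapses to the entropy of $X$; any other choice of $q$ (such as $0.5$, which always costs exactly 1 bit) can only make it larger.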

