# Speculative Decoding Experiments

## Simulation

In this section, we simulate speculative decoding: we draw many samples with the accept/reject/resample procedure and compare their empirical distribution with the target distribution $p(x)$.

In [None]:
import random
import matplotlib.pyplot as plt

# Define simple discrete distributions (example)
p_x = {0: 0.4, 1: 0.3, 2: 0.2, 3: 0.1}  # Target distribution
q_x = {0: 0.3, 1: 0.4, 2: 0.2, 3: 0.1}  # Proposal distribution
NUM_SAMPLES = 100000  # Number of samples for empirical distribution


def residual_distribution(p, q):
    """Return the adjusted distribution p'(x) = max(0, p(x) - q(x)) / Z.

    Z is the sum of the positive parts of p - q.  When that sum is not
    positive (no outcome has p(x) > q(x)), ``p`` itself is returned.

    Args:
        p: Mapping outcome -> target probability.
        q: Mapping outcome -> proposal probability (same keys as ``p``).

    Returns:
        Mapping outcome -> resampling probability.
    """
    residuals = {x: max(0, p[x] - q[x]) for x in p}
    total_residual = sum(residuals.values())
    if total_residual > 0:
        return {x: r / total_residual for x, r in residuals.items()}
    return p


def speculative_sample(p, q, p_prime=None):
    """Draw one outcome using speculative sampling.

    A candidate is drawn from ``q``.  It is kept outright when
    q(x) <= p(x); otherwise it is rejected with probability
    1 - p(x)/q(x), in which case a replacement is drawn from ``p_prime``.

    Args:
        p: Mapping outcome -> target probability.
        q: Mapping outcome -> proposal probability.
        p_prime: Precomputed residual distribution; computed from ``p`` and
            ``q`` when omitted.  Passing it avoids rebuilding it per sample.

    Returns:
        The sampled outcome (a key of ``p``).
    """
    if p_prime is None:
        p_prime = residual_distribution(p, q)

    # Sample from q(x)
    x = random.choices(list(q), weights=list(q.values()), k=1)[0]

    # Acceptance check: candidates that q does not over-propose are always kept.
    if q[x] <= p[x]:
        return x

    # Rejection with probability 1 - p(x)/q(x)
    rejection_prob = 1 - (p[x] / q[x])
    if random.random() > rejection_prob:
        return x

    # Resample from adjusted distribution p'(x)
    return random.choices(list(p_prime), weights=list(p_prime.values()), k=1)[0]


# Simulate speculative sampling.  p'(x) depends only on p and q, so it is
# built once here instead of on every rejected draw.
p_prime = residual_distribution(p_x, q_x)
samples = [speculative_sample(p_x, q_x, p_prime) for _ in range(NUM_SAMPLES)]

# Compute empirical distribution: count each outcome, then normalise.
counts = dict.fromkeys(p_x, 0)
for x in samples:
    counts[x] += 1
empirical_dist = {x: count / NUM_SAMPLES for x, count in counts.items()}

# Plotting: grouped bar chart with the target p(x) on the left of each tick
# and the simulated (empirical) frequency on the right.
x_values = list(p_x.keys())
bar_width = 0.4
half_width = bar_width / 2

target_positions = [x - half_width for x in x_values]
target_heights = [p_x[x] for x in x_values]
empirical_positions = [x + half_width for x in x_values]
empirical_heights = [empirical_dist[x] for x in x_values]

plt.bar(target_positions, target_heights, bar_width, label='p(x)', color='blue')
plt.bar(empirical_positions, empirical_heights, bar_width, label='Empirical', color='orange')

# Axis labelling and legend.
plt.xlabel('Outcome')
plt.ylabel('Probability')
plt.title('Comparison of p(x) and Empirical Distribution from Speculative Sampling')
plt.legend()
plt.xticks(x_values)
plt.show()

## Calculation

In this section, we calculate the exact probability that speculative sampling outputs each outcome and confirm that it matches the target distribution $p(x)$.

In [None]:
# Define the distributions
p_x = {0: 0.4, 1: 0.3, 2: 0.2, 3: 0.1}  # Target distribution
q_x = {0: 0.3, 1: 0.4, 2: 0.2, 3: 0.1}  # Proposal distribution


def speculative_output_distribution(p, q):
    """Return the exact output distribution of speculative sampling.

    For each outcome x the probability is the sum of two terms:

    * accepted directly: q(x) * min(1, p(x)/q(x)) = min(p(x), q(x));
    * produced by resampling: (total rejection probability) * p'(x), where
      the total rejection probability is sum_y max(0, q(y) - p(y)) and
      p'(x) = max(0, p(x) - q(x)) / sum_y max(0, p(y) - q(y)).

    Using min(p, q) and max(0, q - p) avoids dividing by q(x), so outcomes
    with q(x) == 0 are handled without a ZeroDivisionError.

    Args:
        p: Mapping outcome -> target probability.
        q: Mapping outcome -> proposal probability (same keys as ``p``).

    Returns:
        Mapping outcome -> probability of being output.
    """
    residual = {x: max(0, p[x] - q[x]) for x in p}
    total_residual = sum(residual.values())
    # Probability that a proposed sample is rejected, summed over proposals.
    reject_mass = sum(max(0, q[x] - p[x]) for x in p)

    result = {}
    for x in p:
        # Acceptance probability
        prob = min(p[x], q[x])
        # Rejection and resampling probability (skipped when nothing can be
        # resampled, i.e. no outcome has p(x) > q(x)).
        if total_residual > 0:
            prob += reject_mass * residual[x] / total_residual
        result[x] = prob
    return result


# Calculate the total residual probability R (the normaliser of p'(x))
R = sum(max(0, p_x[x] - q_x[x]) for x in p_x)

# Calculate probabilities for each x
P_x = speculative_output_distribution(p_x, q_x)

# Print results
print("Target distribution p(x):", p_x)
print("Calculated distribution P(x) from speculative sampling:", P_x)
for x in p_x:
    diff = abs(p_x[x] - P_x[x])
    print(f"Difference for x={x}: {diff:.6f}")