# EM Algorithm

In [4]:
# EM Algorithm

import matplotlib.pyplot as plt
import numpy as np

# Step 1: Load in data
# The text file holds space-separated 0/1 values, one sample per line (100 lines).
X = np.loadtxt('/datasets/t1cw-data/binarydigits.txt', dtype=int)

# N = number of samples (rows), D = number of binary variables per sample (columns)
N, D = X.shape

# S[d] counts how many samples have a 1 in variable d
S = X.sum(axis=0)

# Step 2: Define the log-Bernoulli PDF
def log_bernoulli_pdf(X, P):
    """Log-probability of every row of X under every Bernoulli component.

    X: (N, D) array of 0/1 values.
    P: (K, D) array of Bernoulli success probabilities.
    Returns an (N, K) array whose [n, k] entry is
        sum_d  X[n, d] * log P[k, d] + (1 - X[n, d]) * log(1 - P[k, d]).
    """
    tiny = 1e-8
    # Keep every probability strictly inside (0, 1) so both logs stay finite.
    safe = np.clip(P, tiny, 1 - tiny)
    ones_term = X @ np.log(safe).T           # contribution of the observed ones
    zeros_term = (1 - X) @ np.log(1 - safe).T  # contribution of the observed zeros
    return ones_term + zeros_term

def _logsumexp_rows(A):
    """Return log(sum(exp(A), axis=1)) for a 2-D array, using the max-shift trick for stability."""
    row_max = np.max(A, axis=1, keepdims=True)
    return (row_max + np.log(np.sum(np.exp(A - row_max), axis=1, keepdims=True)))[:, 0]


# Step 3: EM Algorithm with weak symmetric priors
def em_map(X, K, max_iter=1000, tol=1e-4, alpha=1.01, beta=1.01, plot_trace=True, verbose=False, seed=None): # Default values for alpha and beta are 1.01
    """
    MAP estimation for a mixture of K multivariate Bernoullis using weak symmetric priors.

    Parameters
    ----------
    X : (N, D) array of 0/1 values.
    K : number of mixture components.
    max_iter : maximum number of EM iterations (must be >= 1).
    tol : stop once the absolute change in log-posterior between two
        consecutive iterations is below this value.
    alpha : symmetric Dirichlet parameter for the mixing proportions.
    beta : symmetric Beta(beta, beta) parameter for every Bernoulli probability.
    plot_trace : if True, plot the log-posterior trace with matplotlib.
    verbose : if True, print the log-posterior after every iteration.
    seed : optional seed for the random initialisation of P (None = fresh entropy).

    Returns
    -------
    loglike : log-likelihood of X (in nats) under the final parameters.
    pi : (K,) final mixing proportions.
    P : (K, D) final Bernoulli parameters, clipped to [1e-6, 1 - 1e-6].
    log_trace : list with one unnormalised log-posterior value per iteration
        (log-likelihood + log-priors, without the priors' normalising constants).

    Raises
    ------
    ValueError : if max_iter < 1 (no iteration would run and nothing could be returned).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    X = np.asarray(X)
    D = X.shape[1]  # take the dimensionality from the data, not from a global
    rng = np.random.default_rng(seed)

    # Initialise: uniform mixing proportions, random Bernoulli parameters away from 0/1
    pi = np.ones(K) / K
    P = rng.uniform(0.25, 0.75, size=(K, D))
    log_trace = []

    for it in range(max_iter):
        # E-step: responsibilities via a numerically stable normalisation
        log_joint = log_bernoulli_pdf(X, P) + np.log(pi)  # (N, K)
        gamma = np.exp(log_joint - _logsumexp_rows(log_joint)[:, None])

        # M-step: MAP updates
        Nk = np.sum(gamma, axis=0)       # effective counts per component
        N_total = np.sum(Nk) + 1e-12     # avoid division by zero
        Nk = Nk + 1e-12                  # avoid division by zero

        # Dirichlet(alpha) prior on pi adds (alpha - 1) pseudo-counts per component
        pi = (Nk + (alpha - 1)) / (N_total + K * (alpha - 1))

        # Beta(beta, beta) prior on each P[k, d] adds (beta - 1) pseudo-counts to successes and failures
        P = (gamma.T @ X + (beta - 1)) / (Nk[:, None] + 2 * (beta - 1))
        P = np.clip(P, 1e-6, 1 - 1e-6)  # keep logs finite

        # Log-posterior tracking (unnormalised)
        log_prior_pi = (alpha - 1) * np.sum(np.log(pi + 1e-16))
        log_prior_P = (beta - 1) * np.sum(np.log(P + 1e-16) + np.log(1 - P + 1e-16))
        # Same stable log-sum-exp as the E-step, so very negative log-probabilities cannot underflow to log(0)
        loglike = np.sum(_logsumexp_rows(log_bernoulli_pdf(X, P) + np.log(pi)))
        log_trace.append(loglike + log_prior_pi + log_prior_P)

        # Report before the convergence check so the final iteration is also printed
        if verbose:
            print(f"Iter {it}: log-posterior = {log_trace[-1]:.2f}")
        # From the second iteration on there is a previous value to compare against
        if it > 0 and np.abs(log_trace[-1] - log_trace[-2]) < tol:
            break

    # Step 4: Plot log-trace
    if plot_trace:
        plt.plot(log_trace, marker='o')
        plt.title(f"EM Log-Posterior Trace, K={K}")
        plt.xlabel("Iteration")
        plt.ylabel("Log-Posterior")
        plt.show()

    return loglike, pi, P, log_trace # Return log-likelihood, final mixture proportions, Bernoulli parameters, and log-trace

def estimated_model_bits(K, D, bits_per_float=32):
    """
    Approximate number of bits needed to transmit a K-component mixture of
    D-dimensional Bernoullis, charging bits_per_float bits (e.g. 32 for
    single or 64 for double precision) per free parameter.

    Free parameters: K - 1 mixing weights (the last one is implied because the
    weights sum to one) plus K * D Bernoulli probabilities.
    """
    weight_bits = bits_per_float * (K - 1)       # mixing weights
    bernoulli_bits = bits_per_float * (K * D)    # one probability per component and dimension
    return weight_bits + bernoulli_bits

# Step 5: Run EM Algorithm for different values of K
Ks = [2, 3, 4, 7, 10]
for K in Ks:
    # Fit the mixture and display its log-posterior trace (plot_trace=True is also the default)
    loglike, pi, P, log_trace = em_map(X, K, max_iter=100, plot_trace=True)
    print(f"\nK={K}, Final mixing proportions: {pi}")
    # Only the first (up to) three rows of P are shown to keep the output short
    print(f"Final Bernoulli parameters (first 3 components):\n{P[:3]}")


# Converting log-likelihoods from nats to bits

## Gzip

In [6]:
# Prepare two raw binary versions of the digit data so that gzip can be applied to each.
# The text file is loaded once and reused for both encodings.
data = np.loadtxt('/datasets/t1cw-data/binarydigits.txt', dtype=int)
# packbits treats any non-zero value as 1 and uint8 would wrap large values, so insist on 0/1 input.
if not np.isin(data, (0, 1)).all():
    raise ValueError("binarydigits.txt must contain only 0/1 values")
digits = data.astype(np.uint8)

# Option 1 (bit-packed): every 8 consecutive digits, in row-major order, become one byte.
packed = np.packbits(digits.ravel())
packed.tofile('binarydigits_packed.raw')

# Option 2 (raw binary): one byte per digit, no separators, N*D bytes in row-major order.
digits.tofile('binarydigits_bin.raw')

## Compare model costs

In [8]:
import os
import gzip # Import gzip module for compression
import shutil # Stream-copy file contents into the gzip writer
import numpy as np # Import numpy for array operations
import matplotlib.pyplot as plt # Import matplotlib for plotting

# Calculate model costs and compare them with naive encoding and Gzip compression
# (two variants: bit-packed and one-byte-per-digit raw binary).

# File locations. The .raw inputs are the files written by the previous cell into the
# current working directory; the .gz outputs go next to them so all paths agree.
RAW_TO_GZ = {
    'binarydigits_bin.raw': 'binarydigits_bin.gz',
    'binarydigits_packed.raw': 'binarydigits_packed.gz',
}

# --- Gzip compression of raw data files ---
for raw_path, gz_path in RAW_TO_GZ.items():
    with open(raw_path, 'rb') as f_in, gzip.open(gz_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
# --- End of Gzip compression steps ---


# For the EM model:
total_costs = []
model_bits_list = []
data_bits_list = []

for K in Ks:
    loglike, pi, P, log_trace = em_map(X, K, max_iter=100, plot_trace=False)
    # Data bits: negative log-likelihood converted from nats to bits
    data_bits = -loglike / np.log(2)
    # Model bits: number of free parameters x bits per parameter
    # (K-1 mixing weights, since they sum to one, plus K*D Bernoulli parameters)
    model_bits = estimated_model_bits(K, D, bits_per_float=32)
    # Total bits: model + data
    total_cost = model_bits + data_bits
    total_costs.append(total_cost)
    model_bits_list.append(model_bits)
    data_bits_list.append(data_bits)
    print(f"\nK={K}")
    print(f"Length of model-based coding in bits (total): {total_cost:.2f}")
    print(f"Model-based coding bits per digit: {total_cost / (N*D):.4f}")


# Calculate baseline (i.e. naive encoding)
naive_bits = N * D # each digit uses 1 bit
print("\nLength of naive encoding in bits:", naive_bits)

# For Gzip, the code length in bits is the compressed file size in bytes x 8.
# The bit-packed input already uses 1 bit per digit before compression, so gzip
# only wins if it finds redundancy beyond that.
packedfilesize = os.path.getsize('binarydigits_packed.gz')
gzip_packed_bits = packedfilesize * 8
print(f"Length of bit-packed Gzip encoding in bits: {gzip_packed_bits}")


# The raw binary input stores 1 digit per byte (8 bits per digit before compression).
# Gzip can still compress it well below 8 bits per digit; 8 is the uncompressed size, not a lower bound.
binfilesize = os.path.getsize('binarydigits_bin.gz')
gzip_raw_bin_bits = binfilesize * 8
print(f"Length of raw binary Gzip encoding in bits: {gzip_raw_bin_bits}")

## Plots

In [None]:
# Figure 1: absolute code lengths (bits) against the number of components K,
# with the naive and gzip encodings drawn as horizontal reference lines.
plt.figure(figsize=(8, 6))
for series, fmt, series_label in (
    (total_costs, 'o-', "Total Cost (Model+Data)"),
    (model_bits_list, 's--', "Model Header (bits)"),
    (data_bits_list, 'x--', "Data Bits (neg. loglike)"),
):
    plt.plot(Ks, series, fmt, label=series_label)
for level, colour, style, level_label in (
    (naive_bits, 'grey', ':', "Naive Encoding (1 bit/digit)"),
    (gzip_packed_bits, 'red', '--', "Bit-packed Gzip Encoding"),
    (gzip_raw_bin_bits, 'magenta', '-', "Raw Binary Gzip Encoding"),
):
    plt.axhline(level, color=colour, linestyle=style, label=level_label)
plt.xlabel("K")
plt.ylabel("Bits")
plt.title("Total Code Length vs. K (MDL Principle)")
plt.legend()
plt.tight_layout()
plt.show()

# Figure 2: the same comparison normalised to bits per digit
n_digits = N * D
plt.figure(figsize=(8, 6))
plt.plot(Ks, np.array(total_costs) / n_digits, linestyle='dashdot', label="EM")
for level, colour, style, level_label in (
    (1, 'grey', ':', "Naive Baseline (1 bit/digit)"),
    (gzip_packed_bits / n_digits, 'red', '--', "Bit-packed Gzip"),
    (gzip_raw_bin_bits / n_digits, 'magenta', '-', "Raw Binary Gzip"),
):
    plt.axhline(level, color=colour, linestyle=style, label=level_label)
plt.xlabel("K")
plt.ylabel("Bits per Digit")
plt.title("Bits per Digit vs. K")
plt.legend()
plt.tight_layout()
plt.show()

# Analysis Aids

In [None]:
# Collage figures for comparison
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# List of file names for your PNG plots
# NOTE(review): the 'x.png' entries are placeholders — replace them with the saved trace images.
png_files = ['x.png', 'x.png', 'x.png', 'x.png', 'x.png']

# squeeze=False keeps `axes` 2-D even for a single file, so the indexing below always works.
fig, axes = plt.subplots(1, len(png_files), figsize=(5 * len(png_files), 5), squeeze=False)
for ax, png in zip(axes[0], png_files):
    ax.imshow(mpimg.imread(png))
    ax.axis('off')

# A figure-level title spans the whole collage; plt.title would only label the last panel.
fig.suptitle("EM Log-Posterior Traces, K = 10")
plt.tight_layout()
plt.show()


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=cb182644-878e-48cb-992b-68a78a5afe3d' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>