# CNS2025 Homework 10

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Load the arrays stored under the keys 's' and 'r' in the homework archive.
# np.load on an .npz returns a lazy NpzFile that keeps the archive open, so
# use it as a context manager: both arrays are read into memory inside the
# block and the file handle is released as soon as the block exits.
with np.load('hw10-data.npz') as data:
    s = data['s']
    r = data['r']
print(f"Data loaded: s shape={s.shape}, r shape={r.shape}")

In [None]:
# Define mutual information function
def calculate_mi(x, y, bins=25):
    """Estimate the mutual information between paired samples, in bits.

    The joint distribution of ``(x, y)`` is approximated with a 2D histogram
    and the estimate is ``H(X) + H(Y) - H(X, Y)``, with every entropy
    computed from the binned probabilities using base-2 logarithms.

    Args:
        x: Array-like of samples; flattened to 1D before binning.
        y: Array-like of samples paired element-wise with ``x``; flattened
            the same way.
        bins: Bin specification forwarded to ``np.histogram2d``.

    Returns:
        The estimated mutual information in bits. The value is never
        negative, and it is ``0.0`` when there are no samples.

    Raises:
        ValueError: If ``x`` and ``y`` hold a different number of samples.
    """
    x = np.ravel(x)
    y = np.ravel(y)
    if x.size != y.size:
        raise ValueError(
            f"x and y must contain the same number of samples "
            f"(got {x.size} and {y.size})"
        )
    if x.size == 0:
        # No samples: normalising an all-zero histogram would be 0/0.
        return np.float64(0.0)

    # Joint counts -> joint probabilities -> marginals.
    joint_counts, _, _ = np.histogram2d(x, y, bins=bins)
    p_xy = joint_counts / np.sum(joint_counts)
    p_x = np.sum(p_xy, axis=1)
    p_y = np.sum(p_xy, axis=0)

    def entropy(p):
        # Empty bins are dropped so that 0 * log2(0) contributes nothing.
        p = p[p > 0]
        return -np.sum(p * np.log2(p))

    mi = entropy(p_x) + entropy(p_y) - entropy(p_xy.flatten())

    # Floating-point rounding can leave a tiny negative residue when the
    # three entropies cancel; mutual information itself is non-negative.
    return np.maximum(mi, 0.0)

## Exercise 1: MI vs Relative Frame Shift

In [None]:
# Sweep circular shifts of r from -64 to +64 and record MI(s, shifted r)
# for each one.
shifts = np.arange(-64, 65)
mi_values = np.array([calculate_mi(s, np.roll(r, lag)) for lag in shifts])

# Locate the shift at which the MI curve is largest.
peak_idx = mi_values.argmax()
peak_shift, peak_mi = shifts[peak_idx], mi_values[peak_idx]

# Draw the MI curve, mark its maximum and add a reference line at zero shift.
plt.figure(figsize=(10, 5))
ax = plt.gca()
ax.plot(shifts, mi_values, 'b-', linewidth=2)
ax.plot(peak_shift, peak_mi, 'ro', markersize=10)
ax.set_xlabel('Relative Frame Shift δ', fontsize=12)
ax.set_ylabel('Mutual Information (bits)', fontsize=12)
ax.set_title('Mutual Information vs Relative Frame Shift', fontsize=14)
ax.grid(True, alpha=0.3)
ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
plt.show()

print(f"Peak position: δ = {peak_shift}")
print(f"Peak MI value: {peak_mi:.4f} bits")

## Exercise 2: Baseline Estimation

In [None]:
# Method 1: shuffle baseline -- MI between s and randomly permuted copies
# of r, summarised by the mean and standard deviation over all shuffles.
rng = np.random.default_rng(42)  # fixed seed keeps the shuffles repeatable
n_shuffles = 1000
mi_shuffled = [calculate_mi(s, rng.permutation(r)) for _ in range(n_shuffles)]

baseline_shuffle, std_shuffle = np.mean(mi_shuffled), np.std(mi_shuffled)

print(f"Shuffle baseline: {baseline_shuffle:.6f} ± {std_shuffle:.6f} bits")

In [None]:
# Method 2: shift baseline -- MI between s and r circularly shifted by
# 0, 1, ..., 999 positions, summarised by mean and standard deviation.
# NOTE(review): the sweep is fixed at 1000 shifts and includes the zero
# shift -- confirm both are intended for a baseline estimate.
all_shifts_mi = [calculate_mi(s, np.roll(r, offset)) for offset in range(1000)]

baseline_allshifts, std_allshifts = np.mean(all_shifts_mi), np.std(all_shifts_mi)

print(f"All shifts baseline: {baseline_allshifts:.6f} ± {std_allshifts:.6f} bits")

In [None]:
# Side-by-side histograms of the two baseline MI distributions; the dashed
# red line in each panel marks that distribution's mean.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

panels = [
    (ax1, mi_shuffled, 'blue', baseline_shuffle,
     f'Shuffle Baseline: {baseline_shuffle:.6f}'),
    (ax2, all_shifts_mi, 'green', baseline_allshifts,
     f'All Shifts Baseline: {baseline_allshifts:.6f}'),
]
for panel, samples, face_color, mean_value, heading in panels:
    panel.hist(samples, bins=30, density=True, alpha=0.7,
               color=face_color, edgecolor='black')
    panel.axvline(mean_value, color='red', linestyle='--', linewidth=2)
    panel.set_xlabel('Mutual Information (bits)')
    panel.set_ylabel('Density')
    panel.set_title(heading)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Overlay both baseline estimates (mean as a dashed line, +/- one standard
# deviation as a shaded band) on the MI-versus-shift curve.
plt.figure(figsize=(10, 6))
ax = plt.gca()
ax.plot(shifts, mi_values, 'b-', linewidth=2, label='MI(δ)')

baselines = [
    (baseline_shuffle, std_shuffle, 'red',
     f'Shuffle baseline: {baseline_shuffle:.6f}'),
    (baseline_allshifts, std_allshifts, 'green',
     f'All shifts baseline: {baseline_allshifts:.6f}'),
]
# Lines first, then bands, so artists are added in the same order as before.
for level, _, tint, caption in baselines:
    ax.axhline(level, color=tint, linestyle='--', alpha=0.7, label=caption)
for level, spread, tint, _ in baselines:
    ax.fill_between(shifts, level - spread, level + spread, alpha=0.2, color=tint)

ax.plot(peak_shift, peak_mi, 'ko', markersize=10)
ax.set_xlabel('Relative Frame Shift δ', fontsize=12)
ax.set_ylabel('Mutual Information (bits)', fontsize=12)
ax.set_title('MI Curve with Baseline Estimates', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Summary: recap the results of both exercises and recommend whichever
# baseline method showed the smaller standard deviation.
recommendation = (
    "  The shuffle method provides a more stable baseline estimate."
    if std_shuffle < std_allshifts
    else "  The all shifts method provides a more stable baseline estimate."
)

summary_lines = [
    "\n=== RESULTS SUMMARY ===",
    "\nExercise 1:",
    f"  Peak position: δ = {peak_shift}",
    f"  Peak MI value: {peak_mi:.4f} bits",
    "\nExercise 2:",
    f"  Shuffle baseline: {baseline_shuffle:.6f} bits",
    f"  All shifts baseline: {baseline_allshifts:.6f} bits",
    "\nRecommendation:",
    recommendation,
]
print("\n".join(summary_lines))