## EM Algorithm

- **Observables**: $x, y, \hat z$
- **Latent**: $z$
- **Model**:
    - $p(z)$: prior over latent $z$
    - $p(\hat z | z)$: observation model for $\hat z$
    - $p(y|z,x,\theta)$: likelihood model for $y$ parameterized by $\theta$ and conditioned on $z$ and $x$

The marginal likelihood is
\begin{align}
p(y|x,\hat z,\theta) &= \int p(y,z|x,\hat z,\theta) dz\\
&= \frac{\int p(y|z,x,\theta)p(\hat z|z)p(z)dz}{\int p(\hat z|z)p(z)dz}
\end{align}

- **E-step**: Infer the posterior over the latent variable $z$ given the current parameter estimate $\theta$:
$$q(z) \leftarrow p(z|x,y,\hat z,\theta) \propto p(z)p(y,\hat z|z,\theta,x) = p(y|z,x,\theta)p(\hat z|z)p(z)$$
<!-- - **M-step**: $\theta = \text{argmax}_\theta\mathbb{E}_{q(z)}[p(y|z,x,\theta)p(\hat z|z)]$ -->
- **M-step**: Update $\theta$ to maximize the expected complete-data log-likelihood under the inferred $q(z)$. Since $p(\hat z|z)$ and $p(z)$ do not depend on $\theta$, only the likelihood term remains:
$$\theta = \text{argmax}_\theta\mathbb{E}_{q(z)}[\log p(y|z,x,\theta)]$$

In [None]:
# Given:
# function model_log_lkhd(x, y, z, theta) = log p(y|z,x,theta), evaluated for a batch z of pixel coordinates with shape [M, 2]
# grid size = [H, W]

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import MultivariateNormal

In [None]:
def E_step(x, y, z_hat, theta, z_sample, sigma=1.0, *, log_lkhd_fn=None):
    """
    E-step: compute q(z) ∝ p(y|z,x,theta) p(z_hat|z) p(z) on the pixel grid.

    The prior p(z) is represented empirically by ``z_sample`` (sample counts
    per pixel), so only pixels with a positive count receive posterior mass.

    Parameters:
    x, y        : stimulus and observed neural response, forwarded to the likelihood
    z_hat       : array-like of length 2 (e.g. torch tensor), observed eye trace location
    theta       : model parameters, forwarded to the likelihood
    z_sample    : np.ndarray (or tensor) of shape [H, W], count of times each pixel was sampled
    sigma       : float, std deviation of the isotropic Gaussian likelihood p(z_hat | z)
    log_lkhd_fn : optional callable (x, y, z, theta) -> log p(y|z,x,theta) for a batch
                  z of integer pixel coordinates of shape [M, 2], returning shape [M].
                  Defaults to the global ``model_log_lkhd``.

    Returns:
    q_z : np.ndarray (float64) of shape [H, W], posterior probability of each pixel;
          sums to 1.

    Raises:
    ValueError : if ``z_sample`` contains no positive counts.
    """
    if log_lkhd_fn is None:
        log_lkhd_fn = model_log_lkhd

    counts = torch.as_tensor(z_sample)
    H, W = counts.shape
    rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    z_grid = torch.stack([rows, cols], dim=-1).reshape(-1, 2)  # [H*W, 2], int64 pixel coords
    # Work in float64 regardless of the count dtype (sample_z yields integer counts).
    counts_flat = counts.flatten().to(torch.float64)

    mask = counts_flat > 0
    if not bool(mask.any()):
        raise ValueError("z_sample has no sampled pixels; q(z) is undefined")
    z_masked = z_grid[mask]  # [M, 2]

    # q(z) is a fixed target for the M-step: keep it out of autograd so it can be
    # converted to numpy even when theta requires grad.
    with torch.no_grad():
        z_hat_t = torch.as_tensor(z_hat, dtype=torch.float64)
        cov = torch.eye(2, dtype=torch.float64) * sigma**2
        mvn = MultivariateNormal(z_hat_t, cov)
        log_p_z_hat_given_z = mvn.log_prob(z_masked.to(torch.float64))
        log_lkhd = torch.as_tensor(
            log_lkhd_fn(x, y, z_masked, theta), dtype=torch.float64, device='cpu'
        )
        log_q = log_lkhd + log_p_z_hat_given_z + counts_flat[mask].log()
        # Subtract the max before exponentiating to avoid overflow/underflow.
        q_masked = torch.exp(log_q - log_q.max())

    q_flat = torch.zeros(H * W, dtype=torch.float64)
    q_flat[mask] = q_masked
    q_flat /= q_flat.sum()
    return q_flat.reshape(H, W).numpy()

In [None]:
def sample_z(q_z, num_samples):
    """
    Sample z from q(z)

    Parameters:
    q_z         : np.ndarray (or CPU tensor) of shape [H, W], probability of sampling
                  each pixel; renormalised internally, so it need only be non-negative
                  with a positive sum
    num_samples : int, number of samples

    Returns:
    z_sample    : np.ndarray of shape [H, W], count of times each pixel was sampled
    """
    probs = np.asarray(q_z, dtype=np.float64)
    H, W = probs.shape
    flat_q = probs.ravel()
    # np.random.choice rejects p unless it sums to 1 within ~1.5e-8; a float32
    # posterior (or accumulated rounding) can violate that, so renormalise in float64.
    flat_q = flat_q / flat_q.sum()
    sampled_indices = np.random.choice(flat_q.size, size=num_samples, p=flat_q)
    counts_flat = np.bincount(sampled_indices, minlength=H * W)
    return counts_flat.reshape(H, W)

In [None]:
def M_step(x, y, q_z, theta, lr=1e-2, steps=100, *, log_lkhd_fn=None):
    """
    M-step: update theta to maximize E_q(z)[log p(y | z, x, theta)]

    Parameters:
        x, y        : stimulus and observed neural response, forwarded to the likelihood
        q_z         : np.ndarray (or tensor) of shape [H, W], posterior probability over latent z
        theta       : torch.nn.Module or torch.nn.Parameter (parameters to optimize)
        lr          : learning rate
        steps       : number of Adam steps
        log_lkhd_fn : optional callable (x, y, z, theta) -> log p(y|z,x,theta) for a batch
                      z of integer pixel coordinates of shape [M, 2], returning shape [M].
                      Defaults to the global ``model_log_lkhd``.

    Returns:
        Updated theta (the same object, optimized in place)

    Raises:
        ValueError : if q_z has no positive entries.
    """
    if log_lkhd_fn is None:
        log_lkhd_fn = model_log_lkhd

    # q(z) is a fixed target here; never backpropagate into it.
    q = torch.as_tensor(q_z).detach()
    H, W = q.shape
    q_flat = q.flatten()
    rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    z_grid = torch.stack([rows, cols], dim=-1).reshape(-1, 2)  # [H*W, 2], int64 pixel coords

    # Only pixels with posterior mass contribute to the expectation.
    mask = q_flat > 0
    if not bool(mask.any()):
        raise ValueError("q_z has no positive entries; nothing to optimize against")
    z_masked = z_grid[mask]
    weights = q_flat[mask]

    # nn.Parameter has no .parameters(); optimize it directly.
    params = list(theta.parameters()) if isinstance(theta, nn.Module) else [theta]
    optimizer = optim.Adam(params, lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        log_lkhd = log_lkhd_fn(x, y, z_masked, theta)
        loss = -(weights * log_lkhd).sum()
        loss.backward()
        optimizer.step()
    return theta

In [None]:
# Seed both NumPy and torch RNGs so the synthetic setup below is reproducible.
np.random.seed(0)
torch.manual_seed(0)

H, W = 60, 100  # grid size [H, W] of candidate latent locations
z_true = np.array([10, 12])  # latent location
sigma_z_hat = 2.0  # std of the simulated eye-trace noise
# Noisy observation of the latent location: z_hat = z_true + N(0, sigma_z_hat^2 * I).
# NOTE(review): E_step defaults to sigma=1.0; presumably it should be called with
# sigma=sigma_z_hat so p(z_hat | z) matches this noise level — confirm at the call site.
z_hat = z_true + np.random.randn(2) * sigma_z_hat
z_hat = torch.tensor(z_hat, dtype=torch.float32)

# Random stimulus vector of length 5 (the "x" input to the likelihood model).
x = np.random.randn(5)
