<a href="https://colab.research.google.com/github/sreehitha177/hhh/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import torch
import math

# def log_posterior(X, mu, Sigma, pi):
#     """
#     Compute log P(Z=z | X=x_n) for a Gaussian Mixture Model.

#     Parameters:
#         X: Tensor of shape (N, 784) - input data.
#         mu: Tensor of shape (10, 784) - cluster means.
#         Sigma: Tensor of shape (10, 784, 784) - cluster covariances.
#         pi: Tensor of shape (10, 1) - mixture proportions.

#     Returns:
#         Tensor of shape (N, 10) - log probabilities.
#     """
#     N, D = X.shape
#     K = mu.shape[0]  # Number of clusters

#     # Precompute useful terms
#     log_pi = torch.log(pi.squeeze())  # Shape (10,)
#     Sigma_inv = torch.linalg.inv(Sigma)  # Shape (10, 784, 784)
#     Sigma_det = torch.linalg.det(Sigma)  # Shape (10,)
#     log_det_Sigma = torch.log(Sigma_det)  # Shape (10,)

#     log_prob = torch.zeros((N, K))  # Output tensor

#     for z in range(K):
#         diff = X - mu[z]  # Shape (N, 784)
#         term1 = -0.5 * (D * math.log(2 * math.pi) + log_det_Sigma[z])
#         term2 = -0.5 * (diff @ Sigma_inv[z] * diff).sum(dim=1)  # Quadratic form
#         log_prob[:, z] = log_pi[z] + term1 + term2

#     # Normalize using log-sum-exp for numerical stability
#     log_prob_x = torch.logsumexp(log_prob, dim=1, keepdim=True)  # Shape (N, 1)
#     log_post = log_prob - log_prob_x  # Shape (N, 10)

#     return log_post


In [21]:
def log_posterior(X, mu, Sigma, pi):
    N, D = X.shape
    K = mu.shape[0]

    # Regularize covariance matrix to ensure positive definiteness
    epsilon = 1e-6
    Sigma = Sigma + torch.eye(Sigma.shape[-1]) * epsilon

    # Precompute terms
    log_pi = torch.log(torch.clamp(pi.squeeze(), min=epsilon))  # Avoid log(0)
    Sigma_inv = torch.linalg.inv(Sigma)  # Inverse covariance
    Sigma_det = torch.linalg.det(Sigma)  # Determinant
    log_det_Sigma = torch.log(torch.clamp(Sigma_det, min=epsilon))  # Log determinant

    log_prob = torch.zeros((N, K))  # Output tensor

    for z in range(K):
        diff = X - mu[z]  # Shape (N, 784)
        term1 = -0.5 * (D * math.log(2 * math.pi) + log_det_Sigma[z])
        term2 = -0.5 * (diff @ Sigma_inv[z] * diff).sum(dim=1)
        log_prob[:, z] = log_pi[z] + term1 + term2

    # Normalize using log-sum-exp
    log_prob_x = torch.logsumexp(log_prob, dim=1, keepdim=True)
    log_post = log_prob - log_prob_x

    return log_post


In [22]:
import numpy as np

# Load data
data = np.load('/content/mixture_data.npz')
model = np.load('/content/mixture_model.npz')

# print(data.files)
# print(model.files)

# print(data['X'].shape)
# print(model['mu'].shape)
# print(model['Sigma'].shape)
# print(model['pi'].shape)

# print(data['X'])
# print(model['mu'])
# print(model['Sigma'])
# print(model['pi'])

X = torch.tensor(data['X'], dtype=torch.float32)  # Shape (10, 784)
mu = torch.tensor(model['mu'], dtype=torch.float32)  # Shape (10, 784)
Sigma = torch.tensor(model['Sigma'], dtype=torch.float32)  # Shape (10, 784, 784)
pi = torch.tensor(model['pi'], dtype=torch.float32)  # Shape (10, 1)

# Compute log probabilities
log_probs = log_posterior(X, mu, Sigma, pi)

# Format the output table
output_table = log_probs.numpy()
print(np.round(output_table, 4))


[[ 0.0000000e+00 -1.4490413e+03 -1.2133790e+02 -1.8086990e+02
  -3.7586850e+02 -2.6029459e+02 -2.6055621e+02 -6.0158801e+02
  -2.5019270e+02 -3.5346121e+02]
 [-2.0523740e+02  0.0000000e+00 -8.5421700e+01 -1.1874310e+02
  -8.4223297e+01 -1.4616409e+02 -2.7214090e+02 -1.9995700e+02
  -2.2805099e+01 -1.8698030e+02]
 [-3.0925781e+02 -9.4687451e+02  0.0000000e+00 -1.8090610e+02
  -3.2771011e+02 -2.7386020e+02 -4.2414221e+02 -5.3770880e+02
  -1.9584380e+02 -5.1668219e+02]
 [-1.8257629e+02 -8.3698218e+02 -8.0354401e+01  0.0000000e+00
  -4.1945560e+02 -1.4189560e+02 -7.3004779e+02 -5.0164871e+02
  -6.2616199e+01 -2.5344749e+02]
 [-5.2869360e+02 -1.2784456e+03 -2.8783179e+02 -2.5842789e+02
   0.0000000e+00 -4.5504941e+02 -7.1718738e+02 -1.5763370e+02
  -3.9440979e+02 -1.8817790e+02]
 [-2.1799440e+02 -6.4717310e+02 -1.4254111e+02 -9.6253004e+00
  -3.4454999e+02 -9.9999997e-05 -6.5221899e+02 -4.0397180e+02
  -7.0150803e+01 -3.2922919e+02]
 [-2.3335181e+02 -1.0722101e+03 -1.7255040e+02 -2.5294180e

In [23]:
def log_posterior_left(X, mu, Sigma, pi, indices):
    """
    Compute log P(Z=z | X=x_n,l) for the left half of the image.

    Parameters:
        X: Tensor of shape (N, 784) - input data.
        mu: Tensor of shape (10, 784) - cluster means.
        Sigma: Tensor of shape (10, 784, 784) - cluster covariances.
        pi: Tensor of shape (10, 1) - mixture proportions.
        indices: Tensor of indices for the left half of the image.

    Returns:
        Tensor of shape (N, 10) - log probabilities.
    """
    X_left = X[:, indices]  # Select left half pixels
    mu_left = mu[:, indices]  # Adjust means
    Sigma_left = Sigma[:, indices, :][:, :, indices]  # Adjust covariance matrices

    return log_posterior(X_left, mu_left, Sigma_left, pi)


In [24]:
left_indices = torch.arange(392)  # Indices for left half
log_probs_left = log_posterior_left(X, mu, Sigma, pi, left_indices)

# Format the output table
output_table_left = log_probs_left.numpy()
print(np.round(output_table_left, 4))


[[ 0.000000e+00 -6.397985e+02 -6.065180e+01 -1.045065e+02 -7.825170e+01
  -1.701527e+02 -1.415796e+02 -2.505657e+02 -5.975220e+01 -1.191777e+02]
 [-5.626850e+01 -0.000000e+00 -2.794870e+01 -2.147520e+01 -2.426430e+01
  -6.645190e+01 -6.782190e+01 -1.290239e+02 -1.089110e+01 -7.796000e+01]
 [-2.453350e+01 -2.283005e+02 -6.770000e-02 -9.255800e+00 -3.833030e+01
  -1.013166e+02 -1.902393e+02 -2.438320e+01 -2.727300e+00 -1.069990e+01]
 [-6.488020e+01 -3.568331e+02 -1.086860e+01 -3.250000e-02 -9.692220e+01
  -1.100251e+02 -2.473572e+02 -1.667924e+02 -3.443100e+00 -7.647750e+01]
 [-1.765778e+02 -5.130408e+02 -1.265878e+02 -1.020840e+02  0.000000e+00
  -2.907932e+02 -3.456262e+02 -3.584280e+01 -1.247374e+02 -1.358392e+02]
 [-8.505390e+01 -2.555688e+02 -8.521680e+01 -5.253130e+01 -1.018681e+02
  -8.508900e+00 -2.446612e+02 -2.305468e+02 -2.000000e-04 -1.839394e+02]
 [-1.362167e+02 -4.523502e+02 -1.243000e+02 -1.498995e+02 -2.052042e+02
  -1.232898e+02  0.000000e+00 -4.646506e+02 -1.745310e+02 