<a href="https://colab.research.google.com/github/vanadhisivakumar-source/Machine-learning-projects/blob/main/EM_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Expectation-Maximization (EM) Algorithm for Clustering

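The EM algorithm is an iterative method for finding maximum-likelihood estimates of the parameters of a statistical model when the model depends on unobserved (latent) variables. Each iteration alternates between an expectation (E) step, which uses the current parameters to estimate how likely each latent assignment is, and a maximization (M) step, which re-estimates the parameters from those assignments. In clustering, it is most often used with Gaussian Mixture Models (GMMs) to identify the underlying Gaussian distributions within the data.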

In [None]:
import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs # To generate synthetic data

# 1. Generate Synthetic Data
# Sample 300 points scattered around 3 blob centers. The fixed random_state
# makes the draw reproducible from run to run.
X, y = make_blobs(
    n_samples=300,
    centers=3,
    cluster_std=0.8,
    random_state=42,
)

print("Generated Synthetic Data (first 5 rows):\n", X[:5])

# Show the raw points without any cluster colouring, so the structure the
# mixture model is expected to recover is visible up front.
raw_fig, raw_ax = plt.subplots(figsize=(8, 6))
raw_ax.scatter(X[:, 0], X[:, 1], s=50, alpha=0.7)
raw_ax.set_title('Synthetic Data for EM Clustering')
raw_ax.set_xlabel('Feature 1')
raw_ax.set_ylabel('Feature 2')
raw_ax.grid(True)
plt.show()

# 2. Implement the EM Algorithm using GaussianMixture
# Build a 3-component mixture and fit it to X in one chained expression;
# fit() hands back the estimator, so `gmm` is the trained model.
gmm = GaussianMixture(n_components=3, random_state=42).fit(X)

# Assign every sample in X a component label using the trained model.
em_predictions = gmm.predict(X)

# Report a preview of the labels, followed by the fitted parameters.
print("\nEM Predictions (first 5):\n", em_predictions[:5])
print("\nMean of each Gaussian component:\n", gmm.means_)
print("\nCovariance of each Gaussian component:\n", gmm.covariances_)

# 3. Visualize the EM Clustering Results
# Colour each sample by its predicted component, then overlay the fitted
# component means as red 'X' markers.
component_means = gmm.means_

result_fig, result_ax = plt.subplots(figsize=(8, 6))
result_ax.scatter(
    X[:, 0],
    X[:, 1],
    c=em_predictions,
    s=50,
    cmap='viridis',
    alpha=0.7,
)
result_ax.scatter(
    component_means[:, 0],
    component_means[:, 1],
    marker='X',
    s=200,
    color='red',
    label='Cluster Centers',
)
result_ax.set_title('EM Clustering Results (Gaussian Mixture Model)')
result_ax.set_xlabel('Feature 1')
result_ax.set_ylabel('Feature 2')
result_ax.legend()
result_ax.grid(True)
plt.show()
