# ANU ASTR4004 2025 - Week 9

Author: Sven Buder (sven.buder@anu.edu.au)

Based on the tutorial by Yuan-Sen Ting from ASTR4004 2023

In [None]:
try:
    %matplotlib inline
    %config InlineBackend.figure_format='retina'
except:
    pass

import numpy as np
from astropy.table import Table

# We will use these later for the Gaussian Mixture Models
from scipy.stats import multivariate_normal as mvnorm
from scipy.stats import chi2

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.colors import LogNorm

# Make the size and fonts larger for this presentation
plt.rcParams['font.size'] = 15
plt.rcParams['legend.fontsize'] = 12

## 1 Data: SDSS DR19 and GALAH DR4

In [None]:
# Let's read in the tables; These are measurements of [Fe/H], [Mg/Fe], [Al/Fe], [Mn/Fe] from SDSS DR19 and GALAH DR4, already matched via their Gaia DR3 source_id
galah = Table.read('data/galah_dr4_alfe_mgmn.fits')
sdss  = Table.read('data/sdss_dr19_alfe_mgmn.fits')

In [None]:
def plot_abundance(element = 'Mg'):
    """
    Show the [Fe/H] vs. [element/Fe] density of both surveys side by side.

    One log-scaled 2D histogram is drawn per survey (SDSS DR19 on the left,
    GALAH DR4 on the right) from the module-level tables `sdss` and `galah`,
    using the columns 'fe_h_<survey>' and '<element>_fe_<survey>'.

    Parameters
    ----------
        element : Element symbol, e.g. 'Mg'; it is lower-cased to build the
            abundance column name and used as-is in the y-axis label.
    """
    surveys = [
        ('SDSS DR19', sdss, 'sdss'),
        ('GALAH DR4', galah, 'galah'),
    ]
    column_stub = element.lower() + '_fe_'

    _, panels = plt.subplots(1, 2, figsize=(10, 3.5), sharex=True, sharey=True)

    for index, (panel, (title, table, suffix)) in enumerate(zip(panels, surveys)):
        panel.set_xlabel('[Fe/H]')
        if index == 0:
            # The y-axis is shared, so only the left panel carries the label
            panel.set_ylabel('[' + element + '/Fe]')
        panel.set_title(title)
        counts = panel.hist2d(
            table['fe_h_' + suffix],
            table[column_stub + suffix],
            bins=100,
            norm=LogNorm(),
        )
        # The last entry returned by hist2d is the image the colorbar describes
        plt.colorbar(counts[-1], ax=panel, label='Nr.')

    plt.tight_layout()
    plt.show()
    plt.close()

# Plain loop: the calls are made for their plots, not to collect (and echo) a list of None values
for element in ['Mg','Al','Mn']:
    plot_abundance(element)

In [None]:
# nan values in GALAH?

# Report, per element, how many stars have a raised quality flag (columns named flag_*)
for element in ['al','mg','mn']:
    flagged_rows = np.where(galah['flag_'+element+'_fe_galah'] > 0)[0]
    print('Stars with raised flags for element '+element+':', flagged_rows.size)

# Filter for good abundance measurements in GALAH (flag == 0 for the elements we care about)
galah = galah#[] --> add filters here

In [None]:
# Specific stars observed a lot in GALAH?
np.unique(galah['gaiadr3_source_id'], return_counts=True)

In [None]:
# Let's apply some more filtering to avoid repeat observations / wrong Gaia source_id matches

# Filter to avoid missing Gaia source_ids (source_id has negative values)
sdss = sdss#[] --> add filters here
galah = galah#[] --> add filters here

# 2) Classification of thin and thick disk in SDSS DR19

## 2.1) K-Means

In [None]:
# for this tutorial, let's use variables rather than keywords
feh  = np.array(sdss['fe_h_sdss'])
mgfe = np.array(sdss['mg_fe_sdss'])

# normalize data: shift each abundance to zero mean and scale it to unit standard deviation,
# so that both axes carry equal weight in the Euclidean distances used by K-means
feh = (feh - np.mean(feh)) / np.std(feh)
mgfe = (mgfe - np.mean(mgfe)) / np.std(mgfe)

# prepare data for sklearn: one row per star, columns are ([Fe/H], [Mg/Fe])
data = np.vstack([feh, mgfe]).T
data_labels = ["[Fe/H]", "[Mg/Fe]"]

### Initializing K-means Clustering Parameters

Before running the K-means algorithm on our dataset, it's important to set initial values for the model's centroids. These initializations guide the iterative process of the K-means algorithm. For our problem, we'll use a two-cluster K-means model to capture the bimodal distribution of alpha-enriched and alpha-normal stars.

We will initialize the centroids of our two clusters as follows:

- For the first cluster ($ C_1 $):
$$
\begin{bmatrix}
    -1 \\
    +1
\end{bmatrix}
$$

- For the second cluster ($ C_2 $):
$$
\begin{bmatrix}
    +1 \\
    -1
\end{bmatrix}
$$

By initializing these parameters, we can now proceed with the K-means algorithm to cluster our dataset based on elemental abundances.


In [None]:
# K-means initial centroids
C_0 = np.array([[-1.0, 1.0], [1.0, -1.0]])

### Visualizing the Data and Cluster Centroids

To better understand the iterative process of the K-means algorithm and to visually inspect how well the centroids represent different clusters in the dataset, we plot both the data points and the centroids. Optionally, we can also visualize the "redness" of each data point, which could represent its closeness to a specific centroid or its likelihood of belonging to a particular cluster.

**Helper Functions for Plotting: plot_data_and_centroids Function**

The `plot_data_and_centroids` function serves the primary role of plotting both the data and the centroids on the same chart. Here are the parameters:

- `data`: The actual data points you wish to plot.
- `centroids`: The positions of the centroids.
- `redness`: Optional parameter to color data points based on some metric, which could be their distance to the nearest centroid or some other relevant measure.

This function uses an optional 'redness' parameter to color the data points. When the 'redness' values are provided, each point will be colored on a blue-to-red scale based on this value. This visual cue can help us better understand the cluster assignment at each step of the K-means iteration.

**Optional: Custom Color Map**

A custom color map (`br_cmap`) is defined using Matplotlib's `LinearSegmentedColormap`. This color map is used when coloring data points based on their 'redness' value. The color map transitions from blue to red, allowing for a clear visualization of cluster memberships or distances to centroids.

These helper functions enable us to visually assess how well the K-means algorithm is performing at each iteration and offer insights into the clustering process.


In [None]:
def plot_data_and_centroids(data, centroids, redness=None):
    """
    Scatter-plot the data together with the current centroids in a new figure.

    Parameters
    ----------
        data : Array whose first two columns are plotted as x and y.
        centroids : Array with one centroid per row (x, y); each is drawn as a cross.
        redness : Optional sequence with one value in [0, 1] per data point, used
            to colour the points on a blue-to-red scale. Points are grey if omitted.
    """
    plt.figure()

    # Plotting the data points
    point_style = {'s': 0.1, 'alpha': 0.3, 'linewidths': 2, 'label': "Data points"}
    if redness is not None:
        redness = np.asarray(redness)
        assert len(redness) == data.shape[0]
        # Vectorised range check; a Python-level loop over every star is needlessly slow
        assert np.all((redness >= 0) & (redness <= 1))
        point_style['c'] = redness
        # Only hand over a colormap when there are values to map; matplotlib warns otherwise
        point_style['cmap'] = matplotlib.colors.LinearSegmentedColormap.from_list("seismic",["b","r"])
    else:
        point_style['c'] = 'grey'
    plt.scatter(data[:, 0], data[:, 1], **point_style)

    # Plotting the centroids with different colors: the first two keep blue/red to
    # match the colormap, any further centroid falls back to matplotlib's colour
    # cycle instead of raising an IndexError
    centroid_colors = ['b', 'r']
    for i, centroid in enumerate(centroids):
        colour = centroid_colors[i] if i < len(centroid_colors) else f"C{i}"
        plt.scatter(centroid[0], centroid[1], c=colour, marker='x', s=100, label=f"Centroid {i+1}")

    plt.xlabel("[Fe/H]")
    plt.ylabel("[Mg/Fe]")
    plt.legend()
    plt.show()

# Show the starting configuration: data in grey (no 'redness' is passed) with the initial centroids C_0
plot_data_and_centroids(data, C_0)

### Optimizing the K-means Clustering

As discussed in lectures, the K-means algorithm can be optimized through an iterative approach. This involves two steps: the assignment step, and the update step.

**Function Signatures**

The primary functions involved in K-means are:

- For the E-step or Assignment step,
```python
def assignment_step(X, centroids):
```

- For the M-step or the Update step,
```python
def update_step(X, labels, K):
```

**Tips**

You can make use of the `scipy.spatial.distance` library to compute Euclidean distances between points:

```python
from scipy.spatial import distance
```

You can use `distance.euclidean(a, b)` to compute the distance between points `a` and `b`.

In this K-means implementation:

1. `assignment_step`: Assigns each data point to the nearest centroid, effectively partitioning the data into clusters.
2. `update_step`: Calculates the new centroids by taking the mean of all the data points in each cluster.

Unlike GMM, K-means doesn't work with covariance matrices, Gaussian components, or weighted probabilities. It's a simpler algorithm in that regard.

In [None]:
from scipy.spatial import distance
import numpy as np

def assignment_step(X, centroids):
    """
    Performs the E-step (assignment step) of the K-means algorithm.

    Parameters
    ----------
        X : (N, D) array of data points.
        centroids : (K, D) array of centroid positions.

    Returns
    -------
        (N,) integer array holding, for every point, the index of its nearest
        centroid (the first one in case of a tie).
    """
    # (N, K) matrix of Euclidean distances between every point and every centroid
    point_to_centroid = distance.cdist(X, centroids, 'euclidean')
    nearest = point_to_centroid.argmin(axis=1)
    return nearest

def update_step(X, labels, K):
    """
    Performs the M-step (update step) of the K-means algorithm.

    Parameters
    ----------
        X : (N, D) array of data points.
        labels : (N,) array with the cluster index assigned to each point.
        K : Number of clusters.

    Returns
    -------
        (K, D) array whose k-th row is the mean of all points labelled k.
    """
    cluster_means = []
    for cluster in range(K):
        members = X[labels == cluster]
        cluster_means.append(members.mean(axis=0))
    return np.array(cluster_means)

### Evaluating the K-means Model

To evaluate the performance of the K-means model, you can calculate its inertia.

The function computes the inertia by summing up the squared Euclidean distances between each data point and the centroid of its cluster. The resulting value will be a single scalar that measures how well the centroids fit the data $ \mathbf{X} $.


In [None]:
from scipy.spatial import distance
import numpy as np

def calculate_inertia(X, labels, centroids):
    """
    Compute the K-means inertia of a clustering.

    Parameters
    ----------
        X : (N, D) array of data points.
        labels : (N,) array with the cluster index assigned to each point.
        centroids : Sequence of K centroid positions, each of length D.

    Returns
    -------
        Sum over all points of the squared Euclidean distance to the centroid
        of the cluster the point is assigned to.
    """
    inertia = 0
    for k, centroid in enumerate(centroids):
        # Vectorised over all members of the cluster; looping over single
        # points in Python is prohibitively slow for a full catalogue
        offsets = X[labels == k] - centroid
        inertia += np.sum(offsets ** 2)
    return inertia

### Monitoring Convergence in K-means

To check if the K-means algorithm is converging, one can track the model's inertia against the number of iterations. The idea is to run K-means for a fixed number of iterations and compute the inertia at each step. This information can then be visualized to assess the algorithm's convergence.



In [None]:
# Assuming `data` is your input data
# Run K-means for a fixed number of iterations (no early stopping) and record the inertia at each one
# Initialize centroids
initial_centroids = C_0  

# Number of clusters
K = 2

# Number of iterations
iterations = 100

# Array to hold the inertia values
inertia_values = np.zeros(iterations)

# Initialize centroids to some starting values (C_0)
centroids = initial_centroids

# Run the K-means algorithm
for i in range(iterations):
    # Perform the Assignment step to update labels
    labels = assignment_step(data, centroids)
    
    # Compute the inertia of the current model
    # (i.e. with the centroids from before this iteration's update)
    inertia_values[i] = calculate_inertia(data, labels, centroids)
    
    # Update centroids for the next iteration
    centroids = update_step(data, labels, K)
    print(i, centroids)

# Plotting the inertia values
plt.title("Inertia Values")
plt.xlabel("Number of Updates")
plt.ylabel("Inertia")
plt.plot(inertia_values)
plt.show()

In [None]:
# We can also include a "convergence" criterion based on how far the centroids move between iterations
# (reuses `iterations` and `K` from the previous cell)
tolerance = 1e-4  # Define a tolerance level for convergence

# Initialize centroids
initial_centroids = C_0  
inertia_values = np.zeros(iterations)
centroids = initial_centroids

for i in range(iterations):
    # Perform the Assignment step to update labels
    labels = assignment_step(data, centroids)
    
    # Compute the inertia of the current model
    inertia_values[i] = calculate_inertia(data, labels, centroids)
    
    # Update centroids for the next iteration
    new_centroids = update_step(data, labels, K)
    
    # Check for convergence: every centroid coordinate moved by less than `tolerance`
    if np.all(np.abs(new_centroids - centroids) < tolerance):
        print(f"Converged (within absolute tolerance of {tolerance}) after {i} iterations.")
        # NOTE: entries of inertia_values after index i keep their initial value of zero
        break
    
    centroids = new_centroids
    print(i, centroids)


In [None]:
# You could also think about a relative tolerance, e.g. 0.1% change in each centroid coordinate
# (reuses `iterations` and `K` from the earlier cells)
relative_tolerance = 1e-3  # Define a relative tolerance level for convergence

# Initialize centroids
initial_centroids = C_0  
inertia_values = np.zeros(iterations)
centroids = initial_centroids

for i in range(iterations):
    # Perform the Assignment step to update labels
    labels = assignment_step(data, centroids)

    # Compute the inertia of the current model
    inertia_values[i] = calculate_inertia(data, labels, centroids)

    # Update centroids for the next iteration
    new_centroids = update_step(data, labels, K)

    # Check for convergence: every coordinate changed by less than the given fraction of its current value
    # NOTE(review): the strict '<' can never hold for a coordinate that is exactly 0 — confirm this is acceptable
    if np.all(np.abs(new_centroids - centroids) < relative_tolerance * np.abs(centroids)):
        print(f"Converged (within relative tolerance of {relative_tolerance}) after {i} iterations.")
        # NOTE: entries of inertia_values after index i keep their initial value of zero
        break

    centroids = new_centroids
    print(i, centroids)


### Visualising K-means

Use the function `plot_data_and_centroids` to visualize the data points and your current cluster centroids at different stages of the K-means algorithm.


In [None]:
# Initialize centroids to some starting values (C_0)
centroids = C_0 

# Number of clusters
K = 2

# Number of iterations for the K-means algorithm
iterations = 50

# Array to keep track of the inertia values
inertia_values = np.zeros(iterations)

# Run the K-means algorithm
for i in range(iterations):
    
    # Perform the Assignment step to update labels
    labels = assignment_step(data, centroids)
    
    # Compute the inertia of the current model
    inertia_values[i] = calculate_inertia(data, labels, centroids)
    
    # Visualize the model every 10 iterations
    if i % 10 == 0:
        print(f"Iteration {i}: Current Inertia = {inertia_values[i]}")
        
        # Show the current state of the model
        # (with K = 2 the labels are 0 or 1, so they can be used directly as 'redness')
        plot_data_and_centroids(data, centroids, redness=labels)
        
    # Perform the Update step to update centroids
    centroids = update_step(data, labels, K)

### Why May K-means Not Work Well?

In our analysis, we observed that the K-means clustering algorithm did not perform well with our elemental abundances data that forms a double moon shape. The reason lies in the assumptions behind the K-means algorithm, particularly its use of Euclidean distance to partition data points into clusters. K-means tries to minimize the variance within each cluster, which is equivalent to minimizing the Euclidean distance from each data point to its cluster's centroid. This works well for clusters that are spherical and equally sized, but not for data with more complex shapes or varying densities.

**Limitations of K-means:**

1. **Spherical Assumption**: K-means assumes that the clusters are spherical and equally sized, which doesn't hold true for double moon-shaped data.
  
2. **Euclidean Distance**: The algorithm uses Euclidean distance to allocate points to the nearest cluster. For more complex shapes like double moons, a different distance measure or model might be more appropriate.
  
3. **Lack of Flexibility**: K-means does not have the flexibility to account for different shapes or orientations of clusters.

To overcome these limitations, we turn to Gaussian Mixture Models (GMMs), a more flexible approach that can model elliptical clusters and is capable of handling different cluster shapes and orientations. Unlike K-means, GMMs do not assume equal-sized or spherical clusters, and they provide a probabilistic framework to capture uncertainty, making them a more suitable option for clustering our double moon-shaped data.


## 2.2) Gaussian Mixture Models

In [None]:
# for this tutorial, let's again use variables rather than keywords
# (this overwrites the normalized `feh`, `mgfe` and `data` of the K-means section with the raw abundances)
feh  = np.array(sdss['fe_h_sdss'])
mgfe = np.array(sdss['mg_fe_sdss'])

# one row per star, columns are ([Fe/H], [Mg/Fe]); no normalization is applied here
data = np.vstack([feh, mgfe]).T
data_labels = ["[Fe/H]", "[Mg/Fe]"]

Gaussian Mixture Models (GMMs) offer a robust approach to model complex data distributions. When given a dataset $\{\mathbf{x}_1, \ldots, \mathbf{x}_n \}$ with each $\mathbf{x}_i \in \mathbb{R}^D$, GMMs postulate that this data is generated from a sum of $K$ different Gaussian distributions.

We want to find $K$ components, each with a mixture weight $\pi_k$ and individual Gaussian distributions $\mathcal{N}(x \mid \mu_k, \Sigma_k )$ with means $\mu_k$ and covariances $\Sigma_k$. The mixture weights must satisfy the constraint:

$$
\sum_{k=1}^K \pi_k = 1
$$

The goal is to find the $\pi$, $\mu$, and $\Sigma$ that maximize the likelihood of the data under this mixture (defined in the next section). However, this is easier said than done due to the complex inter-dependencies between these parameters.

### Linear and Log-Likelihood function

Imagine drawing a random number $p$ uniformly from the interval $[0,1)$. Based on the value of $p$, we determine which Gaussian distribution to sample from. For example, if $p$ lies in the range $[\sum_{i=1}^{k-1} \pi_i, \sum_{i=1}^{k} \pi_i)$, we sample $\mathbf{x}$ from $\mathcal{N}(\mathbf{x} \mid \mu_k, \Sigma_k )$.

Mathematically, we can then describe the probability distribution $p$:

$$
p(\mathbf{x}) := \sum_{k=1}^K \pi_k \mathcal{N}(x \mid \mu_k, \Sigma_k )
$$

The log-likelihood function for GMMs is expressed as:

$$
\log p(\mathbf{X} \mid \pi, \mu, \Sigma) 
= \sum_{n=1}^N \log \left\{ \sum_{k=1}^K \pi_k \mathcal{N}(\mathbf{x}_n \mid \mu_k, \Sigma_k) \right\}
$$

### Maximizing Log-Likelihood Through EM Algorithm

**Responsibilities and Effective Number of Points**

We introduce a variable called the 'responsibility', denoted by $\gamma_{nk}$, which essentially quantifies the likelihood that the $n^{th}$ data point is generated by the $k^{th}$ Gaussian. It is computed as:

$$
\gamma_{nk} = \frac{\pi_{k} \mathcal{N}( \mathbf{x}_n \mid \mu_k, \Sigma_k ) }
    { \sum_{j=1}^K \pi_{j} \mathcal{N}( \mathbf{x}_n \mid \mu_j, \Sigma_j)}
$$

We then define $N_k$, the effective number of points for the $k^{th}$ Gaussian, as:

$$
N_k := \sum_{n=1}^N \gamma_{nk}
$$

**EM Algorithm Steps**

The Expectation-Maximization (EM) algorithm proceeds in a series of steps to iteratively update the model parameters and maximize the log-likelihood.

- **Step 1 (Initialization)**: Choose initial values for $\mu_k, \Sigma_k, \pi_k$, and calculate the initial log-likelihood.

- **Step 2 (Expectation)**: Compute the responsibilities using the current parameters.

- **Step 3 (Maximization)**: Update the parameters based on the newly computed responsibilities:

\begin{align*}
N_k^\text{new} & := \sum_{n=1}^N \gamma_{nk} \\
\pi_k^{\text{new}} & := \frac{N_k^\text{new}}{N} \\
\mu_k^\text{new} & := \frac{1}{N_k^\text{new}} \sum_{n=1}^N \gamma_{nk} \mathbf{x}_n \\
\Sigma_k^\text{new} & := \frac{1}{N_k^\text{new}} \sum_{n=1}^N \gamma_{nk} (\mathbf{x}_n - \mu_k^\text{new}) (\mathbf{x}_n - \mu_k^\text{new})^T
\end{align*}


- **Step 4 (Evaluate)**: Calculate the new log-likelihood. If it hasn't converged, return to Step 2.

By iteratively applying these steps, the EM algorithm arrives at an optimized set of parameters that maximize the log-likelihood.

### Initializing Gaussian Mixture Model Parameters

Before running the Gaussian Mixture Model (GMM) on our dataset, it's crucial to set initial values for the model's parameters. These initializations guide the iterative process of the Expectation-Maximization (EM) algorithm, which we will be using to optimize the model. For our problem, we'll use a two-component Gaussian Mixture Model to capture the bimodal distribution of alpha-enriched and alpha-normal stars.

**Initial Means: $ \mu_0 $**

We will initialize the means of our two Gaussian components as follows:

- For the first Gaussian ($ \mu_1 $):
$$
\begin{bmatrix}
    -1 \\
    +1
\end{bmatrix}
$$

- For the second Gaussian ($ \mu_2 $):
$$
\begin{bmatrix}
    +1 \\
    -1
\end{bmatrix}
$$

**Initial Covariance Matrices: $ \Sigma_0 $**

Both covariance matrices will be initialized as identity matrices. This assumes, initially, that the features are uncorrelated and have a unit variance.

**Initial Mixture Weights: $ \pi_0 $**

Initially, we will assume that both Gaussian components are equally likely to generate any given data point. Thus, the initial mixture weights will be equal; specifically, each will be set to 0.5.

**Summary of Initial Parameters**

- `mu_0`: A $2 \times 2$ matrix containing the initial means of the Gaussians.
  - $\mu_0 = \left[ \begin{array}{cc} -1 & 1 \\ 1 & -1 \end{array} \right]$

- `Sigma_0`: A $2 \times 2 \times 2$ 3-tensor containing the initial covariance matrices.
  - $\Sigma_0 = \left[ \begin{array}{cc} 1 & 0 \\ 0 & 1 \end{array} \right] \times 2$

- `pi_0`: A vector containing the initial mixture weights.
  - $\pi_0 = \left[ 0.5, 0.5 \right]$

By initializing these parameters, we can now proceed with the EM algorithm to fit the Gaussian Mixture Model to our dataset.

In [None]:
# Solution
pi_0 = np.array([0.5,0.5])  # equal mixture weights for both components
# NOTE(review): these means differ from the (-1, +1) and (+1, -1) values quoted in the text above — confirm which is intended
mu_0 = np.array([[-0.5,1.0],[0.5,-0.5]])  # one row per component: ([Fe/H], [Mg/Fe])
Sigma_0 = np.array([ [[1.0,0.0],[0.0,1.0]] , [[1.0,0.0],[0.0,1.0]]])  # one 2x2 identity covariance per component

### Visualizing the Data and Gaussian Components

To better understand the Expectation-Maximization (EM) process and the distribution of our data, we will plot both the data points and ellipses representing our Gaussian components. The ellipses will provide a visual interpretation of the covariance and mean of each Gaussian in our mixture model.

**plot_cov_ellipse Function**

The `plot_cov_ellipse` function plots an ellipse based on a given 2x2 covariance matrix and a location for the ellipse's center. The function also allows the customization of the ellipse's appearance through various parameters.

- `cov`: The 2x2 covariance matrix to base the ellipse on.
- `pos`: The location of the center of the ellipse.
- `volume`: The volume inside the ellipse; default is 0.5.
- `ax`: The axis to plot the ellipse on; defaults to the current axis.

**plot_components Function**

The `plot_components` function uses `plot_cov_ellipse` to plot ellipses for each Gaussian component in our mixture model. It takes in the means (`mu`) and covariances (`Sigma`) of each Gaussian, along with their respective colors, to render these ellipses on the plot.

**plot_data Function**

The `plot_data` function plots the actual data points. Optionally, it can color the points based on a 'redness' value, allowing us to visually separate data that might belong to different components of the Gaussian mixture.

These helper functions will enable us to visualize the Gaussian components and their evolution as we proceed with the EM algorithm.

In [None]:
# plot_cov_ellipse was taken from here:
# http://www.nhsilbert.net/source/2014/06/bivariate-normal-ellipse-plotting-in-python/

def plot_cov_ellipse(cov, pos, volume=.5, ax=None, fc='none', ec=(0,0,0), a=1, lw=2):
    """
    Plots an ellipse enclosing *volume* based on the specified covariance
    matrix (*cov*) and location (*pos*).

    Parameters
    ----------
        cov : The 2x2 covariance matrix to base the ellipse on
        pos : The location of the center of the ellipse. Expects a 2-element
            sequence of [x0, y0].
        volume : The volume inside the ellipse; defaults to 0.5
        ax : The axis that the ellipse will be plotted on. Defaults to the 
            current axis.
        fc : Face colour of the ellipse patch; defaults to 'none' (unfilled)
        ec : Edge colour of the ellipse patch; defaults to black
        a : Alpha (opacity) of the ellipse patch; defaults to 1
        lw : Line width of the ellipse edge; defaults to 2
    """
    if ax is None:
        ax = plt.gca()

    # Sort the eigen-decomposition so that the largest eigenvalue (major axis) comes first
    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]

    # Orientation of the major axis: angle of the leading eigenvector, arctan2(y, x)
    major_x, major_y = vecs[:, 0]
    theta = np.degrees(np.arctan2(major_y, major_x))

    # Width and height are "full" widths, not radius; chi2.ppf(volume, 2) is the
    # squared radius (in units of the variance) that encloses *volume* in 2D
    width, height = 2 * np.sqrt(chi2.ppf(volume,2)) * np.sqrt(vals)
    ellip = patches.Ellipse(
        xy=pos, width=width, height=height, angle=theta,
        facecolor=fc, edgecolor=ec, alpha=a, linewidth=lw,
    )

    ax.add_artist(ellip)
    

def plot_components(mu, Sigma, colours, *args, **kwargs):
    '''
    Plot ellipses for the bivariate normals with mean mu[i] and covariance Sigma[i]

    Parameters
    ----------
        mu : (K, 2) array with one mean per component.
        Sigma : (K, 2, 2) array with one covariance matrix per component.
        colours : Sequence with at least K edge colours, one per component.
        *args, **kwargs : Passed on to plot_cov_ellipse after the covariance
            and position; the 'ec' keyword is overwritten with the component colour.
    '''
    n_components = mu.shape[0]
    # The component index is the FIRST axis of both arrays, the remaining axes
    # are the two data dimensions
    assert mu.shape[1] == 2
    assert Sigma.shape[0] == n_components
    assert Sigma.shape[1] == 2
    assert Sigma.shape[2] == 2
    for i in range(n_components):
        kwargs['ec'] = colours[i]
        plot_cov_ellipse(Sigma[i], mu[i], *args, **kwargs)

br_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("seismic",["b","r"])

def plot_data(redness=None):
    """
    Scatter-plot the module-level `data` in a new figure with equal axis scaling.

    Parameters
    ----------
        redness : Optional sequence with one value in [0, 1] per row of `data`,
            used to colour the points with the blue-to-red `br_cmap`. Points
            are grey if omitted.
    """
    point_style = {'marker': '.', 's': 0.1, 'alpha': 0.3, 'linewidths': 2}
    if redness is not None:
        redness = np.asarray(redness)
        assert len(redness) == data.shape[0]
        # Vectorised range check; a Python-level loop over every star is needlessly slow
        assert np.all((redness >= 0) & (redness <= 1))
        point_style['c'] = redness
        # Only hand over a colormap when there are values to map; matplotlib warns otherwise
        point_style['cmap'] = br_cmap
    else:
        point_style['c'] = 'grey'
    plt.figure()
    plt.scatter(data[:,0], data[:,1], **point_style)
    plt.xlabel(data_labels[0])
    plt.ylabel(data_labels[1])
    plt.axis('equal')

In the next code cell, we demonstrate how to utilize the plotting functions we've defined earlier. We'll plot the dataset along with the initial guesses for our Gaussian components. This visualization will help us understand our starting point before we begin iterating with the Expectation-Maximization algorithm.

The code will only execute if the required variables (`mu_0`, `Sigma_0`, and `data`) have been previously defined. If these variables are in place, the function `plot_data()` will plot the data points, and `plot_components()` will overlay the initial Gaussian ellipses on the same plot.

This plot will serve as a reference point, helping us to evaluate how well our Gaussian Mixture Model fits the data as we proceed with the EM algorithm.

In [None]:
# Check if the required variables are defined
# (at the top level of a cell, locals() is the module namespace, so this looks up global names)
if 'mu_0' in locals() and 'Sigma_0' in locals() and 'data' in locals():
    # Plot the data points
    plot_data()
    # Plot the initial Gaussian components
    # (the trailing 0.2 is forwarded to plot_cov_ellipse as its `volume` argument)
    plot_components(mu_0, Sigma_0, ['C0', 'C3'], 0.2)
    plt.xlim(-2,1)
    plt.ylim(-1,1)
    # Display the plot
    plt.show()

### Optimizing the Gaussian Mixture Model with EM

As discussed in lectures, the Gaussian Mixture Model (GMM) can be optimized using the Expectation-Maximization (EM) algorithm. This technique allows for the maximum likelihood estimation of the model parameters $\mathbf{\mu}$, $\mathbf{\Sigma}$, and $\mathbf{\pi}$.

**Function Signatures**

The primary functions involved in this algorithm are the E-step and the M-step. The suggested function signatures are as follows:

- For the E-step: 
```python
def e_step(X, mu, Sigma, pi):
```

- For the M-step:
```python
def m_step(X, gamma):
```

**Helper Functions**

In addition to these, a helper function named `weighted_normals` is used to calculate an $ N \times K $ matrix of weighted normal probabilities, i.e., $ \pi_{k} \mathcal{N}( \mathbf{x}_n \mid \mu_k, \Sigma_k) $.

The function signature for the helper function is:

```python
def weighted_normals(X, mu, Sigma, pi):
```

**Tips**

You can make use of the `scipy.stats` library to compute the multivariate normal distribution probabilities:

```python
from scipy.stats import multivariate_normal as mvnorm
```

You can use `mvnorm.pdf(x, mu, sigma)` to compute $ \mathcal{N}(\mathbf{x}_n \mid \mu_k, \Sigma_k) $.


In [None]:
def weighted_normals(X, mu, Sigma, pi):
    """
    Calculates the numerator of the gamma_i's, i.e., the 
    weighted normal probabilities for each data point.

    Parameters
    ----------
        X : (N, D) array of data points.
        mu : (K, D) array with one mean per component.
        Sigma : (K, D, D) array with one covariance matrix per component.
        pi : (K,) array of mixture weights.

    Returns
    -------
        (N, K) array whose entry [n, k] is pi[k] * N(X[n] | mu[k], Sigma[k]).
    """
    N, _ = X.shape
    K, = pi.shape
    w_norms = np.zeros((N,K)) # (N,K)
    for k in range(K):
        # Evaluate the density on the X that was passed in, not on the global `data`
        w_norms[:,k] = pi[k] * mvnorm.pdf(X, mu[k], Sigma[k]) # (N,)
    return w_norms # (N,K)

def e_step(X, mu, Sigma, pi):
    """
    Performs the E-step of the EM algorithm.

    Returns the (N, K) array of responsibilities gamma: the weighted normal
    probabilities of each data point, normalised so that every row sums to one.
    """
    weighted = weighted_normals(X, mu, Sigma, pi) # (N,K)
    per_point_total = weighted.sum(axis=1)[:, np.newaxis] # (N,1)
    return weighted / per_point_total # (N,K)

def m_step(X, gamma):
    """
    Performs the M-step of the EM algorithm.

    Parameters
    ----------
        X : (N, D) array of data points.
        gamma : (N, K) array of responsibilities.

    Returns
    -------
        new_mu : (K, D) responsibility-weighted means.
        new_Sigma : (K, D, D) responsibility-weighted covariances around new_mu.
        new_pi : (K,) mixture weights, the effective number of points per
            component divided by N.
    """
    n_points, _ = X.shape
    _, n_components = gamma.shape

    # Effective number of points per component
    Nk = gamma.sum(axis=0) # (K,)
    new_pi = Nk / n_points # (K,)

    # Weighted sum of the points for every component, then normalise
    new_mu = gamma.T @ X # (K,N)@(N,D) => (K,D)
    new_mu /= Nk[:, np.newaxis] # (K,D) / (K,1) => (K,D)

    # Offsets of every point from every new mean
    offsets = X[:, np.newaxis, :] - new_mu[np.newaxis, :, :] # (N,K,D)
    weighted_offsets = gamma[:, :, np.newaxis] * offsets # (N,K,D)
    # Batched outer products, summed over the points by the matrix product
    new_Sigma = np.transpose(weighted_offsets, (1, 2, 0)) @ np.transpose(offsets, (1, 0, 2)) # (K,D,N)@(K,N,D) => (K,D,D)
    new_Sigma /= Nk[:, np.newaxis, np.newaxis] # (K,D,D)

    return new_mu, new_Sigma, new_pi

### Evaluating the Model

To evaluate the performance of the Gaussian Mixture Model, we will calculate its log-likelihood for given parameters $ \mathbf{\mu} $, $ \mathbf{\Sigma} $, and $\mathbf{\pi}$.

The function computes the log-likelihood by first obtaining the weighted normal probabilities using the `weighted_normals` function. It then sums these over the components for each data point, takes the logarithm, and sums the result across all data points. The resulting value will be a single scalar that measures how well the given parameters $ \mathbf{\mu} $, $ \mathbf{\Sigma} $, and $ \mathbf{\pi} $ fit the data $ \mathbf{X} $.


In [None]:
def log_likelihood(X, mu, Sigma, pi):
    """
    Log-likelihood of the data X under the mixture with parameters mu, Sigma, pi.

    The weighted normal probabilities are summed over the components for each
    data point; the logarithms of these sums are then added up over all points.
    """
    # Mixture density at every data point: sum of the K weighted normals
    mixture_density = weighted_normals(X, mu, Sigma, pi).sum(axis=1) # (N,)

    return np.sum(np.log(mixture_density))

An essential aspect of any iterative algorithm, like the EM algorithm, is to check if it converges. For this purpose, we can plot the log-likelihood of the model against the number of updates made to it.

The idea is to run the EM algorithm for a fixed number of trials and calculate the log-likelihood at each step. We will then visualize this information to assess the convergence of the model.



In [None]:
# Initialize parameters
mu = mu_0
Sigma = Sigma_0
pi = pi_0

# Number of trials
trials = 100

# Array to hold the log-likelihoods
ll = np.zeros(trials)

# Run the EM algorithm
for i in range(0, trials):
    # ll[i] is the log-likelihood of the parameters BEFORE this iteration's update
    ll[i] = log_likelihood(data, mu, Sigma, pi)
    # E-step: responsibilities under the current parameters
    gamma = e_step(data, mu, Sigma, pi)
    # M-step: new parameters from the responsibilities
    (mu, Sigma, pi) = m_step(data, gamma)

# Plotting the log-likelihoods
plt.title("Log-Likelihoods")
plt.xlabel("Number of Updates")
plt.ylabel("Log-Likelihood")
plt.plot(ll)
plt.show()

In [None]:
import os

# Initialize parameters to some starting values (mu_0, Sigma_0, pi_0)
mu = mu_0
Sigma = Sigma_0
pi = pi_0

# Number of iterations for the EM algorithm
trials = 50

# Array to keep track of the log-likelihood values
ll = np.zeros(trials)

# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("figures", exist_ok=True)

# Run the EM algorithm
for i in range(trials):
    # Compute the log-likelihood of the current model
    ll[i] = log_likelihood(data, mu, Sigma, pi)
    
    # Visualize the model the first 5 and then every 5 iterations
    if 0 < i < 5 or i % 5 == 0:
        print(f"Iteration {i}: Current Log-Likelihood = {ll[i]}")
        
        # show the color; in iteration 0 this run has not computed any
        # responsibilities yet, so the points are drawn without colouring
        if i < 1:
            plot_data()
        else:
            plot_data(redness = gamma[:,1])

        # Plot the Gaussian components
        plot_components(mu, Sigma, ['C0', 'C3'], 0.2)
        
        # Add a title to indicate the current iteration and log-likelihood
        plt.title(f"Iteration {i}: Log-Likelihood = {np.round(ll[i])}")

        # Zero-padded iteration number keeps the files in order when sorted by name
        plt.savefig(f"figures/em_iteration_{i:03d}_sdss_dr19.png", dpi=150)

        plt.show()
        
    # Perform the E-step to update responsibilities
    gamma = e_step(data, mu, Sigma, pi)
    
    # Perform the M-step to update parameters
    mu, Sigma, pi = m_step(data, gamma)

## 2.4) Let's make a movie/GIF out of this!

In [None]:
# Use python packages to execute terminal command inside jupyter
import os
# Combine the saved frames into figures/em_sdss_dr19.gif using the external `magick` command
# NOTE(review): the exit status returned by os.system is discarded, so a failed or missing command goes unnoticed
os.system('magick -delay 100 -loop 0 figures/em_iteration_*sdss_dr19.png figures/em_sdss_dr19.gif');

# Note: this only works properly if the images are sorted correctly; to ensure that, you can use {i:02d}, which adds a leading zero; if you have more than 100 iterations, use {i:03d}
# >>> plt.savefig(f"figures/em_iteration_{i:02d}_sdss_dr19.png", dpi=150)

## 2.3) sklearn.mixture.GaussianMixture

In [None]:
matplotlib.__version__

In [None]:
from sklearn.mixture import GaussianMixture

# Set the number of components (start with 2, as expected for thin/thick disk)
gmm = GaussianMixture(
    n_components=2, 
    random_state=42, # optional, important for reproducing results
    n_init=5, # optional, but a good idea for robustness
    means_init=mu_0 # optional, but can be helpful
)
# Fit on the same `data` array ([Fe/H], [Mg/Fe]) used for the hand-written EM above
gmm.fit(data)

# Extract the predicted labels and probabilities
labels = gmm.predict(data)
probabilities = gmm.predict_proba(data)

# Plotting the [Fe/H] vs. [Mg/Fe] distribution colored by GMM components
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5),sharex=True,sharey=True)

# Left panel: [Fe/H] vs. [Mg/Fe] colored by component labels
scatter1 = ax1.scatter(data[:, 0], data[:, 1], c=labels, cmap='seismic', s=5, alpha=0.5)

# Because we are now using the off-the-shelf GMM, we have to redefine the ellipse plotting
def plot_gmm_ellipses(gmm, ax, colors=('C0', 'C3')):
    """
    Draw one filled, semi-transparent ellipse per GMM component onto *ax*.

    Parameters
    ----------
        gmm : Fitted mixture model providing `means_` and `covariances_`
            (assumes one 2x2 covariance matrix per component — TODO confirm
            if a non-default covariance_type is ever used).
        ax : The axis that the ellipses are added to.
        colors : One colour per component to draw; components are taken in
            order, so only the first len(colors) components are plotted.
    """
    for n, color in enumerate(colors):
        mean = gmm.means_[n]
        covar = gmm.covariances_[n]
        
        # Eigenvalues (ascending) and eigenvectors (stored as COLUMNS) to draw ellipses
        v, w = np.linalg.eigh(covar)
        v = np.sqrt(2.0) * np.sqrt(v)  # sqrt(2) x standard deviation along each principal axis
        # The eigenvector that belongs to v[0] is the column w[:, 0], not the row w[0]
        u = w[:, 0] / np.linalg.norm(w[:, 0])
        
        angle = np.arctan2(u[1], u[0]) * 180 / np.pi
        ellipse = patches.Ellipse(mean, v[0], v[1], angle = 180 + angle, color=color, alpha=0.3)
        ax.add_patch(ellipse)

# Overlay the ellipses of the fitted GMM components on the left panel
plot_gmm_ellipses(gmm, ax1)

# Axis labels and plotting window of the left panel
ax1.set(
    xlabel='[Fe/H]',
    ylabel='[Mg/Fe]',
    xlim=(-2, 1),
    ylim=(-1, 1),
)
fig.colorbar(scatter1, ax=ax1, orientation='horizontal', label='GMM Component')

# Right panel: same stars, now coloured by how certain the assignment is,
# i.e. the largest membership probability of each star
certainty = probabilities.max(axis=1)
scatter2 = ax2.scatter(
    data[:, 0],
    data[:, 1],
    c=certainty,
    cmap='seismic',
    s=5,
    alpha=0.5,
)
ax2.set_xlabel('[Fe/H]')
fig.colorbar(scatter2, ax=ax2, orientation='horizontal', label='Component Membership Probability')

fig.tight_layout()
plt.show()

In [None]:
# Mixture weights of the fitted GMM (one weight per Gaussian component)
gmm.weights_

In [None]:
# Means of the Gaussian components of the fitted GMM (one row per component)
gmm.means_

In [None]:
# Covariance matrices of the Gaussian components of the fitted GMM (one matrix per component)
gmm.covariances_

### How many components do we need?

We will discuss this in detail on Thursday, but the Bayesian and Akaike Information Criteria (BIC and AIC) tell you how well your model explains the data (via its maximum likelihood $\hat{L}$) while applying a penalty for the number of free parameters $k$.

For a given model, the AIC is defined as

\begin{equation}
\mathrm{AIC} = 2k - 2\ln\left(\hat{L}\right).
\end{equation}

BIC has a similar form, but its penalty for the number of parameters also grows with the number of data points $n$:

\begin{equation}
\mathrm{BIC} = k\ln(n) - 2\ln\left(\hat{L}\right).
\end{equation}

The lower these values, the better.

In [None]:
import numpy as np

# Candidate model sizes: 1 up to 10 Gaussian components
n_components = np.arange(1, 11)

# Information criteria of each fit, in the same order as n_components
bics, aics = [], []

for n in n_components:
    # Fixed seed and 5 restarts for every model size
    gmm = GaussianMixture(n_components=n, random_state=42, n_init=5).fit(data)
    aics.append(gmm.aic(data))
    bics.append(gmm.bic(data))

# Plot BIC/AIC against the number of components to find the optimal model size
plt.figure(figsize=(8,6))
for scores, name in ((bics, 'BIC'), (aics, 'AIC')):
    plt.plot(n_components, scores, label=name)
plt.title('BIC and AIC for GMM with Different Components')
plt.xlabel('Number of Components')
plt.ylabel('BIC/AIC')
plt.legend()
plt.show()

## 3) Let's use GALAH DR4 instead of SDSS DR19

Apply what you have learned to the lower-precision GALAH DR4 measurements. Do you get similar results?

In [None]:
# Exercise placeholders: replace the empty lists with the GALAH [Fe/H] and [Mg/Fe] values
feh = []
mg_fe = []

# Exercise placeholders: the two abundances combined into one array (one row per star)
# and the matching axis labels
data = []
data_labels = []

In [None]:
# Run the self-written GMM on the GALAH data
# (exercise skeleton: fill in the placeholders and the loop body yourself)

# Initialize parameters to some starting values (mu_0, Sigma_0, pi_0)
pi = []
mu = []
Sigma = []

# Number of iterations for the EM algorithm
trials = 50

# Array to keep track of the log-likelihood values
ll = np.zeros(trials)

# Run the EM algorithm
for i in range(trials):

    # Compute the log-likelihood of the current model
    
    # Perform the E-step to update responsibilities
    
    # Perform the M-step to update parameters
    
    # Visualize the model for selected iterations

    # A loop body made of comments only is a syntax error; this placeholder keeps the
    # cell runnable until the steps above are filled in (remove it afterwards)
    pass

## 4) Comparison of Classifications

In [None]:
from sklearn.mixture import GaussianMixture

# Both surveys are fitted with exactly the same model setup
gmm_settings = dict(
    n_components=2,
    random_state=42,  # optional, important for reproducing results
    n_init=5,         # optional, but a good idea for robustness
    means_init=mu_0,  # optional, but can be helpful
)

# GMM for SDSS DR19: fit [Fe/H] vs. [Mg/Fe], then store label and probabilities in the table
sdss_data = np.vstack([sdss['fe_h_sdss'], sdss['mg_fe_sdss']]).T
gmm_sdss = GaussianMixture(**gmm_settings)
gmm_sdss.fit(sdss_data)
sdss_labels = gmm_sdss.predict(sdss_data)
sdss_probabilities = gmm_sdss.predict_proba(sdss_data)
sdss['gmm_label_sdss'] = sdss_labels
for component in (0, 1):
    sdss[f'gmm_{component + 1}_prob_label_sdss'] = sdss_probabilities[:, component]

# GMM for GALAH DR4: same procedure on the GALAH measurements
galah_data = np.vstack([galah['fe_h_galah'], galah['mg_fe_galah']]).T
gmm_galah = GaussianMixture(**gmm_settings)
gmm_galah.fit(galah_data)
galah_labels = gmm_galah.predict(galah_data)
galah_probabilities = gmm_galah.predict_proba(galah_data)
galah['gmm_label_galah'] = galah_labels
for component in (0, 1):
    galah[f'gmm_{component + 1}_prob_galah'] = galah_probabilities[:, component]

In [None]:
# Combine the two catalogues into one table, matching rows on the shared Gaia DR3 source_id column
from astropy.table import join
sdss_galah = join(sdss, galah, keys='gaiadr3_source_id')  # optional extra argument: metadata_conflicts='silent'
# Display the joined table as the cell output
sdss_galah

In [None]:
fig, gs = plt.subplots(1,2,figsize=(12,5),sharex=True,sharey=True)

# --- Left panel: 2x2 table of the labels (of highest probability) from both surveys ---
ax = gs[0]

# Bin the label pairs with numpy directly. This is the same binning that ax.hist2d performs
# internally, but nothing is drawn, so the axes do not have to be cleared afterwards.
counts, xedges, yedges = np.histogram2d(
    sdss_galah['gmm_label_sdss'],
    sdss_galah['gmm_label_galah'],
    bins=2
)

# Normalize by total entries
counts_norm = counts / counts.sum()

# Plot normalized counts; counts are indexed [x, y], imshow expects [row (y), column (x)]
im = ax.imshow(
    counts_norm.T,
    origin="lower",
    extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
    aspect="auto",
    cmap="coolwarm_r"
)

# Add colorbar
cbar = plt.colorbar(im, ax=ax, label="Fraction of total")

# Write the absolute number of stars into the centre of each bin
x_centres = (xedges[:-1] + xedges[1:]) / 2
y_centres = (yedges[:-1] + yedges[1:]) / 2
for (i, j), count in np.ndenumerate(counts):
    ax.text(x_centres[i], y_centres[j], f"{count:.0f}",
            ha="center", va="center", color="white", fontsize=12, fontweight="bold")

ax.set_xlabel("SDSS DR19 GMM Label")
ax.set_ylabel("GALAH DR4 GMM Label")
ax.set_title("GMM Labels\nComparison (normalized)")

# --- Right panel: component 1 membership probability, SDSS vs. GALAH ---
ax = gs[1]
# Separate names, so that the label counts from the left panel stay available
prob_counts, prob_xedges, prob_yedges, prob_im = ax.hist2d(
    sdss_galah['gmm_1_prob_label_sdss'],
    sdss_galah['gmm_1_prob_galah'],
    bins=50,
    norm=LogNorm(),
    range=[[0, 1], [0, 1]],
    cmap="coolwarm_r"
)
# Add colorbar
cbar = plt.colorbar(prob_im, ax=ax, label="Nr.")
# Dashed lines at probability 0.5, where the most likely component changes
ax.plot([0,1],[0.5,0.5],'k--')
ax.plot([0.5,0.5],[0,1],'k--')
ax.set_xlim(0,1)
ax.set_ylim(0,1)
ax.set_xlabel("SDSS DR19 GMM\nComponent 1 Probability")
ax.set_ylabel("GALAH DR4 GMM\nComponent 1 Probability")
ax.set_title("GMM Component 1\nProbability Comparison")

plt.tight_layout()
plt.show()


In [None]:
# Comparison of Classifications: four panels sharing both axes
fig, gs = plt.subplots(1,4,figsize=(15,5),sharex=True,sharey=True)

# Per-star agreement of the two classifications, and the absolute difference
# between the component 1 probabilities from the two surveys
match_labels = sdss_galah['gmm_label_sdss'] == sdss_galah['gmm_label_galah']
prob_diff = np.abs(sdss_galah['gmm_1_prob_label_sdss'] - sdss_galah['gmm_1_prob_galah'])

# One entry per panel:
# (survey whose abundances are plotted, colour values, colormap, title, colorbar label, GMM to overlay)
panels = [
    ('sdss', sdss_galah['gmm_label_sdss'], 'seismic',
     'GMM Classification\nSDSS DR19', 'GMM Component', gmm_sdss),
    ('galah', sdss_galah['gmm_label_galah'], 'seismic',
     'GMM Classification\nGALAH DR4', 'GMM Component', gmm_galah),
    ('sdss', match_labels, 'bwr',
     'Matching Classifications', 'Classification Match', None),
    ('sdss', prob_diff, 'viridis',
     'Difference in\nGMM Probabilities', '|P(GMM1_SDSS) - P(GMM1_GALAH)|', None),
]

for ax, (survey, colour_values, cmap, title, cbar_label, overlay_gmm) in zip(gs, panels):
    s = ax.scatter(
        sdss_galah[f'fe_h_{survey}'],
        sdss_galah[f'mg_fe_{survey}'],
        c=colour_values,
        cmap=cmap,
        s=5,
        alpha=0.5
    )
    ax.set_xlabel(f'[Fe/H] {survey.upper()}')
    ax.set_ylabel(f'[Mg/Fe] {survey.upper()}')
    ax.set_title(title)
    fig.colorbar(s, ax=ax, label=cbar_label, orientation='horizontal')

    # Only the two classification panels get the component ellipses and explicit limits
    if overlay_gmm is not None:
        plot_gmm_ellipses(overlay_gmm, ax)  # Add ellipses for the GMM components
        ax.set_xlim(-2,1)
        ax.set_ylim(-1,1)

plt.tight_layout()
plt.show()

## 5) If there is time: Can we classify accreted vs. thin disc vs. thick disc stars?

In [None]:
# For this, we will use the plane of [Al/Fe] vs. [Mg/Mn], with [Mg/Mn] = [Mg/Fe] - [Mn/Fe].

# Data for the GMM: SDSS measurements of the stars in common between both surveys
data = np.vstack([sdss_galah['al_fe_sdss'], sdss_galah['mg_fe_sdss'] - sdss_galah['mn_fe_sdss']]).T
data_labels = ["[Al/Fe]", "[Mg/Mn]"]

# Density of stars in that plane, one panel per survey
f, gs = plt.subplots(1,2,figsize=(10,4),sharex=True,sharey=True)

for ax, survey, title in zip(gs, ('sdss', 'galah'), ('SDSS DR19', 'GALAH DR4')):
    h = ax.hist2d(
        sdss_galah[f'al_fe_{survey}'],
        sdss_galah[f'mg_fe_{survey}'] - sdss_galah[f'mn_fe_{survey}'],
        bins = 100,
        norm = LogNorm()
    )
    # The last element returned by hist2d is the image that the colorbar describes
    cbar = plt.colorbar(h[-1], ax=ax, label = 'Nr.')
    ax.set_xlabel(f'[Al/Fe] {survey.upper()}')
    ax.set_ylabel(f'[Mg/Fe] - [Mn/Fe] {survey.upper()}')
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Thin and thick disc stars dominate the sample, while we actually care about the accreted
# stars in the top-left. We therefore fit a range of component numbers and compare them.

n_components = np.arange(2, 11)

# Fitted model for each number of components
gmm_models = {}
for n in n_components:
    gmm = GaussianMixture(
        n_components=n,
        random_state=42,  # optional, important for reproducing results
        n_init=5,         # optional, but a good idea for robustness
    ).fit(data)
    gmm_models[n] = gmm
    print(f"Fitted GMM with {n} components: BIC={gmm.bic(data)}, AIC={gmm.aic(data)}")

# Information criteria of all fitted models, in order of increasing component number
bics, aics = [], []
for n in n_components:
    bics.append(gmm_models[n].bic(data))
    aics.append(gmm_models[n].aic(data))

# Plot BIC/AIC to find the optimal number of components
plt.figure(figsize=(8,6))
for scores, name in ((bics, 'BIC'), (aics, 'AIC')):
    plt.plot(n_components, scores, label=name)
plt.title('BIC and AIC for GMM with Different Components')
plt.xlabel('Number of Components')
plt.ylabel('BIC/AIC')
plt.legend()
plt.show()

In [None]:
# This seems to stabilize around 5 components, so let's use that
nr_components_best = 5
best_gmm = gmm_models[nr_components_best]

# Membership probabilities of every star and the resulting most likely component
probabilities = best_gmm.predict_proba(data)
labels = best_gmm.predict(data)

# Two panels of [Al/Fe] vs. [Mg/Fe]-[Mn/Fe] that share both axes
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), sharex=True, sharey=True)

# Left panel: stars coloured by their component label
scatter1 = ax1.scatter(
    data[:, 0],
    data[:, 1],
    c=labels,
    cmap='tab10',
    s=5,
    alpha=0.5,
)
# Plot the GMM ellipses
def plot_gmm_ellipses_custom(gmm, ax):
    """Overlay a black covariance ellipse for every component of `gmm` on `ax`.

    Parameters
    ----------
    gmm : fitted mixture model
        Must provide `means_` and `covariances_`. Each `covariances_[n]` is used as a
        2x2 covariance matrix (i.e. a 2D fit with full covariance matrices is assumed).
    ax : matplotlib Axes
        Axes the ellipses are added to.
    """
    # Loop over the components of the model that was passed in, rather than relying on a
    # global component count that may not belong to this model
    for mean, covar in zip(gmm.means_, gmm.covariances_):
        # eigh returns the eigenvalues in ascending order and the matching unit
        # eigenvectors as the COLUMNS of the second output
        eigenvalues, eigenvectors = np.linalg.eigh(covar)

        # Full axis lengths: sqrt(2) times the standard deviation along each principal axis
        axis_lengths = np.sqrt(2.0) * np.sqrt(eigenvalues)

        # Direction of the axis belonging to eigenvalues[0] is column 0.
        # (Taking row 0 instead gives the mirrored tilt whenever the eigenvector
        # matrix happens to be a pure rotation.)
        u = eigenvectors[:, 0]
        angle = np.arctan2(u[1], u[0]) * 180 / np.pi

        # The extra 180 degrees maps the ellipse onto itself; kept for continuity
        ellipse = patches.Ellipse(
            mean,
            axis_lengths[0],
            axis_lengths[1],
            angle=180 + angle,
            color='k',
            alpha=0.3,
        )
        ax.add_patch(ellipse)
# Overlay the ellipses of the GMM components on the left panel
plot_gmm_ellipses_custom(best_gmm, ax1)

# Axis labels and plotting window of the left panel
ax1.set(
    xlabel='[Al/Fe]',
    ylabel='[Mg/Fe] - [Mn/Fe]',
    xlim=(-1, 1),
    ylim=(-1, 1),
)
fig.colorbar(scatter1, ax=ax1, orientation='horizontal', label='GMM Component')

# Right panel: [Al/Fe] vs. [Mg/Fe]-[Mn/Fe] coloured by the certainty of the assignment,
# i.e. the largest membership probability of each star
certainty = probabilities.max(axis=1)
scatter2 = ax2.scatter(
    data[:, 0],
    data[:, 1],
    c=certainty,
    cmap='viridis',
    s=5,
    alpha=0.5,
)
ax2.set(xlabel='[Al/Fe]', ylabel='[Mg/Fe] - [Mn/Fe]')
fig.colorbar(scatter2, ax=ax2, orientation='horizontal', label='Component Membership Probability')

fig.tight_layout()
plt.show()

In [None]:
# This clearly has not worked well, because the accreted stars are not cleanly separated.
# But we know that accreted stars have lower [Fe/H] than the disc stars, so we keep only the
# stars with [Fe/H] < -0.5, which removes the more metal-rich stars from the sample.

alfe = sdss_galah['al_fe_sdss']
mgfe = sdss_galah['mg_fe_sdss']
mnfe = sdss_galah['mn_fe_sdss']
feh_poor = sdss_galah['fe_h_sdss'] < -0.5

# Restrict each abundance to the metal-poor stars, then build [Al/Fe] vs. [Mg/Fe] - [Mn/Fe]
alfe_poor, mgfe_poor, mnfe_poor = (column[feh_poor] for column in (alfe, mgfe, mnfe))
data = np.vstack([alfe_poor, mgfe_poor - mnfe_poor]).T
data_labels = ["[Al/Fe]", "[Mg/Mn]"]

# Density of the remaining stars in the [Al/Fe] vs. [Mg/Mn] plane
f, ax = plt.subplots(figsize=(6,4))
h = ax.hist2d(*data.T, bins=100, norm=LogNorm())
cbar = plt.colorbar(h[-1], ax=ax, label = 'Nr.')
ax.set(
    xlabel='[Al/Fe] SDSS',
    ylabel='[Mg/Fe] - [Mn/Fe] SDSS',
    title='SDSS DR19',
)
plt.show()

In [None]:
# Fit GMMs with 2 to 10 components to the current data and compare them
n_components = np.arange(2, 11)

# Fitted model for each number of components
gmm_models = {}
for n in n_components:
    gmm = GaussianMixture(
        n_components=n,
        random_state=42,  # optional, important for reproducing results
        n_init=5,         # optional, but a good idea for robustness
    ).fit(data)
    gmm_models[n] = gmm
    print(f"Fitted GMM with {n} components: BIC={gmm.bic(data)}, AIC={gmm.aic(data)}")

# Information criteria of all fitted models, in order of increasing component number
bics, aics = [], []
for n in n_components:
    bics.append(gmm_models[n].bic(data))
    aics.append(gmm_models[n].aic(data))

# Plot BIC/AIC to find the optimal number of components
plt.figure(figsize=(8,6))
for scores, name in ((bics, 'BIC'), (aics, 'AIC')):
    plt.plot(n_components, scores, label=name)
plt.title('BIC and AIC for GMM with Different Components')
plt.xlabel('Number of Components')
plt.ylabel('BIC/AIC')
plt.legend()
plt.show()

In [None]:
# This seems to stabilize around 5 components, so let's use that
nr_components_best = 5
best_gmm = gmm_models[nr_components_best]

# Membership probabilities of every star and the resulting most likely component
probabilities = best_gmm.predict_proba(data)
labels = best_gmm.predict(data)

# Two panels of [Al/Fe] vs. [Mg/Fe]-[Mn/Fe] that share both axes
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), sharex=True, sharey=True)

# Left panel: stars coloured by their component label
scatter1 = ax1.scatter(
    data[:, 0],
    data[:, 1],
    c=labels,
    cmap='tab10',
    s=5,
    alpha=0.5,
)
# Plot the GMM ellipses
def plot_gmm_ellipses_custom(gmm, ax):
    """Overlay a black covariance ellipse for every component of `gmm` on `ax`.

    Parameters
    ----------
    gmm : fitted mixture model
        Must provide `means_` and `covariances_`. Each `covariances_[n]` is used as a
        2x2 covariance matrix (i.e. a 2D fit with full covariance matrices is assumed).
    ax : matplotlib Axes
        Axes the ellipses are added to.
    """
    # Loop over the components of the model that was passed in, rather than relying on a
    # global component count that may not belong to this model
    for mean, covar in zip(gmm.means_, gmm.covariances_):
        # eigh returns the eigenvalues in ascending order and the matching unit
        # eigenvectors as the COLUMNS of the second output
        eigenvalues, eigenvectors = np.linalg.eigh(covar)

        # Full axis lengths: sqrt(2) times the standard deviation along each principal axis
        axis_lengths = np.sqrt(2.0) * np.sqrt(eigenvalues)

        # Direction of the axis belonging to eigenvalues[0] is column 0.
        # (Taking row 0 instead gives the mirrored tilt whenever the eigenvector
        # matrix happens to be a pure rotation.)
        u = eigenvectors[:, 0]
        angle = np.arctan2(u[1], u[0]) * 180 / np.pi

        # The extra 180 degrees maps the ellipse onto itself; kept for continuity
        ellipse = patches.Ellipse(
            mean,
            axis_lengths[0],
            axis_lengths[1],
            angle=180 + angle,
            color='k',
            alpha=0.3,
        )
        ax.add_patch(ellipse)
# Overlay the ellipses of the GMM components on the left panel
plot_gmm_ellipses_custom(best_gmm, ax1)

# Axis labels and plotting window of the left panel
ax1.set(
    xlabel='[Al/Fe]',
    ylabel='[Mg/Fe] - [Mn/Fe]',
    xlim=(-1, 1),
    ylim=(-1, 1),
)
fig.colorbar(scatter1, ax=ax1, orientation='horizontal', label='GMM Component')

# Right panel: [Al/Fe] vs. [Mg/Fe]-[Mn/Fe] coloured by the certainty of the assignment,
# i.e. the largest membership probability of each star
certainty = probabilities.max(axis=1)
scatter2 = ax2.scatter(
    data[:, 0],
    data[:, 1],
    c=certainty,
    cmap='viridis',
    s=5,
    alpha=0.5,
)
ax2.set(xlabel='[Al/Fe]', ylabel='[Mg/Fe] - [Mn/Fe]')
fig.colorbar(scatter2, ax=ax2, orientation='horizontal', label='Component Membership Probability')

fig.tight_layout()
plt.show()

# Solutions to make the notebook run during tutorial

In [None]:
# Filters

# GALAH: keep only stars for which none of the Mg, Al and Mn abundance flags is raised
unflagged = (
    (galah['flag_mg_fe_galah'] == 0)
    & (galah['flag_al_fe_galah'] == 0)
    & (galah['flag_mn_fe_galah'] == 0)
)
galah = galah[unflagged]

# Both surveys: keep only rows with a positive Gaia DR3 source_id
galah_has_gaia_id = galah['gaiadr3_source_id'] > 0
galah = galah[galah_has_gaia_id]
sdss_has_gaia_id = sdss['gaiadr3_source_id'] > 0
sdss = sdss[sdss_has_gaia_id]

In [None]:
# As before in this tutorial, work with plain numpy arrays instead of table columns
feh, mgfe = (np.array(galah[column]) for column in ('fe_h_galah', 'mg_fe_galah'))

# One row per star, columns [Fe/H] and [Mg/Fe]
data = np.transpose(np.vstack((feh, mgfe)))
data_labels = ["[Fe/H]", "[Mg/Fe]"]

In [None]:
# Initialize parameters to some starting values (mu_0, Sigma_0, pi_0)
pi = np.array([0.5,0.5])
mu = np.array([[-1.0,1.0],[1.0,-1.0]])
Sigma = np.array([ [[1.0,0.0],[0.0,1.0]] , [[1.0,0.0],[0.0,1.0]]])

# Number of iterations for the EM algorithm
trials = 50

# Array to keep track of the log-likelihood values
ll = np.zeros(trials)

# Run the EM algorithm
for i in range(trials):
    # Compute the log-likelihood of the current model (i.e. before this iteration's update)
    ll[i] = log_likelihood(data, mu, Sigma, pi)

    # Perform the E-step to update responsibilities
    gamma = e_step(data, mu, Sigma, pi)
    
    # Perform the M-step to update parameters
    mu, Sigma, pi = m_step(data, gamma)

    # Visualize the model for the first 5 iterations and for every 5th iteration after that
    if i < 5 or i % 5 == 0:
        print(f"Iteration {i}: Current Log-Likelihood = {ll[i]}")
        
        # First iteration: plain data; afterwards: colour by the responsibility of component 2
        if i < 1:
            plot_data()
        else:
            plot_data(redness = gamma[:,1])

        # Plot the Gaussian components
        plot_components(mu, Sigma, ['b', 'r'], 0.2)
        
        # Add a title to indicate the current iteration and log-likelihood
        plt.title(f"Iteration {i}: Log-Likelihood = {np.round(ll[i])}")

        # Save BEFORE showing: once the figure has been shown, a later savefig call
        # would write an empty image instead of this plot
        plt.savefig(f"figures/em_iteration_{i:03d}_galah_dr4.png", dpi=150)
        plt.show()
        