In [2]:
import numpy as np

x = np.arange(1, 10).reshape(3, 3)

In [3]:
x

array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [6]:
x[:, 1][:, None].shape

(3, 1)

In [6]:
x[None, :].shape

(1, 8)

In [7]:
x.sum(axis=0, keepdims=True)

array([36])

In [23]:
import numpy as np 
from scipy.stats import multivariate_normal as mvn 

seed = 1010

def init(X, K):
    pi = np.full(K, 1.0 / K)

    rng = np.random.default_rng(seed)
    rand_idx = rng.choice(X.shape[0], size=K, replace=False)
    mu = X[rand_idx].copy()

    var = np.cov(X, rowvar=False)
    var = np.expand_dims(var, axis=0)
    var = np.repeat(var, K, axis=0)

    return pi, mu, var

def e_step(X, K, pi, mu, var):
    N = X.shape[0]
    gamma = np.zeros((N, K))
    mvn_dists = [mvn(mean=mu[i], cov=var[i]) for i in range(K)]
    
    for j in range(K):
        gamma[:, j] = mvn_dists[j].pdf(X) * pi[j]
    
    gamma /= gamma.sum(axis=1, keepdims=True)
    return gamma 

def m_step(X, K, gamma, mu, var):
    
    N, D = X.shape
    gamma_sum = gamma.sum(axis=0)
    pi_new = gamma_sum / N

    mu_new = (gamma.T @ X) / gamma_sum[:, None]

    var_new = np.zeros_like(var)
    for k in range(K):
        diff = X - mu_new[k]
        g = gamma[:, k][:, None]
        var_new[k] = (g * diff).T @ diff / gamma_sum[k] + 1e-6 * np.eye(D)

    return pi_new, mu_new, var_new

def compute_likelihood(X, pi, mu, var):
    N, K = X.shape[0], pi.shape[0]
    likelihood = np.zeros(N)
    for i in range(K):
        likelihood += pi[i] * mvn(mean=mu[i], cov=var[i]).pdf(X)
    return np.log(likelihood).sum()

def gmm(X, K, max_iter):

    pi, mu, var = init(X, K)
    print(mu)
    likelihood = []

    for i in range(max_iter):
        gamma = e_step(X, K, pi, mu, var)
        pi, mu, var = m_step(X, K, gamma, mu, var)
        likelihood.append(
            compute_likelihood(X, pi, mu, var)
        )

    return pi, mu, var, gamma, likelihood

In [24]:
import numpy as np 
from PIL import Image 

def load_pixel(path):
    img = Image.open(path).convert("RGB")
    arr = np.asarray(img, dtype=np.float32) / 255.0
    h, w = arr.shape[:2]
    pixels = arr.reshape(-1, 3)
    return pixels, h, w

X, h, w = load_pixel('ronaldo.png')

In [25]:
K = 16
pi, mu, var, gamma, _ = gmm(X, K, max_iter=50)

labels = np.argmax(gamma, axis=1)
palette = mu * 255
q = palette[labels].astype(np.uint8)
quant_img = q.reshape(h, w, 3)

[[0.7921569  0.5882353  0.43529412]
 [0.09411765 0.10588235 0.23529412]
 [0.2901961  0.2784314  0.2509804 ]
 [0.49019608 0.38431373 0.02745098]
 [0.03137255 0.14117648 0.28627452]
 [0.5529412  0.5686275  0.5019608 ]
 [0.39607844 0.22352941 0.1254902 ]
 [0.78431374 0.62352943 0.49803922]
 [0.00784314 0.29803923 0.8901961 ]
 [0.47058824 0.41568628 0.2784314 ]
 [0.42352942 0.44313726 0.46666667]
 [0.5176471  0.44313726 0.31764707]
 [0.4627451  0.47058824 0.41960785]
 [0.4509804  0.43529412 0.3372549 ]
 [0.4        0.34117648 0.17254902]
 [0.03529412 0.07843138 0.14901961]]


In [27]:
Image.fromarray(quant_img).show()  

In [26]:
labels

array([2, 2, 2, ..., 8, 8, 8], dtype=int64)

In [21]:
mu

array([[0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285],
       [0.36173343, 0.34768365, 0.26570285]])