In [1]:
import numpy as np

# 1. K-means and K-medoids

In [2]:
# Four 2-D sample points, one point per row.
X = np.array([[ 0.,-6.],
              [ 4., 4.],
              [ 0., 0.],
              [-5., 2.]])

# Initial cluster centers (one per row); both coincide with rows of X.
centers = np.array([[-5., 2.],
                    [ 0.,-6.]])

## K-medoids

In [3]:
class kMedodis:
  """K-medoids clustering with a selectable point-to-point distance.

  Each iteration assigns every point to its nearest center and then moves
  every center onto the row of X that minimises the summed distance to the
  points currently assigned to that center.

  The class name keeps its original spelling because other code in this
  notebook instantiates and subclasses it under that name.
  """

  def __init__(self, max_iter=10, dist="l2"):
    """Configure the clusterer.

    Args:
      max_iter: maximum number of assign/update iterations run by `fit`.
      dist: "l1" or "l2", selecting the distance used between points.

    Raises:
      ValueError: if `dist` is not one of the supported names.
    """
    self.max_iter = max_iter
    metrics = {"l1": self._l1, "l2": self._l2}
    if dist not in metrics:
      # Fail here rather than with an AttributeError on the first fit() call.
      raise ValueError(
          "dist must be one of %s, got %r" % (sorted(metrics), dist))
    self.dist = metrics[dist]

  def fit(self, X, centers, k):
    """Cluster the rows of X starting from the given centers.

    Args:
      X: array with one point per row.
      centers: array with one initial center per row; it is copied, so the
        caller's array is left untouched.
      k: number of clusters (rows of `centers` that are used).

    Returns:
      (idxs, centers): `idxs` is a float array holding the cluster index of
      each row of X, `centers` is the array of final centers.
    """
    centers = centers.copy()
    idxs = None
    for _ in range(self.max_iter):
      idxs = self._step(X, centers, k)
      previous = centers.copy()
      centers = self._update(X, idxs, centers, k)
      # Assignment and update are deterministic, so once an update leaves the
      # centers unchanged every further iteration would repeat this result.
      if np.array_equal(previous, centers):
        break

    if idxs is None:
      # max_iter == 0: still return an assignment for the initial centers.
      idxs = self._step(X, centers, k)

    return idxs, centers

  def _step(self, X, centers, k):
    """Return, for each row of X, the index of its nearest center.

    Ties go to the lowest center index (np.argmin behaviour). The result is
    a float array.
    """
    N = X.shape[0]
    idxs = np.zeros(N)
    for i in range(N):
      dists = [self.dist(X[i], centers[j]) for j in range(k)]
      idxs[i] = np.argmin(dists)

    return idxs

  def _update(self, X, idxs, centers, k):
    """Move each center onto the best medoid for its cluster.

    The candidate medoids are all rows of X, not only the members of the
    cluster. `centers` is modified in place and also returned. A cluster
    with no assigned points keeps its previous center.
    """
    N = X.shape[0]
    for j in range(k):
      members = X[idxs == j]
      if members.shape[0] == 0:
        # Every candidate would score 0 here, so argmin would pick X[0]
        # arbitrarily; keeping the old center is the safer choice.
        continue

      costs = np.zeros(N)
      for i in range(N):
        for x in members:
          costs[i] += self.dist(x, X[i])

      centers[j] = X[np.argmin(costs)]

    return centers

  def _l1(self, x, y):
    """L1 norm of x - y (sum of absolute differences for 1-D vectors)."""
    return np.linalg.norm(x-y, ord=1)

  def _l2(self, x, y):
    """L2 norm of x - y (Euclidean distance for 1-D vectors)."""
    return np.linalg.norm(x-y, ord=2)

In [4]:
# K-medoids with the L1 distance: max_iter=10, k=2 clusters.
km = kMedodis(10, "l1")
km.fit(X, centers, 2)

(array([1., 0., 1., 0.]), array([[ 4.,  4.],
        [ 0., -6.]]))

In [5]:
# K-medoids on the same data and initial centers, now with the L2 distance.
km = kMedodis(10, "l2")
km.fit(X, centers, 2)

(array([1., 0., 0., 0.]), array([[ 0.,  0.],
        [ 0., -6.]]))

## K-means

In [6]:
class kMeans(kMedodis):
  """K-means variant of kMedodis.

  The assignment step and the choice of distance are inherited; only the
  center update differs: each center becomes the mean of its assigned points.
  """

  def _update(self, X, idxs, centers, k):
    """Replace each center by the mean of the points assigned to it.

    Float `centers` are updated in place and returned. Non-float `centers`
    are first converted to a float copy so the means are not truncated on
    assignment. A cluster with no assigned points keeps its previous center
    instead of becoming NaN (the mean of an empty slice).
    """
    centers = np.asarray(centers, dtype=float)
    for j in range(k):
      members = X[idxs == j]
      if members.shape[0] == 0:
        continue
      centers[j] = np.mean(members, axis=0)
    return centers

In [7]:
# K-means on the same data: L1 distance for the assignment step, cluster
# means for the center update; max_iter=10, k=2.
km = kMeans(10, "l1")
km.fit(X, centers, 2)

(array([1., 0., 1., 0.]), array([[-0.5,  3. ],
        [ 0. , -3. ]]))

# 3. EM Algorithm

In [8]:
# Initial parameters of a two-component, one-dimensional Gaussian mixture.
pis = np.array([0.5, 0.5])  # mixing weights
mus = np.array([6, 7])  # component means
sigs= np.array([1, 2])  # component standard deviations (norm() squares them)

# One-dimensional observations. NOTE: this rebinds X, which held the 2-D
# points of the K-means section above.
X = np.array([-1, 0, 4, 5, 6])

In [9]:
def norm(x, mu, sig):
  """Gaussian density with mean `mu` and standard deviation `sig` at `x`.

  Works element-wise on NumPy arrays through broadcasting.
  """
  variance = sig**2
  scale = 1/np.sqrt(2*np.pi*variance)
  exponent = -(x-mu)**2/(2*variance)
  return scale * np.exp(exponent)

## Likelihood Function

In [10]:
def logLikelihood(X, pis, mus, sigs):
  """Log-likelihood of 1-D data under a Gaussian mixture.

  Computes sum_i log( sum_j pis[j] * N(X[i]; mus[j], sigs[j]^2) ).

  The mixture density is evaluated in log space and the components are
  combined with logaddexp. Evaluating the density directly underflows to 0
  for points far from every component, which turns the result into -inf.

  Args:
    X: 1-D array-like of observations.
    pis: mixing weights, one per component.
    mus: component means.
    sigs: component standard deviations.

  Returns:
    The total log-likelihood as a NumPy float.
  """
  X = np.asarray(X, dtype=float).reshape(-1, 1)
  pis = np.asarray(pis, dtype=float)
  mus = np.asarray(mus, dtype=float)
  sigs = np.asarray(sigs, dtype=float)

  with np.errstate(divide="ignore"):
    # A zero weight gives log(0) = -inf, which logaddexp handles correctly.
    log_pis = np.log(pis)

  # log N(x; mu, sig^2) for every (point, component) pair, shape (N, k).
  log_pdf = -0.5*np.log(2*np.pi*sigs**2) - (X - mus)**2/(2*sigs**2)

  # log of the per-point mixture density, then summed over the points.
  log_mix = np.logaddexp.reduce(log_pis + log_pdf, axis=1)
  return np.sum(log_mix)

In [11]:
# Log-likelihood of the data under the current (initial) parameters.
logLikelihood(X, pis, mus, sigs)

-24.512532330086678

## E-Step

In [12]:
def EStep(X, pis, mus, sigs):
  """E-step of EM for a 1-D Gaussian mixture.

  Computes the responsibilities
    gammas[i, j] = pis[j] * N(X[i]; mus[j], sigs[j]^2) / sum_l pis[l] * N(X[i]; mus[l], sigs[l]^2).

  The normalisation is done in log space. Evaluating the densities directly
  underflows to 0 for points far from every component, which makes the
  row normalisation 0/0 = NaN.

  Args:
    X: 1-D array-like of N observations.
    pis: mixing weights, one per component (k entries).
    mus: component means.
    sigs: component standard deviations.

  Returns:
    Array of shape (N, k); each row sums to 1.
  """
  X = np.asarray(X, dtype=float).reshape(-1, 1)
  pis = np.asarray(pis, dtype=float)
  mus = np.asarray(mus, dtype=float)
  sigs = np.asarray(sigs, dtype=float)

  with np.errstate(divide="ignore"):
    # A zero weight gives log(0) = -inf, i.e. zero responsibility.
    log_pis = np.log(pis)

  # log( pis[j] * N(X[i]; mus[j], sigs[j]^2) ) for every pair, shape (N, k).
  log_joint = log_pis - 0.5*np.log(2*np.pi*sigs**2) - (X - mus)**2/(2*sigs**2)

  # Per-row log normaliser; subtracting it before exp() keeps the largest
  # term of each row well away from underflow.
  log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
  gammas = np.exp(log_joint - log_norm)

  return gammas

In [13]:
# Responsibilities under the current parameters, followed by the 1-based
# index of the component with the largest responsibility for each point.
gammas = EStep(X, pis, mus ,sigs)
np.argmax(gammas, axis=1)+1

array([2, 2, 2, 1, 1])

In [14]:
# Full responsibility matrix: one row per point, one column per component.
gammas

array([[1.36512049e-07, 9.99999863e-01],
       [1.39244156e-05, 9.99986076e-01],
       [4.54661673e-01, 5.45338327e-01],
       [6.66666667e-01, 3.33333333e-01],
       [6.93842896e-01, 3.06157104e-01]])

## M-Step

In [15]:
def MStep(X, gammas):
  """M-step of EM for a 1-D Gaussian mixture.

  Re-estimates the parameters from the responsibilities:
    mus[j]  = sum_i gammas[i, j] * X[i] / sum_i gammas[i, j]
    sigs[j] = sqrt( sum_i gammas[i, j] * (X[i] - mus[j])^2 / sum_i gammas[i, j] )
    pis[j]  = mean_i gammas[i, j]

  A component whose responsibilities sum to zero yields NaN for its mean
  and standard deviation (division by zero).

  Args:
    X: 1-D array-like of N observations.
    gammas: (N, k) array-like of responsibilities.

  Returns:
    (mus, sigs, pis), in that order; each is an array of length k.
  """
  X = np.asarray(X, dtype=float).reshape(-1)
  gammas = np.asarray(gammas, dtype=float)

  # Total responsibility per component.
  weights = gammas.sum(axis=0)

  ## update mus
  mus = (gammas.T @ X) / weights

  ## update sigs
  # Weighted squared deviations summed per component; this avoids building
  # a k x k matrix product only to read off its diagonal.
  sq_dev = (X[:, None] - mus)**2
  sigs = np.sqrt(np.sum(gammas * sq_dev, axis=0) / weights)

  ## update pis
  pis = gammas.mean(axis=0)

  return mus, sigs, pis

In [16]:
# Parameters before a single M-step, then the result of that step.
# MStep returns (mus, sigs, pis), the same order as the first print.
print(mus, sigs, pis)
print(MStep(X, gammas))

[6 7] [1 2] [0.5 0.5]
(array([5.13172803, 1.47103149]), array([0.78457791, 2.63951172]), array([0.36303706, 0.63696294]))


In [17]:
# Run 10 EM iterations from the current parameters, printing the parameters
# (mus, sigs, pis) before the loop and after every iteration.
# NOTE: the loop rebinds the global gammas, mus, sigs and pis, so re-running
# this cell, or any earlier cell that reads them, continues from the updated
# values rather than from the initial parameters.
print(mus, sigs, pis)

for i in range(10):
  gammas = EStep(X, pis, mus ,sigs)
  mus, sigs, pis = MStep(X, gammas)

  print(mus, sigs, pis)

[6 7] [1 2] [0.5 0.5]
[5.13172803 1.47103149] [0.78457791 2.63951172] [0.36303706 0.63696294]
[5.13978976 1.02237472] [0.77435184 2.42503305] [0.43173332 0.56826668]
[5.11546151 0.58512637] [0.78648    2.13473459] [0.48889841 0.51110159]
[5.0734208  0.12831221] [0.80041232 1.70726345] [0.54026878 0.45973122]
[ 5.02314395 -0.32454628] [0.81208719 1.01508805] [0.58427959 0.41572041]
[ 5.00004802 -0.49967409] [0.8164869  0.50144798] [0.59997106 0.40002894]
[ 4.99999998 -0.5       ] [0.81649663 0.5       ] [0.6 0.4]
[ 4.99999998 -0.5       ] [0.81649664 0.5       ] [0.6 0.4]
[ 4.99999998 -0.5       ] [0.81649664 0.5       ] [0.6 0.4]
[ 4.99999998 -0.5       ] [0.81649664 0.5       ] [0.6 0.4]
