# Linear Discriminant Analysis

## Python Code for LDA

In [1]:
import warnings
import numpy as np

def _class_means(X, y):
    """Return the mean feature vector of every class.

    Row ``k`` of the result is the mean of the rows of ``X`` whose label equals
    the ``k``-th entry of ``np.unique(y)`` (classes in sorted order), so for a
    2-D ``X`` the result has shape (n_classes, n_features).
    """
    rows = []
    for label in np.unique(y):
        members = X[y == label]
        rows.append(members.mean(axis=0))
    return np.array(rows)

def _class_cov(X, y, priors, ddof=1):
    """Return the shared covariance: the prior-weighted sum of per-class covariances.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
    y : ndarray of shape (n_samples,)
        Class labels; classes are visited in ``np.unique(y)`` (sorted) order.
    priors : sequence of float
        Weight of each class, aligned with ``np.unique(y)``.
    ddof : int, default 1
        Delta degrees of freedom passed to ``np.cov``. The default (1) gives the
        unbiased per-class estimate (scatter divided by n_k - 1); ``ddof=0``
        gives the empirical covariance (scatter divided by n_k), which is the
        estimator scikit-learn's LinearDiscriminantAnalysis uses.

    Returns
    -------
    ndarray of shape (n_features, n_features)

    Notes
    -----
    A class with ``ddof`` or fewer samples produces non-finite entries (``np.cov``
    warns about non-positive degrees of freedom).
    """
    _, n_features = X.shape
    shared = np.zeros((n_features, n_features))
    for k, label in enumerate(np.unique(y)):
        # np.cov returns a 0-d array when there is a single feature; atleast_2d
        # makes the (1, 1) shape explicit instead of relying on broadcasting.
        class_cov = np.atleast_2d(np.cov(X[y == label], rowvar=False, ddof=ddof))
        shared += priors[k] * class_cov
    return shared

def softmax(X):
    """Apply a numerically stable softmax to each row of ``X``.

    The row maximum is subtracted before exponentiating, so large scores cannot
    overflow. ``X`` is copied first, so the caller's array is never modified.
    Expects a 2-D floating-point array; the result has the same shape and dtype,
    and each row sums to 1.
    """
    probs = np.array(X, copy=True)
    probs -= probs.max(axis=1, keepdims=True)
    np.exp(probs, out=probs)
    probs /= probs.sum(axis=1, keepdims=True)
    return probs

class LinearDiscriminantAnalysis:
    """Linear Discriminant Analysis classifier with a scikit-learn-like API.

    Each class is modelled with its own mean vector and one covariance matrix
    shared by all classes (the prior-weighted sum of per-class covariances).
    The resulting class scores are linear in the features.

    Parameters
    ----------
    priors : array-like of shape (n_classes,), optional
        Class prior probabilities, aligned with ``np.unique(y)``. When None, the
        empirical class frequencies of ``y`` are used. Priors that do not sum
        to 1 are rescaled with a ``UserWarning``.

    Attributes set by ``fit``
    -------------------------
    classes_ : ndarray of shape (n_classes,)
    priors_ : ndarray of shape (n_classes,)
    means_ : ndarray of shape (n_classes, n_features)
    covariance_ : ndarray of shape (n_features, n_features)
    coef_ : ndarray of shape (n_classes, n_features)
    intercept_ : ndarray of shape (n_classes,)
    """

    def __init__(self, priors=None):
        self.priors = priors

    def _solve(self, X, y):
        """Estimate means and shared covariance, then the linear discriminant weights.

        For class k with mean mu_k, prior pi_k and shared covariance S:
        ``coef_[k] = S^-1 mu_k`` and
        ``intercept_[k] = -0.5 * mu_k . S^-1 mu_k + log(pi_k)``.
        """
        self.means_ = _class_means(X, y)
        self.covariance_ = _class_cov(X, y, self.priors_)
        # S is symmetric, so means_ @ inv(S) == solve(S, means_.T).T; solving the
        # linear system is cheaper and numerically more stable than inverting S.
        self.coef_ = np.linalg.solve(self.covariance_, self.means_.T).T
        # Row-wise dot product mu_k . (S^-1 mu_k), without building the full
        # (n_classes, n_classes) product just to read its diagonal.
        quad = np.sum(self.means_ * self.coef_, axis=1)
        self.intercept_ = -0.5 * quad + np.log(self.priors_)

    def fit(self, X, y):
        """Fit the model to training data and return ``self``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)

        Raises
        ------
        ValueError
            If there are no more samples than classes, fewer than two classes,
            a number of priors different from the number of classes, or a
            negative prior.

        Warns
        -----
        UserWarning
            If the priors do not sum to 1; they are rescaled.
        """
        X = np.array(X)
        y = np.array(y)
        self.classes_ = np.unique(y)
        n_samples, _ = X.shape
        n_classes = len(self.classes_)

        if n_samples == n_classes:
            raise ValueError('데이터의 개수는 범주의 개수보다 많아야합니다.')
        # With a single class there is nothing to discriminate, and the
        # prediction methods would fail on a one-column score matrix.
        if n_classes < 2:
            raise ValueError('범주의 개수는 2개 이상이어야 합니다.')

        if self.priors is None:
            # Empirical class frequencies, in the same order as classes_.
            _, counts = np.unique(y, return_counts=True)
            self.priors_ = counts / len(y)
        else:
            self.priors_ = np.asarray(self.priors, dtype=float)
            # Fail early instead of mis-indexing priors inside _class_cov.
            if self.priors_.shape != (n_classes,):
                raise ValueError('사전 확률의 개수는 범주의 개수와 같아야 합니다.')

        # NOTE(review): zero priors pass this check even though the message
        # says "must be greater than 0"; they lead to an -inf intercept.
        if np.any(self.priors_ < 0):
            raise ValueError('사전 확률은 0보다 커야합니다.')
        if not np.isclose(self.priors_.sum(), 1):
            warnings.warn('사전 확률의 합이 1이 아닙니다. 값을 재조정합니다', UserWarning)
            self.priors_ = self.priors_ / self.priors_.sum()

        self._solve(X, y)

        return self

    def decision_function(self, X):
        """Return the linear discriminant score of every class for each sample.

        The result has shape (n_samples, n_classes); column k corresponds to
        ``classes_[k]`` and a higher score means a more likely class.
        """
        X = np.array(X)
        return X.dot(self.coef_.T) + self.intercept_

    def predict(self, X):
        """Return the label in ``classes_`` with the highest score for each sample."""
        decision = self.decision_function(X)
        return self.classes_.take(decision.argmax(axis=1))

    def predict_proba(self, X):
        """Return posterior class probabilities (row-wise softmax of the scores)."""
        decision = self.decision_function(X)
        return softmax(decision)

    def predict_log_proba(self, X):
        """Return log posterior class probabilities.

        Computed directly as a log-softmax of the decision scores, so very small
        probabilities keep their exact logarithm instead of underflowing to 0,
        and the ordering of classes is the same as for ``predict_proba``.
        A class with a zero prior gets ``-inf``.
        """
        scores = self.decision_function(X)
        # Shift by the row maximum so the largest exponent is exp(0) == 1 and the
        # row sum is >= 1, which keeps the log finite.
        shifted = scores - scores.max(axis=1, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

---

## Test with Toy Data

In [2]:
from sklearn.datasets import load_iris

In [3]:
# Load the iris features and labels with a single call instead of loading the dataset twice.
X, y = load_iris(return_X_y=True)

In [4]:
lda = LinearDiscriminantAnalysis()

In [5]:
lda.fit(X, y)

<__main__.LinearDiscriminantAnalysis at 0x7f82fdaf2af0>

In [6]:
lda.means_

array([[5.006, 3.428, 1.462, 0.246],
       [5.936, 2.77 , 4.26 , 1.326],
       [6.588, 2.974, 5.552, 2.026]])

In [7]:
lda.covariance_

array([[0.26500816, 0.09272109, 0.16751429, 0.03840136],
       [0.09272109, 0.11538776, 0.05524354, 0.0327102 ],
       [0.16751429, 0.05524354, 0.18518776, 0.04266531],
       [0.03840136, 0.0327102 , 0.04266531, 0.04188163]])

In [8]:
lda.priors_

array([0.33333333, 0.33333333, 0.33333333])

In [9]:
lda.coef_

array([[ 23.54416672,  23.5878705 , -16.43063902, -17.39841078],
       [ 15.69820908,   7.07250984,   5.21145093,   6.4342292 ],
       [ 12.44584899,   3.68527961,  12.76654497,  21.07911301]])

In [10]:
lda.intercept_

array([ -86.30846997,  -72.8526074 , -104.36831999])

In [11]:
lda.predict(X)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [12]:
lda.predict_proba(X)

array([[1.00000000e+00, 3.89635793e-22, 2.61116827e-42],
       [1.00000000e+00, 7.21796992e-18, 5.04214335e-37],
       [1.00000000e+00, 1.46384895e-19, 4.67593159e-39],
       [1.00000000e+00, 1.26853638e-16, 3.56661049e-35],
       [1.00000000e+00, 1.63738745e-22, 1.08260527e-42],
       [1.00000000e+00, 3.88328166e-21, 4.56654013e-40],
       [1.00000000e+00, 1.11346945e-18, 2.30260848e-37],
       [1.00000000e+00, 3.87758638e-20, 1.07449600e-39],
       [1.00000000e+00, 1.90281306e-15, 9.48293562e-34],
       [1.00000000e+00, 1.11180261e-18, 2.72405964e-38],
       [1.00000000e+00, 1.18527749e-23, 3.23708368e-44],
       [1.00000000e+00, 1.62164851e-18, 1.83320074e-37],
       [1.00000000e+00, 1.45922505e-18, 3.26250644e-38],
       [1.00000000e+00, 1.11721886e-19, 1.31664193e-39],
       [1.00000000e+00, 5.48739873e-30, 1.53126473e-52],
       [1.00000000e+00, 1.26150510e-27, 2.26870463e-48],
       [1.00000000e+00, 6.75433806e-25, 3.86827125e-45],
       [1.00000000e+00, 4.22374

In [13]:
lda.predict_log_proba(X)

array([[ 0.00000000e+00, -4.92968298e+01, -9.57487762e+01],
       [ 0.00000000e+00, -3.94699579e+01, -8.35778172e+01],
       [ 0.00000000e+00, -4.33680475e+01, -8.82583902e+01],
       [-2.22044605e-16, -3.66034977e+01, -7.93188626e+01],
       [ 0.00000000e+00, -5.01637701e+01, -9.66292035e+01],
       [ 0.00000000e+00, -4.69976064e+01, -9.05846479e+01],
       [ 0.00000000e+00, -4.13390509e+01, -8.43616058e+01],
       [ 0.00000000e+00, -4.46964890e+01, -8.97289669e+01],
       [-1.99840144e-15, -3.38954430e+01, -7.60383992e+01],
       [ 0.00000000e+00, -4.13405490e+01, -8.64961103e+01],
       [ 0.00000000e+00, -5.27894802e+01, -1.00139071e+02],
       [ 0.00000000e+00, -4.09630884e+01, -8.45895850e+01],
       [ 0.00000000e+00, -4.10686262e+01, -8.63157378e+01],
       [ 0.00000000e+00, -4.36382743e+01, -8.95257341e+01],
       [ 0.00000000e+00, -6.73750985e+01, -1.19308331e+02],
       [ 0.00000000e+00, -6.19374920e+01, -1.09704875e+02],
       [ 0.00000000e+00, -5.56544424e+01

---

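## Compare to Scikit-Learn

With `solver='lsqr'`, scikit-learn predicts exactly the same labels, but its fitted parameters differ slightly. Each iris class has 50 samples, and scikit-learn's `covariance_` equals ours multiplied by 49/50, while its coefficients are ours multiplied by 50/49. The difference comes from the covariance estimator: `np.cov` divides each class's scatter matrix by n_k − 1 (the unbiased estimate), whereas scikit-learn uses the empirical covariance, which divides by n_k.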

In [14]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

In [15]:
lda = LinearDiscriminantAnalysis(solver='lsqr')

In [16]:
lda.fit(X, y)

LinearDiscriminantAnalysis(solver='lsqr')

In [17]:
lda.means_

array([[5.006, 3.428, 1.462, 0.246],
       [5.936, 2.77 , 4.26 , 1.326],
       [6.588, 2.974, 5.552, 2.026]])

In [18]:
lda.covariance_

array([[0.259708  , 0.09086667, 0.164164  , 0.03763333],
       [0.09086667, 0.11308   , 0.05413867, 0.032056  ],
       [0.164164  , 0.05413867, 0.181484  , 0.041812  ],
       [0.03763333, 0.032056  , 0.041812  , 0.041044  ]])

In [19]:
lda.priors_

array([0.33333333, 0.33333333, 0.33333333])

In [20]:
lda.coef_

array([[ 24.02465992,  24.06925561, -16.76595819, -17.75348039],
       [ 16.01858069,   7.21684677,   5.31780708,   6.56554   ],
       [ 12.69984591,   3.7604894 ,  13.02708671,  21.50929899]])

In [21]:
lda.intercept_

array([ -88.04744666,  -74.31697465, -106.47586504])

In [22]:
lda.predict(X)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [23]:
lda.predict_proba(X)

array([[1.00000000e+00, 1.42473310e-22, 3.69997541e-43],
       [1.00000000e+00, 3.22542393e-18, 9.15908803e-38],
       [1.00000000e+00, 6.04113635e-20, 7.72005795e-40],
       [1.00000000e+00, 6.01008392e-17, 7.06708789e-36],
       [1.00000000e+00, 5.88223407e-23, 1.50671375e-43],
       [1.00000000e+00, 1.48816789e-21, 7.18988118e-41],
       [1.00000000e+00, 4.78943170e-19, 4.11633116e-38],
       [1.00000000e+00, 1.55743408e-20, 1.72156439e-40],
       [1.00000000e+00, 9.52742714e-16, 2.00910604e-34],
       [1.00000000e+00, 4.78211583e-19, 4.66217295e-39],
       [1.00000000e+00, 4.03588528e-24, 4.19378389e-45],
       [1.00000000e+00, 7.02901743e-19, 3.26196740e-38],
       [1.00000000e+00, 6.31138642e-19, 5.60430565e-39],
       [1.00000000e+00, 4.58527692e-20, 2.11829983e-40],
       [1.00000000e+00, 1.38742889e-30, 1.34153598e-53],
       [1.00000000e+00, 3.56391683e-28, 2.41794766e-49],
       [1.00000000e+00, 2.16924684e-25, 4.79887806e-46],
       [1.00000000e+00, 1.62141

In [24]:
lda.predict_log_proba(X)

array([[ 0.00000000e+00, -5.03028875e+01, -9.77028328e+01],
       [ 0.00000000e+00, -4.02754673e+01, -8.52834869e+01],
       [ 0.00000000e+00, -4.42531097e+01, -9.00595818e+01],
       [ 0.00000000e+00, -3.73505079e+01, -8.09376148e+01],
       [ 0.00000000e+00, -5.11875205e+01, -9.86012280e+01],
       [ 0.00000000e+00, -4.79567412e+01, -9.24333142e+01],
       [ 0.00000000e+00, -4.21827050e+01, -8.60832713e+01],
       [ 0.00000000e+00, -4.56086622e+01, -9.15601703e+01],
       [-8.88178420e-16, -3.45871868e+01, -7.75902033e+01],
       [ 0.00000000e+00, -4.21842337e+01, -8.82613370e+01],
       [ 0.00000000e+00, -5.38668166e+01, -1.02182726e+02],
       [ 0.00000000e+00, -4.17990698e+01, -8.63159030e+01],
       [ 0.00000000e+00, -4.19067614e+01, -8.80772835e+01],
       [ 0.00000000e+00, -4.45288514e+01, -9.13527899e+01],
       [ 0.00000000e+00, -6.87501005e+01, -1.21743195e+02],
       [ 0.00000000e+00, -6.32015224e+01, -1.11943750e+02],
       [ 0.00000000e+00, -5.67902473e+01