In [1]:
###Author: Tuan Le
###Email: tuanle@hotmail.de

## Task:
## Implement quadratic discriminant analysis (QDA) as explained in the lecture

QDA needs the following per-class estimates (1–3) and the resulting posterior rule (4):

$$
\begin{array}{c}
1.  \hat{\pi}_j = \frac{n_j}{n} \cr \cr
2.  \hat{\mu}_j = \sum_{i: y_i=j} \frac{x_i}{n_j} \cr \cr
3.  \hat{\Sigma}_j = \frac{1}{n_j - 1}\sum_{i: y_i=j}(x_i - \hat{\mu}_j)(x_i - \hat{\mu}_j)' \cr \cr
4.  \pi_k(x) \propto \pi_k \cdot p(x|y=k) \cr \cr
\end{array}
$$

where $p(x|y=k)$ is the multivariate normal density with a class-specific covariance matrix $\Sigma_k$ (in the normalising constant, $p$ denotes the number of features):

$$
\begin{array}{c}
5. p(x | y = k) = \frac{1}{(2\pi)^{p/2}\det(\Sigma_k)^{1/2}}\exp\left(-\frac{1}{2}(x - \mu_k)'{\Sigma_k}^{-1}(x-\mu_k)\right)
\end{array}
$$


In [2]:
import pandas as pd
import numpy as np
from sklearn import datasets
from scipy.stats import multivariate_normal

# Load the iris data and assemble one frame: four feature columns plus a
# "Species" column holding the class name for each integer target code.
iris = datasets.load_iris()

iris_features = pd.DataFrame(
    iris.data,
    columns=["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"],
)
iris_target = pd.DataFrame(
    {"Species": pd.Series(iris.target).map({0: "setosa", 1: "versicolor", 2: "virginica"})}
)

iris_data = pd.concat([iris_features, iris_target], axis=1)

In [3]:
class myQDA:
    """Quadratic discriminant analysis (QDA) on a pandas DataFrame.

    Every class j gets its own prior ``pi_j = n_j / n``, mean vector ``mu_j``
    and unbiased covariance matrix ``Sigma_j`` (divisor ``n_j - 1``).  An
    observation x is assigned to the class maximising the posterior
    ``pi_k * N(x | mu_k, Sigma_k)`` normalised over all classes.

    Parameters
    ----------
    data : pandas.DataFrame
        Training data: feature columns plus one label column.
    target : str
        Name of the label column in ``data``.
    """

    #constructor
    def __init__(self, data, target):
        self.data = data
        self.feature_data = data.drop([target], axis = 1)
        self.target = target
        self.n = data.shape[0]
        self.p = data.shape[1] - 1
        self.classes = data[target].unique()
        self.n_j = data[target].value_counts()
        # positions of the label column, as returned by np.where (a tuple of arrays)
        self.label_col = np.where(self.data.columns == self.target)
        self.pi_j = None
        self.mu_j = None
        self.sigma_j = None
        self.probs = None
        self.predicted_class = None

    #helpers
    def dropCol(self, data, col):
        """Return ``data`` without the column at integer position ``col``."""
        return data.drop(data.columns[[col]], axis = 1)

    def estimateParams(self):
        """Estimate class priors, class means and class covariance matrices.

        Returns
        -------
        dict
            ``{'pi_j': Series of priors indexed by class,
            'mu_j': {class: Series of feature means},
            'sigma_j': {class: (p, p) ndarray covariance}}``.
            The same objects are stored on ``self``.
        """
        self.pi_j = self.n_j / self.n
        self.mu_j = {}
        self.sigma_j = {}
        labels = self.data[self.target]
        for cl in self.classes:
            class_features = self.feature_data[labels == cl]
            self.mu_j[cl] = class_features.mean()
            # rowvar=False: rows are observations, columns are features.
            # np.cov uses ddof=1 by default, i.e. the 1/(n_j - 1) estimator.
            self.sigma_j[cl] = np.cov(class_features, rowvar=False)
        return({'pi_j': self.pi_j, 'mu_j': self.mu_j, 'sigma_j': self.sigma_j})

    def predictQDA(self, data):
        """Predict the most probable class for every row of ``data``.

        Parameters
        ----------
        data : pandas.DataFrame or array-like
            Observations to classify.  A DataFrame has its columns selected
            by the training feature names (so a label column or a different
            column order is tolerated); other array-likes are used
            positionally and must have the training feature order.

        Returns
        -------
        dict
            ``{"class_prob": Series}`` holding the predicted class label per
            row.  The normalised posterior probabilities (rows sum to 1) are
            stored in ``self.probs``; the labels also in ``self.predicted_class``.
        """
        # Estimate lazily so calling predict first does not crash on None params.
        if self.pi_j is None or self.mu_j is None or self.sigma_j is None:
            self.estimateParams()

        if isinstance(data, pd.DataFrame):
            index = data.index
            features = data[self.feature_data.columns]
        else:
            index = None
            features = data

        # Work in log space: raw pdf values underflow to 0 for points far from
        # every class, which would make the normalisation 0/0 = NaN.
        # atleast_1d: scipy returns a scalar (not a length-1 array) for one row.
        log_scores = pd.DataFrame(
            {
                cl: np.atleast_1d(
                    multivariate_normal.logpdf(features, mean = self.mu_j[cl], cov = self.sigma_j[cl])
                ) + np.log(self.pi_j[cl])
                for cl in self.classes
            },
            index = index,
            columns = self.classes,
        )

        # Subtract the row maximum before exponentiating (log-sum-exp trick),
        # then normalise every row -- including the last -- to sum to 1.
        unnormalised = np.exp(log_scores.sub(log_scores.max(axis=1), axis=0))
        self.probs = unnormalised.div(unnormalised.sum(axis=1), axis=0)

        self.predicted_class = self.probs.idxmax(axis=1)
        return({"class_prob": self.predicted_class})

In [4]:
# Test the classifier on the iris data: fit per-class parameters, show them,
# then predict the training rows.
testQDA = myQDA(data = iris_data, target = "Species")
params = testQDA.estimateParams()

print(testQDA.pi_j)
print(testQDA.mu_j)
print(testQDA.sigma_j)
probs = testQDA.predictQDA(data = iris_data.drop("Species", axis=1))
print(probs)

# Resubstitution (training-set) accuracy as a quick sanity check of the fit;
# it is optimistic and not an estimate of generalisation error.
train_accuracy = (probs["class_prob"] == iris_data["Species"]).mean()
print(f"Training accuracy: {train_accuracy:.3f}")


virginica     0.333333
setosa        0.333333
versicolor    0.333333
Name: Species, dtype: float64
{'setosa': Sepal.Length    5.006
Sepal.Width     3.418
Petal.Length    1.464
Petal.Width     0.244
dtype: float64, 'versicolor': Sepal.Length    5.936
Sepal.Width     2.770
Petal.Length    4.260
Petal.Width     1.326
dtype: float64, 'virginica': Sepal.Length    6.588
Sepal.Width     2.974
Petal.Length    5.552
Petal.Width     2.026
dtype: float64}
{'setosa': array([[0.12424898, 0.10029796, 0.01613878, 0.01054694],
       [0.10029796, 0.14517959, 0.01168163, 0.01143673],
       [0.01613878, 0.01168163, 0.03010612, 0.00569796],
       [0.01054694, 0.01143673, 0.00569796, 0.01149388]]), 'versicolor': array([[0.26643265, 0.08518367, 0.18289796, 0.05577959],
       [0.08518367, 0.09846939, 0.08265306, 0.04120408],
       [0.18289796, 0.08265306, 0.22081633, 0.07310204],
       [0.05577959, 0.04120408, 0.07310204, 0.03910612]]), 'virginica': array([[0.40434286, 0.09376327, 0.3032898 , 0.0490938