In [7]:
import numpy as np
from scipy.linalg import inv, pinv
from tabulate import tabulate

class RegularizedDiscriminantAnalysis:
    """Gaussian discriminant classifier with two shrinkage parameters.

    * ``alpha`` shrinks every per-class covariance towards the identity:
      ``alpha * I + (1 - alpha) * cov_c``.
    * ``gamma`` is used twice. At fit time it shrinks the prior-weighted
      pooled covariance towards the identity
      (``gamma * pooled + (1 - gamma) * I``); at scoring time it blends the
      shrunk class covariance with that regularised pooled matrix
      (``gamma * class_cov + (1 - gamma) * pooled_cov``).

    With ``gamma == 0`` every class is therefore scored with the identity
    matrix, independently of ``alpha``.

    The per-class score is ``-0.5 * d^T pinv(cov) d + log(prior)`` where
    ``d = x - mean_c``.

    NOTE(review): the score has no log-determinant term, so classes with
    different blended covariances are compared on the quadratic form and the
    prior only. This is kept as-is so previously recorded results stay
    reproducible; confirm whether that is intended.
    """

    def __init__(self, alpha=0.0, gamma=0.0):
        """Store the two shrinkage weights and create empty fitted state.

        Args:
            alpha: weight of the identity matrix in each class covariance.
            gamma: weight of the class covariance when blending with the
                pooled covariance (also the weight of the raw pooled
                covariance against the identity at fit time).
        """
        self.alpha = alpha
        self.gamma = gamma
        self.classes = None      # sorted unique labels, set by fit()
        self.means = {}          # label -> feature mean vector
        self.covariances = {}    # label -> alpha-shrunk covariance matrix
        self.priors = {}         # label -> n_c / n
        self.pooled_cov = None   # gamma-regularised pooled covariance, set by fit()

    def fit(self, X, y):
        """Estimate class means, shrunk covariances, priors and the pooled covariance.

        Args:
            X: array-like of shape (n_samples, n_features).
            y: array-like of n_samples class labels.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)

        self.classes = np.unique(y)
        n, p = X.shape
        pooled_cov = np.zeros((p, p))

        # Start from a clean slate so refitting on different labels does not
        # leave entries for classes that are no longer present.
        self.means = {}
        self.covariances = {}
        self.priors = {}

        for c in self.classes:
            X_c = X[y == c]
            n_c = X_c.shape[0]

            self.means[c] = np.mean(X_c, axis=0)

            if n_c > 1:
                # atleast_2d: np.cov returns a 0-d array when p == 1.
                cov = np.atleast_2d(np.cov(X_c, rowvar=False))
            else:
                # np.cov of a single observation is NaN (zero degrees of
                # freedom); treat the class as having no observed scatter so
                # the NaN does not propagate into the pooled covariance.
                cov = np.zeros((p, p))
            self.covariances[c] = self.alpha * np.identity(p) + (1 - self.alpha) * cov

            self.priors[c] = n_c / n

            # Pooled covariance is the prior-weighted sum of the *unshrunk*
            # class covariances.
            pooled_cov += self.priors[c] * cov

        self.pooled_cov = self.gamma * pooled_cov + (1 - self.gamma) * np.identity(p)

    def _class_precisions(self):
        """Return ``{label: pinv(blended covariance)}`` for the current state.

        Raises:
            RuntimeError: if the model has not been fitted.
        """
        if self.classes is None or self.pooled_cov is None:
            raise RuntimeError("RegularizedDiscriminantAnalysis must be fitted before scoring")
        return {
            c: pinv(self.gamma * self.covariances[c] + (1 - self.gamma) * self.pooled_cov)
            for c in self.classes
        }

    def predict(self, X):
        """Return a list with the highest-scoring class label for each row of X.

        Ties resolve to the first class in ``self.classes`` order.
        """
        X = np.asarray(X)
        if len(X) == 0:
            return []

        # The pseudo-inverses depend only on the fitted state, not on the
        # sample, so compute them once per call instead of once per row.
        precisions = self._class_precisions()

        predictions = []
        for x in X:
            class_probs = self.compute_class_probs(x, precisions)
            predictions.append(max(class_probs, key=class_probs.get))
        return predictions

    def compute_class_probs(self, x, precisions=None):
        """Return ``{label: score}`` for a single sample ``x``.

        Despite the name, the values are unnormalised discriminant scores
        (``-0.5 * quadratic form + log prior``), not probabilities.

        Args:
            x: feature vector of one sample.
            precisions: optional ``{label: inverse covariance}`` mapping as
                produced by ``_class_precisions``; computed on demand when
                omitted, so existing one-argument callers are unaffected.
        """
        if precisions is None:
            precisions = self._class_precisions()

        class_probs = {}
        for c in self.classes:
            diff = x - self.means[c]
            term = np.dot(np.dot(diff, precisions[c]), diff.T)
            class_probs[c] = -0.5 * term + np.log(self.priors[c])
        return class_probs

# Example usage: build a classifier with alpha = gamma = 0.1.
# The fit/predict calls below are left commented out because X_train, y_train
# and X_test are not defined in this cell.
rda = RegularizedDiscriminantAnalysis(alpha=0.1, gamma=0.1)
# rda.fit(X_train, y_train)
# predictions = rda.predict(X_test)


In [2]:
import argparse
import pandas as pd
import numpy as np
from util import *
from sklearn.metrics import accuracy_score
from tabulate import tabulate
from sklearn.model_selection import KFold

In [3]:
args = argparse.Namespace()
# Broader selection that the original cell assigned first and then immediately
# overwrote (kept here for reference only):
#   LDA,QDA,RDA_0.5,RDA_0.25,RDA_0.75,KDE_1.0,KDE_2.0,MQD_0.01,MQD_0.5,MQD_0.99
args.model = "MQD_0.01,MQD_0.5,MQD_0.99"
args.dataset = "breast_cancer_wisconsin_diagnostic,iris,wine,breast_cancer_wisconsin_original,ionosphere"

# NOTE: model_list is not consulted by the sweep below, which always evaluates
# RegularizedDiscriminantAnalysis.
model_list = args.model.split(',')
dataset_list = args.dataset.split(',')

# Sweep settings: number of repeated splits (the repeat index is also passed
# to split_dataset as random_state) and the shared grid of alpha/gamma values.
N_REPEATS = 3
PARAM_GRID = np.linspace(0, 1, 11)

# all_res[k] is the (alpha x gamma) grid of test accuracies for
# dataset_list[k], averaged over the N_REPEATS splits.
all_res = []
for dataset_name in dataset_list:
    dataset = get_dataset(dataset_name)
    dataset = format_dataset(dataset)
    icv_res = []
    for cv in range(N_REPEATS):
        X_train, X_test, y_train, y_test = split_dataset(dataset, random_state=cv)
        print(dataset_name)
        result = []
        for alpha in PARAM_GRID:          # one printed row per alpha
            res = []
            for gamma in PARAM_GRID:      # one printed column per gamma
                model = RegularizedDiscriminantAnalysis(alpha=alpha, gamma=gamma)
                model.fit(X_train, y_train)
                predict_y = model.predict(X_test)
                # accuracy_score is documented as (y_true, y_pred).
                acc = accuracy_score(y_test, predict_y)
                res.append(acc)
                print("%3.3f" % acc, end=" ")
            print(" ")
            result.append(res)
        icv_res.append(result)
    # Average the per-split grids element-wise.
    icv_res = np.array(icv_res)
    icv_res = icv_res.mean(axis=0)
    all_res.append(icv_res)


breast_cancer_wisconsin_diagnostic
0.899 0.947 0.952 0.947 0.947 0.941 0.947 0.941 0.931 0.920 0.883  
0.899 0.947 0.952 0.957 0.947 0.952 0.941 0.936 0.936 0.920 0.883  
0.899 0.947 0.952 0.957 0.957 0.952 0.941 0.941 0.926 0.920 0.888  
0.899 0.952 0.952 0.963 0.963 0.957 0.947 0.941 0.931 0.920 0.888  
0.899 0.952 0.952 0.963 0.963 0.963 0.947 0.941 0.931 0.926 0.888  
0.899 0.957 0.957 0.957 0.963 0.963 0.957 0.947 0.931 0.926 0.883  
0.899 0.957 0.957 0.963 0.952 0.957 0.963 0.952 0.936 0.926 0.872  
0.899 0.957 0.957 0.957 0.957 0.952 0.957 0.963 0.936 0.926 0.872  
0.899 0.963 0.957 0.957 0.957 0.957 0.957 0.957 0.952 0.926 0.878  
0.899 0.957 0.957 0.957 0.957 0.957 0.957 0.957 0.957 0.947 0.872  
0.899 0.957 0.957 0.957 0.957 0.957 0.957 0.957 0.957 0.957 0.899  
breast_cancer_wisconsin_diagnostic
0.894 0.926 0.915 0.915 0.910 0.910 0.915 0.915 0.920 0.910 0.840  
0.894 0.926 0.920 0.915 0.910 0.910 0.910 0.910 0.899 0.894 0.862  
0.894 0.926 0.920 0.915 0.915 0.910 0.910 0.90

In [6]:
# Display the collected accuracy grids: one array per entry of dataset_list,
# rows indexed by alpha and columns by gamma.
all_res

[array([[0.89716312, 0.93794326, 0.93617021, 0.93617021, 0.93617021,
         0.93439716, 0.93971631, 0.93794326, 0.93794326, 0.92375887,
         0.87943262],
        [0.89716312, 0.93794326, 0.93794326, 0.93971631, 0.93617021,
         0.93794326, 0.93617021, 0.93439716, 0.93085106, 0.91666667,
         0.88120567],
        [0.89716312, 0.93794326, 0.93794326, 0.94148936, 0.94148936,
         0.93794326, 0.93439716, 0.93439716, 0.92730496, 0.92198582,
         0.88297872],
        [0.89716312, 0.93971631, 0.93971631, 0.94503546, 0.94503546,
         0.94148936, 0.93617021, 0.93439716, 0.92553191, 0.92021277,
         0.88475177],
        [0.89716312, 0.93794326, 0.94148936, 0.94680851, 0.94326241,
         0.94326241, 0.93617021, 0.93617021, 0.92730496, 0.91843972,
         0.88297872],
        [0.89716312, 0.94148936, 0.94503546, 0.94680851, 0.94680851,
         0.94503546, 0.94148936, 0.93439716, 0.92730496, 0.92198582,
         0.87943262],
        [0.89716312, 0.94503546, 0.94503

In [10]:
# Convert the first dataset's averaged accuracy grid to nested Python lists
# and display it.
t = all_res[0].tolist()
t

[[0.8971631205673759,
  0.9379432624113475,
  0.9361702127659575,
  0.9361702127659575,
  0.9361702127659575,
  0.9343971631205674,
  0.9397163120567376,
  0.9379432624113475,
  0.9379432624113475,
  0.923758865248227,
  0.8794326241134751],
 [0.8971631205673759,
  0.9379432624113475,
  0.9379432624113475,
  0.9397163120567376,
  0.9361702127659575,
  0.9379432624113475,
  0.9361702127659575,
  0.9343971631205674,
  0.9308510638297873,
  0.9166666666666666,
  0.8812056737588653],
 [0.8971631205673759,
  0.9379432624113475,
  0.9379432624113475,
  0.9414893617021276,
  0.9414893617021276,
  0.9379432624113475,
  0.9343971631205674,
  0.9343971631205674,
  0.9273049645390071,
  0.9219858156028368,
  0.8829787234042553],
 [0.8971631205673759,
  0.9397163120567376,
  0.9397163120567376,
  0.9450354609929077,
  0.9450354609929077,
  0.9414893617021276,
  0.9361702127659575,
  0.9343971631205674,
  0.9255319148936171,
  0.9202127659574467,
  0.8847517730496454],
 [0.8971631205673759,
  0.937

In [15]:
# Build a LaTeX (booktabs) table for the first dataset's accuracy grid.
# Each row is prefixed with the grid value of its outer sweep parameter
# (alpha); the header carries the grid values of the inner one (gamma).
t = all_res[0].tolist()
tt = [[row_value, *row] for row_value, row in zip(np.linspace(0, 1, 11), t)]

latex_table = tabulate(
    tt,
    headers=["h", *(f"{col_value:3.3f}" for col_value in np.linspace(0, 1, 11).tolist())],
    tablefmt="latex_booktabs",
    floatfmt="3.3f",
)
print(latex_table)

\begin{tabular}{rrrrrrrrrrrr}
\toprule
     h &   0.000 &   0.100 &   0.200 &   0.300 &   0.400 &   0.500 &   0.600 &   0.700 &   0.800 &   0.900 &   1.000 \\
\midrule
 0.000 &   0.897 &   0.938 &   0.936 &   0.936 &   0.936 &   0.934 &   0.940 &   0.938 &   0.938 &   0.924 &   0.879 \\
 0.100 &   0.897 &   0.938 &   0.938 &   0.940 &   0.936 &   0.938 &   0.936 &   0.934 &   0.931 &   0.917 &   0.881 \\
 0.200 &   0.897 &   0.938 &   0.938 &   0.941 &   0.941 &   0.938 &   0.934 &   0.934 &   0.927 &   0.922 &   0.883 \\
 0.300 &   0.897 &   0.940 &   0.940 &   0.945 &   0.945 &   0.941 &   0.936 &   0.934 &   0.926 &   0.920 &   0.885 \\
 0.400 &   0.897 &   0.938 &   0.941 &   0.947 &   0.943 &   0.943 &   0.936 &   0.936 &   0.927 &   0.918 &   0.883 \\
 0.500 &   0.897 &   0.941 &   0.945 &   0.947 &   0.947 &   0.945 &   0.941 &   0.934 &   0.927 &   0.922 &   0.879 \\
 0.600 &   0.897 &   0.945 &   0.945 &   0.949 &   0.945 &   0.947 &   0.943 &   0.938 &   0.929 &   0.922 &   0

In [23]:
# Print the averaged accuracy grid for the dataset at index 3 of dataset_list.
# No tablefmt is passed, so tabulate's default plain-text format is used even
# though the result is stored in a variable named latex_table.
t = all_res[3].tolist()
latex_table = tabulate(t, floatfmt="3.3f")
print(latex_table)

-----  -----  -----  -----  -----  -----  -----  -----  -----  -----  -----
0.948  0.955  0.957  0.962  0.958  0.951  0.954  0.944  0.938  0.922  0.889
0.948  0.955  0.955  0.961  0.958  0.954  0.949  0.952  0.941  0.932  0.918
0.948  0.955  0.955  0.958  0.961  0.958  0.951  0.952  0.948  0.938  0.925
0.948  0.952  0.955  0.955  0.961  0.960  0.957  0.951  0.954  0.944  0.935
0.948  0.949  0.955  0.955  0.957  0.962  0.958  0.957  0.952  0.951  0.941
0.948  0.951  0.954  0.955  0.955  0.958  0.962  0.960  0.955  0.955  0.945
0.948  0.948  0.954  0.954  0.955  0.955  0.957  0.962  0.960  0.957  0.955
0.948  0.951  0.952  0.955  0.954  0.955  0.955  0.955  0.961  0.962  0.961
0.948  0.949  0.949  0.952  0.954  0.955  0.954  0.955  0.955  0.958  0.962
0.948  0.948  0.948  0.951  0.951  0.951  0.952  0.954  0.954  0.955  0.955
0.948  0.945  0.947  0.947  0.947  0.947  0.947  0.947  0.947  0.945  0.948
-----  -----  -----  -----  -----  -----  -----  -----  -----  -----  -----
