In [1]:
import numpy as np
import math

# Probabilistic Neural Network with 4 layers
class PNN(object):
    def __init__(self):
        self.L2 = []         # Layer 2 that holds the patterns
        print('Empty PNN created')

    def train(self, X, y, p=2):
        self.n_ = X.shape[1]  # num of features
        self.p_ = p           # num of classes
        
        # Layer 2 (Pattern): Set up empty lists for each class
        for k in range(self.p_):
            self.L2.append([])

        # Enter patterns into Layer 2
        for i in range(X.shape[0]):
            self.L2[y[i]].append(X[i])

        print('Trained.')

    def crossValidate(self, X, y, sigma=0.5):
        result = self.predict(X, sigma)
        num_correct = sum(result[:, 0] == y)

        print('Cross validation accuracy with sigma', sigma, ':', num_correct/len(y) * 100, '%')

    def predict(self, X, sigma=0.5):
        self.sigma_ = sigma   # smoothing parameter, not standard deviation
        
        m = X.shape[0]
        accL3 = np.zeros((m, self.p_))
        accL4 = np.zeros(m)
       
        # Layer 1 (Input): x
        for i in range(m):
            x = X[i]
#            if i == 50:    # After 50
#                print('-------------------------------------------')
#            elif i == 100:    # After 100
#                print('-------------------------------------------')

            # Layer 3 (Averaging): for each class
            self.L3_ = np.zeros(self.p_)
            for k in range(self.p_):
                for ki in range(len(self.L2[k])):
                    self.L3_[k] += self._activation(x, self.L2[k][ki])
                self.L3_[k] /= len(self.L2[k])
                
                # Multiply constant
                self.L3_[k] *= (math.sqrt(2*math.pi) * self.sigma_)**(- self.n_)
                accL3[i][k] = self.L3_[k]

            # Layer 4 (Output/Decision): Maxing
            self.L4_ = self.L3_.argmax()
            accL4[i] = self.L4_

        return np.column_stack((accL4, accL3))

    def _activation(self, x, w):
        return math.exp( - np.dot((x - w), (x - w)) / (2 * self.sigma_**2) )


# Normalize to unit length: [0, 1]
# X must be numpy array
def Normalize(X):
    x_max = X.max(axis=0)
    x_min = X.min(axis=0)
    return (X - x_min) / (x_max - x_min)

In [2]:
import pandas as pd

df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)
df.tail()

Unnamed: 0,0,1,2,3,4
145,6.7,3.0,5.2,2.3,Iris-virginica
146,6.3,2.5,5.0,1.9,Iris-virginica
147,6.5,3.0,5.2,2.0,Iris-virginica
148,6.2,3.4,5.4,2.3,Iris-virginica
149,5.9,3.0,5.1,1.8,Iris-virginica


In [3]:
X = df.iloc[:, :4].values

y = df.iloc[:, 4].values
y[:50] = 0
y[50:100] = 1
y[100:] = 2

X_tr = Normalize(X)    # Training data
X_cv = X_tr            # Cross validation data (may be same as training data for a PNN)
X_tt = X_tr            # Test data (same as training data for simplicity. Should be different in practice)

In [4]:
pnn = PNN()

Empty PNN created


In [5]:
pnn.train(X_tr, y, p=3)

Trained.


In [6]:
n_iter = 10
for sigma in range(n_iter):
    pnn.crossValidate(X_cv, y, (sigma+1)/n_iter)

Cross validation accuracy with sigma 0.1 : 97.3333333333 %
Cross validation accuracy with sigma 0.2 : 96.6666666667 %
Cross validation accuracy with sigma 0.3 : 96.0 %
Cross validation accuracy with sigma 0.4 : 94.6666666667 %
Cross validation accuracy with sigma 0.5 : 94.6666666667 %
Cross validation accuracy with sigma 0.6 : 92.0 %
Cross validation accuracy with sigma 0.7 : 91.3333333333 %
Cross validation accuracy with sigma 0.8 : 91.3333333333 %
Cross validation accuracy with sigma 0.9 : 91.3333333333 %
Cross validation accuracy with sigma 1.0 : 91.3333333333 %


In [7]:
pnn.predict(X_tt, sigma=0.5)

array([[ 0.        ,  0.37627058,  0.11545619,  0.03247945],
       [ 0.        ,  0.3574241 ,  0.13178106,  0.03426819],
       [ 0.        ,  0.36717443,  0.11390842,  0.02896835],
       [ 0.        ,  0.35823278,  0.12069577,  0.03076219],
       [ 0.        ,  0.37349894,  0.10662514,  0.02984881],
       [ 0.        ,  0.34004024,  0.11832374,  0.0410828 ],
       [ 0.        ,  0.36849206,  0.11249613,  0.03045874],
       [ 0.        ,  0.37770034,  0.12216335,  0.03398294],
       [ 0.        ,  0.33093038,  0.11303293,  0.02708203],
       [ 0.        ,  0.36348709,  0.12279655,  0.03128301],
       [ 0.        ,  0.36019638,  0.11186349,  0.03398561],
       [ 0.        ,  0.37507059,  0.1190057 ,  0.03258235],
       [ 0.        ,  0.35389293,  0.11854458,  0.02907882],
       [ 0.        ,  0.32975809,  0.08942308,  0.01962029],
       [ 0.        ,  0.30947537,  0.08396871,  0.02680337],
       [ 0.        ,  0.25980904,  0.07209917,  0.02676542],
       [ 0.        ,  0.