In [9]:
from Gaussian import Gaussian
from EM import EM
from NN import NN
import numpy as np
import matplotlib.pyplot as plt

def complex_train(D, K, sample_number=10, max_attempts=None, em_time_limit=10):
    """Benchmark EM against four neural-network classifiers on fresh Gaussian samples.

    For every accepted sample a new ``Gaussian`` data set is generated and five
    methods are trained on it and evaluated on its test split. Result columns:

        0. EM(K)                          -- trained on unlabeled train points,
                                             then ``order_correction`` on the
                                             validation split
        1. NN "QNN", layers {0: K}        -- softmax
        2. NN "CNN", layers {0: 100, 1: K} -- relu, softmax
        3. NN "CNN", layers {0: 10, 1: K}  -- relu, softmax
        4. NN "CNN", layers {0: K}         -- softmax

    A sample is rejected (its row is zeroed and a new sample is drawn) as soon
    as one of the following is observed:

        * EM's ``train_time`` exceeds ``em_time_limit``;
        * a method's test accuracy is strictly higher than the previous column's;
        * method 2 or 3 has a smaller ``train_time`` than method 1.

    NOTE(review): these rules keep only samples that already exhibit the
    expected ranking, so statistics computed from the returned arrays are
    biased toward that ranking. Consider also reporting the rejection rate.

    Parameters
    ----------
    D : int
        Passed as ``D`` to ``Gaussian`` and as the first argument of ``NN``.
    K : int
        Passed to ``Gaussian`` / ``EM`` and used as the output-layer size.
    sample_number : int, optional
        Number of accepted samples (rows) to collect. Default 10.
    max_attempts : int or None, optional
        Maximum number of samples to generate in total (accepted + rejected).
        ``None`` (default) means unbounded, which was the original behaviour.
    em_time_limit : float, optional
        Reject a sample if EM's ``train_time`` is greater than this. Default 10
        (same units as ``train_time``).

    Returns
    -------
    accuracy : numpy.ndarray, shape (sample_number, 5)
        Test accuracy of each method times 100.
    train_time : numpy.ndarray, shape (sample_number, 5)
        ``train_time`` attribute of each method after training.

    Raises
    ------
    RuntimeError
        If ``max_attempts`` samples were generated before ``sample_number``
        of them were accepted.
    """
    method_number = 5

    # Network layouts: layer index -> neuron count / activation function.
    neuron_1     = {0: K}
    neuron_2_10  = {0: 10, 1: K}
    neuron_2_100 = {0: 100, 1: K}
    act_func_1 = {0: NN.softmax}
    act_func_2 = {0: NN.relu, 1: NN.softmax}

    # One row per accepted sample, one column per method (order above).
    accuracy   = np.zeros([sample_number, method_number])
    # Named train_time rather than `time` to avoid shadowing the stdlib module.
    train_time = np.zeros([sample_number, method_number])

    i = 0          # next row to fill == number of accepted samples
    attempts = 0   # samples generated so far, accepted or rejected
    while i < sample_number:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                "complex_train: only %d of %d samples accepted after %d attempts"
                % (i, sample_number, attempts))
        attempts += 1

        # Generate a fresh data set for this attempt.
        G = Gaussian(D=D, K=K, background=False, index_para=[6000, 9000])

        # Models are rebuilt per attempt so each sample starts untrained.
        method_set = [EM(K),
                      NN(D, neuron_1, act_func_1, NN_type="QNN"),
                      NN(D, neuron_2_100, act_func_2, NN_type="CNN"),
                      NN(D, neuron_2_10, act_func_2, NN_type="CNN"),
                      NN(D, neuron_1, act_func_1, NN_type="CNN")]

        rejected = False
        for j, method in enumerate(method_set):
            if j == 0:
                # EM is fitted on points only; order_correction then uses the
                # labelled validation split (presumably to map clusters to labels).
                method.train(G.train_point)
                method.order_correction(G.valid_point, G.valid_label)
            elif j == 1:
                method.train(G.train_point, G.train_label,
                             G.valid_point, G.valid_label,
                             stop_point=200, step_size=200)
            else:
                method.train(G.train_point, G.train_label,
                             G.valid_point, G.valid_label, step_size=500)

            accuracy[i, j]   = method.test(G.test_point, G.test_label)[0] * 100
            train_time[i, j] = method.train_time
            print(i, accuracy[i], "\n ", train_time[i])

            if j == 0:
                # EM too slow on this sample.
                rejected = train_time[i, 0] > em_time_limit
            else:
                # A later method must not beat the one before it ...
                rejected = accuracy[i, j - 1] < accuracy[i, j]
                # ... and the two-layer CNNs must not train faster than the QNN.
                if j in (2, 3):
                    rejected = rejected or train_time[i, j] < train_time[i, 1]
            if rejected:
                break  # skip the remaining (expensive) methods

        if rejected:
            # Discard the partial row; the next attempt refills row i.
            accuracy[i] = 0
            train_time[i] = 0
            continue

        i += 1

    return accuracy, train_time

In [None]:
# Configuration D=2, K=5; sample_number spelled out (it is complex_train's default).
accuracy_25, time_25 = complex_train(K=5, D=2, sample_number=10)

In [None]:
# Configuration D=2, K=8; sample_number spelled out (it is complex_train's default).
accuracy_28, time_28 = complex_train(K=8, D=2, sample_number=10)

In [None]:
# Configuration D=3, K=5; sample_number spelled out (it is complex_train's default).
accuracy_35, time_35 = complex_train(K=5, D=3, sample_number=10)

In [None]:
# Configuration D=3, K=8; sample_number spelled out (it is complex_train's default).
accuracy_38, time_38 = complex_train(K=8, D=3, sample_number=10)

In [None]:
def append_sample_and_save(accuracy, train_time, D, K):
    """Collect one more accepted complex_train sample for (D, K) and checkpoint it.

    The new row is appended to copies of ``accuracy`` / ``train_time`` (the
    inputs are not modified). The full arrays are then written to
    ``accuracy_{D}{K}.csv`` and ``time_{D}{K}.csv`` in the working directory.

    Returns the extended ``(accuracy, train_time)`` pair.
    """
    accuracy_new, time_new = complex_train(D=D, K=K, sample_number=1)
    # np.concatenate takes a *sequence* of arrays; its second positional
    # parameter is `axis`, so the arrays must be passed together in a list.
    accuracy = np.concatenate([accuracy, accuracy_new])
    train_time = np.concatenate([train_time, time_new])
    np.savetxt("accuracy_%d%d.csv" % (D, K), accuracy, delimiter=",")
    np.savetxt("time_%d%d.csv" % (D, K), train_time, delimiter=",")
    return accuracy, train_time


# Keep adding one sample per configuration until the kernel is interrupted.
# The CSVs are rewritten after every sample, so an interrupt loses at most the
# sample in progress. Each pair is reassigned in a single statement, so the
# accuracy and time arrays always stay the same length.
while True:
    accuracy_25, time_25 = append_sample_and_save(accuracy_25, time_25, D=2, K=5)
    accuracy_28, time_28 = append_sample_and_save(accuracy_28, time_28, D=2, K=8)
    accuracy_35, time_35 = append_sample_and_save(accuracy_35, time_35, D=3, K=5)
    accuracy_38, time_38 = append_sample_and_save(accuracy_38, time_38, D=3, K=8)

0 [99.18381228  0.          0.          0.          0.        ] 
  [0.22259974 0.         0.         0.         0.        ]
0 [99.18381228 98.94575752  0.          0.          0.        ] 
  [ 0.22259974 47.2782805   0.          0.          0.        ]
0 [99.18381228 98.94575752 99.10729468  0.          0.        ] 
  [  0.22259974  47.2782805  119.81751585   0.           0.        ]
0 [73.99580713  0.          0.          0.          0.        ] 
  [4.60076118 0.         0.         0.         0.        ]
0 [73.99580713 74.11320755  0.          0.          0.        ] 
  [ 4.60076118 25.83665323  0.          0.          0.        ]
0 [89.07035176  0.          0.          0.          0.        ] 
  [1.09653211 0.         0.         0.         0.        ]
0 [89.07035176 88.77077696  0.          0.          0.        ] 
  [ 1.09653211 26.97872877  0.          0.          0.        ]
0 [89.07035176 88.77077696 88.8964051   0.          0.        ] 
  [ 1.09653211 26.97872877 85.56292772  0.