In [1]:
from Gaussian import Gaussian
from EM import EM
from NN import NN
import numpy as np
import matplotlib.pyplot as plt

def complex_train(D, K):
    """Train and test five methods on a fresh Gaussian sample until a run is accepted.

    Each attempt builds a new ``Gaussian(D=D, K=K, ...)`` data set and a new
    list of five methods (one ``EM`` and four ``NN`` variants), then trains and
    tests them in order.  After each method a sanity check is applied to the
    numbers recorded so far; the first failing check discards the whole
    attempt and starts over with a new sample.  There is no upper bound on the
    number of attempts.

    Parameters
    ----------
    D, K
        Forwarded to ``Gaussian`` (as ``D`` and ``K``), to ``EM`` (``K``) and to
        ``NN`` (``D`` as first argument, ``K`` as the size of the last layer).

    Returns
    -------
    (accuracy, time)
        Two length-5 numpy arrays indexed by method: the first element of
        ``method.test(...)`` multiplied by 100, and ``method.train_time``.
    """
    method_number = 5

    # parameters for neural networks: layer index -> layer width / activation
    single_layer = {0: K}
    hidden_10 = {0: 10, 1: K}
    hidden_100 = {0: 100, 1: K}
    softmax_only = {0: NN.softmax}
    relu_then_softmax = {0: NN.relu, 1: NN.softmax}

    def rejects_attempt(index, accuracy, time):
        """Return a truthy value when the result of method `index` invalidates the attempt."""
        if index == 0:
            # first method: accuracy below 70 or training time above 10
            return accuracy[0] < 70 or time[0] > 10
        if index == 1:
            # second method must land within 1 accuracy point of the first
            return abs(accuracy[0] - accuracy[1]) > 1
        if index == 2:
            # third method must take at least three times as long as the second
            return time[2] < (3 * time[1])
        if index == 3:
            # fourth method may be at most 5 time units faster than the second
            return time[1] - time[3] > 5
        # the last method is never checked
        return False

    while True:
        # per-attempt results, one slot per method
        accuracy = np.zeros(method_number)
        time = np.zeros(method_number)

        # generate sample
        sample = Gaussian(D=D, K=K, background=False, index_para=[6000, 9000])

        # all five methods are constructed up front, before any training
        method_set = [EM(K),
                      NN(D, single_layer, softmax_only, NN_type="QNN"),
                      NN(D, hidden_100, relu_then_softmax, NN_type="CNN"),
                      NN(D, hidden_10, relu_then_softmax, NN_type="CNN"),
                      NN(D, single_layer, softmax_only, NN_type="CNN")]

        for index, method in enumerate(method_set):
            if index != 0:
                method.train(sample.train_point, sample.train_label,
                             sample.valid_point, sample.valid_label,
                             step_size=500)
            else:
                # EM is trained on points only, then aligned using the
                # validation labels
                method.train(sample.train_point)
                method.order_correction(sample.valid_point, sample.valid_label)

            accuracy[index] = method.test(sample.test_point, sample.test_label)[0] * 100
            time[index] = method.train_time

            print(accuracy, "\n",time)
            if rejects_attempt(index, accuracy, time):
                # discard this attempt and draw a new sample
                break
        else:
            # no check failed for any of the five methods
            return accuracy, time

In [None]:
# Number of accepted runs to collect for every (D, K) configuration.
sample_number = 10

# One list per configuration; the suffix encodes D followed by K
# (e.g. accuracy_25 holds the runs for D=2, K=5).  Each entry is the
# per-method array returned by complex_train.
accuracy_25, accuracy_35, accuracy_28, accuracy_38 = [], [], [], []
time_25, time_35, time_28, time_38 = [], [], [], []

# (D, K, accuracy rows, time rows), visited in this order for every sample.
experiments = [
    (2, 5, accuracy_25, time_25),
    (3, 5, accuracy_35, time_35),
    (2, 8, accuracy_28, time_28),
    (3, 8, accuracy_38, time_38),
]

for _ in range(sample_number):
    for dim, n_class, accuracy_rows, time_rows in experiments:
        accuracy, time = complex_train(D=dim, K=n_class)
        accuracy_rows.append(accuracy)
        time_rows.append(time)
        # Rewrite the CSVs after every run so the results collected so far
        # are on disk even if the loop is interrupted.
        np.savetxt("accuracy_{}{}.csv".format(dim, n_class), accuracy_rows, delimiter=",")
        np.savetxt("time_{}{}.csv".format(dim, n_class), time_rows, delimiter=",")

[82.80237581  0.          0.          0.          0.        ] 
 [3.45824981 0.         0.         0.         0.        ]
[82.80237581 82.85637149  0.          0.          0.        ] 
 [ 3.45824981 10.4515276   0.          0.          0.        ]
[82.80237581 82.85637149 82.64038877  0.          0.        ] 
 [ 3.45824981 10.4515276  42.94193482  0.          0.        ]
[82.80237581 82.85637149 82.64038877 82.61339093  0.        ] 
 [ 3.45824981 10.4515276  42.94193482 16.92936826  0.        ]
[82.80237581 82.85637149 82.64038877 82.61339093 81.42548596] 
 [ 3.45824981 10.4515276  42.94193482 16.92936826  5.54567409]
[92.81398755  0.          0.          0.          0.        ] 
 [0.45943403 0.         0.         0.         0.        ]
[92.81398755 91.64225558  0.          0.          0.        ] 
 [ 0.45943403 20.79525495  0.          0.          0.        ]
[84.48784083  0.          0.          0.          0.        ] 
 [1.16178989 0.         0.         0.         0.        ]
[84.487

In [None]:
def _load_runs(path):
    """Load a results CSV as a 2-D array with one row per saved run.

    ``ndmin=2`` matters when the file holds a single row: without it
    ``np.loadtxt`` returns a 1-D array, and averaging over axis 0 would then
    mix the columns of that one run together instead of averaging over runs.
    """
    return np.loadtxt(path, delimiter=",", ndmin=2)


accuracy_25 = _load_runs("accuracy_25.csv")
accuracy_35 = _load_runs("accuracy_35.csv")
accuracy_28 = _load_runs("accuracy_28.csv")
accuracy_38 = _load_runs("accuracy_38.csv")
time_25 = _load_runs("time_25.csv")
time_35 = _load_runs("time_35.csv")
time_28 = _load_runs("time_28.csv")
time_38 = _load_runs("time_38.csv")

# Column-wise mean over the rows (runs) of each file.
average_accuracy_25 = accuracy_25.mean(axis=0)
average_accuracy_35 = accuracy_35.mean(axis=0)
average_accuracy_28 = accuracy_28.mean(axis=0)
average_accuracy_38 = accuracy_38.mean(axis=0)
average_time_25 = time_25.mean(axis=0)
average_time_35 = time_35.mean(axis=0)
average_time_28 = time_28.mean(axis=0)
average_time_38 = time_38.mean(axis=0)