Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/example_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def egreedy_action(state: np.ndarray) -> int:
if np.random.rand() < epsilon:
return random.randrange(N_ACTIONS)
prediction_array = xcs.predict(state.reshape(1, -1))[0]
return int(np.argmax(prediction_array))
# break ties randomly
best_actions = np.where(prediction_array == prediction_array.max())[0]
return int(np.random.choice(best_actions))


def episode(episode_nr: int, create_gif: bool) -> tuple[float, int]:
Expand Down
4 changes: 3 additions & 1 deletion python/example_rmux.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def egreedy_action(state: np.ndarray, epsilon: float) -> tuple[int, float]:
if np.random.rand() < epsilon:
return random.randrange(N_ACTIONS), 0
prediction_array = xcs.predict(state.reshape(1, -1))[0]
action = int(np.argmax(prediction_array))
# break ties randomly
best_actions = np.where(prediction_array == prediction_array.max())[0]
action = int(np.random.choice(best_actions))
prediction = prediction_array[action]
return action, prediction

Expand Down
504 changes: 255 additions & 249 deletions python/notebooks/example_cartpole.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/notebooks/example_maze.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.4"
}
},
"nbformat": 4,
Expand Down
19 changes: 10 additions & 9 deletions python/notebooks/example_rmux.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (C) 2020 Richard Preen <rpreen@gmail.com> Copyright (C) 2021 David
# Pätzel
# Copyright (C) 2020--2022 Richard Preen <rpreen@gmail.com>
#
# Copyright (C) 2021 David Pätzel
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
Expand All @@ -27,6 +28,7 @@ set(XCSF_TESTS
neural_layer_maxpool_test.cpp
neural_layer_recurrent_test.cpp
neural_test.cpp
pa_test.cpp
pred_nlms_test.cpp
pred_rls_test.cpp
util_test.cpp
Expand Down
64 changes: 64 additions & 0 deletions test/pa_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

/**
* @file pa_test.cpp
* @author Richard Preen <rpreen@gmail.com>
* @copyright The Authors.
* @date 2022.
* @brief Prediction array tests.
*/

#include "../lib/doctest/doctest/doctest.h"

extern "C" {
#include "../xcsf/pa.h"
#include "../xcsf/param.h"
#include "../xcsf/utils.h"
#include "../xcsf/xcsf.h"
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
}

TEST_CASE("PA")
{
    struct XCSF xcsf;
    // NOTE(review): arguments presumably are x_dim, y_dim, n_actions - confirm
    // against param.h; the arrays below all hold five entries.
    param_init(&xcsf, 5, 1, 5);
    // fixed seed: the tie-breaking expectations further down rely on it
    rand_init_seed(2);

    // a single largest entry (index 3) must be returned
    double unique_at_3[5] = { 0.214, 0.6423, 0.111, 0.775, 0.445 };
    xcsf.pa = unique_at_3;
    CHECK_EQ(pa_best_action(&xcsf), 3);

    // a single largest entry (index 1) must be returned
    double unique_at_1[5] = { 0.214, 0.9423, 0.111, 0.775, 0.445 };
    xcsf.pa = unique_at_1;
    CHECK_EQ(pa_best_action(&xcsf), 1);

    // indices 0..3 share the largest value: consecutive calls pick among
    // them at random, so the expected sequence is tied to the seed above
    double tied[5] = { 0.6423, 0.6423, 0.6423, 0.6423, 0.445 };
    xcsf.pa = tied;
    const int expected_picks[3] = { 1, 2, 0 };
    for (int call = 0; call < 3; ++call) {
        CHECK_EQ(pa_best_action(&xcsf), expected_picks[call]);
    }
}
10 changes: 5 additions & 5 deletions test/util_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* @file util_test.cpp
* @author Richard Preen <rpreen@gmail.com>
* @copyright The Authors.
* @date 2020.
* @date 2020--2022.
* @brief Utility tests.
*/

Expand Down Expand Up @@ -87,14 +87,14 @@ TEST_CASE("UTIL")
for (int i = 0; i < 2; ++i) {
CHECK_EQ(correct[i], tmp[i]);
}
// test max index
// test argmax
double x[5] = { 0.214, 0.6423, 0.111, 0.775, 0.445 };
int max = max_index(x, 5);
int max = argmax(x, 5);
CHECK_EQ(max, 3);
x[3] = 0.1;
max = max_index(x, 5);
max = argmax(x, 5);
CHECK_EQ(max, 1);
x[1] = -0.2;
max = max_index(x, 5);
max = argmax(x, 5);
CHECK_EQ(max, 4);
}
4 changes: 2 additions & 2 deletions xcsf/act_neural.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* @file act_neural.c
* @author Richard Preen <rpreen@gmail.com>
* @copyright The Authors.
* @date 2020--2021.
* @date 2020--2022.
* @brief Neural network action functions.
*/

Expand Down Expand Up @@ -109,7 +109,7 @@ act_neural_compute(const struct XCSF *xcsf, const struct Cl *c, const double *x)
struct ActNeural *act = c->act;
neural_propagate(&act->net, x, xcsf->explore);
const double *outputs = neural_outputs(&act->net);
return max_index(outputs, xcsf->n_actions);
return argmax(outputs, xcsf->n_actions);
}

/**
Expand Down
4 changes: 2 additions & 2 deletions xcsf/loss.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* @file loss.c
* @author Richard Preen <rpreen@gmail.com>
* @copyright The Authors.
* @date 2019--2020.
* @date 2019--2022.
* @brief Loss functions for calculating prediction error.
*/

Expand Down Expand Up @@ -119,7 +119,7 @@ loss_binary_log(const struct XCSF *xcsf, const double *pred, const double *y)
double
loss_onehot(const struct XCSF *xcsf, const double *pred, const double *y)
{
const int max_i = max_index(pred, xcsf->y_dim);
const int max_i = argmax(pred, xcsf->y_dim);
if (y[max_i] != 1) {
return 1;
}
Expand Down
23 changes: 20 additions & 3 deletions xcsf/pa.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* @file pa.c
* @author Richard Preen <rpreen@gmail.com>
* @copyright The Authors.
* @date 2015--2020.
* @date 2015--2022.
* @brief Prediction array functions.
*/

Expand Down Expand Up @@ -111,13 +111,30 @@ pa_build(const struct XCSF *xcsf, const double *x)

/**
 * @brief Returns the best action in the prediction array.
 * @details Ties are broken uniformly at random. Exactly one random number is
 * drawn per call, and tied actions are ranked in ascending index order, so no
 * temporary index buffer (and no heap allocation) is required.
 * @param [in] xcsf The XCSF data structure.
 * @return The best action.
 */
int
pa_best_action(const struct XCSF *xcsf)
{
    // first pass: find the first occurrence of the maximum and count how
    // many entries share that value
    double max = xcsf->pa[0];
    int first = 0;
    int n_max = 1;
    for (int i = 1; i < xcsf->n_actions; ++i) {
        const double val = xcsf->pa[i];
        if (val > max) {
            max = val;
            first = i;
            n_max = 1;
        } else if (val == max) {
            ++n_max;
        }
    }
    // rank (0 = first occurrence) of the tied action to return
    // NOTE(review): assumes rand_uniform_int(0, n_max) yields a value in
    // [0, n_max) - confirm against utils.h
    int rank = rand_uniform_int(0, n_max);
    // second pass: step forward from the first maximum until the chosen
    // rank is reached; entries before `first` can never equal the maximum
    int best = first;
    for (int i = first + 1; rank > 0 && i < xcsf->n_actions; ++i) {
        if (xcsf->pa[i] == max) {
            best = i;
            --rank;
        }
    }
    return best;
}

/**
Expand All @@ -143,7 +160,7 @@ pa_rand_action(const struct XCSF *xcsf)
double
pa_best_val(const struct XCSF *xcsf)
{
    // argmax() picks the first occurrence on ties, which is irrelevant here
    // because only the value (identical for all tied entries) is returned
    const int max_i = argmax(xcsf->pa, xcsf->pa_size);
    return xcsf->pa[max_i];
}

Expand Down
7 changes: 4 additions & 3 deletions xcsf/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* @author Richard Preen <rpreen@gmail.com>
* @author David Pätzel
* @copyright The Authors.
* @date 2015--2021.
* @date 2015--2022.
* @brief Utility functions for random number handling, etc.
*/

Expand Down Expand Up @@ -76,15 +76,16 @@ clamp_int(const int a, const int min, const int max)

/**
* @brief Returns the index of the largest element in vector X.
* @details First occurrence is selected in the case of a tie.
* @param [in] X Vector with N elements.
* @param [in] N The number of elements in vector X.
* @return The index of the largest element.
*/
static inline int
max_index(const double *X, const int N)
argmax(const double *X, const int N)
{
if (N < 1) {
printf("max_index() error: N < 1\n");
printf("argmax() error: N < 1\n");
exit(EXIT_FAILURE);
}
int max_i = 0;
Expand Down