Skip to content

Commit

Permalink
LCB test
Browse files Browse the repository at this point in the history
  • Loading branch information
yssaya committed Sep 7, 2021
1 parent 9eb2809 commit 73258b3
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/usi-engine/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void Utils::create_z_table() {
boost::math::students_t dist(i);
auto z = boost::math::quantile(boost::math::complement(dist, cfg_ci_alpha));
z_lookup[i - 1] = z;
// myprintf("%4d:z=%f\n",i,z); // 1:z=31830, 7:z=10, 30:z=5, 100:z=4.5, 1000:z=4.3
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/usi-engine/bona/yss_dcnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ typedef struct child {
float value; // win rate (win=+1, loss=0)
float bias; // policy
int exact_value; // WIN or LOSS or DRAW
float squared_eval_diff; // Variable used for calculating variance of evaluations. for LCB
int acc_virtual_loss; // accumulate virtual loss. for LCB
// std::atomic<bool> exact_value;
} CHILD;

Expand Down
108 changes: 104 additions & 4 deletions src/usi-engine/bona/ysszero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "process_batch.h"

#include "../GTP.h"
#include "../Utils.h"

int NOT_USE_NN = 0;

Expand All @@ -44,6 +45,7 @@ int fPrtNetworkRawPath = 0;
// File-scope engine switches and search limits.
// NOTE(review): the exact meaning of fVerbose / fClearHashAlways / fUsiInfo is not
// visible in this part of the file - confirm against their users before relying on them.
int fVerbose = 1;
int fClearHashAlways = 0;
int fUsiInfo = 0;
// When true, uct_search_start() picks the child with the largest lower confidence
// bound (mean - z * stddev) instead of the child with the most games.
// Set by the "-lcb" command-line option in getCmdLineParam().
bool fLCB = false;

// NOTE(review): presumably the playout-count limit and the per-move time limit in
// seconds - confirm against the search loop.
int nLimitUctLoop = 100;
double dLimitSec = 0;
Expand Down Expand Up @@ -290,6 +292,14 @@ const int USI_BESTMOVE_LEN = MAX_LEGAL_MOVES*(8+5)+10;

int YssZero_com_turn_start( tree_t * restrict ptree )
{
if ( 0 ) {
if ( ptree->nrep > 180 ) {
dLimitSec = 1.8;
} else {
dLimitSec = 4.8;
}
}

if ( 0 ) {
int ct1 = get_clock();
int i;
Expand Down Expand Up @@ -866,6 +876,7 @@ int uct_search_start(tree_t * restrict ptree, int sideToMove, int ply, char *buf
int sort_n = 0;
int select_count = 0;
bool found_mate = false;
float max_lcb = -1e6f;

if ( is_use_exact() ) for (i=0;i<phg->child_num;i++) {
CHILD *pc = &phg->child[i];
Expand All @@ -882,15 +893,36 @@ int uct_search_start(tree_t * restrict ptree, int sideToMove, int ply, char *buf
break;
}


if ( !found_mate ) for (i=0;i<phg->child_num;i++) {
CHILD *pc = &phg->child[i];
if ( pc->games > max_games ) {
max_games = pc->games;
max_i = i;

float lcb = 0;
if ( fLCB ) { // Lower confidence bound of winrate.
int visits = pc->games;
lcb = -1e6f + visits; // large negative value if not enough visits.
if (visits >= 2) {
float mean = pc->value;
// if ( sideToMove == white ) mean = -mean; // AZ(-1<x<1), LZ (0<x<1), mean = 1.0f - mean;
float eval_variance = visits > 1 ? pc->squared_eval_diff / (visits - 1) : 1.0f;
auto stddev = std::sqrt(eval_variance / visits);
auto z = Utils::cached_t_quantile(visits - 1);
lcb = mean - z * stddev;
}
if ( lcb > max_lcb ) {
max_lcb = lcb;
max_i = i;
}
} else {
if ( pc->games > max_games ) {
max_games = pc->games;
max_i = i;
}
}
sum_games += pc->games;
if ( pc->games ) {
PRT("%3d(%3d)%7s,%5d,%6.3f,bias=%.10f\n",i,select_count++,str_CSA_move(pc->move),pc->games,pc->value,pc->bias);
float v = pc->value;
PRT("%3d(%3d)%7s,%5d,%6.3f,bias=%.10f,V=%6.2f%%,LCB=%6.2f%%\n",i,select_count++,str_CSA_move(pc->move),pc->games,pc->value,pc->bias,100.0*(v+1.0)/2.0,100.0*(lcb+1.0)/2.0);
if ( sort_n < SORT_MAX ) {
sort[sort_n][0] = pc->games;
sort[sort_n][1] = pc->move;
Expand Down Expand Up @@ -1144,6 +1176,8 @@ if (0) {
pc->games = 0;
pc->value = 0;
pc->exact_value = EX_NONE;
pc->squared_eval_diff = 1e-4f; // Initialized to small non-zero value to avoid accidental zero variances at low visits.
pc->acc_virtual_loss = 0;
}
phg->child_num = move_num;

Expand Down Expand Up @@ -1445,6 +1479,7 @@ double uct_tree(tree_t * restrict ptree, int sideToMove, int ply, int *pExactVal
pc->value = (float)(((double)pc->games * pc->value + one_win*VL_N) / (pc->games + VL_N)); // games==0 の時はpc->value は無視されるので問題なし
pc->games += VL_N;
phg->games_sum += VL_N; // 末端のノードで減らしても意味がない、のでUCTの木だけで減らす
pc->acc_virtual_loss += VL_N;
}

UnLock(phg->entry_lock);
Expand All @@ -1453,6 +1488,7 @@ double uct_tree(tree_t * restrict ptree, int sideToMove, int ply, int *pExactVal
Lock(phg->entry_lock);

if ( fVirtualLoss ) {
pc->acc_virtual_loss -= VL_N;
phg->games_sum -= VL_N;
pc->games -= VL_N; // gamesを減らすのは非常に危険! あちこちで games==0 で判定してるので
if ( pc->games < 0 ) { PRT("Err pc->games=%d\n",pc->games); debug(); }
Expand All @@ -1471,6 +1507,17 @@ double uct_tree(tree_t * restrict ptree, int sideToMove, int ply, int *pExactVal
win = +1.0;
}

// LCB用 Welford's online algorithm for calculating variance.
float eval = win; // eval は netの値そのもの。
float old_eval = (float)pc->games * pc->value + pc->acc_virtual_loss * 1; // 累積。old_accumulate_eval が正しいか
float old_visits = pc->games - pc->acc_virtual_loss;
if ( old_visits < 0 ) DEBUG_PRT("");
float old_delta = old_visits > 0 ? eval - old_eval / old_visits : 0.0f;
float new_delta = eval - (old_eval + eval) / (old_visits + 1);
float delta = old_delta * new_delta;
pc->squared_eval_diff += delta;


double win_prob = ((double)pc->games * pc->value + win) / (pc->games + 1); // 単純平均

pc->value = (float)win_prob;
Expand Down Expand Up @@ -1587,6 +1634,11 @@ int getCmdLineParam(int argc, char *argv[])
dSelectRandom = nf;
continue;
}
if ( strstr(p,"-lcb") ) {
fLCB = true;
PRT("fLCB=%d\n",fLCB);
continue;
}
#ifdef USE_OPENCL
if ( strstr(p,"-dirtune") ) {
PRT("DirTune=%s\n",q);
Expand Down Expand Up @@ -2063,3 +2115,51 @@ double get_sel_rand_prob_from_rate(int rate)
過去8000棋譜の勝率で無条件で 、調整
*/

#if 0
float UCTNode::get_eval_lcb(int color) const {
// Lower confidence bound of winrate.
auto visits = get_visits();
if (visits < 2) {
// Return large negative value if not enough visits.
return -1e6f + visits;
}
auto mean = get_raw_eval(color);

auto stddev = std::sqrt(get_eval_variance(1.0f) / visits);
auto z = cached_t_quantile(visits - 1);

return mean - z * stddev;
}

float UCTNode::get_raw_eval(int tomove, int virtual_loss) const {
auto visits = get_visits() + virtual_loss;
assert(visits > 0);
auto blackeval = get_blackevals(); // 累積。accumulate_eval とすべき
if (tomove == FastBoard::WHITE) {
blackeval += static_cast<double>(virtual_loss);
}
auto eval = static_cast<float>(blackeval / double(visits));
if (tomove == FastBoard::WHITE) {
eval = 1.0f - eval;
}
return eval; // これは勝率。ややこしい
}

// Sample variance of the evaluations seen at this node:
// m_squared_eval_diff / (m_visits - 1).
// Falls back to `default_var` when there are not more than one visit.
float UCTNode::get_eval_variance(float default_var) const {
    if (m_visits > 1) {
        return m_squared_eval_diff / (m_visits - 1);
    }
    return default_var;
}

// Records one new evaluation `eval` at this node: bumps the visit count, adds
// `eval` to the accumulated eval, and adds the product of the pre-update and
// post-update deviations from the mean to m_squared_eval_diff, the running
// sum that get_eval_variance() divides by (visits - 1).
void UCTNode::update(float eval) {
// Cache values to avoid race conditions.
auto old_eval = static_cast<float>(m_blackevals); // running sum of evals (cumulative), not an average
auto old_visits = static_cast<int>(m_visits);
// Deviation of the new sample from the mean before this update (0 on the first visit).
auto old_delta = old_visits > 0 ? eval - old_eval / old_visits : 0.0f;
m_visits++;
accumulate_eval(eval); // adds to m_blackevals (cumulative); eval is the raw network value
// Deviation of the new sample from the mean after this update, from the cached values.
auto new_delta = eval - (old_eval + eval) / (old_visits + 1);
// Welford's online algorithm for calculating variance.
auto delta = old_delta * new_delta;
atomic_add(m_squared_eval_diff, delta);
}

#endif

0 comments on commit 73258b3

Please sign in to comment.