Skip to content

Commit

Permalink
Refactor global variables and thread model (#24)
Browse files Browse the repository at this point in the history
Refactor most of the search-related global variables into the
ThreadPool.
While there, simplify the TT code and remove the PerftTable tests.
Also, use smart pointers for the ThreadPool object.

Non-regression STC:
LLR:  2.98/2.94<-6.00, 0.00> Elo diff: -0.73 [-2.20, 0.75] (95%)
Games: 43095 W: 5266 L: 5355 D: 32474 Draw ratio: 75.4%
Pntl: [230, 3508, 14153, 3434, 222]

No functional change
  • Loading branch information
ruicoelhopedro committed Feb 26, 2024
1 parent 035c351 commit c9f62cf
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pawn-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,4 @@ jobs:
sudo apt update
sudo apt install valgrind
make -B -j4 DEBUG=2
valgrind --leak-check=full build/pawn bench 11
valgrind --leak-check=full --error-exitcode=1 build/pawn bench 11
4 changes: 0 additions & 4 deletions src/Hash.cpp

This file was deleted.

8 changes: 4 additions & 4 deletions src/Hash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class TranspositionEntry

public:
TranspositionEntry()
: m_hash(0), m_type(gen_type(0, EntryType::EMPTY))
: m_hash(0), m_depth(0), m_type(gen_type(0, EntryType::EMPTY)),
m_score(SCORE_NONE), m_best_move(MOVE_NULL), m_static_eval(SCORE_NONE)
{}

inline bool query(Age age, Hash hash, TranspositionEntry** entry)
Expand Down Expand Up @@ -193,6 +194,5 @@ class HashTable
}
};


extern HashTable<TranspositionEntry> ttable;
extern HashTable<PerftEntry> perft_table;
using TranspositionTable = HashTable<TranspositionEntry>;
using PerftTable = HashTable<PerftEntry>;
2 changes: 1 addition & 1 deletion src/NNUE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace NNUE



Feature Accumulator::get_feature(PieceType p, Square s, Square ks, Turn pt, Turn kt) const
Feature Accumulator::get_feature(PieceType p, Square s, Square ks, Turn pt, Turn kt)
{
// Vertical mirror for black kings
if (kt == BLACK)
Expand Down
4 changes: 1 addition & 3 deletions src/NNUE.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace NNUE

void clear();

Feature get_feature(PieceType p, Square s, Square ks, Turn pt, Turn kt) const;
static Feature get_feature(PieceType p, Square s, Square ks, Turn pt, Turn kt);

void push(PieceType p, Square s, Square ks, Turn pt, Turn kt);

Expand All @@ -52,8 +52,6 @@ namespace NNUE
bool operator!=(const Accumulator& other) const;
};

extern const Net* nnue_net;

void init();

void load(std::string file);
Expand Down
6 changes: 4 additions & 2 deletions src/Search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ namespace Search
: BoundType::LOWER_BOUND; // Blessed loss
}

void MultiPVData::write_pv(const Board& board, int index, uint64_t nodes, uint64_t tb_hits, double elapsed) const
void MultiPVData::write_pv(const Board& board, int index, int hashfull, uint64_t nodes, uint64_t tb_hits, double elapsed) const
{
// Don't write if PV line is incomplete
if (search_bound == BoundType::NO_BOUND)
Expand All @@ -167,7 +167,7 @@ namespace Search
// Nodes, nps, hashful and timing
std::cout << " nodes " << nodes;
std::cout << " nps " << static_cast<int>(nodes / elapsed);
std::cout << " hashfull " << ttable.hashfull();
std::cout << " hashfull " << hashfull;
std::cout << " tbhits " << tb_hits;
std::cout << " time " << std::max(1, static_cast<int>(elapsed * 1000));

Expand Down Expand Up @@ -344,6 +344,7 @@ namespace Search
const Turn Turn = position.get_turn();
const Depth Ply = data.ply();
CurrentHistory history = data.histories.get(position);
TranspositionTable& ttable = data.thread().pool().tt();

if (PvNode)
{
Expand Down Expand Up @@ -740,6 +741,7 @@ namespace Search
const bool InCheck = position.in_check();
const Turn Turn = position.get_turn();
const Depth Ply = data.ply();
TranspositionTable& ttable = data.thread().pool().tt();

if (PvNode)
{
Expand Down
17 changes: 4 additions & 13 deletions src/Search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ namespace Search

Score score() const;
BoundType bound() const;
void write_pv(const Board& board, int index, uint64_t nodes, uint64_t tb_hits, double elapsed) const;
void write_pv(const Board& board, int index, int hashfull, uint64_t nodes, uint64_t tb_hits, double elapsed) const;
};


Expand Down Expand Up @@ -185,14 +185,9 @@ namespace Search
bool legality_tests(Position& position, MoveList& move_list);


template<bool OUTPUT, bool USE_ORDER = false, bool TT = false, bool LEGALITY = false, bool VALIDITY = false>
template<bool OUTPUT, bool USE_ORDER = false, bool LEGALITY = false, bool VALIDITY = false>
int64_t perft(Position& position, Depth depth, Histories& hists)
{
// TT lookup
PerftEntry* entry = nullptr;
if (TT && perft_table.query(position.hash(), &entry) && entry->depth() == depth)
return entry->n_nodes();

// Move generation
int64_t n_nodes = 0;
auto move_list = position.generate_moves(MoveGenType::LEGAL);
Expand Down Expand Up @@ -220,7 +215,7 @@ namespace Search
if (depth > 1)
{
position.make_move(move);
count = perft<false, USE_ORDER, TT, LEGALITY, VALIDITY>(position, depth - 1, hists);
count = perft<false, USE_ORDER, LEGALITY, VALIDITY>(position, depth - 1, hists);
position.unmake_move();
}
n_nodes += count;
Expand All @@ -244,7 +239,7 @@ namespace Search
return 0;

position.make_move(move);
count = perft<false, USE_ORDER, TT, LEGALITY, VALIDITY>(position, depth - 1, hists);
count = perft<false, USE_ORDER, LEGALITY, VALIDITY>(position, depth - 1, hists);
position.unmake_move();

n_nodes += count;
Expand All @@ -262,10 +257,6 @@ namespace Search
}
}

// TT storing
if (TT)
perft_table.store(position.hash(), depth, n_nodes);

return n_nodes;
}

Expand Down
21 changes: 8 additions & 13 deletions src/Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,8 @@ namespace Tests

void bench(Search::Limits limits, int threads, int hash)
{
// Resize thread pool and hash
pool->resize(threads);
ttable.resize(hash);
pool->clear();
ttable.clear();
// Create a thread pool with the given options
ThreadPool pool(threads, hash);

// Start master timer
Search::Timer time;
Expand All @@ -171,7 +168,7 @@ namespace Tests

// Loop over each position
int i = 0;
Position& pos = pool->position();
Position& pos = pool.position();
std::vector<std::string> fens = bench_suite();
for (auto fen : fens)
{
Expand All @@ -181,16 +178,16 @@ namespace Tests
// Update position
pos = Position(fen);
pos.set_init_ply();
pool->update_position_threads();
pool.update_position_threads();
std::cerr << "\nPosition " << (++i) << "/" << fens.size() << ": " << fen << std::endl;

// Start searching and wait for completion
Search::Timer pos_timer;
pool->search(pos_timer, limits);
pool->wait();
pool.search(pos_timer, limits);
pool.wait();

// Update number of nodes
nodes += pool->nodes_searched();
nodes += pool.nodes_searched();
}

// Output bench stats
Expand All @@ -200,10 +197,8 @@ namespace Tests
std::cerr << "Elapsed time (s): " << std::setw(7) << elapsed << std::endl;
std::cerr << "Nodes per second: " << uint64_t(nodes / elapsed) << std::endl;

// Restore initial options, thread pool and hash
// Restore initial options
UCI::Options::UCI_Chess960 = Chess960;
pool->resize(UCI::Options::Threads);
ttable.resize(UCI::Options::Hash);
}


Expand Down
12 changes: 2 additions & 10 deletions src/Tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,14 @@ namespace Tests
void bench(Search::Limits limits, int threads, int hash);


template<bool USE_ORDER, bool TT, bool LEGALITY, bool VALIDITY>
template<bool USE_ORDER, bool LEGALITY, bool VALIDITY>
int perft_techniques_tests()
{
// Store initial state
bool Chess960 = UCI::Options::UCI_Chess960;

auto tests = test_suite();

// Allocate TT
if (TT)
perft_table.resize(16);

int n_failed = 0;
for (auto& test : tests)
{
Expand All @@ -59,7 +55,7 @@ namespace Tests
auto hists = std::make_unique<Histories>();
Depth depth = test.depth() - 1 - 2 * (LEGALITY || VALIDITY);
auto result_base = Search::perft<false>(pos, depth, *hists);
auto result_test = Search::template perft<false, USE_ORDER, TT, LEGALITY, VALIDITY>(pos, depth, *hists);
auto result_test = Search::template perft<false, USE_ORDER, LEGALITY, VALIDITY>(pos, depth, *hists);
if (result_base == result_test)
{
std::cout << "[ OK ] " << test.fen() << " (" << result_test << ")" << std::endl;
Expand All @@ -71,10 +67,6 @@ namespace Tests
}
}

// Deallocate TT
if (TT)
perft_table.resize(0);

UCI::Options::UCI_Chess960 = Chess960;
std::cout << "\nFailed/total tests: " << n_failed << "/" << tests.size() << std::endl;
return n_failed;
Expand Down
24 changes: 16 additions & 8 deletions src/Thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#include <condition_variable>


ThreadPool* pool;


Thread::Thread(int id, ThreadPool& pool)
: m_id(id),
Expand Down Expand Up @@ -82,13 +80,14 @@ bool Thread::is_root_move(Move move) const

void Thread::output_pvs()
{
int hashfull = m_pool.tt().hashfull();
double elapsed = m_pool.m_time.elapsed();
uint64_t nodes = m_pool.nodes_searched();
uint64_t tb_hits = m_pool.tb_hits();

// Output information
for (int iPV = 0; iPV < UCI::Options::MultiPV; iPV++)
m_multiPV[iPV].write_pv(m_position.board(), iPV, nodes, tb_hits, elapsed);
m_multiPV[iPV].write_pv(m_position.board(), iPV, hashfull, nodes, tb_hits, elapsed);
}


Expand All @@ -104,11 +103,19 @@ void Thread::tb_hit() { m_tb_hits.fetch_add(1, std::memory_order_relaxed); }



ThreadPool::ThreadPool()
: m_threads(0),
ThreadPool::ThreadPool(int n_threads, int hash_size_mb)
: m_tt(hash_size_mb),
m_threads(0),
m_status(ThreadStatus::WAITING)
{
m_threads.push_back(std::make_unique<Thread>(0, *this));
for (int i = 0; i < n_threads; i++)
m_threads.push_back(std::make_unique<Thread>(i, *this));
}


ThreadPool::~ThreadPool()
{
kill_threads();
}


Expand Down Expand Up @@ -156,7 +163,7 @@ void ThreadPool::search(const Search::Timer& timer, const Search::Limits& limits
this->wait();

// Set the search data before waking the threads
ttable.new_search();
m_tt.new_search();
m_status = ThreadStatus::SEARCHING;
m_limits = limits;

Expand Down Expand Up @@ -198,6 +205,7 @@ void ThreadPool::wait(Thread* skip)

void ThreadPool::clear()
{
m_tt.clear();
for (auto& thread : m_threads)
thread->clear();
}
Expand Down Expand Up @@ -509,7 +517,7 @@ void Thread::search()
{
m_position.make_move(bestmove);
TranspositionEntry* entry = nullptr;
if (ttable.query(m_position.hash(), &entry) && m_position.board().legal(entry->hash_move()))
if (m_pool.tt().query(m_position.hash(), &entry) && m_position.board().legal(entry->hash_move()))
pondermove = entry->hash_move();
m_position.unmake_move();
}
Expand Down
11 changes: 7 additions & 4 deletions src/Thread.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include "Types.hpp"
#include "Move.hpp"
#include "Position.hpp"
Expand Down Expand Up @@ -83,6 +85,7 @@ class Thread
class ThreadPool
{
Position m_position;
TranspositionTable m_tt;
std::vector<std::unique_ptr<Thread>> m_threads;

void send_signal(ThreadStatus signal);
Expand All @@ -94,7 +97,9 @@ class ThreadPool
std::atomic<ThreadStatus> m_status;

public:
ThreadPool();
ThreadPool(int n_threads, int hash_size_mb);

virtual ~ThreadPool();

void resize(int n_threads);

Expand Down Expand Up @@ -131,7 +136,5 @@ class ThreadPool
Thread* get_best_thread() const;

inline Thread& front() { return *(m_threads.front()); }
inline TranspositionTable& tt() { return m_tt; }
};


extern ThreadPool* pool;

0 comments on commit c9f62cf

Please sign in to comment.