From a979442b430663417399e1504de3457c42ad0b8e Mon Sep 17 00:00:00 2001 From: taylorhansen Date: Sun, 29 Jan 2023 14:10:26 -0800 Subject: [PATCH] Synchronize rollout with proper replay buffer, reorganize Improvement on #353 and setup for #354. Rewrite training algorithm (again) to remove the concept of episodes and instead focus on pure learning steps according to the DQN algorithm. Also add a proper replay buffer implementation. Add/rewrite some configs/metrics code to mesh with above. Also reorganize source tree, general housekeeping. --- src/config/config.example.ts | 40 ++- src/config/types.ts | 85 +++-- src/demo/compare.ts | 7 +- src/demo/train.ts | 6 +- src/game/agent/random.ts | 13 + src/{train => }/game/experience/Experience.ts | 23 +- src/{train => }/game/experience/index.ts | 0 src/game/experience/tensor.ts | 25 ++ src/{train => }/game/pool/GamePipeline.ts | 31 +- src/{train => }/game/pool/GamePool.ts | 17 +- src/{train => }/game/pool/GamePoolStream.ts | 0 src/{train => }/game/pool/index.ts | 0 .../game/pool/worker/GameProtocol.ts | 18 +- .../game/pool/worker/GameWorker.ts | 48 ++- src/{train => }/game/pool/worker/index.ts | 0 src/{train => }/game/pool/worker/worker.js | 0 src/{train => }/game/pool/worker/worker.ts | 80 ++--- src/{train => }/game/sim/playGame.ts | 61 ++-- src/game/sim/ps/experienceBattleParser.ts | 116 ++++++ src/{train => }/game/sim/ps/index.ts | 0 src/{train => }/game/sim/ps/ps.ts | 28 +- src/model/model.ts | 97 ++++- src/model/port/ModelPort.ts | 152 ++++++++ src/model/port/ModelPortProtocol.ts | 66 ++++ src/{train => }/model/port/index.ts | 0 src/model/shapes.ts | 72 +--- src/model/verify.ts | 51 +++ src/{train => }/model/worker/Metrics.ts | 4 +- src/{train => }/model/worker/ModelProtocol.ts | 36 +- src/{train => }/model/worker/ModelRegistry.ts | 67 +--- src/{train => }/model/worker/ModelWorker.ts | 4 +- src/{train => }/model/worker/PredictBatch.ts | 31 +- src/{train => }/model/worker/index.ts | 0 src/{train => }/model/worker/worker.js 
| 0 src/{train => }/model/worker/worker.ts | 23 +- src/psbot/handlers/battle/BattleHandler.ts | 2 +- .../handlers/battle/agent/BattleAgent.ts | 5 +- src/psbot/handlers/battle/ai/maxAgent.ts | 12 +- src/psbot/handlers/battle/ai/networkAgent.ts | 27 +- src/train/{model/worker => }/Evaluate.ts | 27 +- src/train/Learn.ts | 232 ++++++++++++ src/train/README.md | 23 +- src/train/ReplayBuffer.ts | 124 +++++++ src/train/{model/worker => }/Rollout.ts | 94 +++-- src/train/RolloutModel.ts | 307 ++++++++++++++++ src/train/TrainingProgress.ts | 158 ++++---- src/train/game/agent/random.ts | 33 -- .../game/sim/ps/experienceBattleParser.ts | 112 ------ src/train/model/port/ModelPort.ts | 167 --------- src/train/model/port/ModelPortProtocol.ts | 36 -- src/train/model/worker/Learn.ts | 340 ------------------ src/train/model/worker/dataset.ts | 73 ---- src/train/model/worker/train.ts | 198 ---------- src/train/pool/index.ts | 1 - src/train/train.ts | 259 +++++++++++++ src/util/format.ts | 35 ++ src/util/model.ts | 20 ++ src/{train => util}/pool/ThreadPool.ts | 42 ++- src/{train => util}/port/AsyncPort.ts | 0 src/{train => util}/port/PortProtocol.ts | 0 src/{train/port => util/worker}/WorkerPort.ts | 11 +- .../port => util/worker}/WorkerProtocol.ts | 6 +- 62 files changed, 1989 insertions(+), 1556 deletions(-) create mode 100644 src/game/agent/random.ts rename src/{train => }/game/experience/Experience.ts (57%) rename src/{train => }/game/experience/index.ts (100%) create mode 100644 src/game/experience/tensor.ts rename src/{train => }/game/pool/GamePipeline.ts (87%) rename src/{train => }/game/pool/GamePool.ts (91%) rename src/{train => }/game/pool/GamePoolStream.ts (100%) rename src/{train => }/game/pool/index.ts (100%) rename src/{train => }/game/pool/worker/GameProtocol.ts (89%) rename src/{train => }/game/pool/worker/GameWorker.ts (65%) rename src/{train => }/game/pool/worker/index.ts (100%) rename src/{train => }/game/pool/worker/worker.js (100%) rename src/{train => 
}/game/pool/worker/worker.ts (66%) rename src/{train => }/game/sim/playGame.ts (66%) create mode 100644 src/game/sim/ps/experienceBattleParser.ts rename src/{train => }/game/sim/ps/index.ts (100%) rename src/{train => }/game/sim/ps/ps.ts (92%) create mode 100644 src/model/port/ModelPort.ts create mode 100644 src/model/port/ModelPortProtocol.ts rename src/{train => }/model/port/index.ts (100%) create mode 100644 src/model/verify.ts rename src/{train => }/model/worker/Metrics.ts (94%) rename src/{train => }/model/worker/ModelProtocol.ts (85%) rename src/{train => }/model/worker/ModelRegistry.ts (81%) rename src/{train => }/model/worker/ModelWorker.ts (98%) rename src/{train => }/model/worker/PredictBatch.ts (72%) rename src/{train => }/model/worker/index.ts (100%) rename src/{train => }/model/worker/worker.js (100%) rename src/{train => }/model/worker/worker.ts (88%) rename src/train/{model/worker => }/Evaluate.ts (86%) create mode 100644 src/train/Learn.ts create mode 100644 src/train/ReplayBuffer.ts rename src/train/{model/worker => }/Rollout.ts (63%) create mode 100644 src/train/RolloutModel.ts delete mode 100644 src/train/game/agent/random.ts delete mode 100644 src/train/game/sim/ps/experienceBattleParser.ts delete mode 100644 src/train/model/port/ModelPort.ts delete mode 100644 src/train/model/port/ModelPortProtocol.ts delete mode 100644 src/train/model/worker/Learn.ts delete mode 100644 src/train/model/worker/dataset.ts delete mode 100644 src/train/model/worker/train.ts delete mode 100644 src/train/pool/index.ts create mode 100644 src/train/train.ts create mode 100644 src/util/format.ts create mode 100644 src/util/model.ts rename src/{train => util}/pool/ThreadPool.ts (83%) rename src/{train => util}/port/AsyncPort.ts (100%) rename src/{train => util}/port/PortProtocol.ts (100%) rename src/{train/port => util/worker}/WorkerPort.ts (92%) rename src/{train/port => util/worker}/WorkerProtocol.ts (89%) diff --git a/src/config/config.example.ts 
b/src/config/config.example.ts index e72c7562..96aedea4 100644 --- a/src/config/config.example.ts +++ b/src/config/config.example.ts @@ -8,6 +8,8 @@ import {Config} from "./types"; // the training script should make the whole process fully deterministic. const numThreads = os.cpus().length; +const maxTurns = 100; + /** * Top-level config. Should only be accessed by the top-level. * @@ -32,7 +34,7 @@ export const config: Config = { tf: {gpu: false}, train: { name: "train", - episodes: 16, + steps: maxTurns * 2 * 50 /*enough for at least 50 games*/, batchPredict: { maxSize: numThreads, timeoutNs: 10_000_000n /*10ms*/, @@ -47,36 +49,36 @@ export const config: Config = { rollout: { pool: { numThreads, - maxTurns: 100, + maxTurns, reduceLogs: true, }, policy: { exploration: 1.0, - explorationDecay: 0.9, + explorationDecay: 0.999, minExploration: 0.1, }, prev: 0.1, }, + experience: { + rewardDecay: 0.99, + bufferSize: maxTurns * 2 * 25 /*enough for at least 25 games*/, + prefill: maxTurns * 2 * numThreads /*at least one complete game*/, + }, + learn: { + learningRate: 0.0001, + batchSize: 32, + target: "double", + targetInterval: 500, + metricsInterval: 100, + }, eval: { - numGames: 128, + numGames: 32, pool: { numThreads, - maxTurns: 100, + maxTurns, reduceLogs: true, }, - }, - learn: { - updates: 1024, - learningRate: 0.0001, - buffer: { - shuffle: 100 * 2 * 8 /*at least 8 game's worth*/, - batch: 32, - prefetch: 4, - }, - experience: { - rewardDecay: 0.99, - }, - target: "double", + interval: 1000, }, seeds: { model: "abc", @@ -87,6 +89,8 @@ export const config: Config = { learn: "pqr", }, savePreviousVersions: true, + checkpointInterval: 1000, + progress: true, verbose: Verbose.Info, }, compare: { diff --git a/src/config/types.ts b/src/config/types.ts index 663c2f7d..68b5a648 100644 --- a/src/config/types.ts +++ b/src/config/types.ts @@ -52,25 +52,31 @@ export interface TensorflowConfig { export interface TrainConfig { /** Name of the training run under which to 
store logs. */ readonly name: string; - /** Number of training episodes to complete. */ - readonly episodes: number; - /** Batch predict config. */ + /** Number of learning steps. Omit or set to zero to train indefinitely. */ + readonly steps?: number; + /** Batch predict config for models outside the learning step. */ readonly batchPredict: BatchPredictConfig; /** Model config. */ readonly model: ModelConfig; /** Rollout config. */ readonly rollout: RolloutConfig; - /** Evaluation config. */ - readonly eval: EvalConfig; + /** Experience config. */ + readonly experience: ExperienceConfig; /** Learning config. */ readonly learn: LearnConfig; + /** Evaluation config. */ + readonly eval: EvalConfig; /** RNG config. */ readonly seeds?: TrainSeedConfig; + /** Whether to save model checkpoints as separate versions. */ + readonly savePreviousVersions?: boolean; /** - * Whether to save each previous version of the model separately after each - * training step. + * Step interval for saving model checkpoints. Omit to not store + * checkpoints. */ - readonly savePreviousVersions: boolean; + readonly checkpointInterval?: number; + /** Whether to display a progress bar if {@link steps} is also defined. */ + readonly progress?: boolean; /** Verbosity level for logging. Default highest. */ readonly verbose?: Verbose; } @@ -133,6 +139,8 @@ export interface RolloutConfig { /** * Fraction of self-play games that should by played against the model's * previous version rather than itself. + * + * The previous version is defined by the last {@link EvalConfig eval} step. */ readonly prev: number; } @@ -160,54 +168,53 @@ export interface PolicyConfig { * proportion of actions to take randomly rather than consulting the model. */ readonly exploration: number; - /** - * Exploration (epsilon) decay factor. Applied after each full episode of - * training. - */ + /** Exploration (epsilon) decay factor. Applied after each learning step. 
*/ readonly explorationDecay: number; /** Minumum exploration (epsilon) value. */ readonly minExploration: number; } -/** Configuration for the evaluation process. */ -export interface EvalConfig { - /** Number of games to play against each eval opponent. */ - readonly numGames: number; - /** Game pool config. */ - readonly pool: GamePoolConfig; +/** Configuration for learning on experience generated from rollout games. */ +export interface ExperienceConfig { + /** Discount factor for future rewards. */ + readonly rewardDecay: number; + /** Size of the experience replay buffer. */ + readonly bufferSize: number; + /** + * Minimum number of experiences to generate before starting training. Must + * be at least as big as the batch size. + */ + readonly prefill: number; } /** Configuration for the learning process. */ export interface LearnConfig { - /** Number of batch updates before starting the next episode. */ - readonly updates: number; /** Optimizer learning rate. */ readonly learningRate: number; - /** Replay buffer config. */ - readonly buffer: BufferConfig; - /** Experience config. */ - readonly experience: ExperienceConfig; + /** Batch size. */ + readonly batchSize: number; /** * Whether to use a target network to increase training stability, or * `"double"` to implement double Q learning approach using the target net. */ - readonly target?: boolean | "double"; -} - -/** Configuration for the experience replay buffer. */ -export interface BufferConfig { - /** Number of experiences to buffer for shuffling. */ - readonly shuffle: number; - /** Batch size for learning updates. */ - readonly batch: number; - /** Number of batches to prefetch for learning. */ - readonly prefetch: number; + readonly target: boolean | "double"; + /** Step interval for updating the target network. */ + readonly targetInterval: number; + /** + * Step interval for tracking expensive batch update model metrics such as + * histograms which can significantly slow down training.
+ */ + readonly metricsInterval: number; } -/** Configuration for learning on experience generated from rollout games. */ -export interface ExperienceConfig { - /** Discount factor for future rewards. */ - readonly rewardDecay: number; +/** Configuration for the evaluation process. */ +export interface EvalConfig { + /** Number of games to play against each eval opponent. */ + readonly numGames: number; + /** Game pool config. */ + readonly pool: GamePoolConfig; + /** Step interval for performing model evaluations. */ + readonly interval: number; } /** Configuration for random number generators in the training script. */ diff --git a/src/demo/compare.ts b/src/demo/compare.ts index 2cb3167a..95ca9ecf 100644 --- a/src/demo/compare.ts +++ b/src/demo/compare.ts @@ -9,8 +9,8 @@ import { GameArgsGenOptions, GameArgsGenSeeders, GamePipeline, -} from "../train/game/pool"; -import {ModelWorker} from "../train/model/worker"; +} from "../game/pool"; +import {ModelWorker} from "../model/worker"; import {Logger} from "../util/logging/Logger"; import {Verbose} from "../util/logging/Verbose"; import {ensureDir} from "../util/paths/ensureDir"; @@ -198,8 +198,9 @@ void (async function () { ); } }); + await games.close(); } finally { - await games.cleanup(); + await games.terminate(); progressBar.terminate(); } diff --git a/src/demo/train.ts b/src/demo/train.ts index f4cf0a31..4a727c91 100644 --- a/src/demo/train.ts +++ b/src/demo/train.ts @@ -2,8 +2,9 @@ import {join} from "path"; import {setGracefulCleanup} from "tmp-promise"; import {config} from "../config"; +import {ModelWorker} from "../model/worker"; import {TrainingProgress} from "../train/TrainingProgress"; -import {ModelWorker} from "../train/model/worker"; +import {formatUptime} from "../util/format"; import {Logger} from "../util/logging/Logger"; import {Verbose} from "../util/logging/Verbose"; import {ensureDir} from "../util/paths/ensureDir"; @@ -70,7 +71,7 @@ void (async function () { ); } - const trainProgress = new 
TrainingProgress(config, logger); + const trainProgress = new TrainingProgress(config.train, logger); try { await models.train( model, @@ -84,6 +85,7 @@ void (async function () { await models.unload(model); await models.close(); trainProgress.done(); + logger.info("Uptime: " + formatUptime(process.uptime())); logger.info("Done"); } })(); diff --git a/src/game/agent/random.ts b/src/game/agent/random.ts new file mode 100644 index 00000000..04610984 --- /dev/null +++ b/src/game/agent/random.ts @@ -0,0 +1,13 @@ +import {Choice} from "../../psbot/handlers/battle/agent"; +import {ReadonlyBattleState} from "../../psbot/handlers/battle/state"; +import {Rng, shuffle} from "../../util/random"; + +/** BattleAgent that chooses actions randomly. */ +export async function randomAgent( + state: ReadonlyBattleState, + choices: Choice[], + random?: Rng, +): Promise { + shuffle(choices, random); + return await Promise.resolve(); +} diff --git a/src/train/game/experience/Experience.ts b/src/game/experience/Experience.ts similarity index 57% rename from src/train/game/experience/Experience.ts rename to src/game/experience/Experience.ts index f51b26e1..d2bc20cf 100644 --- a/src/train/game/experience/Experience.ts +++ b/src/game/experience/Experience.ts @@ -1,20 +1,13 @@ -import {BattleAgent} from "../../../psbot/handlers/battle/agent"; - -/** BattleAgent decision data. */ -export interface ExperienceAgentData { - /** State in which the action was taken. Flattened array data. */ - state: Float32Array[]; -} - -/** BattleAgent type that emits partial Experience objects. */ -export type ExperienceAgent = BattleAgent; +import {BattleAgent} from "../../psbot/handlers/battle/agent"; /** * BattleAgent decision evaluation data. Can be processed in batches to * effectively train a neural network. */ -export interface Experience extends ExperienceAgentData { - /** ID of the Choice that was taken. */ +export interface Experience { + /** State in which the action was taken. Flattened array data. 
*/ + state: Float32Array[]; + /** Id of the action that was taken. */ action: number; /** Reward gained from the action and state transition. */ reward: number; @@ -23,3 +16,9 @@ export interface Experience extends ExperienceAgentData { /** Marks {@link nextState} as a terminal state so it won't be processed. */ done: boolean; } + +/** BattleAgent with additional args for experience generation. */ +export type ExperienceBattleAgent = BattleAgent< + TInfo, + [lastAction?: number, reward?: number] +>; diff --git a/src/train/game/experience/index.ts b/src/game/experience/index.ts similarity index 100% rename from src/train/game/experience/index.ts rename to src/game/experience/index.ts diff --git a/src/game/experience/tensor.ts b/src/game/experience/tensor.ts new file mode 100644 index 00000000..7326aa8f --- /dev/null +++ b/src/game/experience/tensor.ts @@ -0,0 +1,25 @@ +import * as tf from "@tensorflow/tfjs"; +import {Experience} from "./Experience"; + +/** {@link Experience} with values converted to {@link tf.Tensor tensors}. */ +export type TensorExperience = { + [T in keyof Experience]: Experience[T] extends number | boolean + ? tf.Scalar + : Experience[T] extends Float32Array[] + ? tf.Tensor[] + : never; +}; + +/** + * Batched {@link Experience} stacked {@link tf.Tensor tensors}. + * + * Essentially a list of {@link TensorExperience}s but with values converted to + * stacked tensors. + */ +export type BatchTensorExperience = { + [T in keyof Experience]: Experience[T] extends number | boolean + ? tf.Tensor1D + : Experience[T] extends Float32Array[] + ? 
tf.Tensor[] + : never; +}; diff --git a/src/train/game/pool/GamePipeline.ts b/src/game/pool/GamePipeline.ts similarity index 87% rename from src/train/game/pool/GamePipeline.ts rename to src/game/pool/GamePipeline.ts index 5764b67f..c926b2a1 100644 --- a/src/train/game/pool/GamePipeline.ts +++ b/src/game/pool/GamePipeline.ts @@ -1,9 +1,8 @@ import {join} from "path"; import {pipeline} from "stream/promises"; import {MessagePort} from "worker_threads"; -import {GamePoolConfig, ExperienceConfig} from "../../../config/types"; -import {generatePsPrngSeed, rng, Rng, Seeder} from "../../../util/random"; -import {Experience} from "../experience"; +import {GamePoolConfig, ExperienceConfig} from "../../config/types"; +import {generatePsPrngSeed, rng, Rng, Seeder} from "../../util/random"; import { GamePool, GamePoolAgentConfig, @@ -38,11 +37,6 @@ export interface GameArgsGenOptions { readonly experienceConfig?: ExperienceConfig; /** Random seed generators. */ readonly seeders?: GameArgsGenSeeders; - /** - * Callback for processing Experience objects if the game is configured for - * it. - */ - readonly experienceCallback?: (exp: Experience) => void; } /** Random number generators used by the game and policy. */ @@ -59,8 +53,6 @@ export interface GameArgsGenSeeders { export class GamePipeline { /** Manages game threads. */ private readonly pool: GamePool; - /** Used to abort the run if needed. */ - private readonly ac = new AbortController(); /** * Creates a GamePipeline. @@ -71,12 +63,22 @@ export class GamePipeline { this.pool = new GamePool(config); } - /** Closes game threads. Can be called while games are still running. */ - public async cleanup(): Promise { - this.ac.abort(); + /** + * Waits for in-progress games to complete then closes game threads. Calls + * to {@link run} that are currently running may never resolve. + */ + public async close(): Promise { return await this.pool.close(); } + /** + * Terminates in-progress games and closes the thread pool. 
Calls to + * {@link run} that are currently running may never resolve. + */ + public async terminate(): Promise { + return await this.pool.terminate(); + } + /** * Starts the game pipeline. Can be called multiple times. * @@ -98,7 +100,6 @@ export class GamePipeline { await callback?.(result); } }, - {signal: this.ac.signal}, ); } @@ -112,7 +113,6 @@ export class GamePipeline { reduceLogs, experienceConfig, seeders, - experienceCallback, }: GameArgsGenOptions): Generator { const battleRandom = seeders?.battle && rng(seeders.battle()); const teamRandom = seeders?.team && rng(seeders.team()); @@ -141,7 +141,6 @@ export class GamePipeline { seed: generatePsPrngSeed(battleRandom), ...(experienceConfig && {experienceConfig}), }, - ...(experienceCallback && {experienceCallback}), }; } } diff --git a/src/train/game/pool/GamePool.ts b/src/game/pool/GamePool.ts similarity index 91% rename from src/train/game/pool/GamePool.ts rename to src/game/pool/GamePool.ts index 02cbacda..722ab75d 100644 --- a/src/train/game/pool/GamePool.ts +++ b/src/game/pool/GamePool.ts @@ -1,9 +1,8 @@ import {resolve} from "path"; import {MessagePort} from "worker_threads"; import {PRNGSeed} from "@pkmn/sim"; -import {ExperienceConfig, GamePoolConfig} from "../../../config/types"; -import {ThreadPool} from "../../pool"; -import {Experience} from "../experience"; +import {ExperienceConfig, GamePoolConfig} from "../../config/types"; +import {ThreadPool} from "../../util/pool/ThreadPool"; import {SimResult} from "../sim/playGame"; import { GameAgentConfig, @@ -24,11 +23,6 @@ export interface GamePoolArgs { ) => MessagePort | Promise; /** Args for starting the game. */ readonly play: PlayArgs; - /** - * Callback for processing Experience objects if the game is configured for - * it. - */ - readonly experienceCallback?: (experience: Experience) => void; } /** Config for {@link GamePool.add} agents. */ @@ -128,8 +122,13 @@ export class GamePool { } } - /** Closes the thread pool. 
*/ + /** Waits for in-progress games to complete then closes the thread pool. */ public async close(): Promise { return await this.pool.close(); } + + /** Terminates in-progress games and closes the thread pool. */ + public async terminate(): Promise { + return await this.pool.terminate(); + } } diff --git a/src/train/game/pool/GamePoolStream.ts b/src/game/pool/GamePoolStream.ts similarity index 100% rename from src/train/game/pool/GamePoolStream.ts rename to src/game/pool/GamePoolStream.ts diff --git a/src/train/game/pool/index.ts b/src/game/pool/index.ts similarity index 100% rename from src/train/game/pool/index.ts rename to src/game/pool/index.ts diff --git a/src/train/game/pool/worker/GameProtocol.ts b/src/game/pool/worker/GameProtocol.ts similarity index 89% rename from src/train/game/pool/worker/GameProtocol.ts rename to src/game/pool/worker/GameProtocol.ts index 160c1fef..a34e48f6 100644 --- a/src/train/game/pool/worker/GameProtocol.ts +++ b/src/game/pool/worker/GameProtocol.ts @@ -1,9 +1,8 @@ /** @file Defines the protocol typings for GameWorkers. */ import {MessagePort} from "worker_threads"; import {PRNGSeed} from "@pkmn/sim"; -import {PortMessageBase, PortResultBase} from "../../../port/PortProtocol"; -import {WorkerProtocol} from "../../../port/WorkerProtocol"; -import {Experience} from "../../experience"; +import {PortMessageBase, PortResultBase} from "../../../util/port/PortProtocol"; +import {WorkerProtocol} from "../../../util/worker/WorkerProtocol"; import {SimResult} from "../../sim/playGame"; import {PlayArgs} from "../GamePool"; @@ -102,14 +101,8 @@ export type GameResult = GameProtocol[GameRequestType]["result"]; /** Base interface for game worker message results. */ type GameResultBase = PortResultBase; -interface GamePlayResultExp extends GameResultBase<"play"> { - /** Generated experience data. 
*/ - experience: Experience; - /** @override */ - done: false; -} - -interface GamePlayResultDone +/** Result of a game after it has been completed and processed by the worker. */ +export interface GamePlayResult extends GameResultBase<"play">, Omit { /** @@ -121,6 +114,3 @@ interface GamePlayResultDone /** @override */ done: true; } - -/** Result of a game after it has been completed and processed by the worker. */ -export type GamePlayResult = GamePlayResultExp | GamePlayResultDone; diff --git a/src/train/game/pool/worker/GameWorker.ts b/src/game/pool/worker/GameWorker.ts similarity index 65% rename from src/train/game/pool/worker/GameWorker.ts rename to src/game/pool/worker/GameWorker.ts index 8f50008e..f6895424 100644 --- a/src/train/game/pool/worker/GameWorker.ts +++ b/src/game/pool/worker/GameWorker.ts @@ -1,6 +1,6 @@ import {deserialize} from "v8"; import {Worker} from "worker_threads"; -import {WorkerPort} from "../../../port/WorkerPort"; +import {WorkerPort} from "../../../util/worker/WorkerPort"; import {GamePoolArgs, GamePoolResult} from "../GamePool"; import {GameProtocol, GameAgentConfig, GamePlay} from "./GameProtocol"; @@ -23,6 +23,11 @@ export class GameWorker { await this.workerPort.close(); } + /** Force-closes the worker. */ + public async terminate(): Promise { + await this.workerPort.terminate(); + } + /** Queues a game for the worker. 
*/ public async playGame(args: GamePoolArgs): Promise { const msg: GamePlay = { @@ -59,33 +64,24 @@ export class GameWorker { ), workerResult => { let result: GamePoolResult; - if (!workerResult.done) { - args.experienceCallback?.(workerResult.experience); + if (workerResult.type === "error") { + result = { + id: args.id, + agents: [args.agents[0].name, args.agents[1].name], + err: workerResult.err, + }; } else { - if (workerResult.type === "error") { - result = { - id: args.id, - agents: [ - args.agents[0].name, - args.agents[1].name, - ], - err: workerResult.err, - }; - } else { - result = { - id: args.id, - agents: workerResult.agents, - winner: workerResult.winner, - // GamePort doesn't automatically deserialize - // errors outside of PortResultError (where - // type=error). - ...(workerResult.err && { - err: deserialize(workerResult.err) as Error, - }), - }; - } - res(result); + result = { + id: args.id, + agents: workerResult.agents, + winner: workerResult.winner, + // Manually deserialize game error. 
+ ...(workerResult.err && { + err: deserialize(workerResult.err) as Error, + }), + }; } + res(result); }, ), ); diff --git a/src/train/game/pool/worker/index.ts b/src/game/pool/worker/index.ts similarity index 100% rename from src/train/game/pool/worker/index.ts rename to src/game/pool/worker/index.ts diff --git a/src/train/game/pool/worker/worker.js b/src/game/pool/worker/worker.js similarity index 100% rename from src/train/game/pool/worker/worker.js rename to src/game/pool/worker/worker.js diff --git a/src/train/game/pool/worker/worker.ts b/src/game/pool/worker/worker.ts similarity index 66% rename from src/train/game/pool/worker/worker.ts rename to src/game/pool/worker/worker.ts index 45f050bc..9a981733 100644 --- a/src/train/game/pool/worker/worker.ts +++ b/src/game/pool/worker/worker.ts @@ -1,11 +1,11 @@ import * as stream from "stream"; import {serialize} from "v8"; import {parentPort, TransferListItem, workerData} from "worker_threads"; -import {rng} from "../../../../util/random"; import {ModelPort} from "../../../model/port"; -import {RawPortResultError} from "../../../port/PortProtocol"; -import {WorkerClosed} from "../../../port/WorkerProtocol"; -import {randomAgent, randomExpAgent} from "../../agent/random"; +import {RawPortResultError} from "../../../util/port/PortProtocol"; +import {rng} from "../../../util/random"; +import {WorkerClosed} from "../../../util/worker/WorkerProtocol"; +import {randomAgent} from "../../agent/random"; import {playGame, SimArgsAgent} from "../../sim/playGame"; import { GameMessage, @@ -36,15 +36,21 @@ const gameStream = new stream.Writable({ const transferList: TransferListItem[] = []; const modelPorts: ModelPort[] = []; try { + const experiencePorts: {[name: string]: ModelPort} = {}; const agents = msg.agents.map(config => { switch (config.exploit.type) { case "model": { const modelPort = new ModelPort(config.exploit.port); modelPorts.push(modelPort); + if (config.emitExperience) { + experiencePorts[config.name] = 
modelPort; + } return { name: config.name, agent: modelPort.getAgent(config.explore), - emitExperience: !!config.emitExperience, + ...(config.emitExperience && { + emitExperience: true, + }), ...(config.seed && {seed: config.seed}), }; } @@ -52,29 +58,12 @@ const gameStream = new stream.Writable({ const agentRandom = config.exploit.seed ? rng(config.exploit.seed) : undefined; - return config.emitExperience - ? ({ - name: config.name, - agent: async (state, choices) => - await randomExpAgent( - state, - choices, - agentRandom, - ), - emitExperience: true, - ...(config.seed && {seed: config.seed}), - } as SimArgsAgent) - : ({ - name: config.name, - agent: async (state, choices) => - await randomAgent( - state, - choices, - agentRandom, - ), - emitExperience: false, - ...(config.seed && {seed: config.seed}), - } as SimArgsAgent); + return { + name: config.name, + agent: async (state, choices) => + await randomAgent(state, choices, agentRandom), + ...(config.seed && {seed: config.seed}), + }; } default: { const unsupported: unknown = config.exploit; @@ -96,19 +85,16 @@ const gameStream = new stream.Writable({ ...(msg.play.onlyLogOnError && {onlyLogOnError: true}), ...(msg.play.seed && {seed: msg.play.seed}), }, - experience => { - const expResult: GamePlayResult = { - type: "play", - rid: msg.rid, - done: false, - experience, - }; - parentPort!.postMessage( - expResult, - [experience.state, experience.nextState].flatMap( - s => s?.map(a => a.buffer) ?? [], - ), - ); + async (name, state, action, reward) => { + if ( + !Object.prototype.hasOwnProperty.call( + experiencePorts, + name, + ) + ) { + return; + } + await experiencePorts[name].finalize(state, action, reward); }, ); @@ -146,7 +132,17 @@ parentPort.on("message", function handle(msg: GameMessage) { case "play": // Note: Due to stream buffering, this may not be immediately // processed. 
- gameStream.write(msg); + try { + gameStream.write(msg); + } catch (err) { + const result: RawPortResultError = { + type: "error", + rid: msg.rid, + done: true, + err: serialize(err), + }; + parentPort!.postMessage(result, [result.err.buffer]); + } break; case "close": gameStream.end(() => { diff --git a/src/train/game/sim/playGame.ts b/src/game/sim/playGame.ts similarity index 66% rename from src/train/game/sim/playGame.ts rename to src/game/sim/playGame.ts index 32e1862d..008967f1 100644 --- a/src/train/game/sim/playGame.ts +++ b/src/game/sim/playGame.ts @@ -1,8 +1,6 @@ import {PRNGSeed} from "@pkmn/sim"; -import {BattleAgent} from "../../../psbot/handlers/battle/agent"; -import {BattleParser} from "../../../psbot/handlers/battle/parser/BattleParser"; -import {main} from "../../../psbot/handlers/battle/parser/main"; -import {Experience, ExperienceAgent} from "../experience"; +import {BattleAgent} from "../../psbot/handlers/battle/agent"; +import {main} from "../../psbot/handlers/battle/parser/main"; import {experienceBattleParser, PlayerOptions, startPsBattle} from "./ps"; /** Arguments for general battle sims. */ @@ -29,32 +27,18 @@ export interface SimArgs { readonly seed?: PRNGSeed; } -/** Base interface for {@link SimArgsAgent}. */ -interface SimArgsAgentBase { +/** Config for a {@link BattleAgent}. */ +export interface SimArgsAgent { /** Name for logging. */ readonly name: string; /** BattleAgent function. */ - readonly agent: TAgent; - /** - * Whether the {@link agent} emits ExperienceAgentData that should be used - * to generate {@link Experience} objects. - */ - readonly emitExperience: TExp; + readonly agent: BattleAgent; + /** Whether to track experience creation. */ + readonly emitExperience?: true; /** Seed for generating the random team. */ readonly seed?: PRNGSeed; } -/** - * {@link SimArgsAgent} that doesn't emit ExperienceAgentData or wants that data - * to be ignored. 
- */ -export type SimArgsNoexpAgent = SimArgsAgentBase; -/** {@link SimArgsAgent} that emits ExperienceAgentData. */ -export type SimArgsExpAgent = SimArgsAgentBase; - -/** Config for a {@link BattleAgent}. */ -export type SimArgsAgent = SimArgsNoexpAgent | SimArgsExpAgent; - /** Base simulator result type. */ export interface SimResult { /** Names of the two agents that participated in the game. */ @@ -72,13 +56,17 @@ export interface SimResult { * Plays a single game and processes the results. * * @param args Arguments for the simulator. - * @param experienceCallback Callback for processing {@link Experience}s if any - * of the provided BattleAgents are configured to emit them. If omitted, the - * Experiences will instead be discarded. + * @param experienceCallback Callback for processing the final state transition + * for experience generation. Includes agent name as an identifier. */ export async function playGame( args: SimArgs, - experienceCallback?: (exp: Experience) => void, + experienceCallback?: ( + name: string, + state?: Float32Array[], + action?: number, + reward?: number, + ) => Promise, ): Promise { // Detect battle agents that want to generate Experience objects. const [p1, p2] = args.agents.map(function (agentArgs) { @@ -95,13 +83,18 @@ export async function playGame( return { name: agentArgs.name, agent: agentArgs.agent, - parser: agentArgs.emitExperience - ? 
(experienceBattleParser( - main, - experienceCallback, - agentArgs.name /*username*/, - ) as BattleParser) - : main, + parser: experienceBattleParser( + main, + async (state, action, reward) => + await experienceCallback( + agentArgs.name, + state, + action, + reward, + ), + agentArgs.name /*username*/, + args.maxTurns, + ), ...(agentArgs.seed && {seed: agentArgs.seed}), }; }); diff --git a/src/game/sim/ps/experienceBattleParser.ts b/src/game/sim/ps/experienceBattleParser.ts new file mode 100644 index 00000000..c46033e6 --- /dev/null +++ b/src/game/sim/ps/experienceBattleParser.ts @@ -0,0 +1,116 @@ +import { + BattleAgent, + Choice, + choiceIds, +} from "../../../psbot/handlers/battle/agent"; +import { + allocEncodedState, + encodeState, +} from "../../../psbot/handlers/battle/ai/encoder"; +import {BattleParser} from "../../../psbot/handlers/battle/parser/BattleParser"; +import {ExperienceBattleAgent} from "../../experience"; + +/** + * Wraps a BattleParser to track rewards/decisions and emit Experience objects. + * + * Returned wrapper requires an {@link ExperienceAgent}. + * + * @template TArgs Parser arguments. + * @template TResult Parser return type. + * @param parser Parser function to wrap. + * @param callback Callback for processing the final state transition. + * @param username Client's username to parse game-over reward. + * @param maxTurns Configured turn limit. + * @returns The wrapped BattleParser function. + */ +export function experienceBattleParser< + TArgs extends unknown[] = unknown[], + TResult = unknown, +>( + parser: BattleParser, + callback: ( + state?: Float32Array[], + action?: number, + reward?: number, + ) => Promise, + username: string, + maxTurns?: number, +): BattleParser { + return async function experienceBattleParserImpl(ctx, ...args: TArgs) { + let forcedGameOver = false; + let lastChoice: Choice | null = null; + let reward = 0; + const result = await parser( + { + ...ctx, + // Extract additional info from the ExperienceAgent. 
+ async agent(state, choices, logger) { + const lastAction = lastChoice + ? choiceIds[lastChoice] + : undefined; + lastChoice = null; + const lastReward = reward; + reward = 0; + + ctx.logger.debug(`Reward = ${lastReward}`); + return await ctx.agent( + state, + choices, + logger, + lastAction, + lastReward, + ); + }, + iter: { + ...ctx.iter, + async next() { + // Observe events before the parser consumes them. + const r = await ctx.iter.next(); + if (r.done) { + return r; + } + switch (r.value.args[0]) { + case "turn": + if ( + maxTurns !== undefined && + Number(r.value.args[1]) > maxTurns + ) { + forcedGameOver = true; + } + break; + case "win": + // Add win/loss reward. + reward += r.value.args[1] === username ? 1 : -1; + break; + default: + } + return r; + }, + }, + // Extract the last choice that was accepted. + async sender(choice) { + const r = await ctx.sender(choice); + if (!r) { + lastChoice = choice; + } + return r; + }, + }, + ...args, + ); + // Emit final experience at the end of the game. + if (lastChoice && !forcedGameOver) { + const stateData = allocEncodedState(); + encodeState(stateData, ctx.state); + const lastAction = lastChoice ? choiceIds[lastChoice] : undefined; + ctx.logger.debug(`Finalizing experience: reward = ${reward}`); + await callback(stateData, lastAction, reward); + } else { + // Game result was forced, so the previous experience was actually + // the final one. 
+ ctx.logger.debug("Finalizing experience: forced game over"); + await callback(); + } + return result; + }; +} diff --git a/src/train/game/sim/ps/index.ts b/src/game/sim/ps/index.ts similarity index 100% rename from src/train/game/sim/ps/index.ts rename to src/game/sim/ps/index.ts diff --git a/src/train/game/sim/ps/ps.ts b/src/game/sim/ps/ps.ts similarity index 92% rename from src/train/game/sim/ps/ps.ts rename to src/game/sim/ps/ps.ts index 9511677d..7ba49fa3 100644 --- a/src/train/game/sim/ps/ps.ts +++ b/src/game/sim/ps/ps.ts @@ -3,19 +3,19 @@ import {setTimeout} from "timers/promises"; import {TeamGenerators} from "@pkmn/randoms"; import {BattleStreams, PRNGSeed, Teams} from "@pkmn/sim"; import {SideID} from "@pkmn/types"; -import {Sender} from "../../../../psbot/PsBot"; -import {BattleHandler} from "../../../../psbot/handlers/battle"; -import {BattleAgent} from "../../../../psbot/handlers/battle/agent"; -import {BattleParser} from "../../../../psbot/handlers/battle/parser/BattleParser"; +import {Sender} from "../../../psbot/PsBot"; +import {BattleHandler} from "../../../psbot/handlers/battle"; +import {BattleAgent} from "../../../psbot/handlers/battle/agent"; +import {BattleParser} from "../../../psbot/handlers/battle/parser/BattleParser"; import { Event, HaltEvent, MessageParser, RoomEvent, -} from "../../../../psbot/parser"; -import {DeferredFile} from "../../../../util/DeferredFile"; -import {LogFunc, Logger} from "../../../../util/logging/Logger"; -import {Verbose} from "../../../../util/logging/Verbose"; +} from "../../../psbot/parser"; +import {DeferredFile} from "../../../util/DeferredFile"; +import {LogFunc, Logger} from "../../../util/logging/Logger"; +import {Verbose} from "../../../util/logging/Verbose"; import {SimArgs, SimResult} from "../playGame"; Teams.setGeneratorFactory(TeamGenerators); @@ -149,7 +149,7 @@ export async function startPsBattle( // Start event loop for this side of the battle. 
- // Note: keep this separate from the above pipeline streams since for + // Note: Keep this separate from the above pipeline streams since for // some reason it causes the whole worker process to crash when an // error is encountered due to the underlying handler.finish() promise // rejecting before the method itself can be called/caught. @@ -165,7 +165,7 @@ export async function startPsBattle( } await wrapTimeout( async () => await handler.handle(e as RoomEvent), - 30e3 /*30s*/, + 60e3 /*60s*/, ); } } catch (e) { @@ -182,7 +182,7 @@ export async function startPsBattle( } else { await handler.finish(); } - }, 30e3 /*30s*/); + }, 60e3 /*60s*/); } catch (e) { if (loopErr !== e) { logError(innerLog, battleStream, e as Error); @@ -271,7 +271,7 @@ function logError( function throwLog(logPath?: string): never { throw new Error( - "startPSBattle() encountered an error." + + "startPsBattle() encountered an error." + (logPath ? `Check ${logPath} for details.` : ""), ); } @@ -284,9 +284,7 @@ async function wrapTimeout( return ( await Promise.all([ f().finally(() => ac.abort()), - setTimeout(milliseconds, true, { - signal: ac.signal, - }) + setTimeout(milliseconds, true, {signal: ac.signal, ref: false}) .catch(err => { if (!(err instanceof Error) || err.name !== "AbortError") { throw err; diff --git a/src/model/model.ts b/src/model/model.ts index c6a9b5d4..49491916 100644 --- a/src/model/model.ts +++ b/src/model/model.ts @@ -1,10 +1,17 @@ +/** @file Defines the model and a few utilities for interacting with it. 
*/ import * as tf from "@tensorflow/tfjs"; import {ModelAggregateConfig, ModelConfig} from "../config/types"; import {Moveset} from "../psbot/handlers/battle/state/Moveset"; import {Team} from "../psbot/handlers/battle/state/Team"; import {Rng, rng} from "../util/random"; import * as customLayers from "./custom_layers"; -import {modelInputShapesMap, verifyModel} from "./shapes"; +import { + modelInputNames, + modelInputShapes, + modelInputShapesMap, + modelOutputName, + modelOutputShape, +} from "./shapes"; /** * Creates a default model for training. @@ -884,3 +891,91 @@ function duelingQ( .apply([stateValue, centeredAdv]) as tf.SymbolicTensor; return actionQ; } + +/** + * Verifies that the model is compatible with the input/output shapes as + * specified by {@link modelInputShapes} and {@link modelOutputShape}. + * + * @throws Error if invalid input/output shapes. + */ +export function verifyModel(model: tf.LayersModel): void { + validateInput(model.input); + validateOutput(model.output); +} + +/** Ensures that the model input shape is valid. 
*/ +function validateInput(input: tf.SymbolicTensor | tf.SymbolicTensor[]): void { + if (!Array.isArray(input)) { + throw new Error("Model input is not an array"); + } + if (input.length !== modelInputShapes.length) { + throw new Error( + `Expected ${modelInputShapes.length} inputs but found ` + + `${input.length}`, + ); + } + for (let i = 0; i < modelInputShapes.length; ++i) { + const {shape} = input[i]; + const expectedShape = [null, ...modelInputShapes[i]]; + let invalid: boolean | undefined; + if (shape.length !== expectedShape.length) { + invalid = true; + } else { + for (let j = 0; j < expectedShape.length; ++j) { + if (shape[j] !== expectedShape[j]) { + invalid = true; + break; + } + } + } + if (invalid) { + throw new Error( + `Expected input ${i} (${modelInputNames[i]}) to have shape ` + + `${JSON.stringify(expectedShape)} but found ` + + `${JSON.stringify(shape)}`, + ); + } + } +} + +/** Ensures that the model output shape is valid. */ +function validateOutput(output: tf.SymbolicTensor | tf.SymbolicTensor[]): void { + if (Array.isArray(output)) { + throw new Error("Model output must not be an array"); + } + const expectedShape = [null, ...modelOutputShape]; + for (let i = 0; i < expectedShape.length; ++i) { + if (output.shape[i] !== expectedShape[i]) { + throw new Error( + `Expected output (${modelOutputName}) to have shape ` + + `${JSON.stringify(expectedShape)} but found ` + + `${JSON.stringify(output.shape)}`, + ); + } + } +} + +/** + * Converts the data lists into tensors + * + * @param includeBatchDim Whether to include an extra 1 dimension in the first + * axis for the batch. Default false. + */ +export function encodedStateToTensors( + arr: Float32Array[], + includeBatchDim?: boolean, +): tf.Tensor[] { + if (arr.length !== modelInputShapes.length) { + throw new Error( + `Expected ${modelInputShapes.length} inputs but found ` + + `${arr.length}`, + ); + } + return modelInputShapes.map((shape, i) => + tf.tensor( + arr[i], + includeBatchDim ? 
[1, ...shape] : [...shape], + "float32", + ), + ); +} diff --git a/src/model/port/ModelPort.ts b/src/model/port/ModelPort.ts new file mode 100644 index 00000000..0a510909 --- /dev/null +++ b/src/model/port/ModelPort.ts @@ -0,0 +1,152 @@ +import {MessagePort} from "worker_threads"; +import {randomAgent} from "../../game/agent/random"; +import {ExperienceBattleAgent} from "../../game/experience"; +import {AgentExploreConfig} from "../../game/pool/worker"; +import {verifyInputData, verifyOutputData} from "../../model/verify"; +import { + allocEncodedState, + encodeState, +} from "../../psbot/handlers/battle/ai/encoder"; +import {maxAgent} from "../../psbot/handlers/battle/ai/maxAgent"; +import {WrappedError} from "../../util/errors/WrappedError"; +import {AsyncPort, ProtocolResultRaw} from "../../util/port/AsyncPort"; +import {rng} from "../../util/random"; +import {ModelPortProtocol, PredictResult} from "./ModelPortProtocol"; + +/** + * Abstracts the interface between a game worker and a model owned by the main + * ModelWorker. + * + * Intended to be used by only one BattleAgent within a game worker that + * received a port to connect to a model. + */ +export class ModelPort { + /** Port wrapper. */ + private readonly asyncPort: AsyncPort< + MessagePort, + ModelPortProtocol, + keyof ModelPortProtocol + >; + + /** + * Creates a ModelPort. + * + * @param port Message port. + */ + public constructor(port: MessagePort) { + this.asyncPort = new AsyncPort(port); + port.on( + "message", + ( + res: ProtocolResultRaw< + ModelPortProtocol, + keyof ModelPortProtocol, + keyof ModelPortProtocol + >, + ) => this.asyncPort.receiveMessage(res), + ); + port.on("error", (err: Error) => + this.asyncPort.receiveError( + new WrappedError( + err, + msg => + "ModelPort encountered an unhandled exception: " + msg, + ), + ), + ); + } + + /** Closes the connection. */ + public close(): void { + this.asyncPort.port.close(); + } + + /** + * Creates a BattleAgent from this port. 
+ * + * @param explore Exploration policy config. + */ + public getAgent(explore?: AgentExploreConfig): ExperienceBattleAgent { + const random = explore?.seed ? rng(explore.seed) : Math.random; + + const greedyAgent = maxAgent( + async (state, lastAction?: number, reward?: number) => { + const stateData = allocEncodedState(); + encodeState(stateData, state); + verifyInputData(stateData); + + const result = await this.predict( + stateData, + lastAction, + reward, + ); + verifyOutputData(result.output); + + return result.output; + }, + ); + + return async function modelPortAgent( + state, + choices, + logger, + lastAction, + reward, + ) { + await greedyAgent(state, choices, logger, lastAction, reward); + + if (explore && random() < explore.factor) { + logger?.debug("Exploring"); + await randomAgent(state, choices, random); + } + }; + } + + /** + * Finalizes game experience generation. + * @param state Final state. + * @param lastAction Last action taken before arriving at state. + * @param reward Final reward. + */ + public async finalize( + state?: Float32Array[], + lastAction?: number, + reward?: number, + ): Promise { + return await new Promise((res, rej) => + this.asyncPort.postMessage<"finalize">( + { + type: "finalize", + rid: this.asyncPort.nextRid(), + ...(state && {state}), + ...(lastAction !== undefined && {lastAction}), + ...(reward !== undefined && {reward}), + }, + state?.map(a => a.buffer) ?? [], + result => (result.type === "error" ? rej(result.err) : res()), + ), + ); + } + + /** Requests a prediction from the neural network. */ + private async predict( + state: Float32Array[], + lastAction?: number, + reward?: number, + ): Promise { + return await new Promise((res, rej) => + this.asyncPort.postMessage<"predict">( + { + type: "predict", + rid: this.asyncPort.nextRid(), + state, + ...(lastAction !== undefined && {lastAction}), + ...(reward !== undefined && {reward}), + }, + state.map(a => a.buffer), + result => + result.type === "error" ? 
rej(result.err) : res(result),
+            ),
+        );
+    }
+}
diff --git a/src/model/port/ModelPortProtocol.ts b/src/model/port/ModelPortProtocol.ts
new file mode 100644
index 00000000..ec5b5ac9
--- /dev/null
+++ b/src/model/port/ModelPortProtocol.ts
@@ -0,0 +1,66 @@
+/** @file Defines the protocol typings for ModelPorts. */
+import {
+    PortProtocol,
+    PortRequestBase,
+    PortResultBase,
+} from "../../util/port/PortProtocol";
+
+/** ModelPort request protocol typings. */
+export interface ModelPortProtocol
+    extends PortProtocol<"predict" | "finalize"> {
+    predict: {message: PredictMessage; result: PredictWorkerResult};
+    finalize: {message: FinalizeMessage; result: FinalizeResult};
+}
+
+/** The types of requests that can be made to the model port. */
+export type ModelPortRequestType = keyof ModelPortProtocol;
+
+/** Types of messages that the ModelPort can send. */
+export type ModelPortMessage =
+    ModelPortProtocol[ModelPortRequestType]["message"];
+
+/** Base interface for the predict request protocol. */
+type PredictRequestBase<T extends ModelPortRequestType> = PortRequestBase<T>;
+
+/** Prediction request message format. */
+export interface PredictMessage extends PredictRequestBase<"predict"> {
+    /** State data. */
+    state: Float32Array[];
+    /** Id of the previous action. Used for experience generation. */
+    lastAction?: number;
+    /** Reward from the state transition. Used for experience generation. */
+    reward?: number;
+}
+
+/** Finalizes experience generation for the current game. */
+export interface FinalizeMessage extends PredictRequestBase<"finalize"> {
+    /** Data representing the final state. Omit to use previous state. */
+    state?: Float32Array[];
+    /** Id of the previous action. Omit to use the previous action. */
+    lastAction?: number;
+    /** Reward from game end. Omit to use the previous reward. */
+    reward?: number;
+}
+
+/** Types of results that can be given to the ModelPort.
*/ +export type ModelPortResult = ModelPortProtocol[ModelPortRequestType]["result"]; + +/** Prediction returned from the model. */ +export interface PredictWorkerResult + extends PortResultBase<"predict">, + PredictResult { + /** @override */ + done: true; +} + +/** Result from a prediction. */ +export interface PredictResult { + /** Action output. */ + output: Float32Array; +} + +/** Result from finalizing a game. */ +export interface FinalizeResult extends PortResultBase<"finalize"> { + /** @override */ + done: true; +} diff --git a/src/train/model/port/index.ts b/src/model/port/index.ts similarity index 100% rename from src/train/model/port/index.ts rename to src/model/port/index.ts diff --git a/src/model/shapes.ts b/src/model/shapes.ts index afa1c05f..21488505 100644 --- a/src/model/shapes.ts +++ b/src/model/shapes.ts @@ -1,5 +1,7 @@ -/** @file Specifies the input/output shapes of the model. */ -import * as tf from "@tensorflow/tfjs"; +/** + * @file Specifies the input/output shapes of the model. Safe to import in + * non-tf threads. + */ import {intToChoice} from "../psbot/handlers/battle/agent"; import * as encoders from "../psbot/handlers/battle/ai/encoder/encoders"; import {Moveset} from "../psbot/handlers/battle/state/Moveset"; @@ -80,71 +82,9 @@ export const modelInputShapesMap: { ); /** - * Output shape for the {@link createModel model}, with the batch dimension. + * Output shape for the {@link createModel model}, without the batch dimension. */ -export const modelOutputShape: Readonly = [null, intToChoice.length]; +export const modelOutputShape = [intToChoice.length]; /** Output name for the {@link createModel model}. */ export const modelOutputName = "action"; - -/** - * Verifies that the model is compatible with the input/output shapes as - * specified by {@link modelInputShapes} and {@link modelOutputShape}. - * - * @throws Error if invalid input/output shapes. 
- */ -export function verifyModel(model: tf.LayersModel): void { - validateInput(model.input); - validateOutput(model.output); -} - -/** Ensures that the model input shape is valid. */ -function validateInput(input: tf.SymbolicTensor | tf.SymbolicTensor[]): void { - if (!Array.isArray(input)) { - throw new Error("Model input is not an array"); - } - if (input.length !== modelInputShapes.length) { - throw new Error( - `Expected ${modelInputShapes.length} inputs but found ` + - `${input.length}`, - ); - } - for (let i = 0; i < modelInputShapes.length; ++i) { - const {shape} = input[i]; - const expectedShape = [null, ...modelInputShapes[i]]; - let invalid: boolean | undefined; - if (shape.length !== expectedShape.length) { - invalid = true; - } else { - for (let j = 0; j < expectedShape.length; ++j) { - if (shape[j] !== expectedShape[j]) { - invalid = true; - break; - } - } - } - if (invalid) { - throw new Error( - `Expected input ${i} (${modelInputNames[i]}) to have shape ` + - `${JSON.stringify(expectedShape)} but found ` + - `${JSON.stringify(shape)}`, - ); - } - } -} - -/** Ensures that the model output shape is valid. */ -function validateOutput(output: tf.SymbolicTensor | tf.SymbolicTensor[]): void { - if (Array.isArray(output)) { - throw new Error("Model output must not be an array"); - } - for (let i = 0; i < modelOutputShape.length; ++i) { - if (output.shape[i] !== modelOutputShape[i]) { - throw new Error( - `Expected output (${modelOutputName}) to have shape ` + - `${JSON.stringify(modelOutputShape)} but found ` + - `${JSON.stringify(output.shape)}`, - ); - } - } -} diff --git a/src/model/verify.ts b/src/model/verify.ts new file mode 100644 index 00000000..d2ad67ee --- /dev/null +++ b/src/model/verify.ts @@ -0,0 +1,51 @@ +/** + * @file Utility functions for model input/output data verification. Safe to + * import in non-tf threads. 
+ */
+import {intToChoice} from "../psbot/handlers/battle/agent";
+import {flattenedInputShapes, modelInputNames} from "./shapes";
+
+/** Assertions for model input data. */
+export function verifyInputData(data: Float32Array[]): void {
+    for (let i = 0; i < data.length; ++i) {
+        const arr = data[i];
+        if (arr.length !== flattenedInputShapes[i]) {
+            throw new Error(
+                `Model input ${i} (${modelInputNames[i]}) requires ` +
+                    `${flattenedInputShapes[i]} elements but got ${arr.length}`,
+            );
+        }
+        for (let j = 0; j < arr.length; ++j) {
+            const value = arr[j];
+            if (isNaN(value)) {
+                throw new Error(
+                    `Model input ${i} (${modelInputNames[i]}) contains ` +
+                        `NaN at index ${j}`,
+                );
+            }
+            if (value < -1 || value > 1) {
+                throw new Error(
+                    `Model input ${i} (${modelInputNames[i]}) contains ` +
+                        `an out-of-range value ${value} at index ${j}`,
+                );
+            }
+        }
+    }
+}
+
+/** Assertions for model output data. */
+export function verifyOutputData(output: Float32Array): void {
+    if (output.length !== intToChoice.length) {
+        throw new Error(
+            `Expected ${intToChoice.length} output values but got ` +
+                `${output.length}`,
+        );
+    }
+    for (let i = 0; i < output.length; ++i) {
+        if (isNaN(output[i])) {
+            throw new Error(
+                `Model output contains NaN for action ${i} (${intToChoice[i]})`,
+            );
+        }
+    }
+}
diff --git a/src/train/model/worker/Metrics.ts b/src/model/worker/Metrics.ts
similarity index 94%
rename from src/train/model/worker/Metrics.ts
rename to src/model/worker/Metrics.ts
index b59fa069..a7d74935 100644
--- a/src/train/model/worker/Metrics.ts
+++ b/src/model/worker/Metrics.ts
@@ -1,6 +1,6 @@
 import {workerData} from "worker_threads";
 import * as tf from "@tensorflow/tfjs";
-import {importTfn} from "../../../util/tfn";
+import {importTfn} from "../../util/tfn";
 import {ModelWorkerData} from "./ModelProtocol";
 
 const {gpu, metricsPath} = workerData as ModelWorkerData;
@@ -9,7 +9,7 @@ const tfn = importTfn(gpu);
 
 /** Used for writing model summary metrics to Tensorboard.
*/ export class Metrics { private static readonly writer = metricsPath - ? tfn.node.summaryFileWriter(metricsPath) + ? tfn.node.summaryFileWriter(metricsPath, 100 /*maxQueue*/) : null; private static readonly instances = new Map(); diff --git a/src/train/model/worker/ModelProtocol.ts b/src/model/worker/ModelProtocol.ts similarity index 85% rename from src/train/model/worker/ModelProtocol.ts rename to src/model/worker/ModelProtocol.ts index 005036ba..2d462b1c 100644 --- a/src/train/model/worker/ModelProtocol.ts +++ b/src/model/worker/ModelProtocol.ts @@ -5,10 +5,10 @@ import { ModelConfig, PathsConfig, TrainConfig, -} from "../../../config/types"; +} from "../../config/types"; import {SimResult} from "../../game/sim/playGame"; -import {PortMessageBase, PortResultBase} from "../../port/PortProtocol"; -import {WorkerProtocol} from "../../port/WorkerProtocol"; +import {PortMessageBase, PortResultBase} from "../../util/port/PortProtocol"; +import {WorkerProtocol} from "../../util/worker/WorkerProtocol"; /** Typings for the `workerData` object given to the model worker. */ export interface ModelWorkerData { @@ -30,7 +30,7 @@ export interface ModelProtocol /** The types of requests that can be made to the model worker. */ export type ModelRequestType = keyof ModelProtocol; -/** Types of messages that the Model can send. */ +/** Types of messages that the ModelWorker can send. */ export type ModelMessage = ModelProtocol[ModelRequestType]["message"]; /** Base interface for ModelThread messages. */ @@ -99,12 +99,6 @@ interface ModelTrainDataBase { type: T; } -/** Reports that the training episode is just starting. */ -export interface ModelTrainEpisode extends ModelTrainDataBase<"episode"> { - /** Current episode number. */ - step: number; -} - /** Data that gets reported after each rollout game. */ export interface ModelTrainRollout extends ModelTrainDataBase<"rollout"> { @@ -118,17 +112,11 @@ export interface ModelTrainRollout err?: TSerialized extends true ? 
Buffer : Error; } -/** Data that gets reported after each batch in the learning step. */ -export interface ModelTrainBatch extends ModelTrainDataBase<"batch"> { - /** Current batch number. */ - step: number; - /** Training loss for the batch. */ - loss: number; -} - -/** Data that gets reported after completing a learning episode. */ +/** Data that gets reported after completing a learning step. */ export interface ModelTrainLearn extends ModelTrainDataBase<"learn"> { - /** Average training loss for the episode. */ + /** Step number. */ + step: number; + /** Training loss for the step. */ loss: number; } @@ -136,7 +124,7 @@ export interface ModelTrainLearn extends ModelTrainDataBase<"learn"> { export interface ModelTrainEval extends ModelTrainDataBase<"eval">, Omit { - /** Current episode number. */ + /** Step number when eval started. */ readonly step: number; /** Unique identifier for logging. */ readonly id: number; @@ -148,9 +136,9 @@ export interface ModelTrainEval err?: TSerialized extends true ? Buffer : Error; } -/** Notification that all eval games are completed for the current episode. */ +/** Notification that the current evaluation run has been completed. */ export interface ModelTrainEvalDone extends ModelTrainDataBase<"evalDone"> { - /** Current episode number. */ + /** Step number when eval started. */ readonly step: number; /** Result counts for each opponent. */ readonly wlt: { @@ -160,9 +148,7 @@ export interface ModelTrainEvalDone extends ModelTrainDataBase<"evalDone"> { /** Data for events that get reported during training. 
*/ export type ModelTrainData = - | ModelTrainEpisode | ModelTrainRollout - | ModelTrainBatch | ModelTrainLearn | ModelTrainEval | ModelTrainEvalDone; diff --git a/src/train/model/worker/ModelRegistry.ts b/src/model/worker/ModelRegistry.ts similarity index 81% rename from src/train/model/worker/ModelRegistry.ts rename to src/model/worker/ModelRegistry.ts index 777db3f3..90d996ad 100644 --- a/src/train/model/worker/ModelRegistry.ts +++ b/src/model/worker/ModelRegistry.ts @@ -2,16 +2,18 @@ import {serialize} from "v8"; import {MessageChannel, MessagePort} from "worker_threads"; import * as tf from "@tensorflow/tfjs"; import {ListenerSignature, TypedEmitter} from "tiny-typed-emitter"; -import {BatchPredictConfig} from "../../../config/types"; -import {modelInputShapes, verifyModel} from "../../../model/shapes"; -import {intToChoice} from "../../../psbot/handlers/battle/agent"; -import {setTimeoutNs} from "../../../util/nanosecond"; -import {RawPortResultError} from "../../port/PortProtocol"; +import {BatchPredictConfig} from "../../config/types"; +import {intToChoice} from "../../psbot/handlers/battle/agent"; +import {setTimeoutNs} from "../../util/nanosecond"; +import {RawPortResultError} from "../../util/port/PortProtocol"; +import {encodedStateToTensors, verifyModel} from "../model"; import { + ModelPortMessage, PredictMessage, PredictResult, PredictWorkerResult, } from "../port/ModelPortProtocol"; +import {modelInputShapes} from "../shapes"; import {Metrics} from "./Metrics"; import {PredictBatch} from "./PredictBatch"; @@ -26,8 +28,6 @@ interface Events extends ListenerSignature<{[predictReady]: true}> { /** Manages a neural network registry. */ export class ModelRegistry { - /** Neural network object. */ - public readonly model: tf.LayersModel; /** Currently held game worker ports. */ private readonly ports = new Set(); /** Event manager for throttling batch predict requests. */ @@ -63,48 +63,18 @@ export class ModelRegistry { * Creates a ModelRegistry. 
* * @param name Name of the model. - * @param model Neural network object. This registry object will own the - * model as soon as the constructor is called. + * @param model Neural network object. * @param config Configuration for batching predict requests. */ public constructor( public readonly name: string, - model: tf.LayersModel, + public readonly model: tf.LayersModel, private readonly config: BatchPredictConfig, ) { - try { - verifyModel(model); - } catch (e) { - // Cleanup model so it doesn't cause a memory leak. - model.dispose(); - throw e; - } - this.model = model; + verifyModel(model); this.events.setMaxListeners(config.maxSize); } - /** Clones the current model into a new registry with the same config. */ - public async clone(name: string): Promise { - const modelArtifact = new Promise( - res => - void this.model.save({ - save: async _modelArtifact => { - res(_modelArtifact); - return await Promise.resolve({} as tf.io.SaveResult); - }, - }), - ); - const clonedModel = await tf.loadLayersModel({ - load: async () => await Promise.resolve(modelArtifact), - }); - return new ModelRegistry(name, clonedModel, this.config); - } - - /** Saves the neural network to the given url. */ - public async save(url: string): Promise { - await this.model.save(url); - } - /** Safely closes ports and disposes the model. */ public unload(): void { for (const port of this.ports) { @@ -212,7 +182,8 @@ export class ModelRegistry { this.ports.add(port1); port1.on( "message", - (msg: PredictMessage) => + (msg: ModelPortMessage) => + msg.type === "predict" && void this.predict(msg) .then(prediction => { const result: PredictWorkerResult = { @@ -221,11 +192,7 @@ export class ModelRegistry { done: true, ...prediction, }; - port1.postMessage(result, [ - // Give the state buffer back to the calling thread. 
- ...result.input.map(s => s.buffer), - result.output.buffer, - ]); + port1.postMessage(result, [result.output.buffer]); }) .catch(err => { const result: RawPortResultError = { @@ -237,16 +204,10 @@ export class ModelRegistry { port1.postMessage(result, [result.err.buffer]); }), ); - // Remove this port from the recorded references after close. port1.on("close", () => this.ports.delete(port1)); return port2; } - /** Copies the weights of the current model to the given model. */ - public copyTo(other: ModelRegistry): void { - other.model.setWeights(this.model.getWeights()); - } - /** * Queues a prediction for the neural network. Can be called multiple times * while other predict requests are still queued. @@ -257,7 +218,7 @@ export class ModelRegistry { } const result = new Promise(res => - this.predictBatch.add(msg.state, res), + this.predictBatch.add(encodedStateToTensors(msg.state), res), ); await this.checkPredictBatch(); return await result; diff --git a/src/train/model/worker/ModelWorker.ts b/src/model/worker/ModelWorker.ts similarity index 98% rename from src/train/model/worker/ModelWorker.ts rename to src/model/worker/ModelWorker.ts index 670cae28..545fe878 100644 --- a/src/train/model/worker/ModelWorker.ts +++ b/src/model/worker/ModelWorker.ts @@ -6,8 +6,8 @@ import { ModelConfig, PathsConfig, TrainConfig, -} from "../../../config/types"; -import {WorkerPort} from "../../port/WorkerPort"; +} from "../../config/types"; +import {WorkerPort} from "../../util/worker/WorkerPort"; import {ModelProtocol, ModelTrainData, ModelWorkerData} from "./ModelProtocol"; /** Path to the worker script. 
*/ diff --git a/src/train/model/worker/PredictBatch.ts b/src/model/worker/PredictBatch.ts similarity index 72% rename from src/train/model/worker/PredictBatch.ts rename to src/model/worker/PredictBatch.ts index b68c6082..bd6b613a 100644 --- a/src/train/model/worker/PredictBatch.ts +++ b/src/model/worker/PredictBatch.ts @@ -1,5 +1,5 @@ import * as tf from "@tensorflow/tfjs"; -import {PredictResult} from "../port/ModelPortProtocol"; +import {PredictResult} from "../port"; /** State+callback entries for managing batched model predict requests. */ export class PredictBatch { @@ -13,7 +13,7 @@ export class PredictBatch { */ private readonly transposedInput: tf.Tensor[][] = []; /** Resolver callbacks for each request within the batch. */ - private readonly callbacks: ((output: Float32Array) => void)[] = []; + private readonly callbacks: ((result: PredictResult) => void)[] = []; /** Corresponding times that the requests were {@link add added}. */ public get times(): readonly bigint[] { return this._times; @@ -31,9 +31,12 @@ export class PredictBatch { * @param inputShapes Input shape that the batch must conform to. Should be * an array of shapes (excluding the batch dimension) for each input that * the model receives. + * @param autoDisposeInput Whether to automatically dispose input tensors + * after the batch request is resolved. Default true. */ public constructor( private readonly inputShapes: readonly (readonly number[])[], + private readonly autoDisposeInput = true, ) { for (let i = 0; i < inputShapes.length; ++i) { this.transposedInput[i] = []; @@ -43,20 +46,24 @@ export class PredictBatch { /** * Adds a predict request to the batch. * - * @param inputs Flattened tensor data to use as inputs. + * @param inputs Tensors to use as inputs. * @param callback Called when the batch is executed and the corresponding * result is extracted. 
*/ public add( - input: Float32Array[], + input: tf.Tensor[], callback: (result: PredictResult) => void, ): void { - for (let i = 0; i < input.length; ++i) { - this.transposedInput[i].push( - tf.tensor(input[i], this.inputShapes[i] as number[], "float32"), + if (input.length !== this.transposedInput.length) { + throw new Error( + `Expected ${this.transposedInput.length} inputs but found ` + + `${input.length}`, ); } - this.callbacks.push(output => callback({input, output})); + for (let i = 0; i < input.length; ++i) { + this.transposedInput[i].push(input[i]); + } + this.callbacks.push(callback); this._times.push(process.hrtime.bigint()); } @@ -76,12 +83,16 @@ export class PredictBatch { `${results.length}`, ); } - tf.dispose(this.transposedInput); await Promise.all( results.map(async (t, i) => { try { - this.callbacks[i](await t.data<"float32">()); + this.callbacks[i]({output: await t.data<"float32">()}); } finally { + if (this.autoDisposeInput) { + for (const input of this.transposedInput) { + input[i].dispose(); + } + } t.dispose(); } }), diff --git a/src/train/model/worker/index.ts b/src/model/worker/index.ts similarity index 100% rename from src/train/model/worker/index.ts rename to src/model/worker/index.ts diff --git a/src/train/model/worker/worker.js b/src/model/worker/worker.js similarity index 100% rename from src/train/model/worker/worker.js rename to src/model/worker/worker.js diff --git a/src/train/model/worker/worker.ts b/src/model/worker/worker.ts similarity index 88% rename from src/train/model/worker/worker.ts rename to src/model/worker/worker.ts index ce7e51b6..4dc1b86c 100644 --- a/src/train/model/worker/worker.ts +++ b/src/model/worker/worker.ts @@ -2,9 +2,10 @@ import {serialize} from "v8"; import {parentPort, TransferListItem, workerData} from "worker_threads"; import * as tf from "@tensorflow/tfjs"; -import {createModel} from "../../../model/model"; -import {importTfn} from "../../../util/tfn"; -import {RawPortResultError} from 
"../../port/PortProtocol"; +import {train} from "../../train/train"; +import {RawPortResultError} from "../../util/port/PortProtocol"; +import {importTfn} from "../../util/tfn"; +import {createModel} from "../model"; import { ModelMessage, ModelResult, @@ -12,7 +13,6 @@ import { ModelWorkerData, } from "./ModelProtocol"; import {ModelRegistry} from "./ModelRegistry"; -import {train} from "./train"; if (!parentPort) { throw new Error("No parent port!"); @@ -56,10 +56,15 @@ async function handle(msg: ModelMessage): Promise { const model = msg.url ? await tf.loadLayersModel(msg.url) : createModel(msg.name, msg.config, msg.seed); - models.set( - msg.name, - new ModelRegistry(msg.name, model, msg.predict), - ); + try { + models.set( + msg.name, + new ModelRegistry(msg.name, model, msg.predict), + ); + } catch (e) { + model.dispose(); + throw e; + } result = {type: "load", rid, done: true, name: msg.name}; break; } @@ -71,7 +76,7 @@ async function handle(msg: ModelMessage): Promise { } case "train": { await train( - getRegistry(msg.model), + getRegistry(msg.model).model, msg.config, msg.paths, data => { diff --git a/src/psbot/handlers/battle/BattleHandler.ts b/src/psbot/handlers/battle/BattleHandler.ts index 7c6d6286..38b080bc 100644 --- a/src/psbot/handlers/battle/BattleHandler.ts +++ b/src/psbot/handlers/battle/BattleHandler.ts @@ -87,7 +87,7 @@ export class BattleHandler /** Creates a BattleHandler. */ public constructor({ username, - parser = main as unknown as BattleParser, + parser = main, stateCtor = BattleState, agent, sender, diff --git a/src/psbot/handlers/battle/agent/BattleAgent.ts b/src/psbot/handlers/battle/agent/BattleAgent.ts index 4167a95b..f9c99e7b 100644 --- a/src/psbot/handlers/battle/agent/BattleAgent.ts +++ b/src/psbot/handlers/battle/agent/BattleAgent.ts @@ -6,14 +6,17 @@ import {Choice} from "./Choice"; * Generic function type alias that makes decisions during a battle. * * @template TInfo Optional decision info type to return. 
+ * @template TArgs Optional additional arguments. * @param state State data for decision making. * @param choices Available choices to choose from. This method will sort the * choices array in-place from most to least preferable. * @param logger Optional logger object. + * @param args Optional additional arguments. * @returns Optional data returned after making a decision. */ -export type BattleAgent = ( +export type BattleAgent = ( state: ReadonlyBattleState, choices: Choice[], logger?: Logger, + ...args: TArgs ) => Promise; diff --git a/src/psbot/handlers/battle/ai/maxAgent.ts b/src/psbot/handlers/battle/ai/maxAgent.ts index 7e463d3b..0b4a0e4d 100644 --- a/src/psbot/handlers/battle/ai/maxAgent.ts +++ b/src/psbot/handlers/battle/ai/maxAgent.ts @@ -7,22 +7,24 @@ import {ReadonlyBattleState} from "../state"; * Creates a {@link BattleAgent} function that selects the best choice based on * the given evaluator. * - * @param getOutputValues Function for getting the output values of each choice. + * @param evaluator Function for evaluating the ranking of each choice. * @param debugRankings If true, the returned BattleAgent will also return a * debug string displaying the rankings for each choice. 
*/ -export function maxAgent( - getOutputValues: ( +export function maxAgent( + evaluator: ( state: ReadonlyBattleState, + ...args: TArgs ) => Float32Array | Promise, debugRankings?: boolean, -): BattleAgent { +): BattleAgent { return async function ( state: ReadonlyBattleState, choices: Choice[], logger?: Logger, + ...args: TArgs ): Promise { - const output = await getOutputValues(state); + const output = await evaluator(state, ...args); logger?.debug( "Ranked choices: {" + intToChoice diff --git a/src/psbot/handlers/battle/ai/networkAgent.ts b/src/psbot/handlers/battle/ai/networkAgent.ts index 936da0d9..0b2fbeca 100644 --- a/src/psbot/handlers/battle/ai/networkAgent.ts +++ b/src/psbot/handlers/battle/ai/networkAgent.ts @@ -1,5 +1,5 @@ import * as tf from "@tensorflow/tfjs"; -import {modelInputShapes, verifyModel} from "../../../../model/shapes"; +import {encodedStateToTensors, verifyModel} from "../../../../model/model"; import {BattleAgent} from "../agent"; import {allocEncodedState, encodeState} from "./encoder"; import {maxAgent} from "./maxAgent"; @@ -47,28 +47,3 @@ export function networkAgent( return outputData; }, debugRankings); } - -/** - * Converts the data lists into tensors - * - * @param includeBatchDim Whether to include an extra 1 dimension in the first - * axis for the batch. Default false. - */ -export function encodedStateToTensors( - arr: Float32Array[], - includeBatchDim?: boolean, -): tf.Tensor[] { - if (arr.length !== modelInputShapes.length) { - throw new Error( - `Expected ${modelInputShapes.length} inputs but found ` + - `${arr.length}`, - ); - } - return modelInputShapes.map((shape, i) => - tf.tensor( - arr[i], - includeBatchDim ? 
[1, ...shape] : [...shape], - "float32", - ), - ); -} diff --git a/src/train/model/worker/Evaluate.ts b/src/train/Evaluate.ts similarity index 86% rename from src/train/model/worker/Evaluate.ts rename to src/train/Evaluate.ts index 29b6cb38..9f30b189 100644 --- a/src/train/model/worker/Evaluate.ts +++ b/src/train/Evaluate.ts @@ -1,14 +1,14 @@ import {join} from "path"; -import {EvalConfig} from "../../../config/types"; +import {EvalConfig} from "../config/types"; import { GameArgsGenOptions, GameArgsGenSeeders, GamePipeline, GamePoolArgs, GamePoolResult, -} from "../../game/pool"; -import {Metrics} from "./Metrics"; -import {ModelRegistry} from "./ModelRegistry"; +} from "../game/pool"; +import {Metrics} from "../model/worker/Metrics"; +import {ModelRegistry} from "../model/worker/ModelRegistry"; /** * Encapsulates the evaluation step of training, where the model plays games @@ -43,24 +43,29 @@ export class Evaluate { } /** Closes game threads. */ - public async cleanup(): Promise { - return await this.games.cleanup(); + public async close(): Promise { + return await this.games.close(); + } + + /** Force-closes game threads. */ + public async terminate(): Promise { + return await this.games.terminate(); } /** * Runs the evaluation step on the current model versions. * - * @param step Current episode step. + * @param step Current learning step. * @param callback Called for each game result. 
*/ public async run( step: number, - callback?: (result: GamePoolResult) => void | Promise, + callback?: (result: GamePoolResult) => void, ): Promise<{[vs: string]: {win: number; loss: number; tie: number}}> { const wlts: { [vs: string]: {win: number; loss: number; tie: number}; } = {}; - await this.games.run(this.genArgs(step), async result => { + await this.games.run(this.genArgs(step), result => { const wlt = (wlts[result.agents[1]] ??= { win: 0, loss: 0, @@ -73,7 +78,7 @@ export class Evaluate { } else { ++wlt.tie; } - await callback?.(result); + callback?.(result); }); if (this.metrics) { @@ -121,7 +126,7 @@ export class Evaluate { }, numGames: this.config.numGames, ...(this.logPath !== undefined && { - logPath: join(this.logPath, `episode-${step}`), + logPath: join(this.logPath, `step-${step}`), }), ...(this.config.pool.reduceLogs && {reduceLogs: true}), ...(this.seeders && {seeders: this.seeders}), diff --git a/src/train/Learn.ts b/src/train/Learn.ts new file mode 100644 index 00000000..4fc4bff8 --- /dev/null +++ b/src/train/Learn.ts @@ -0,0 +1,232 @@ +import * as tf from "@tensorflow/tfjs"; +import {ExperienceConfig, LearnConfig} from "../config/types"; +import {BatchTensorExperience} from "../game/experience/tensor"; +import {Metrics} from "../model/worker/Metrics"; +import {intToChoice} from "../psbot/handlers/battle/agent"; + +/** + * Encapsulates the learning step of training, where the model is updated based + * on experience generated by rollout games. + */ +export class Learn { + /** Metrics logger. */ + private readonly metrics = Metrics.get(`${this.name}/learn`); + /** Used for calculating gradients. */ + private readonly optimizer = tf.train.sgd(this.config.learningRate); + /** Collection of trainable variables in the model. */ + private readonly variables = this.model.trainableWeights.map( + w => w.read() as tf.Variable, + ); + /** Used for logging inputs during loss calcs. 
*/ + private readonly hookLayers: readonly tf.layers.Layer[] = + this.model.layers.filter(l => + ["Dense", "SetAttention", "SetMultiHeadAttention"].includes( + l.getClassName(), + ), + ); + + /** + * Creates a Learn object. + * + * @param name Name of the training run for logging. + * @param model Model to train. + * @param targetModel Model for computing TD targets. Can be set to the same + * model to disable target model mechanism. + * @param config Learning config. + * @param expConfig Experience config for computing TD targets. + */ + public constructor( + public readonly name: string, + private readonly model: tf.LayersModel, + private readonly targetModel: tf.LayersModel, + private readonly config: LearnConfig, + private readonly expConfig: ExperienceConfig, + ) { + // Log initial weights. + for (const weights of this.variables) { + if (weights.size === 1) { + const weightScalar = weights.asScalar(); + this.metrics?.scalar( + `${weights.name}/weights`, + weightScalar, + 0, + ); + tf.dispose(weightScalar); + } else { + this.metrics?.histogram(`${weights.name}/weights`, weights, 0); + } + } + } + + /** + * Performs a single batch update step. + * + * @param step Step number for logging. + * @param batch Batch to train on. + * @returns The loss for this batch. + */ + public step(step: number, batch: BatchTensorExperience): tf.Scalar { + return tf.tidy(() => { + const preStep = process.hrtime.bigint(); + const storeBatchMetrics = step % this.config.metricsInterval === 0; + + const target = this.calculateTarget( + batch.reward, + batch.nextState, + batch.done, + ); + + const hookedInputs: {[name: string]: tf.Tensor1D[]} = {}; + if (storeBatchMetrics) { + for (const layer of this.hookLayers) { + // Note: Call hook is wrapped in tf.tidy() so tf.keep() is + // used to extract training inputs. 
+ layer.setCallHook(function logInputs(inputs) { + if (!Array.isArray(inputs)) { + inputs = [inputs]; + } + for (let i = 0; i < inputs.length; ++i) { + const input = inputs[i].flatten(); + let name = `${layer.name}/input`; + if (inputs.length > 1) { + name += `/${i}`; + } + (hookedInputs[name] ??= []).push(tf.keep(input)); + } + }); + } + } + + const {value: loss, grads} = this.optimizer.computeGradients( + () => this.loss(batch.state, batch.action, target), + this.variables, + ); + this.optimizer.applyGradients(grads); + + const postStep = process.hrtime.bigint(); + const updateMs = Number(postStep - preStep) / 1e6; + this.metrics?.scalar("update_ms", updateMs, step); + this.metrics?.scalar( + "update_throughput_s", + this.config.batchSize / + (updateMs / 1e3) /*experiences per sec*/, + step, + ); + + this.metrics?.scalar("loss", loss, step); + if (storeBatchMetrics) { + for (const name in grads) { + if (Object.prototype.hasOwnProperty.call(grads, name)) { + const grad = grads[name]; + if (grad.size === 1) { + this.metrics?.scalar( + `${name}/grads`, + grad.asScalar(), + step, + ); + } else { + this.metrics?.histogram( + `${name}/grads`, + grad, + step, + ); + } + } + } + + for (const name in hookedInputs) { + if ( + Object.prototype.hasOwnProperty.call(hookedInputs, name) + ) { + this.metrics?.histogram( + name, + tf.concat1d(hookedInputs[name]), + step, + ); + } + } + tf.dispose(hookedInputs); + + for (const weights of this.variables) { + if (weights.size === 1) { + this.metrics?.scalar( + `${weights.name}/weights`, + weights.asScalar(), + step, + ); + } else { + this.metrics?.histogram( + `${weights.name}/weights`, + weights, + step, + ); + } + } + + for (const layer of this.hookLayers) { + layer.clearCallHook(); + } + } + + return loss; + }); + } + + /** Calculates TD target for an experience batch. 
*/ + private calculateTarget( + reward: tf.Tensor, + nextState: tf.Tensor[], + done: tf.Tensor, + ): tf.Tensor { + return tf.tidy(() => { + let targetQ: tf.Tensor; + const q = this.model.predictOnBatch(nextState) as tf.Tensor; + if (!this.config.target) { + // Vanilla DQN TD target: r + gamma * max_a(Q(s', a)) + targetQ = tf.max(q, -1); + } else { + targetQ = this.targetModel.predictOnBatch( + nextState, + ) as tf.Tensor; + if (this.config.target !== "double") { + // TD target with target net: r + gamma * max_a(Qt(s', a)) + targetQ = tf.max(targetQ, -1); + } else { + // Double Q target: r + gamma * Qt(s', argmax_a(Q(s', a))) + const action = tf.argMax(q, -1); + const actionMask = tf.oneHot(action, intToChoice.length); + targetQ = tf.sum(tf.mul(targetQ, actionMask), -1); + } + } + + // Also mask out q values of terminal states. + targetQ = tf.where(done, 0, targetQ); + + const target = tf.add( + reward, + tf.mul(this.expConfig.rewardDecay, targetQ), + ); + return target; + }); + } + + /** Calculates training loss on an experience batch. */ + private loss( + state: tf.Tensor[], + action: tf.Tensor, + target: tf.Tensor, + ): tf.Scalar { + return tf.tidy("loss", () => { + let q = this.model.predictOnBatch(state) as tf.Tensor; + const mask = tf.oneHot(action, intToChoice.length); + q = tf.sum(tf.mul(q, mask), -1); + return tf.losses.meanSquaredError(target, q); + }); + } + + /** Cleans up dangling variables. */ + public cleanup(): void { + this.optimizer.dispose(); + Metrics.flush(); + } +} diff --git a/src/train/README.md b/src/train/README.md index 502cde88..f0f6689b 100644 --- a/src/train/README.md +++ b/src/train/README.md @@ -13,17 +13,18 @@ that can be taken. ## Algorithm - Load the neural network to be trained, or create one if it doesn't exist. 
-- Run the [training loop](model/worker/train.ts): - - Use a [thread pool](game/pool/GamePool.ts) to queue up multiple +- Run the [training loop](train.ts): + - Use a [thread pool](../game/pool/GamePool.ts) to queue up multiple self-play games and collect reward/state data, called - [experience](game/experience/Experience.ts). This is called the - [rollout](model/worker/Rollout.ts) stage. - - Asynchronously feed experience into the [learner](model/worker/Learn.ts) - to compute the model update. - - Play [evaluation](model/worker/Evaluate.ts) games against previous + [experience](../game/experience/Experience.ts). This is called the + [rollout](Rollout.ts) stage. + - Feed experience into the [learner](Learn.ts) to compute model updates. + - Periodically play [evaluation](Evaluate.ts) games against previous versions to track progress. -The entire algorithm takes place on a separate [thread](model/worker/worker.ts) -where all the Tensorflow operations are kept and -[managed](model/worker/ModelWorker.ts) by the main thread. It also manages batch -predictions for parallel games. +The entire algorithm takes place on a separate +[thread](../model/worker/worker.ts) where all the Tensorflow operations +(including batch predictions for parallel games) are kept and +[managed](../model/worker/ModelWorker.ts) asynchronously by the main thread. The +main thread is also used to [track](TrainingProgress.ts) training progress with +a progress bar if configured for it. diff --git a/src/train/ReplayBuffer.ts b/src/train/ReplayBuffer.ts new file mode 100644 index 00000000..49c812dc --- /dev/null +++ b/src/train/ReplayBuffer.ts @@ -0,0 +1,124 @@ +import * as tf from "@tensorflow/tfjs"; +import { + BatchTensorExperience, + TensorExperience, +} from "../game/experience/tensor"; +import {modelInputShapes} from "../model/shapes"; +import {Rng} from "../util/random"; + +/** Experience replay buffer, implemented as a circular buffer. 
*/ +export class ReplayBuffer { + // Transpose and stack buffered experiences to make batching easier. + private readonly states = Array.from( + modelInputShapes, + () => new Array(this.maxSize), + ); + private readonly actions = new Array(this.maxSize); + private readonly rewards = new Array(this.maxSize); + private readonly nextStates = Array.from( + modelInputShapes, + () => new Array(this.maxSize), + ); + private readonly dones = new Array(this.maxSize); + + private start = 0; + + /** Current length of the buffer. */ + public get length(): number { + return this._length; + } + private _length = 0; + + /** + * Creates a ReplayBuffer. + * + * @param maxSize Size of the buffer. + */ + public constructor(public readonly maxSize: number) {} + + /** + * Adds a new Experience to the buffer. If full, the oldest one is + * discarded. + */ + public add(exp: TensorExperience): void { + let i: number; + if (this._length < this.maxSize) { + i = this._length++; + } else { + i = this.start++; + if (this.start >= this.maxSize) { + this.start = 0; + } + } + + for (let s = 0; s < modelInputShapes.length; ++s) { + this.states[s][i]?.dispose(); + this.states[s][i] = exp.state[s]; + this.nextStates[s][i]?.dispose(); + this.nextStates[s][i] = exp.nextState[s]; + } + this.actions[i]?.dispose(); + this.actions[i] = exp.action; + this.rewards[i]?.dispose(); + this.rewards[i] = exp.reward; + this.dones[i]?.dispose(); + this.dones[i] = exp.done; + } + + /** + * Samples a batch of experiences from the buffer with replacement. + * + * @param size Size of sample. + * @param random Controlled random. + */ + public sample( + size: number, + random: Rng = Math.random, + ): BatchTensorExperience { + if (size > this._length) { + throw new Error( + `Requested batch size ${size} is too big for current ` + + `ReplayBuffer size ${this._length}`, + ); + } + // Knuth's algorithm for sampling without replacement. 
+ const states = Array.from( + modelInputShapes, + () => new Array(size), + ); + const actions = new Array(size); + const rewards = new Array(size); + const nextStates = Array.from( + modelInputShapes, + () => new Array(size), + ); + const dones = new Array(size); + for (let t = 0, m = 0; m < size; ++t) { + if ((this._length - t) * random() < size - m) { + for (let s = 0; s < modelInputShapes.length; ++s) { + states[s][m] = this.states[s][t]; + nextStates[s][m] = this.nextStates[s][t]; + } + actions[m] = this.actions[t]; + rewards[m] = this.rewards[t]; + dones[m] = this.dones[t]; + ++m; + } + } + return tf.tidy(() => ({ + state: modelInputShapes.map((_, s) => tf.stack(states[s])), + action: tf.stack(actions).as1D(), + reward: tf.stack(rewards).as1D(), + nextState: modelInputShapes.map((_, s) => tf.stack(nextStates[s])), + done: tf.stack(dones).as1D(), + })); + } + + /** Disposes tensors left in the buffer. */ + public dispose(): void { + tf.dispose(this.states); + tf.dispose(this.actions); + tf.dispose(this.rewards); + tf.dispose(this.nextStates); + } +} diff --git a/src/train/model/worker/Rollout.ts b/src/train/Rollout.ts similarity index 63% rename from src/train/model/worker/Rollout.ts rename to src/train/Rollout.ts index a00e9390..2a0e2656 100644 --- a/src/train/model/worker/Rollout.ts +++ b/src/train/Rollout.ts @@ -1,17 +1,17 @@ -import {PassThrough} from "stream"; -import {RolloutConfig} from "../../../config/types"; -import {rng, Seeder} from "../../../util/random"; -import {Experience} from "../../game/experience"; +import {RolloutConfig} from "../config/types"; import { GameArgsGenOptions, GameArgsGenSeeders, GamePipeline, GamePoolArgs, GamePoolResult, -} from "../../game/pool"; -import {Metrics} from "./Metrics"; -import {ModelRegistry} from "./ModelRegistry"; +} from "../game/pool"; +import {Metrics} from "../model/worker/Metrics"; +import {ModelRegistry} from "../model/worker/ModelRegistry"; +import {rng, Seeder} from "../util/random"; +import 
{RolloutModel} from "./RolloutModel"; +/** Seeders for {@link Rollout}. */ export interface RolloutSeeders extends GameArgsGenSeeders { /** Random seed generator for opponent selection. */ readonly rollout?: Seeder; @@ -29,9 +29,9 @@ export class Rollout { /** Current exploration factor for the agent. */ private readonly exploration: {factor: number}; - /** Counter for number of games played for the episode. */ + /** Counter for number of games played for the training run. */ private numGames = 0; - /** Number of ties during the episode. */ + /** Number of ties during the training run. */ private numTies = 0; /** @@ -48,7 +48,7 @@ export class Rollout { */ public constructor( public readonly name: string, - private readonly model: ModelRegistry, + private readonly rolloutModel: RolloutModel, private readonly prevModel: ModelRegistry, private readonly config: RolloutConfig, private readonly logPath?: string, @@ -58,46 +58,31 @@ export class Rollout { this.exploration = {factor: config.policy.exploration}; } - /** Closes game threads. */ - public async cleanup(): Promise { - return await this.games.cleanup(); + /** Force-closes game threads. */ + public async terminate(): Promise { + return await this.games.terminate(); } - /** Generator for getting experience data from the training games. */ - public async *gen( + /** + * Runs the rollout stage. + * + * @param callback Called for each game result. + */ + public async run( callback?: (result: GamePoolResult) => void, - ): AsyncGenerator { - const stream = new PassThrough({objectMode: true, highWaterMark: 1}); - - const run = this.games - .run( - this.genArgs(exp => stream.write(exp)), - async result => { - ++this.numGames; - if (result.winner === undefined) { - ++this.numTies; - } - callback?.(result); - // Need to check for backpressure. 
- if (!stream.write(undefined)) { - await new Promise(res => stream.once("drain", res)); - } - }, - ) - .catch(e => void stream.emit("error", e)) - .finally(() => stream.end()); - - for await (const exp of stream) { - if (exp) { - yield exp as Experience; + ): Promise { + await this.games.run(this.genArgs(), result => { + ++this.numGames; + if (result.winner === undefined) { + ++this.numTies; } - } - await run; + callback?.(result); + }); } /** * Updates the exploration rate for future games and logs metrics to prepare - * for the next episode. + * for the next learning step. */ public step(step: number): void { this.metrics?.scalar("exploration", this.exploration.factor, step); @@ -106,20 +91,22 @@ export class Rollout { this.exploration.factor = this.config.policy.minExploration; } - this.metrics?.scalar("num_games", this.numGames, step); - this.metrics?.scalar("tie_rate", this.numTies / this.numGames, step); - this.numGames = 0; - this.numTies = 0; + if (this.numGames > 0) { + this.metrics?.scalar("total_games", this.numGames, step); + this.metrics?.scalar( + "tie_ratio", + this.numTies / this.numGames, + step, + ); + } } /** Generates game configs for the thread pool. 
*/ - private *genArgs( - experienceCallback?: (exp: Experience) => void, - ): Generator { + private *genArgs(): Generator { const opts: GameArgsGenOptions = { agentConfig: { name: "rollout", - exploit: {type: "model", model: this.model.name}, + exploit: {type: "model", model: this.rolloutModel.name}, // Use object reference so that step() updates with the new rate // for newly-created games explore: this.exploration, @@ -127,14 +114,14 @@ export class Rollout { }, opponent: { name: "self", - exploit: {type: "model", model: this.model.name}, + exploit: {type: "model", model: this.rolloutModel.name}, explore: this.exploration, emitExperience: true, }, requestModelPort: (model: string) => { switch (model) { - case this.model.name: - return this.model.subscribe(); + case this.rolloutModel.name: + return this.rolloutModel.subscribe(); case this.prevModel.name: return this.prevModel.subscribe(); default: @@ -144,7 +131,6 @@ export class Rollout { ...(this.logPath !== undefined && {logPath: this.logPath}), ...(this.config.pool.reduceLogs && {reduceLogs: true}), ...(this.seeders && {seeders: this.seeders}), - ...(experienceCallback && {experienceCallback}), }; const gen = GamePipeline.genArgs(opts); diff --git a/src/train/RolloutModel.ts b/src/train/RolloutModel.ts new file mode 100644 index 00000000..eb6cca14 --- /dev/null +++ b/src/train/RolloutModel.ts @@ -0,0 +1,307 @@ +// eslint-disable-next-line max-classes-per-file +import {serialize} from "v8"; +import {MessageChannel, MessagePort} from "worker_threads"; +import * as tf from "@tensorflow/tfjs"; +import {ListenerSignature, TypedEmitter} from "tiny-typed-emitter"; +import {TensorExperience} from "../game/experience/tensor"; +import {encodedStateToTensors, verifyModel} from "../model/model"; +import { + PredictMessage, + PredictWorkerResult, + PredictResult, + ModelPortMessage, + ModelPortResult, + FinalizeMessage, + FinalizeResult, +} from "../model/port"; +import {modelInputShapes} from "../model/shapes"; +import 
{PredictBatch} from "../model/worker/PredictBatch"; +import {intToChoice} from "../psbot/handlers/battle/agent"; +import {RawPortResultError} from "../util/port/PortProtocol"; + +/** Event for when a model prediction is requested. */ +const predictRequest = Symbol("predictRequest"); + +/** Defines events that the RolloutModel implements. */ +interface Events extends ListenerSignature<{[predictRequest]: true}> { + /** When the model is requested to make a prediction. */ + readonly [predictRequest]: () => void; +} + +/** + * Wraps a model for use in the rollout stage of training. Manages batched + * synchronized time steps with both the learn stage and the rollout game + * threads. + */ +export class RolloutModel { + /** Currently held game worker ports. */ + private readonly ports = new Map(); + /** Event manager for predict requests. */ + private readonly events = new TypedEmitter(); + + /** Current pending predict request batch. */ + private predictBatch = new PredictBatch( + modelInputShapes, + false /*autoDisposeInput*/, + ); + + /** + * Stores experiences that haven't yet been emitted into the replay buffer. + */ + private experienceBuffer: TensorExperience[] = []; + + /** + * Creates a RolloutModel object. + * + * @param name Name of the model. + * @param model Model to wrap. Does not assume ownership since it's assumed + * that a ModelRegistry already owns it. + */ + public constructor( + public readonly name: string, + public readonly model: tf.LayersModel, + ) { + verifyModel(model); + } + + /** Safely closes ports. */ + public unload(): void { + for (const [port] of this.ports) { + port.close(); + } + } + + /** + * Indicates that a game worker is subscribing to a model. + * + * @returns A port for queueing predictions that the game worker will use. 
+ */ + public subscribe(): MessagePort { + const ec = new ExperienceContext(exp => + this.experienceBuffer.push(exp), + ); + const {port1, port2} = new MessageChannel(); + this.ports.set(port1, ec); + port1.on( + "message", + (msg: ModelPortMessage) => + void this.handle(msg, ec) + .then(result => { + port1.postMessage( + result, + result.type === "predict" + ? [result.output.buffer] + : [], + ); + }) + .catch(err => { + const result: RawPortResultError = { + type: "error", + rid: msg.rid, + done: true, + err: serialize(err), + }; + port1.postMessage(result, [result.err.buffer]); + }), + ); + port1.on("close", () => this.ports.delete(port1)); + return port2; + } + + /** Handles a ModelPort message. */ + private async handle( + msg: ModelPortMessage, + ec: ExperienceContext, + ): Promise { + switch (msg.type) { + case "predict": + return await this.predict(msg, ec); + case "finalize": + return this.finalize(msg, ec); + } + } + + /** + * Queues a prediction for the neural network. Can be called multiple times + * while other predict requests are still queued. + */ + private async predict( + msg: PredictMessage, + ec: ExperienceContext, + ): Promise { + const state = encodedStateToTensors(msg.state); + ec.add(state, msg.lastAction, msg.reward); + + const result = await new Promise(res => { + this.predictBatch.add(state, res); + this.events.emit(predictRequest); + }); + return { + type: "predict", + rid: msg.rid, + done: true, + ...result, + }; + } + + /** Finalizes experience generation. */ + private finalize( + msg: FinalizeMessage, + ec: ExperienceContext, + ): FinalizeResult { + const state = msg.state && encodedStateToTensors(msg.state); + const lastAction = + msg.lastAction !== undefined + ? tf.scalar(msg.lastAction, "int32") + : undefined; + const reward = + msg.reward !== undefined + ? 
tf.scalar(msg.reward, "float32") + : undefined; + ec.finalize(state, lastAction, reward); + return { + type: "finalize", + rid: msg.rid, + done: true, + }; + } + + /** + * Flushes the predict buffer and executes the queued batch predict requests + * from the game thread pool, returning the generated experiences from each + * request. + */ + public async step(): Promise { + while (this.predictBatch.length <= 0) { + await new Promise(res => + this.events.once(predictRequest, res), + ); + } + // Give time for rollout game threads to make predict requests. + // Without this, all batch predicts will just have size=1. + await tf.nextFrame(); + + const batch = this.predictBatch; + this.predictBatch = new PredictBatch( + modelInputShapes, + false /*autoDisposeInput*/, + ); + + await batch.resolve( + tf.tidy(() => + (this.model.predictOnBatch(batch.toTensors()) as tf.Tensor) + .as2D(batch.length, intToChoice.length) + .unstack(), + ), + ); + + const exps = this.experienceBuffer; + this.experienceBuffer = []; + return exps; + } +} + +/** Tracks Experience generation for one side of a game. */ +class ExperienceContext { + private lastState?: tf.Tensor[]; + private lastAction?: tf.Scalar; + private lastReward?: tf.Scalar; + + /** + * Creates an ExperienceContext. + * + * @param callback Callback to emit experiences. + */ + public constructor( + private readonly callback: (exp: TensorExperience) => void, + ) {} + + /** + * Adds data for generating experience. + * + * @param state Resultant state. + * @param action Action used to get to state. + * @param reward Net reward from state transition. 
+ */ + public add(state: tf.Tensor[], action?: number, reward?: number): void { + if (!this.lastState) { + this.lastState = state; + return; + } + if (action === undefined) { + throw new Error( + "Predict requests after first must include previous action", + ); + } + if (reward === undefined) { + throw new Error( + "Predict requests after first must include previous reward", + ); + } + const {lastState} = this; + this.lastState = state; + this.lastAction?.dispose(); + this.lastAction = tf.scalar(action, "int32"); + this.lastReward?.dispose(); + this.lastReward = tf.scalar(reward, "float32"); + this.callback({ + state: lastState, + action: this.lastAction.clone(), + reward: this.lastReward.clone(), + nextState: state.map(t => t.clone()), + done: tf.scalar(false), + }); + } + + /** + * Generates the final experience for the game. + * + * @param state Final state. If omitted, no experience should be generated. + * @param action Action before final state. + * @param reward Final reward. + */ + public finalize( + state?: tf.Tensor[], + action?: tf.Scalar, + reward?: tf.Scalar, + ): void { + if (!state) { + // Game was forced to end abruptly. 
+ tf.dispose(this.lastState); + this.lastAction?.dispose(); + this.lastReward?.dispose(); + action?.dispose(); + reward?.dispose(); + return; + } + if (!this.lastState) { + throw new Error("No last state"); + } + if (!action) { + action = this.lastAction; + if (!action) { + throw new Error("No last action provided"); + } + } else { + this.lastAction?.dispose(); + this.lastAction = undefined; + } + if (!reward) { + reward = this.lastReward; + if (!reward) { + throw new Error("No last reward provided"); + } + } else { + this.lastReward?.dispose(); + this.lastReward = undefined; + } + this.callback({ + state: this.lastState, + action, + reward, + nextState: state, + done: tf.scalar(true), + }); + this.lastState = undefined; + } +} diff --git a/src/train/TrainingProgress.ts b/src/train/TrainingProgress.ts index 18a777c4..f7b6433b 100644 --- a/src/train/TrainingProgress.ts +++ b/src/train/TrainingProgress.ts @@ -1,36 +1,36 @@ import ProgressBar from "progress"; -import {Config} from "../config/types"; -import {Logger} from "../util/logging/Logger"; +import {TrainConfig} from "../config/types"; import { - ModelTrainBatch, ModelTrainData, - ModelTrainEpisode, ModelTrainEval, ModelTrainEvalDone, ModelTrainLearn, ModelTrainRollout, -} from "./model/worker"; +} from "../model/worker"; +import {formatUptime, numDigits} from "../util/format"; +import {Logger} from "../util/logging/Logger"; /** Handles logging and progress bars during the training loop. */ export class TrainingProgress { - private learnProgress: ProgressBar | undefined; - private readonly stepPadding = numDigits(this.config.train.episodes); - private readonly batchPadding = numDigits(this.config.train.learn.updates); - private readonly lossDigits = 8; - private readonly learnPadding = - "Batch / ".length + - 2 * this.batchPadding + - " loss=-0.".length + - this.lossDigits + - 2; + private progress: ProgressBar | undefined; + private readonly stepPadding = this.config.steps + ? 
numDigits(this.config.steps) + : 0; + private static readonly lossDigits = 8; + private readonly progressPadding = + "Step / ".length + + 2 * this.stepPadding + + " loss=0.".length + + TrainingProgress.lossDigits + + " eta=00d00h00m00s".length + + 1; + private startTime: number | undefined; /** Total rollout games played. */ private numRolloutGames = 0; /** Logger object. */ private readonly logger: Logger; - /** Logger with episode counter. */ - private episodeLogger: Logger; /** * Creates a Train object. @@ -38,19 +38,39 @@ export class TrainingProgress { * @param config Training config. * @param logger Logger object. Note that logging function is ignored. */ - public constructor(private readonly config: Config, logger: Logger) { + public constructor(private readonly config: TrainConfig, logger: Logger) { this.logger = logger.withFunc(msg => this.log(msg)); - this.episodeLogger = this.logger.addPrefix(this.episodePrefix(0)); + + if (this.config.progress && this.config.steps) { + this.progress = new ProgressBar( + this.logger.prefix + + "Step :step/:total :bar loss=:loss eta=:est", + { + total: this.config.steps, + head: ">", + clear: true, + width: + (process.stderr.columns || 80) - + this.logger.prefix.length - + this.progressPadding, + }, + ); + this.progress.render({ + step: "0".padStart(this.stepPadding), + loss: (0).toFixed(TrainingProgress.lossDigits), + est: "0s", + }); + } } /** Logging function that guards for progress bars. */ private log(msg: string): void { - if (this.learnProgress && !this.learnProgress.complete) { + if (this.progress && !this.progress.complete) { if (msg.endsWith("\n")) { // Account for extra newline inserted by interrupt. msg = msg.slice(0, -1); } - this.learnProgress.interrupt(msg); + this.progress.interrupt(msg); } else { Logger.stderr(msg); } @@ -59,10 +79,6 @@ export class TrainingProgress { /** Callback for events during the training loop. 
*/ public callback(data: ModelTrainData): void { switch (data.type) { - case "episode": - return this.episode(data); - case "batch": - return this.batch(data); case "learn": return this.learn(data); case "rollout": @@ -81,51 +97,38 @@ export class TrainingProgress { } } - /** Called at the beginning of each training episode. */ - private episode(data: ModelTrainEpisode): void { - this.episodeLogger = this.logger.addPrefix( - this.episodePrefix(data.step), - ); - this.learnProgress = new ProgressBar( - this.episodeLogger.prefix + "Batch :batch/:total :bar loss=:loss", - { - stream: undefined, - total: this.config.train.learn.updates, - head: ">", - clear: true, - width: - (process.stderr.columns || 80) - - this.episodeLogger.prefix.length - - this.learnPadding, - }, - ); - this.learnProgress.render({ - batch: "0".padStart(this.batchPadding), - loss: "n/a", - }); - } - /** Called after processing a learning batch. */ - private batch(data: ModelTrainBatch): void { - this.learnProgress?.tick({ - batch: String(data.step).padStart(this.batchPadding), - loss: data.loss.toFixed(this.lossDigits), - }); - } - - /** Called after processing all learning batches for the current episode. */ private learn(data: ModelTrainLearn): void { - this.episodeLogger.addPrefix("Learn: ").info(`Loss = ${data.loss}`); - this.learnProgress?.terminate(); - this.learnProgress = undefined; + if (this.progress) { + let est: string; + if (!this.startTime) { + this.startTime = process.uptime(); + est = "n/a"; + } else { + this.startTime ??= process.uptime(); + const elapsed = process.uptime() - this.startTime; + const eta = + this.progress.curr >= this.progress.total + ? 
0 + : elapsed * + (this.progress.total / this.progress.curr - 1); + est = formatUptime(eta); + } + this.progress.tick({ + step: String(data.step).padStart(this.stepPadding), + loss: data.loss.toFixed(TrainingProgress.lossDigits), + est, + }); + } else { + this.logger + .addPrefix(this.stepPrefix(data.step)) + .info(`Loss = ${data.loss}`); + } } /** Called after each rollout game. */ private rollout(data: ModelTrainRollout): void { - // Note that this indicates a game was completed but not necessarily - // that the collected data has made it into the learning step just yet. ++this.numRolloutGames; - if (data.err) { this.logger .addPrefix("Rollout: ") @@ -140,18 +143,18 @@ export class TrainingProgress { private eval(data: ModelTrainEval): void { if (data.err) { this.logger - .addPrefix(this.episodePrefix(data.step)) + .addPrefix(this.stepPrefix(data.step)) .error( `Evaluation game ${data.id} threw an error: ` + `${data.err.stack ?? data.err.toString()}`, ); } - // TODO: Stacked progress for eval opponents along with next learn bar? + // TODO: Stacked progress for eval opponents? } - /** Called after all evaluation games for an episode. */ + /** Called after finishing an evaluation run. */ private evalDone(data: ModelTrainEvalDone): void { - const logger = this.logger.addPrefix(this.episodePrefix(data.step)); + const logger = this.logger.addPrefix(this.stepPrefix(data.step)); for (const vs in data.wlt) { if (Object.prototype.hasOwnProperty.call(data.wlt, vs)) { const wlt = data.wlt[vs]; @@ -162,22 +165,23 @@ export class TrainingProgress { } } - private episodePrefix(step: number): string { - return `Episode(${String(step).padStart(this.stepPadding)}/${ - this.config.train.episodes - }): `; + /** Log prefix with step number. 
*/ + private stepPrefix(step: number): string { + if (!this.config.steps) { + return `Step(${step}): `; + } + return ( + `Step(${String(step).padStart(this.stepPadding)}/` + + `${this.config.steps}): ` + ); } /** Prints short summary of completed training session. */ public done(): void { - this.learnProgress?.terminate(); - this.learnProgress = undefined; + this.progress?.terminate(); + this.progress = undefined; this.logger .addPrefix("Rollout: ") .info(`Total games = ${this.numRolloutGames}`); } } - -function numDigits(n: number): number { - return Math.max(1, Math.ceil(Math.log10(n))); -} diff --git a/src/train/game/agent/random.ts b/src/train/game/agent/random.ts deleted file mode 100644 index 12cc56ce..00000000 --- a/src/train/game/agent/random.ts +++ /dev/null @@ -1,33 +0,0 @@ -import {Choice} from "../../../psbot/handlers/battle/agent"; -import { - allocEncodedState, - encodeState, -} from "../../../psbot/handlers/battle/ai/encoder"; -import {ReadonlyBattleState} from "../../../psbot/handlers/battle/state"; -import {Rng, shuffle} from "../../../util/random"; -import {ModelPort} from "../../model/port"; -import {ExperienceAgentData} from "../experience"; - -/** BattleAgent that chooses actions randomly. */ -export async function randomAgent( - state: ReadonlyBattleState, - choices: Choice[], - random?: Rng, -): Promise { - shuffle(choices, random); - return await Promise.resolve(); -} - -/** ExperienceAgent that chooses actions randomly. 
*/ -export async function randomExpAgent( - state: ReadonlyBattleState, - choices: Choice[], - random?: Rng, -): Promise { - await randomAgent(state, choices, random); - - const data = allocEncodedState(); - encodeState(data, state); - ModelPort.verifyInput(data); - return await Promise.resolve({state: data}); -} diff --git a/src/train/game/sim/ps/experienceBattleParser.ts b/src/train/game/sim/ps/experienceBattleParser.ts deleted file mode 100644 index ddec07ad..00000000 --- a/src/train/game/sim/ps/experienceBattleParser.ts +++ /dev/null @@ -1,112 +0,0 @@ -import { - BattleAgent, - Choice, - choiceIds, -} from "../../../../psbot/handlers/battle/agent"; -import { - allocEncodedState, - encodeState, -} from "../../../../psbot/handlers/battle/ai/encoder"; -import { - BattleParser, - BattleParserContext, -} from "../../../../psbot/handlers/battle/parser/BattleParser"; -import {ReadonlyBattleState} from "../../../../psbot/handlers/battle/state"; -import { - Experience, - ExperienceAgent, - ExperienceAgentData, -} from "../../experience/Experience"; - -/** - * Wraps a BattleParser to track rewards/decisions and emit Experience objects. - * - * Returned wrapper requires an {@link ExperienceAgent}. - * - * @template TArgs Parser arguments. - * @template TResult Parser return type. - * @param parser Parser function to wrap. - * @param callback Callback for emitting Experience objs. - * @param username Client's username to parse game-over reward. - * @returns The wrapped BattleParser function. 
- */ -export function experienceBattleParser< - TArgs extends unknown[] = unknown[], - TResult = unknown, ->( - parser: BattleParser, - callback: (exp: Experience) => void, - username: string, -): BattleParser { - return async function experienceBattleParserImpl( - ctx: BattleParserContext, - ...args: TArgs - ): Promise { - let expAgentData: ExperienceAgentData | null = null; - let lastChoice: Choice | null = null; - let reward = 0; - function emitExperience(state: ReadonlyBattleState, done = false) { - if (!expAgentData || !lastChoice) { - return; - } - // Collect data to emit an experience. - const action = choiceIds[lastChoice]; - const nextState = allocEncodedState(); - encodeState(nextState, state); - - ctx.logger.info(`Emitting experience, reward=${reward}`); - callback({ - ...expAgentData, - action, - reward, - nextState, - done, - }); - // Reset collected data for the next decision. - expAgentData = null; - lastChoice = null; - reward = 0; - } - - // Start tracking the game. - const result = await parser( - { - ...ctx, - // Extract additional info from the ExperienceAgent. - async agent(state, choices, logger) { - // Emit experience between last and current decision. - emitExperience(state); - expAgentData = await ctx.agent(state, choices, logger); - }, - iter: { - ...ctx.iter, - async next() { - // Observe events before the parser consumes them. - const r = await ctx.iter.next(); - if ( - !r.done && - ["win", "tie"].includes(r.value.args[0]) - ) { - // Add win/loss reward. - reward += r.value.args[1] === username ? 1 : -1; - } - return r; - }, - }, - // Extract the last choice that was accepted. - async sender(choice) { - const r = await ctx.sender(choice); - if (!r) { - lastChoice = choice; - } - return r; - }, - }, - ...args, - ); - - // Emit final experience at the end of the game. 
- emitExperience(ctx.state, true /*done*/); - return result; - }; -} diff --git a/src/train/model/port/ModelPort.ts b/src/train/model/port/ModelPort.ts deleted file mode 100644 index 9cd1914f..00000000 --- a/src/train/model/port/ModelPort.ts +++ /dev/null @@ -1,167 +0,0 @@ -import {MessagePort} from "worker_threads"; -import {modelInputNames} from "../../../model/shapes"; -import {intToChoice} from "../../../psbot/handlers/battle/agent"; -import { - allocEncodedState, - encodeState, -} from "../../../psbot/handlers/battle/ai/encoder"; -import {maxAgent} from "../../../psbot/handlers/battle/ai/maxAgent"; -import {WrappedError} from "../../../util/errors/WrappedError"; -import {rng} from "../../../util/random"; -import {randomExpAgent} from "../../game/agent/random"; -import { - ExperienceAgent, - ExperienceAgentData, -} from "../../game/experience/Experience"; -import {AgentExploreConfig} from "../../game/pool/worker"; -import {AsyncPort, ProtocolResultRaw} from "../../port/AsyncPort"; -import { - ModelPortProtocol, - PredictMessage, - PredictResult, -} from "./ModelPortProtocol"; - -/** - * Abstracts the interface between a game worker and a model owned by the main - * ModelWorker. - * - * Intended to be used by only one BattleAgent within a game worker that - * received a port to connect to a model. - */ -export class ModelPort { - /** Port wrapper. */ - private readonly asyncPort: AsyncPort< - MessagePort, - ModelPortProtocol, - keyof ModelPortProtocol - >; - - /** - * Creates a ModelPort. - * - * @param port Message port. 
- */ - public constructor(port: MessagePort) { - this.asyncPort = new AsyncPort(port); - port.on( - "message", - ( - res: ProtocolResultRaw< - ModelPortProtocol, - keyof ModelPortProtocol, - keyof ModelPortProtocol - >, - ) => this.asyncPort.receiveMessage(res), - ); - port.on("error", (err: Error) => - this.asyncPort.receiveError( - new WrappedError( - err, - msg => - "ModelPort encountered an unhandled exception: " + msg, - ), - ), - ); - } - - /** Closes the connection. */ - public close(): void { - this.asyncPort.port.close(); - } - - /** - * Creates a BattleAgent from this port that - * {@link ExperienceAgent returns} data used for building Experience objs. - * - * @param explore Exploration policy config. - */ - public getAgent(explore?: AgentExploreConfig): ExperienceAgent { - const random = explore?.seed ? rng(explore.seed) : Math.random; - - let data: ExperienceAgentData | null = null; - - const innerAgent = maxAgent(async state => { - const stateData = allocEncodedState(); - encodeState(stateData, state); - ModelPort.verifyInput(stateData); - - const result = await this.predict(stateData); - ModelPort.verifyOutput(result.output); - - data = {state: result.input}; - return result.output; - }); - - return async function portAgent(state, choices, logger) { - if (explore && random() < explore.factor) { - logger?.debug("Exploring"); - return await randomExpAgent(state, choices, random); - } - - logger?.debug("Exploiting"); - await innerAgent(state, choices, logger); - if (!data) { - throw new Error( - "ModelPort agent didn't collect experience data", - ); - } - const result = data; - data = null; - return result; - }; - } - - /** - * Makes sure that the input doesn't contain invalid values, i.e. `NaN`s or - * values outside the range `[-1, 1]`. 
- */ - public static verifyInput(data: Float32Array[]): void { - for (let i = 0; i < data.length; ++i) { - const arr = data[i]; - for (let j = 0; j < arr.length; ++j) { - const value = arr[j]; - if (isNaN(value)) { - throw new Error( - `Model input ${i} (${modelInputNames[i]}) contains ` + - `NaN at index ${j}`, - ); - } - if (value < -1 || value > 1) { - throw new Error( - `Model input ${i} (${modelInputNames[i]}) contains ` + - `an out-of-range value ${value} at index ${j}`, - ); - } - } - } - } - - /** Makes sure that the output doesn't contain invalid values. */ - public static verifyOutput(output: Float32Array): void { - for (let i = 0; i < output.length; ++i) { - if (isNaN(output[i])) { - throw new Error( - `Model output contains NaN for action ${i} ` + - `(${intToChoice[i]})`, - ); - } - } - } - - /** Requests a prediction from the neural network. */ - private async predict(state: Float32Array[]): Promise { - const msg: PredictMessage = { - type: "predict", - rid: this.asyncPort.nextRid(), - state, - }; - return await new Promise((res, rej) => - this.asyncPort.postMessage( - msg, - msg.state.map(a => a.buffer), - result => - result.type === "error" ? rej(result.err) : res(result), - ), - ); - } -} diff --git a/src/train/model/port/ModelPortProtocol.ts b/src/train/model/port/ModelPortProtocol.ts deleted file mode 100644 index b16358fd..00000000 --- a/src/train/model/port/ModelPortProtocol.ts +++ /dev/null @@ -1,36 +0,0 @@ -/** @file Defines the protocol typings for ModelPorts. */ -import { - PortProtocol, - PortRequestBase, - PortResultBase, -} from "../../port/PortProtocol"; - -/** ModelPort request protocol typings. */ -export interface ModelPortProtocol extends PortProtocol<"predict"> { - predict: {message: PredictMessage; result: PredictWorkerResult}; -} - -/** Base interface for the predict request protocol. */ -type PredictRequestBase = PortRequestBase<"predict">; - -/** Prediction request message format. 
*/ -export interface PredictMessage extends PredictRequestBase { - /** State data. */ - state: Float32Array[]; -} - -/** Prediction returned from the model. */ -export interface PredictWorkerResult - extends PortResultBase<"predict">, - PredictResult { - /** @override */ - done: true; -} - -/** Result from a prediction. */ -export interface PredictResult { - /** Given state input. */ - input: Float32Array[]; - /** Action output. */ - output: Float32Array; -} diff --git a/src/train/model/worker/Learn.ts b/src/train/model/worker/Learn.ts deleted file mode 100644 index 238c40fa..00000000 --- a/src/train/model/worker/Learn.ts +++ /dev/null @@ -1,340 +0,0 @@ -import * as tf from "@tensorflow/tfjs"; -import {LearnConfig} from "../../../config/types"; -import {intToChoice} from "../../../psbot/handlers/battle/agent"; -import {Metrics} from "./Metrics"; -import {BatchedExp} from "./dataset"; - -/** - * Encapsulates the learning step of training, where the model is updated based - * on experience generated by rollout games. - */ -export class Learn { - /** Metrics logger. */ - private readonly metrics = Metrics.get(`${this.name}/learn`); - /** Used for calculating gradients. */ - private readonly optimizer = tf.train.sgd(this.config.learningRate); - /** Collection of trainable variables in the model. */ - private readonly variables = this.model.trainableWeights.map( - w => w.read() as tf.Variable, - ); - /** Used for logging inputs during loss calcs. */ - private readonly hookLayers: readonly tf.layers.Layer[] = - this.model.layers.filter(l => - ["Dense", "SetAttention", "SetMultiHeadAttention"].includes( - l.getClassName(), - ), - ); - - /** - * Creates a Learn object. - * - * @param name Name of the training run for logging. - * @param model Model to train. - * @param targetModel Model for computing TD targets. Can be set to the same - * model to disable target model mechanism. - * @param iterator Iterator to pull from to obtain batched experiences for - * learning. 
- * @param config Learning config. - */ - public constructor( - public readonly name: string, - private readonly model: tf.LayersModel, - private readonly targetModel: tf.LayersModel, - private readonly iterator: AsyncIterator, - private readonly config: LearnConfig, - ) { - // Log initial weights. - for (const weights of this.variables) { - if (weights.size === 1) { - const weightScalar = weights.asScalar(); - this.metrics?.scalar( - `${weights.name}/weights`, - weightScalar, - 0, - ); - tf.dispose(weightScalar); - } else { - this.metrics?.histogram(`${weights.name}/weights`, weights, 0); - } - } - } - - /** - * Performs the configured amount of batch update steps on the model, - * completing one learning episode. - * - * @param step Episode step number for logging. - * @param callback Called for each batch. - * @returns The average loss of each batch update. - */ - public async episode( - step: number, - callback?: (step: number, loss: number) => void, - ): Promise { - const avgInputs: tf.NamedTensorMap = {}; - for (const layer of this.hookLayers) { - // Note: Call hook is wrapped in tf.tidy() so tf.keep() is used. - layer.setCallHook(function logInputs(inputs) { - if (!Array.isArray(inputs)) { - inputs = [inputs]; - } - for (let i = 0; i < inputs.length; ++i) { - // Average along all axes except last one in order to - // account for vectorized inputs. - const input = inputs[i] - .mean(inputs[i].shape.map((_, j) => j).slice(0, -1)) - .flatten(); - const name = - inputs.length > 1 ? `${layer.name}/${i}` : layer.name; - if (Object.prototype.hasOwnProperty.call(avgInputs, name)) { - const oldInput = avgInputs[name]; - avgInputs[name] = tf.keep(tf.add(oldInput, input)); - tf.dispose(oldInput); - } else { - avgInputs[name] = tf.keep(input); - } - } - }); - } - - // Discard very first batch in order to warmup the gpu and fill rollout - // threads and prefetch buffers without polluting performance stats. 
- if (step === 1) { - const result = await this.iterator.next(); - if (result.done) { - throw new Error("No more data in dataset"); - } - tf.dispose(result.value); - } - - const beforeUpdate = process.hrtime.bigint(); - const batchFetchTimes: number[] = []; - const batchUpdateTimes: number[] = []; - const batchTotalTimes: number[] = []; - let avgLoss = tf.scalar(0, "float32"); - const totalGrads: tf.NamedTensorMap = {}; - for (let i = 0; i < this.config.updates; ++i) { - const beforeFetch = process.hrtime.bigint(); - const result = await this.iterator.next(); - if (result.done) { - throw new Error("No more data in dataset"); - } - const batch = result.value; - const afterFetch = process.hrtime.bigint(); - - const {batchLoss, batchGrads} = this.update(batch); - tf.dispose(batch); - const afterUpdate = process.hrtime.bigint(); - - const oldAvgLoss = avgLoss; - avgLoss = tf.add(oldAvgLoss, batchLoss); - callback?.(i + 1, (await batchLoss.data<"float32">())[0]); - tf.dispose([oldAvgLoss, batchLoss]); - - for (const name in batchGrads) { - if (!Object.prototype.hasOwnProperty.call(batchGrads, name)) { - continue; - } - if (Object.prototype.hasOwnProperty.call(totalGrads, name)) { - const oldGrads = totalGrads[name]; - totalGrads[name] = tf.add(oldGrads, batchGrads[name]); - tf.dispose([oldGrads, batchGrads[name]]); - } else { - totalGrads[name] = batchGrads[name]; - } - } - const afterAll = process.hrtime.bigint(); - - batchFetchTimes.push(Number(afterFetch - beforeFetch) / 1e6 /*ms*/); - batchUpdateTimes.push(Number(afterUpdate - afterFetch) / 1e6); - batchTotalTimes.push(Number(afterAll - beforeFetch) / 1e6); - } - const afterUpdate = process.hrtime.bigint(); - - const updateTime = Number(afterUpdate - beforeUpdate) / 1e9; - this.metrics?.scalar("update_s", updateTime, step); - this.metrics?.scalar( - "update_throughput_s", - (this.config.updates * this.config.buffer.batch) / updateTime, - step, - ); - - tf.tidy(() => { - const batchFetchTensor = 
tf.tensor1d(batchFetchTimes, "float32"); - const batchUpdateTensor = tf.tensor1d(batchUpdateTimes, "float32"); - const batchTotalTensor = tf.tensor1d(batchTotalTimes, "float32"); - this.metrics?.histogram("batch_fetch_ms", batchFetchTensor, step); - this.metrics?.scalar( - "batch_fetch_ms/avg", - tf.mean(batchFetchTensor).asScalar(), - step, - ); - this.metrics?.histogram("batch_update_ms", batchUpdateTensor, step); - this.metrics?.scalar( - "batch_update_ms/avg", - tf.mean(batchUpdateTensor).asScalar(), - step, - ); - this.metrics?.histogram("batch_total_ms", batchTotalTensor, step); - this.metrics?.scalar( - "batch_total_ms/avg", - tf.mean(batchTotalTensor).asScalar(), - step, - ); - }); - - const oldAvgLoss = avgLoss; - avgLoss = tf.div(oldAvgLoss, this.config.updates); - const avgLossData = await avgLoss.data<"float32">(); - this.metrics?.scalar("loss", avgLoss, step); - tf.dispose([oldAvgLoss, avgLoss]); - - for (const name in totalGrads) { - if (!Object.prototype.hasOwnProperty.call(totalGrads, name)) { - continue; - } - const grad = totalGrads[name]; - if (grad.size === 1) { - const gradScalar = grad.asScalar(); - this.metrics?.scalar(`${name}/grads`, gradScalar, step); - tf.dispose(gradScalar); - } else { - this.metrics?.histogram(`${name}/grads`, grad, step); - } - tf.dispose(grad); - } - - for (const name in avgInputs) { - if (!Object.prototype.hasOwnProperty.call(avgInputs, name)) { - continue; - } - const oldInput = avgInputs[name]; - avgInputs[name] = tf.div(oldInput, this.config.updates); - this.metrics?.histogram(`${name}/input`, avgInputs[name], step); - tf.dispose([oldInput, avgInputs[name]]); - } - - for (const weights of this.variables) { - if (weights.size === 1) { - const weightScalar = weights.asScalar(); - this.metrics?.scalar( - `${weights.name}/weights`, - weightScalar, - step, - ); - tf.dispose(weightScalar); - } else { - this.metrics?.histogram( - `${weights.name}/weights`, - weights, - step, - ); - } - } - - for (const layer of 
this.hookLayers) { - layer.clearCallHook(); - } - - return avgLossData[0]; - } - - /** Performs one batch update step, returning the loss. */ - private update(batch: BatchedExp): { - batchLoss: tf.Scalar; - batchGrads: tf.NamedTensorMap; - } { - return tf.tidy(() => { - // Dataset batching process turns arrays into objs so need to undo - // this. - const batchState: tf.Tensor[] = []; - for (const key of Object.keys(batch.state)) { - const index = Number(key); - batchState[index] = batch.state[index]; - } - - const batchNextState: tf.Tensor[] = []; - for (const key of Object.keys(batch.nextState)) { - const index = Number(key); - batchNextState[index] = batch.nextState[index]; - } - const target = this.calculateTarget( - batch.reward, - batchNextState, - batch.done, - ); - - // Compute batch gradients manually to be able to do logging - // in-between. - const {value: batchLoss, grads: batchGrads} = - this.optimizer.computeGradients( - () => this.loss(batchState, batch.action, target), - this.variables, - ); - this.optimizer.applyGradients(batchGrads); - - return {batchLoss, batchGrads}; - }); - } - - /** Calculates TD target. */ - private calculateTarget( - reward: tf.Tensor, - nextState: tf.Tensor[], - done: tf.Tensor, - ): tf.Tensor { - return tf.tidy(() => { - let targetQ: tf.Tensor; - const q = this.model.predictOnBatch(nextState) as tf.Tensor; - if (!this.config.target) { - // Vanilla DQN TD target: r + gamma * max_a(Q(s', a)) - targetQ = tf.max(q, -1); - } else { - targetQ = this.targetModel.predictOnBatch( - nextState, - ) as tf.Tensor; - if (this.config.target !== "double") { - // TD target with target net: r + gamma * max_a(Qt(s', a)) - targetQ = tf.max(targetQ, -1); - } else { - // Double Q target: r + gamma * Qt(s', argmax_a(Q(s', a))) - const action = tf.argMax(q, -1); - const actionMask = tf.oneHot(action, intToChoice.length); - targetQ = tf.sum(tf.mul(targetQ, actionMask), -1); - } - } - - // Also mask out q values of terminal states. 
- targetQ = tf.where(done, 0, targetQ); - - const target = tf.add( - reward, - tf.mul(this.config.experience.rewardDecay, targetQ), - ); - return target; - }); - } - - private loss( - state: tf.Tensor[], - action: tf.Tensor, - target: tf.Tensor, - ): tf.Scalar { - return tf.tidy("loss", () => { - // Isolate the Q-value of the action that was taken. - let q = this.model.predictOnBatch(state) as tf.Tensor; - const mask = tf.oneHot(action, intToChoice.length); - q = tf.sum(tf.mul(q, mask), -1); - - // Compute the loss based on the discount reward actually obtained - // from that action. - return tf.losses.meanSquaredError(target, q); - }); - } - - /** Cleans up dangling variables. */ - public cleanup(): void { - this.optimizer.dispose(); - Metrics.flush(); - } -} diff --git a/src/train/model/worker/dataset.ts b/src/train/model/worker/dataset.ts deleted file mode 100644 index 2afb467c..00000000 --- a/src/train/model/worker/dataset.ts +++ /dev/null @@ -1,73 +0,0 @@ -import * as tf from "@tensorflow/tfjs"; -import {BufferConfig} from "../../../config/types"; -import {encodedStateToTensors} from "../../../psbot/handlers/battle/ai/networkAgent"; -import {Experience} from "../../game/experience"; - -/** - * {@link Experience} with values converted to {@link tf.Tensor tensors}. - */ -export type TensorExp = { - [T in keyof Experience]: Experience[T] extends number | boolean - ? tf.Scalar - : Experience[T] extends Float32Array - ? tf.Tensor1D - : Experience[T] extends Float32Array[] - ? {[index: number]: tf.Tensor} - : never; -}; - -/** - * Batched {@link Experience} stacked {@link tf.Tensor tensors}. - * - * Essentially a list of {@link TensorExp}s but with values converted to - * stacked tensors. - */ -export type BatchedExp = { - [T in keyof Experience]: Experience[T] extends number | boolean - ? tf.Tensor1D - : Experience[T] extends Float32Array - ? tf.Tensor2D - : Experience[T] extends Float32Array[] - ? 
{[index: number]: tf.Tensor2D} - : never; -}; - -/** - * Creates a TensorFlow Dataset from an experience stream for use in training. - * - * @param gen Generator for experience objects. - * @param config Config for buffering and batching. - * @param seed Optional seed for shuffling. - */ -export function datasetFromRollout( - gen: AsyncGenerator, - config: BufferConfig, - seed?: string, -): tf.data.Dataset { - return tf.data - .generator( - // Note: This still works with async generators even though the - // typings don't explicitly support it. - async function* () { - for await (const exp of gen) { - yield experienceToTensor(exp); - } - } as unknown as () => Generator, - ) - .shuffle(config.shuffle, seed) - .batch(config.batch) - .prefetch(config.prefetch) as tf.data.Dataset; -} - -/** Converts Experience fields to tensors suitable for batching. */ -function experienceToTensor(exp: Experience): TensorExp { - return { - // Convert array into an object with integer keys in order to prevent - // the array itself from being batched, just the contained tensors. 
- state: {...encodedStateToTensors(exp.state)}, - action: tf.scalar(exp.action, "int32"), - reward: tf.scalar(exp.reward, "float32"), - nextState: {...encodedStateToTensors(exp.nextState)}, - done: tf.scalar(exp.done, "bool"), - }; -} diff --git a/src/train/model/worker/train.ts b/src/train/model/worker/train.ts deleted file mode 100644 index 9f7a1063..00000000 --- a/src/train/model/worker/train.ts +++ /dev/null @@ -1,198 +0,0 @@ -import {join} from "path"; -import {serialize} from "v8"; -import * as tf from "@tensorflow/tfjs"; -import {PathsConfig, TrainConfig} from "../../../config/types"; -import {ensureDir} from "../../../util/paths/ensureDir"; -import {pathToFileUrl} from "../../../util/paths/pathToFileUrl"; -import {seeder} from "../../../util/random"; -import {GameArgsGenSeeders} from "../../game/pool"; -import {Evaluate} from "./Evaluate"; -import {Learn} from "./Learn"; -import {Metrics} from "./Metrics"; -import {ModelTrainData} from "./ModelProtocol"; -import {ModelRegistry} from "./ModelRegistry"; -import {Rollout} from "./Rollout"; -import {datasetFromRollout} from "./dataset"; - -/** - * Main training loop. - * - * @param model Model to train. - * @param config Training config. - * @param paths Optional paths to store model checkpoints, game logs, and - * metrics. - * @param callback Callback for notifying the main thread of various events - * during each episode, including errors. 
- */ -export async function train( - model: ModelRegistry, - config: TrainConfig, - paths?: Partial, - callback?: (data: ModelTrainData) => void, -): Promise { - const metrics = Metrics.get("train"); - - let checkpointsPath: string | undefined; - if (paths?.models) { - checkpointsPath = join(paths.models, "checkpoints"); - await Promise.all([ - (async function () { - await ensureDir(checkpointsPath), - await model.save( - pathToFileUrl(join(checkpointsPath, "original")), - ); - })(), - model.save(pathToFileUrl(paths.models)), - ]); - } - - const [rolloutModel, prevModel, targetModel] = await Promise.all( - ["rollout", "prev", "target"].map( - async name => await model.clone(name), - ), - ); - - const seeders: GameArgsGenSeeders | undefined = config.seeds && { - ...(config.seeds.battle && {battle: seeder(config.seeds.battle)}), - ...(config.seeds.team && {team: seeder(config.seeds.team)}), - ...(config.seeds.explore && {explore: seeder(config.seeds.explore)}), - }; - - rolloutModel.lock("train", 0 /*step*/); - prevModel.lock("train", 0); - - const rollout = new Rollout( - "train", - rolloutModel, - prevModel, - config.rollout, - paths?.logs ? join(paths.logs, "rollout") : undefined, - { - ...seeders, - ...(config.seeds?.rollout && { - rollout: seeder(config.seeds.rollout), - }), - }, - ); - - const dataset = datasetFromRollout( - rollout.gen( - callback && - (result => - callback({ - type: "rollout", - id: result.id, - ...(result.err && {err: serialize(result.err)}), - })), - ), - config.learn.buffer, - config.seeds?.learn, - ); - const learn = new Learn( - "train", - model.model, - targetModel.model, - await dataset.iterator(), - config.learn, - ); - - const evaluate = new Evaluate( - "train", - rolloutModel, - prevModel, - config.eval, - paths?.logs ? 
join(paths.logs, "eval") : undefined, - seeders && { - ...(seeders.battle && {battle: seeder(seeders.battle())}), - ...(seeders.team && {team: seeder(seeders.team())}), - ...(seeders.explore && {explore: seeder(seeders.explore())}), - }, - ); - - const logMemoryMetrics = (step: number) => { - if (metrics) { - const memory = tf.memory(); - metrics.scalar("memory/num_bytes", memory.numBytes, step); - metrics.scalar("memory/num_tensors", memory.numTensors, step); - } - }; - - let lastEval: Promise | undefined; - try { - rolloutModel.unlock(); - prevModel.unlock(); - rolloutModel.lock("train", 1); - prevModel.lock("train", 1); - logMemoryMetrics(0); - - for (let i = 0; i < config.episodes; ++i) { - const step = i + 1; - callback?.({type: "episode", step}); - const loss = await learn.episode(step, (batchStep, batchLoss) => - callback?.({type: "batch", step: batchStep, loss: batchLoss}), - ); - await lastEval; - callback?.({type: "learn", loss}); - - logMemoryMetrics(step); - - rollout.step(step); - rolloutModel.unlock(); - prevModel.unlock(); - if (i < config.episodes) { - rolloutModel.lock("train", step + 1); - prevModel.lock("train", step + 1); - rolloutModel.copyTo(prevModel); - } - model.copyTo(rolloutModel); - model.copyTo(targetModel); - - lastEval = Promise.all([ - evaluate - .run( - step, - callback && - (result => - callback({ - type: "eval", - step, - id: result.id, - agents: result.agents, - ...(result.winner !== undefined && { - winner: result.winner, - }), - ...(result.err && { - err: serialize(result.err), - }), - })), - ) - .then( - callback && - (wlt => callback({type: "evalDone", step, wlt})), - ), - ...(checkpointsPath - ? [ - rolloutModel.save( - pathToFileUrl( - join(checkpointsPath, `episode-${step}`), - ), - ), - ] - : []), - ...(paths?.models - ? [rolloutModel.save(pathToFileUrl(paths.models))] - : []), - ]); - // Suppress unhandled exception warnings since we'll await this - // promise later. 
- lastEval.catch(() => {}); - } - } finally { - await lastEval; - await Promise.all([rollout.cleanup(), evaluate.cleanup()]); - learn.cleanup(); - for (const m of [rolloutModel, prevModel, targetModel]) { - m.unload(); - } - } -} diff --git a/src/train/pool/index.ts b/src/train/pool/index.ts deleted file mode 100644 index bcf4295d..00000000 --- a/src/train/pool/index.ts +++ /dev/null @@ -1 +0,0 @@ -export {ThreadPool} from "./ThreadPool"; diff --git a/src/train/train.ts b/src/train/train.ts new file mode 100644 index 00000000..814adbfd --- /dev/null +++ b/src/train/train.ts @@ -0,0 +1,259 @@ +import {join} from "path"; +import {serialize} from "v8"; +import * as tf from "@tensorflow/tfjs"; +import {PathsConfig, TrainConfig} from "../config/types"; +import {GameArgsGenSeeders} from "../game/pool"; +import {ModelTrainData} from "../model/worker"; +import {Metrics} from "../model/worker/Metrics"; +import {ModelRegistry} from "../model/worker/ModelRegistry"; +import {cloneModel} from "../util/model"; +import {ensureDir} from "../util/paths/ensureDir"; +import {pathToFileUrl} from "../util/paths/pathToFileUrl"; +import {rng, seeder} from "../util/random"; +import {Evaluate} from "./Evaluate"; +import {Learn} from "./Learn"; +import {ReplayBuffer} from "./ReplayBuffer"; +import {Rollout} from "./Rollout"; +import {RolloutModel} from "./RolloutModel"; + +/** + * Main training loop. + * + * @param model Model to train. + * @param config Training config. + * @param paths Optional paths to store model checkpoints, game logs, and + * metrics. + * @param callback Callback for notifying the main thread of various events + * during each step, including errors. 
+ */ +export async function train( + model: tf.LayersModel, + config: TrainConfig, + paths?: Partial, + callback?: (data: ModelTrainData) => void, +): Promise { + const metrics = Metrics.get("train"); + + let checkpointsPath: string | undefined; + if (paths?.models) { + checkpointsPath = join(paths.models, "checkpoints"); + await Promise.all([ + (async function () { + await ensureDir(checkpointsPath), + await model.save( + pathToFileUrl(join(checkpointsPath, "original")), + ); + })(), + model.save(pathToFileUrl(paths.models)), + ]); + } + + const rolloutModel = new RolloutModel("rollout", model); + const [evalModel, prevModel] = await Promise.all( + ["eval", "prev"].map( + async name => + new ModelRegistry( + name, + await cloneModel(model), + config.batchPredict, + ), + ), + ); + const targetModel = await cloneModel(model); + + evalModel.lock("train", 0 /*step*/); + prevModel.lock("train", 0); + + const seeders: GameArgsGenSeeders | undefined = config.seeds && { + ...(config.seeds.battle && {battle: seeder(config.seeds.battle)}), + ...(config.seeds.team && {team: seeder(config.seeds.team)}), + ...(config.seeds.explore && {explore: seeder(config.seeds.explore)}), + }; + + const rollout = new Rollout( + "train", + rolloutModel, + prevModel, + config.rollout, + paths?.logs ? join(paths.logs, "rollout") : undefined, + { + ...seeders, + ...(config.seeds?.rollout && { + rollout: seeder(config.seeds.rollout), + }), + }, + ); + + const learn = new Learn( + "train", + model, + targetModel, + config.learn, + config.experience, + ); + + const evaluate = new Evaluate( + "train", + evalModel, + prevModel, + config.eval, + paths?.logs ? 
join(paths.logs, "eval") : undefined, + seeders && { + ...(seeders.battle && {battle: seeder(seeders.battle())}), + ...(seeders.team && {team: seeder(seeders.team())}), + ...(seeders.explore && {explore: seeder(seeders.explore())}), + }, + ); + + const logMemoryMetrics = (step: number) => { + if (metrics) { + const memory = tf.memory(); + metrics.scalar("memory/num_bytes", memory.numBytes, step); + metrics.scalar("memory/num_tensors", memory.numTensors, step); + } + }; + + const replayBuffer = new ReplayBuffer(config.experience.bufferSize); + const bufferRandom = config.seeds?.learn + ? rng(config.seeds.learn) + : undefined; + + void rollout.run( + callback && + (result => + callback({ + type: "rollout", + id: result.id, + ...(result.err && {err: serialize(result.err)}), + })), + ); + + let lastEval: Promise | undefined; + try { + let step = 0; + evalModel.unlock(); + prevModel.unlock(); + + let i = 0; + while (i < config.experience.prefill) { + const exps = await rolloutModel.step(); + for (const exp of exps) { + replayBuffer.add(exp); + ++i; + } + } + + evalModel.lock("train", step); + prevModel.lock("train", step); + + logMemoryMetrics(step); + + ++step; + while (!config.steps || step < config.steps) { + const exps = await rolloutModel.step(); + for (const exp of exps) { + if (config.steps && step >= config.steps) { + break; + } + replayBuffer.add(exp); + const loss = tf.tidy(() => + learn.step( + step, + replayBuffer.sample( + config.learn.batchSize, + bufferRandom, + ), + ), + ); + callback?.({ + type: "learn", + step, + loss: (await loss.data<"float32">())[0], + }); + loss.dispose(); + + rollout.step(step); + logMemoryMetrics(step); + + if (step % config.learn.targetInterval === 0) { + targetModel.setWeights(model.getWeights()); + } + + if (step % config.eval.interval === 0) { + await lastEval; + + evalModel.unlock(); + prevModel.unlock(); + prevModel.model.setWeights(evalModel.model.getWeights()); + evalModel.model.setWeights(model.getWeights()); + 
evalModel.lock("train", step); + prevModel.lock("train", step); + + const evalStep = step; + lastEval = evaluate + .run( + evalStep, + callback && + (result => + callback({ + type: "eval", + step: evalStep, + id: result.id, + agents: result.agents, + ...(result.winner !== undefined && { + winner: result.winner, + }), + ...(result.err && { + err: serialize(result.err), + }), + })), + ) + .then(wlt => + callback?.({type: "evalDone", step: evalStep, wlt}), + ); + // Suppress unhandled exception warnings since we'll await + // this promise later. + lastEval.catch(() => {}); + } + if ( + config.checkpointInterval && + step % config.checkpointInterval === 0 + ) { + // TODO: Use a separate model copy so this isn't blocking. + await Promise.all( + [ + ...(paths?.models + ? [pathToFileUrl(paths.models)] + : []), + ...(config.savePreviousVersions && checkpointsPath + ? [ + pathToFileUrl( + join(checkpointsPath, `step-${step}`), + ), + ] + : []), + ].map(async url => await model.save(url)), + ); + } + + // Async yield to allow for evaluate step to run in parallel. + await tf.nextFrame(); + + ++step; + } + } + await lastEval; + await evaluate.close(); + } finally { + await Promise.all([rollout.terminate(), evaluate.terminate()]); + replayBuffer.dispose(); + learn.cleanup(); + for (const m of [rolloutModel, evalModel, prevModel]) { + m.unload(); + } + targetModel.dispose(); + } + if (paths?.models) { + await model.save(pathToFileUrl(paths.models)); + } +} diff --git a/src/util/format.ts b/src/util/format.ts new file mode 100644 index 00000000..8325f758 --- /dev/null +++ b/src/util/format.ts @@ -0,0 +1,35 @@ +/** Gets the number of whole decimal digits in a number. */ +export function numDigits(n: number): number { + return Math.max(1, Math.ceil(Math.log10(Math.abs(n)))); +} + +/** Formats a {@link process.uptime} value. 
*/ +export function formatUptime(seconds: number): string { + seconds = Math.floor(seconds); + let s = ""; + const days = Math.floor(seconds / (24 * 60 * 60)); + seconds %= 24 * 60 * 60; + if (days > 0) { + s += `${days}d`; + } + const hours = Math.floor(seconds / (60 * 60)); + seconds %= 60 * 60; + if (s.length > 0) { + s += String(hours).padStart(2, "0") + "h"; + } else if (hours > 0) { + s += String(hours) + "h"; + } + const minutes = Math.floor(seconds / 60); + seconds %= 60; + if (s.length > 0) { + s += String(minutes).padStart(2, "0") + "m"; + } else if (minutes > 0) { + s += String(minutes) + "m"; + } + if (s.length > 0) { + s += String(seconds).padStart(2, "0") + "s"; + } else { + s += String(seconds) + "s"; + } + return s; +} diff --git a/src/util/model.ts b/src/util/model.ts new file mode 100644 index 00000000..111b1b70 --- /dev/null +++ b/src/util/model.ts @@ -0,0 +1,20 @@ +import * as tf from "@tensorflow/tfjs"; + +/** Clones a TensorFlow model. */ +export async function cloneModel( + model: tf.LayersModel, +): Promise { + const modelArtifact = new Promise( + res => + void model.save({ + // eslint-disable-next-line @typescript-eslint/require-await + save: async _modelArtifact => { + res(_modelArtifact); + return {} as tf.io.SaveResult; + }, + }), + ); + return await tf.loadLayersModel({ + load: async () => await Promise.resolve(modelArtifact), + }); +} diff --git a/src/train/pool/ThreadPool.ts b/src/util/pool/ThreadPool.ts similarity index 83% rename from src/train/pool/ThreadPool.ts rename to src/util/pool/ThreadPool.ts index 0deaee5e..e78a517c 100644 --- a/src/train/pool/ThreadPool.ts +++ b/src/util/pool/ThreadPool.ts @@ -1,7 +1,7 @@ import {Worker} from "worker_threads"; import {ListenerSignature, TypedEmitter} from "tiny-typed-emitter"; -import {WorkerPort} from "../port/WorkerPort"; -import {WorkerProtocol} from "../port/WorkerProtocol"; +import {WorkerPort} from "../worker/WorkerPort"; +import {WorkerProtocol} from "../worker/WorkerProtocol"; /** 
* Required methods for {@link ThreadPool}'s `TWorker` type param. @@ -16,7 +16,7 @@ import {WorkerProtocol} from "../port/WorkerProtocol"; export type WorkerPortLike< TProtocol extends WorkerProtocol, TTypes extends string, -> = Pick, "close">; +> = Pick, "close" | "terminate">; /** Event for when a WorkerPort is free. */ const workerFree = Symbol("workerFree"); @@ -81,8 +81,6 @@ export class ThreadPool< `Expected positive numThreads but got ${numThreads}`, ); } - // Note: Heavy takePort() usage can cause listeners to build up, but - // they should always stay under the number of threads. this.workerEvents.setMaxListeners(this.numThreads); for (let i = 0; i < this.numThreads; ++i) { @@ -102,6 +100,10 @@ export class ThreadPool< * of outstanding calls to this method by the {@link numThreads}. */ public async takePort(): Promise { + if (this.ports.size <= 0) { + throw new Error("ThreadPool is closed"); + } + // Wait until a port is open. while (this.freePorts.length <= 0) { await new Promise(res => @@ -117,6 +119,10 @@ export class ThreadPool< * pool. */ public givePort(port: TWorker): void { + if (this.ports.size <= 0) { + throw new Error("ThreadPool is closed"); + } + if (!this.ports.has(port)) { // Errored port has been returned. if (this.erroredPorts.has(port)) { @@ -131,7 +137,7 @@ export class ThreadPool< } /** - * Safely closes each port. + * Safely closes each port by calling {@link takePort}. * * Note that future calls to {@link takePort} will never resolve after this * resolves. @@ -151,6 +157,26 @@ export class ThreadPool< await Promise.all(closePromises); } + /** + * Closes each of the threads that are currently running. + * + * Future calls to {@link takePort} and {@link givePort} will throw after + * this method is called. 
+ */ + public async terminate(): Promise { + this.freePorts.length = 0; + const closePromises: Promise[] = []; + for (const port of this.ports) { + closePromises.push(port.terminate()); + this.ports.delete(port); + } + for (const port of this.erroredPorts) { + closePromises.push(port.terminate()); + this.erroredPorts.delete(port); + } + await Promise.all(closePromises); + } + /** * Adds a new worker to the pool. * @@ -165,9 +191,9 @@ export class ThreadPool< // Remove the errored worker and create a new one to replace it. if (!this.freePorts.includes(port)) { - // Port hasn't been returned yet. + // Port hasn't been given back to us yet. // Cache it in a special place so that #givePort() doesn't throw - // later. + // later if the original caller decides to give it back. this.ports.delete(port); this.erroredPorts.add(port); } diff --git a/src/train/port/AsyncPort.ts b/src/util/port/AsyncPort.ts similarity index 100% rename from src/train/port/AsyncPort.ts rename to src/util/port/AsyncPort.ts diff --git a/src/train/port/PortProtocol.ts b/src/util/port/PortProtocol.ts similarity index 100% rename from src/train/port/PortProtocol.ts rename to src/util/port/PortProtocol.ts diff --git a/src/train/port/WorkerPort.ts b/src/util/worker/WorkerPort.ts similarity index 92% rename from src/train/port/WorkerPort.ts rename to src/util/worker/WorkerPort.ts index 371079b4..de3145cd 100644 --- a/src/train/port/WorkerPort.ts +++ b/src/util/worker/WorkerPort.ts @@ -1,12 +1,12 @@ import {TransferListItem, Worker} from "worker_threads"; -import {WrappedError} from "../../util/errors/WrappedError"; +import {WrappedError} from "../errors/WrappedError"; import { AsyncPort, ProtocolMessage, ProtocolResultRaw, ProtocolResult, -} from "./AsyncPort"; -import {PortResultError} from "./PortProtocol"; +} from "../port/AsyncPort"; +import {PortResultError} from "../port/PortProtocol"; import {WorkerProtocol} from "./WorkerProtocol"; /** @@ -69,6 +69,11 @@ export class WorkerPort< } } + /** 
Force-closes the worker. */ + public async terminate(): Promise { + await this.worker.terminate(); + } + /** * Sends and tracks a message through the port. * diff --git a/src/train/port/WorkerProtocol.ts b/src/util/worker/WorkerProtocol.ts similarity index 89% rename from src/train/port/WorkerProtocol.ts rename to src/util/worker/WorkerProtocol.ts index 71dced21..770699c6 100644 --- a/src/train/port/WorkerProtocol.ts +++ b/src/util/worker/WorkerProtocol.ts @@ -1,5 +1,9 @@ /** @file Defines the base protocol typings for WorkerPorts. */ -import {PortMessageBase, PortProtocol, PortResultBase} from "./PortProtocol"; +import { + PortMessageBase, + PortProtocol, + PortResultBase, +} from "../port/PortProtocol"; /** * Base type for WorkerPort request protocol typings.