Synchronize rollout with proper replay buffer, reorganize
Improvement on #353 and setup for #354.

Rewrite training algorithm (again) to remove the concept of episodes and
instead focus on pure learning steps according to the DQN algorithm.
Also add a proper replay buffer implementation.

Add/rewrite some configs/metrics code to mesh with above.

Also reorganize source tree, general housekeeping.
taylorhansen committed Jan 29, 2023
1 parent 4c252d4 commit a979442
Showing 62 changed files with 1,989 additions and 1,556 deletions.
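For context on the switch from episodes to pure learning steps: below is a minimal sketch of a step-based DQN loop backed by a circular replay buffer. The field names mirror the Experience interface in this commit, but the buffer and loop themselves are illustrative only, not the repository's actual implementation.

// Hypothetical sketch; not the repo's replay buffer.
interface Experience {
    state: Float32Array[];
    action: number;
    reward: number;
    nextState: Float32Array[];
    done: boolean;
}

class ReplayBuffer {
    private readonly data: Experience[] = [];
    private next = 0;

    public constructor(private readonly maxSize: number) {}

    /** Adds an experience, overwriting the oldest entry once full. */
    public add(exp: Experience): void {
        if (this.data.length < this.maxSize) this.data.push(exp);
        else this.data[this.next] = exp;
        this.next = (this.next + 1) % this.maxSize;
    }

    /** Samples a random minibatch for one learning step. */
    public sample(batchSize: number): Experience[] {
        return Array.from(
            {length: batchSize},
            () => this.data[Math.floor(Math.random() * this.data.length)],
        );
    }

    public get size(): number {
        return this.data.length;
    }
}

// Outline of the step-based loop: prefill the buffer, then alternate between
// collecting one experience and performing one batch update per learning step.
async function trainLoop(
    steps: number,
    prefill: number,
    batchSize: number,
    buffer: ReplayBuffer,
    collect: () => Promise<Experience>,
    update: (batch: Experience[]) => Promise<void>,
): Promise<void> {
    while (buffer.size < prefill) buffer.add(await collect());
    for (let step = 0; step < steps; ++step) {
        buffer.add(await collect());
        await update(buffer.sample(batchSize));
    }
}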
40 changes: 22 additions & 18 deletions src/config/config.example.ts
@@ -8,6 +8,8 @@ import {Config} from "./types";
// the training script should make the whole process fully deterministic.
const numThreads = os.cpus().length;

const maxTurns = 100;

/**
* Top-level config. Should only be accessed by the top-level.
*
@@ -32,7 +34,7 @@ export const config: Config = {
tf: {gpu: false},
train: {
name: "train",
episodes: 16,
steps: maxTurns * 2 * 50 /*enough for at least 50 games*/,
batchPredict: {
maxSize: numThreads,
timeoutNs: 10_000_000n /*10ms*/,
@@ -47,36 +49,36 @@ export const config: Config = {
rollout: {
pool: {
numThreads,
maxTurns: 100,
maxTurns,
reduceLogs: true,
},
policy: {
exploration: 1.0,
explorationDecay: 0.9,
explorationDecay: 0.999,
minExploration: 0.1,
},
prev: 0.1,
},
experience: {
rewardDecay: 0.99,
bufferSize: maxTurns * 2 * 25 /*enough for at least 25 games*/,
prefill: maxTurns * 2 * numThreads /*at least one complete game*/,
},
learn: {
learningRate: 0.0001,
batchSize: 32,
target: "double",
targetInterval: 500,
metricsInterval: 100,
},
eval: {
numGames: 128,
numGames: 32,
pool: {
numThreads,
maxTurns: 100,
maxTurns,
reduceLogs: true,
},
},
learn: {
updates: 1024,
learningRate: 0.0001,
buffer: {
shuffle: 100 * 2 * 8 /*at least 8 games' worth*/,
batch: 32,
prefetch: 4,
},
experience: {
rewardDecay: 0.99,
},
target: "double",
interval: 1000,
},
seeds: {
model: "abc",
@@ -87,6 +89,8 @@ export const config: Config = {
learn: "pqr",
},
savePreviousVersions: true,
checkpointInterval: 1000,
progress: true,
verbose: Verbose.Info,
},
compare: {
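As a worked example of the derived values in the config above (the factor of 2 is assumed here to account for both sides of a self-play game contributing experience):

// Worked arithmetic for the example config; numThreads = 8 is an assumed
// example value, whereas the real config uses os.cpus().length.
const maxTurns = 100;
const numThreads = 8;

const steps = maxTurns * 2 * 50; // 10,000 learning steps (>= 50 games)
const bufferSize = maxTurns * 2 * 25; // 5,000 buffered experiences (>= 25 games)
const prefill = maxTurns * 2 * numThreads; // 1,600 experiences before learning starts

console.log({steps, bufferSize, prefill});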
85 changes: 46 additions & 39 deletions src/config/types.ts
@@ -52,25 +52,31 @@ export interface TensorflowConfig {
export interface TrainConfig {
/** Name of the training run under which to store logs. */
readonly name: string;
/** Number of training episodes to complete. */
readonly episodes: number;
/** Batch predict config. */
/** Number of learning steps. Omit or set to zero to train indefinitely. */
readonly steps?: number;
/** Batch predict config for models outside the learning step. */
readonly batchPredict: BatchPredictConfig;
/** Model config. */
readonly model: ModelConfig;
/** Rollout config. */
readonly rollout: RolloutConfig;
/** Evaluation config. */
readonly eval: EvalConfig;
/** Experience config. */
readonly experience: ExperienceConfig;
/** Learning config. */
readonly learn: LearnConfig;
/** Evaluation config. */
readonly eval: EvalConfig;
/** RNG config. */
readonly seeds?: TrainSeedConfig;
/** Whether to save model checkpoints as separate versions. */
readonly savePreviousVersions?: boolean;
/**
* Whether to save each previous version of the model separately after each
* training step.
* Step interval for saving model checkpoints. Omit to not store
* checkpoints.
*/
readonly savePreviousVersions: boolean;
readonly checkpointInterval?: number;
/** Whether to display a progress bar if {@link steps} is also defined. */
readonly progress?: boolean;
/** Verbosity level for logging. Default highest. */
readonly verbose?: Verbose;
}
@@ -133,6 +139,8 @@ export interface RolloutConfig {
/**
* Fraction of self-play games that should be played against the model's
* previous version rather than itself.
*
* The previous version is defined by the last {@link EvalConfig eval} step.
*/
readonly prev: number;
}
@@ -160,54 +168,53 @@
* proportion of actions to take randomly rather than consulting the model.
*/
readonly exploration: number;
/**
* Exploration (epsilon) decay factor. Applied after each full episode of
* training.
*/
/** Exploration (epsilon) decay factor. Applied after each learning step. */
readonly explorationDecay: number;
/** Minimum exploration (epsilon) value. */
readonly minExploration: number;
}

/** Configuration for the evaluation process. */
export interface EvalConfig {
/** Number of games to play against each eval opponent. */
readonly numGames: number;
/** Game pool config. */
readonly pool: GamePoolConfig;
/** Configuration for learning on experience generated from rollout games. */
export interface ExperienceConfig {
/** Discount factor for future rewards. */
readonly rewardDecay: number;
/** Size of the experience replay buffer. */
readonly bufferSize: number;
/**
* Minimum number of experiences to generate before starting training. Must
* be at least as big as the batch size.
*/
readonly prefill: number;
}

/** Configuration for the learning process. */
export interface LearnConfig {
/** Number of batch updates before starting the next episode. */
readonly updates: number;
/** Optimizer learning rate. */
readonly learningRate: number;
/** Replay buffer config. */
readonly buffer: BufferConfig;
/** Experience config. */
readonly experience: ExperienceConfig;
/** Batch size. */
readonly batchSize: number;
/**
* Whether to use a target network to increase training stability, or
* `"double"` to implement double Q learning approach using the target net.
*/
readonly target?: boolean | "double";
}

/** Configuration for the experience replay buffer. */
export interface BufferConfig {
/** Number of experiences to buffer for shuffling. */
readonly shuffle: number;
/** Batch size for learning updates. */
readonly batch: number;
/** Number of batches to prefetch for learning. */
readonly prefetch: number;
readonly target: boolean | "double";
/** Step interval for updating the target network. */
readonly targetInterval: number;
/**
* Step interval for tracking expensive batch update model metrics such as
* histograms, which can significantly slow down training.
*/
readonly metricsInterval: number;
}

/** Configuration for learning on experience generated from rollout games. */
export interface ExperienceConfig {
/** Discount factor for future rewards. */
readonly rewardDecay: number;
/** Configuration for the evaluation process. */
export interface EvalConfig {
/** Number of games to play against each eval opponent. */
readonly numGames: number;
/** Game pool config. */
readonly pool: GamePoolConfig;
/** Step interval for performing model evaluations. */
readonly interval: number;
}

/** Configuration for random number generators in the training script. */
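The target option's "double" setting refers to double Q-learning, where the online network selects the bootstrap action and the target network evaluates it. The following is a sketch of that standard target computation in tfjs, not code from this commit; it assumes [batch, numActions]-shaped Q-value tensors and done encoded as 0/1 floats.

import * as tf from "@tensorflow/tfjs";

/**
 * Double-DQN target: y = r + gamma * Qtarget(s', argmax_a Qonline(s', a)),
 * with the bootstrap term zeroed out on terminal states.
 */
function doubleDqnTarget(
    reward: tf.Tensor1D,
    done: tf.Tensor1D,
    qOnlineNext: tf.Tensor2D,
    qTargetNext: tf.Tensor2D,
    gamma: number,
): tf.Tensor1D {
    return tf.tidy(() => {
        // Online network picks the next action...
        const bestAction = qOnlineNext.argMax(1) as tf.Tensor1D;
        // ...target network evaluates it.
        const nextQ = tf.sum(
            qTargetNext.mul(tf.oneHot(bestAction, qTargetNext.shape[1])),
            1,
        );
        const notDone = tf.scalar(1).sub(done);
        return reward.add(nextQ.mul(notDone).mul(gamma)) as tf.Tensor1D;
    });
}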
7 changes: 4 additions & 3 deletions src/demo/compare.ts
@@ -9,8 +9,8 @@ import {
GameArgsGenOptions,
GameArgsGenSeeders,
GamePipeline,
} from "../train/game/pool";
import {ModelWorker} from "../train/model/worker";
} from "../game/pool";
import {ModelWorker} from "../model/worker";
import {Logger} from "../util/logging/Logger";
import {Verbose} from "../util/logging/Verbose";
import {ensureDir} from "../util/paths/ensureDir";
@@ -198,8 +198,9 @@ void (async function () {
);
}
});
await games.close();
} finally {
await games.cleanup();
await games.terminate();
progressBar.terminate();
}

6 changes: 4 additions & 2 deletions src/demo/train.ts
@@ -2,8 +2,9 @@
import {join} from "path";
import {setGracefulCleanup} from "tmp-promise";
import {config} from "../config";
import {ModelWorker} from "../model/worker";
import {TrainingProgress} from "../train/TrainingProgress";
import {ModelWorker} from "../train/model/worker";
import {formatUptime} from "../util/format";
import {Logger} from "../util/logging/Logger";
import {Verbose} from "../util/logging/Verbose";
import {ensureDir} from "../util/paths/ensureDir";
@@ -70,7 +71,7 @@ void (async function () {
);
}

const trainProgress = new TrainingProgress(config, logger);
const trainProgress = new TrainingProgress(config.train, logger);
try {
await models.train(
model,
@@ -84,6 +85,7 @@ void (async function () {
await models.unload(model);
await models.close();
trainProgress.done();
logger.info("Uptime: " + formatUptime(process.uptime()));
logger.info("Done");
}
})();
13 changes: 13 additions & 0 deletions src/game/agent/random.ts
@@ -0,0 +1,13 @@
import {Choice} from "../../psbot/handlers/battle/agent";
import {ReadonlyBattleState} from "../../psbot/handlers/battle/state";
import {Rng, shuffle} from "../../util/random";

/** BattleAgent that chooses actions randomly. */
export async function randomAgent(
state: ReadonlyBattleState,
choices: Choice[],
random?: Rng,
): Promise<void> {
shuffle(choices, random);
return await Promise.resolve();
}
src/game/experience/Experience.ts
@@ -1,20 +1,13 @@
import {BattleAgent} from "../../../psbot/handlers/battle/agent";

/** BattleAgent decision data. */
export interface ExperienceAgentData {
/** State in which the action was taken. Flattened array data. */
state: Float32Array[];
}

/** BattleAgent type that emits partial Experience objects. */
export type ExperienceAgent = BattleAgent<ExperienceAgentData>;
import {BattleAgent} from "../../psbot/handlers/battle/agent";

/**
* BattleAgent decision evaluation data. Can be processed in batches to
* effectively train a neural network.
*/
export interface Experience extends ExperienceAgentData {
/** ID of the Choice that was taken. */
export interface Experience {
/** State in which the action was taken. Flattened array data. */
state: Float32Array[];
/** Id of the action that was taken. */
action: number;
/** Reward gained from the action and state transition. */
reward: number;
@@ -23,3 +16,9 @@ export interface Experience extends ExperienceAgentData {
/** Marks {@link nextState} as a terminal state so it won't be processed. */
done: boolean;
}

/** BattleAgent with additional args for experience generation. */
export type ExperienceBattleAgent<TInfo = unknown> = BattleAgent<
TInfo,
[lastAction?: number, reward?: number]
>;
File renamed without changes.
25 changes: 25 additions & 0 deletions src/game/experience/tensor.ts
@@ -0,0 +1,25 @@
import * as tf from "@tensorflow/tfjs";
import {Experience} from "./Experience";

/** {@link Experience} with values converted to {@link tf.Tensor tensors}. */
export type TensorExperience = {
[T in keyof Experience]: Experience[T] extends number | boolean
? tf.Scalar
: Experience[T] extends Float32Array[]
? tf.Tensor[]
: never;
};

/**
* Batched {@link Experience} stacked {@link tf.Tensor tensors}.
*
* Essentially a list of {@link TensorExperience}s but with values converted to
* stacked tensors.
*/
export type BatchTensorExperience = {
[T in keyof Experience]: Experience[T] extends number | boolean
? tf.Tensor1D
: Experience[T] extends Float32Array[]
? tf.Tensor[]
: never;
};
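A sketch of how a single Experience might be converted under these types, assuming nextState mirrors state as Float32Array[] (its declaration falls outside the hunk shown above); the conversion itself is illustrative and not taken from this commit.

import * as tf from "@tensorflow/tfjs";
import {Experience} from "./Experience";
import {TensorExperience} from "./tensor";

// Illustrative conversion of one Experience into tensors (hypothetical helper,
// written as if it lived alongside src/game/experience/tensor.ts).
function toTensorExperience(exp: Experience): TensorExperience {
    return {
        state: exp.state.map(arr => tf.tensor1d(arr)),
        action: tf.scalar(exp.action, "int32"),
        reward: tf.scalar(exp.reward),
        nextState: exp.nextState.map(arr => tf.tensor1d(arr)),
        done: tf.scalar(exp.done ? 1 : 0), // 0/1 encoding assumed for the flag
    };
}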