From b0cea9cff14892ae19d602ab3d2e793aa18dbfc1 Mon Sep 17 00:00:00 2001
From: Thomas Santerre
Date: Thu, 14 Mar 2024 08:02:09 -0400
Subject: [PATCH] training config changes

---
 src/config.rs   | 4 ++--
 src/main.rs     | 2 +-
 src/training.rs | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index 8967586..a50f5cd 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -14,11 +14,11 @@ impl Config {
     // Default configuration for initial evaluation, will add larger configs later after confirming valid output
     pub fn default() -> Self {
         Self {
-            dim: 256,
+            dim: 512,
             depth: 8,
             vocab_size: 32000,
             heads: 8,
-            ff_mult: 12,
+            ff_mult: 10,
             eps: 1e-6,
             ff_dropout: 0.1,
             seq_len: 100,
diff --git a/src/main.rs b/src/main.rs
index d4692dd..508f07a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -61,7 +61,7 @@ pub struct TrainingCmd {
     dataset: String,
 
     /// The batch size to use
-    #[arg(long, default_value = "2")]
+    #[arg(long, default_value = "1")]
     batch_size: usize,
 
     /// The learning rate to use
diff --git a/src/training.rs b/src/training.rs
index e747c32..f48d4d5 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -93,7 +93,7 @@ pub fn run(args: &TrainingCmd, common_args: &Args) -> Result<()> {
             let _enter = span.enter();
             opt.backward_step(&loss)?;
         }
-        if batch_index > 0 && batch_index % 10 == 0 {
+        if batch_index > 0 && batch_index % 100 == 0 {
             let training_loss = f64::from(loss.to_vec0::<f32>()?);
             let validation_loss =
                 valid_loss(args.seq_len, args.batch_size, &dataset, &mut model, &device)?;