# Training validation

Full pytorch-ref based training, for loss curve comparison / validation between the various kernel implementations

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')
sys.path.append('../../test')

# Import the model classes
from rwkv_block.v7_goose.model.rwkv7_goose_model import RWKV7GooseModel
from trainer.SimpleTestTrainer import SimpleTestTrainer

# Device to run on.
#
# Defaults to the first GPU. On hosts exposing 8 or more CUDA devices the
# second GPU is picked instead, so that several notebooks can run in parallel
# on different devices.
#
# Replace this expression with a fixed device string if you intend to
# manually set the device.
RUN_DEVICE = "cuda:1" if torch.cuda.device_count() >= 8 else "cuda:0"

# Model shape, size, and time-mix backend
LAYER_COUNT = 12
DIM_SIZE = 512
TMIX_BACKEND = "pytorch"

# Create the model from its configuration, then initialize its parameters
model = RWKV7GooseModel(
    dict(
        n_layer=LAYER_COUNT,
        n_dim=DIM_SIZE,
        tmix_backend=TMIX_BACKEND,
        device=RUN_DEVICE,
        dtype="bfloat16",
        n_vocab=50432,
    )
)
model.init_parameters()

# Set up the trainer on the same device as the model
trainer = SimpleTestTrainer(model, device=RUN_DEVICE)

# Kick off the training run
trainer.train()

  from .autonotebook import tqdm as notebook_tqdm


---------------------------------------------
[SimpleTestTrainer] Initializing the trainer for:  RWKV-Block.SimpleTestTrainer
- hf_dataset:          teven/enwiki_100k
- dataset_ctx_length:  4096
- dataset_min_length:  4096
- tokenizer_name:      EleutherAI/gpt-neox-20b
- batch_size:          1
- learning_rate:       0.001
- num_epochs:          1
---------------------------------------------
[SimpleTestTrainer] Loading the tokenizer:  EleutherAI/gpt-neox-20b ...
[SimpleTestTrainer] Loading the dataset:  teven/enwiki_100k ...
[SimpleTestTrainer] Preparing the training dataset...
[SimpleTestTrainer] Training dataset size:    999000
[SimpleTestTrainer] Validation dataset size:  1000
[SimpleTestTrainer] Preparing the data loaders...
[SimpleTestTrainer] Training batch count:    999000
[SimpleTestTrainer] Validation batch count:  1000
[SimpleTestTrainer] Setting up the optimizer, loss function...
[SimpleTestTrainer] Initializing wandb...
[SimpleTestTrainer] skipping wandb - not initialized.


Training:   0%|          | 138/999000 [57:20<6917:30:46, 24.93s/it, loss=0.0615]


KeyboardInterrupt: 