See our example of constructing language models from scratch in a GPU-backed Colab notebook: Trax Demo.
# Train the MLP model on the mnist dataset via the trax.trainer entry point.
# --config is a gin binding; here it sets `train.train_steps` to 1000.
python -m trax.trainer \
  --dataset=mnist \
  --model=MLP \
  --config='train.train_steps=1000'
# Train using the gin config trax/configs/resnet50_imagenet_8gb.gin.
# The path is built from $PWD, so run this from the root of a trax checkout
# (the directory that contains trax/configs/). The expansion is quoted so a
# working directory containing spaces or glob characters is passed intact.
python -m trax.trainer \
  --config_file="$PWD/trax/configs/resnet50_imagenet_8gb.gin"
# Train using the gin config trax/configs/transformer_lm1b_8gb.gin.
# Like the resnet50 example, the config path is built from $PWD (quoted so
# paths with spaces survive), so run this from the root of a trax checkout.
python -m trax.trainer \
  --config_file="$PWD/trax/configs/transformer_lm1b_8gb.gin"