A PyTorch + Ray implementation of the training pipeline from the original SPIRAL paper (https://arxiv.org/abs/1804.01118).
Requires 2 GPUs: one for policy learning and one for discriminator learning.
Note that this training pipeline targets a single machine; population-based training (PBT) of hyperparameters is not implemented.
- Install https://github.com/deepmind/spiral following its instructions (the libmypaint environment is required)
- Copy all the Python scripts here into spiral/
- Download some data; see real_image_loader.py for the expected dataset location and format
- Adjust hyperparameters in config.py
- Run python spiral_torch.py
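For orientation, here is a hypothetical sketch of the kind of knobs config.py exposes. The field names are illustrative, not the actual ones; the numeric values echo the numbers mentioned elsewhere in this README (2 GPUs, 64 batches, 10 timesteps, MNIST digit 4, 15000 steps):

```python
# Illustrative sketch only -- the real config.py may use different names.
config = {
    "policy_gpu": 0,          # GPU used by the policy learner
    "discriminator_gpu": 1,   # GPU used by the discriminator learner
    "n_batches": 64,          # batches per policy training step
    "n_timesteps": 10,        # painting timesteps per episode
    "dataset": "mnist",       # see real_image_loader.py for the format
    "digit": 4,               # which MNIST class to train on
    "training_steps": 15000,  # total policy training steps
}
```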
Results after 15000 policy training steps on MNIST digit 4 (each training step consumes n_batches * n_timesteps = 64 * 10 frames):
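For scale, the quoted numbers imply the following frame counts (simple arithmetic, not a measured result):

```python
# Frames consumed per policy training step, from the numbers above.
n_batches = 64
n_timesteps = 10
frames_per_step = n_batches * n_timesteps

# Total frames over the full run of 15000 policy training steps.
total_steps = 15000
total_frames = frames_per_step * total_steps
print(frames_per_step, total_frames)
```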
- In the original paper, the discriminator trains faster than the policy because of the network structure. In this implementation, the discriminator trains faster because the policy spends most of its time waiting for batches from the painter agents.
- The policy trains on each trajectory only once, so training is on-policy (this is also why the policy spends so long waiting for batches). The paper, however, describes the training as off-policy.
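The single-use trajectory flow described above can be sketched as a queue between painter agents and the policy learner. This is an illustration only: the names are made up, and the real implementation coordinates via ray rather than a thread-backed queue.

```python
import queue
import threading

# Illustrative sketch: painter agents push trajectories into a queue and
# the policy learner pops each one exactly once, so no trajectory is ever
# replayed -- this is what keeps the training on-policy.
trajectories = queue.Queue(maxsize=8)

def painter_agent(n):
    for i in range(n):
        trajectories.put(f"trajectory-{i}")  # rollout from the current policy

def policy_learner(n):
    seen = []
    for _ in range(n):
        traj = trajectories.get()  # blocks while waiting for painters
        seen.append(traj)          # train on it once, then discard it
    return seen

t = threading.Thread(target=painter_agent, args=(4,))
t.start()
consumed = policy_learner(4)
t.join()
print(consumed)  # each trajectory appears exactly once, in FIFO order
```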
- https://github.com/werner-duvaud/muzero-general: where I learned about ray.
- https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py: the WGAN-GP implementation I used.
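For context, WGAN-GP adds a gradient penalty that pushes the critic's gradient norm toward 1 at points interpolated between real and fake samples. Below is a minimal NumPy illustration of the penalty term using a linear critic, where the gradient is known in closed form; it explains the formula and is not code from this repo or from the linked implementation.

```python
import numpy as np

# For a linear critic D(x) = w . x, grad_x D(x) = w everywhere, so the
# WGAN-GP penalty  lambda * (||grad||_2 - 1)^2  reduces to
# lambda * (||w||_2 - 1)^2 at every interpolated point.
rng = np.random.default_rng(0)
w = np.array([3.0, 4.0])  # ||w||_2 = 5
lam = 10.0                # the gradient-penalty weight used in the paper

real = rng.normal(size=(8, 2))
fake = rng.normal(size=(8, 2))
eps = rng.uniform(size=(8, 1))
interp = eps * real + (1 - eps) * fake  # random points between real and fake

grad = np.tile(w, (len(interp), 1))     # gradient of the linear critic
grad_norm = np.linalg.norm(grad, axis=1)
penalty = lam * np.mean((grad_norm - 1.0) ** 2)
print(penalty)  # 10 * (5 - 1)^2 = 160.0
```

In the real PyTorch implementation the critic is nonlinear, so the gradient at each interpolated point is obtained with autograd rather than in closed form.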