learn2learn is a PyTorch library for meta-learning implementations. It was developed during the first PyTorch Hackathon. Edit: L2L was lucky to win the hackathon!
Note learn2learn is under active development and many things are breaking.
pip install learn2learn
import learn2learn as l2l mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True) task_generator = l2l.data.TaskGenerator(mnist, ways=3) model = Net() maml = l2l.MAML(model, lr=1e-3, first_order=False) opt = optim.Adam(maml.parameters(), lr=4e-3) for iteration in range(num_iterations): learner = maml.clone() # Creates a clone of model task = task_generator.sample(shots=1) # Fast adapt for step in range(adaptation_steps): error = compute_loss(task) learner.adapt(error) # Compute validation loss valid_task = task_generator.sample(shots=1, classes_to_sample=task.sampled_classes) valid_error = compute_loss(valid_task) # Take the meta-learning step opt.zero_grad() valid_error.backward() opt.step()
The following changelog is mostly for the hackathon period.
August 15, 2019
- Algorithm cleanup.
- Added BaseLearner with unified API.
- Support for vectorized environments.
August 12, 2019
- Basic implementation of MAML, FOMAML, Meta-SGD.
- TaskGenerator code for classification tasks.
- Environments for RL.
- Small scale examples of MAML-A2C and MAML-PPO.
- The RL environments are copied from: https://github.com/tristandeleu/pytorch-maml-rl