Tools for building neural networks with PyTorch. The project is heavily inspired by the fast.ai library and brings its convenience functions to regular PyTorch users.
from pytorch_nn_tools.devices import to_device
from pytorch_nn_tools.train.checkpoint import CheckpointSaver
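# path_checkpoints is the location where checkpoints are stored;
# DummyLogger must be imported separately (its import is not shown in this snippet)
checkpoint_saver = CheckpointSaver(path_checkpoints, logger=DummyLogger())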
# create your model, optimizer, scheduler
# train for several epochs
for epoch in range(100):
    .....
    # at the end of each epoch we save the results
    checkpoint_saver.save(model, optimizer, scheduler, epoch)
# later you can load the latest checkpoint and continue training
# (start_epoch and end_epoch bound the epochs to search)
last = checkpoint_saver.find_last(start_epoch, end_epoch)
if last is not None:
    print(f"Found a checkpoint for epoch {last}. Loading...")
    checkpoint_saver.load(model, optimizer, scheduler, last)
else:
    print("No checkpoint found")
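To make the "find the last checkpoint" step more concrete, here is a minimal standard-library sketch of how such a lookup could work. It is not the library's implementation: the function name `find_last_epoch`, the file naming scheme `epoch_<n>.pth`, and the half-open range `[start_epoch, end_epoch)` are all assumptions made for illustration.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Hypothetical naming scheme, assumed only for this sketch: epoch_<n>.pth
_CHECKPOINT_NAME = re.compile(r"^epoch_(\d+)\.pth$")


def find_last_epoch(directory: Path, start_epoch: int, end_epoch: int) -> Optional[int]:
    """Return the highest epoch in [start_epoch, end_epoch) that has a
    checkpoint file in `directory`, or None if there is none."""
    epochs = []
    for entry in directory.iterdir():
        match = _CHECKPOINT_NAME.match(entry.name)
        if match is None:
            continue  # not a checkpoint file, ignore it
        epoch = int(match.group(1))
        if start_epoch <= epoch < end_epoch:
            epochs.append(epoch)
    return max(epochs, default=None)


# Usage: pretend three epochs were already saved, then decide where to resume.
with tempfile.TemporaryDirectory() as tmp:
    checkpoint_dir = Path(tmp)
    for finished_epoch in range(3):
        (checkpoint_dir / f"epoch_{finished_epoch}.pth").touch()

    last = find_last_epoch(checkpoint_dir, 0, 100)
    first_epoch_to_run = 0 if last is None else last + 1
    print(f"last saved epoch: {last}, resuming at epoch {first_epoch_to_run}")
```

Returning `None` when nothing is found mirrors the `if last is not None` check in the snippet above, so the same training script can handle both a fresh start and a resumed run.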
- Free software: MIT license
- Documentation: https://pytorch-nn-tools.readthedocs.io.