Skip to content

Commit

Permalink
Feature/python2.7 (#496)
Browse files Browse the repository at this point in the history
* Add support for Python 2.7

* Update travis.yml

* Remove unused dependency

* Add missing callback list test

* Add missing Tqdm test

* Add missing Metric test

* Fix t-board tests

* Update setup.py
  • Loading branch information
ethanwharris committed Jan 25, 2019
1 parent f8e4353 commit 78b939c
Show file tree
Hide file tree
Showing 62 changed files with 564 additions and 2,257 deletions.
6 changes: 6 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
language: python
matrix:
include:
- python: "2.7"
env: TORCH_VERSION=0.4.0 TORCH_URL=http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl
- python: "2.7"
env: TORCH_VERSION=0.4.1 TORCH_URL=http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
- python: "2.7"
env: TORCH_VERSION=1.0.0 TORCH_URL=http://download.pytorch.org/whl/cpu/torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
- python: "3.5"
env: TORCH_VERSION=0.4.0 TORCH_URL=http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl
- python: "3.5"
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added an unbiased flag to the std and var metrics to optionally not apply Bessel's correction (consistent with torch.std / torch.var)
- Added support for rounding 1D lists to the Tqdm callback
- Added SimpleWeibull distribution
- Added support for Python 2.7
### Changed
- Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
- Tqdm precision argument now rounds to decimal places rather than significant figures
### Deprecated
### Removed
- Removed the old Model API (deprecated since version 0.2.0)
### Fixed
- Fixed a bug in the weight decay callback which would result in potentially negative decay (now just uses torch.norm)
- Fixed a bug in the cite decorator causing the citation to not show up correctly
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<img src="https://raw.githubusercontent.com/ecs-vlc/torchbearer/master/docs/_static/img/logo_dark_text.svg?sanitize=true" width="100%"/>

[![PyPI version](https://badge.fury.io/py/torchbearer.svg)](https://badge.fury.io/py/torchbearer) [![Python 3.5 | 3.6 | 3.7](https://img.shields.io/badge/python-3.5%20%7C%203.6%20%7C%203.7-brightgreen.svg)](https://www.python.org/) [![PyTorch 0.4.0 | 0.4.1 | 1.0.0](https://img.shields.io/badge/pytorch-0.4.0%20%7C%200.4.1%20%7C%201.0.0-brightgreen.svg)](https://pytorch.org/) [![Build Status](https://travis-ci.com/ecs-vlc/torchbearer.svg?branch=master)](https://travis-ci.com/ecs-vlc/torchbearer) [![codecov](https://codecov.io/gh/ecs-vlc/torchbearer/branch/master/graph/badge.svg)](https://codecov.io/gh/ecs-vlc/torchbearer) [![Documentation Status](https://readthedocs.org/projects/torchbearer/badge/?version=latest)](https://torchbearer.readthedocs.io/en/latest/?badge=latest) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/8c9b136fbcd443fa9135d92321be480d)](https://www.codacy.com/app/ewah1g13/torchbearer?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=ecs-vlc/torchbearer&amp;utm_campaign=Badge_Grade)
[![PyPI version](https://badge.fury.io/py/torchbearer.svg)](https://badge.fury.io/py/torchbearer) [![Python 2.7 | 3.5 | 3.6 | 3.7](https://img.shields.io/badge/python-2.7%20%7C%203.5%20%7C%203.6%20%7C%203.7-brightgreen.svg)](https://www.python.org/) [![PyTorch 0.4.0 | 0.4.1 | 1.0.0](https://img.shields.io/badge/pytorch-0.4.0%20%7C%200.4.1%20%7C%201.0.0-brightgreen.svg)](https://pytorch.org/) [![Build Status](https://travis-ci.com/ecs-vlc/torchbearer.svg?branch=master)](https://travis-ci.com/ecs-vlc/torchbearer) [![codecov](https://codecov.io/gh/ecs-vlc/torchbearer/branch/master/graph/badge.svg)](https://codecov.io/gh/ecs-vlc/torchbearer) [![Documentation Status](https://readthedocs.org/projects/torchbearer/badge/?version=latest)](https://torchbearer.readthedocs.io/en/latest/?badge=latest) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/8c9b136fbcd443fa9135d92321be480d)](https://www.codacy.com/app/ewah1g13/torchbearer?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=ecs-vlc/torchbearer&amp;utm_campaign=Badge_Grade)

A model fitting library for PyTorch
## Contents
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ tqdm
tensorboardX>=1.4
visdom
livelossplot
mock
16 changes: 10 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
version_dict = {}
exec(open("./torchbearer/version.py").read(), version_dict)

from os import path
this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read()
import sys
if sys.version_info[0] >= 3:
from os import path
this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read()
else:
long_description = 'A model training and variational auto-encoder library for pytorch'

setup(
name='torchbearer',
Expand All @@ -17,9 +21,9 @@
license='GPL-3.0',
author='Matt Painter',
author_email='mp2u16@ecs.soton.ac.uk',
description='A model training library for pytorch',
description='A model training and variational auto-encoder library for pytorch',
long_description=long_description,
long_description_content_type='text/markdown',
install_requires=['torch>=0.4', 'torchvision', 'tqdm'],
python_requires='>=3.5',
python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
)
22 changes: 19 additions & 3 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from unittest import TestCase
from unittest.mock import MagicMock, Mock
from mock import MagicMock, Mock

import torchbearer
from torchbearer.callbacks import CallbackList, Tqdm, TensorBoard


class TestCallbackList(TestCase):
def __init__(self, methodName='runTest'):
super().__init__(methodName)
super(TestCallbackList, self).__init__(methodName)
self.callback_1 = MagicMock(spec=torchbearer.callbacks.printer.Tqdm())
self.callback_2 = MagicMock(spec=torchbearer.callbacks.tensor_board.TensorBoard())
callbacks = [self.callback_1, self.callback_2]
Expand Down Expand Up @@ -42,8 +42,12 @@ def test_load_state_dict(self):
state = self.list.state_dict()
state[CallbackList.CALLBACK_TYPES] = list(reversed(state[CallbackList.CALLBACK_TYPES]))

with self.assertWarns(UserWarning, msg='Callback classes did not match, expected: {\'TensorBoard\', \'Tqdm\'}'):
import warnings
with warnings.catch_warnings(record=True) as w:
self.list.load_state_dict(state)
self.assertTrue(len(w) == 1)
self.assertTrue(issubclass(w[-1].category, UserWarning))
self.assertTrue('Callback classes did not match, expected: [\'TensorBoard\', \'Tqdm\']' in str(w[-1].message))

def test_for_list(self):
self.list.on_start({})
Expand All @@ -55,3 +59,15 @@ def test_list_in_list(self):
clist = CallbackList([callback])
clist2 = CallbackList([clist])
self.assertTrue(clist2.callback_list[0] == 'test')

def test_iter_copy(self):
callback = 'test'
clist = CallbackList([callback])
cpy = clist.__copy__()
self.assertTrue(cpy.callback_list[0] == 'test')
self.assertTrue(cpy is not clist)
cpy = clist.copy()
self.assertTrue(cpy.callback_list[0] == 'test')
self.assertTrue(cpy is not clist)
for cback in clist:
self.assertTrue(cback == 'test')
16 changes: 8 additions & 8 deletions tests/callbacks/test_checkpointers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from unittest import TestCase
from unittest.mock import patch, Mock
from mock import patch, Mock

import torchbearer
from torchbearer import Model
from torchbearer import Trial
from torchbearer.callbacks.checkpointers import _Checkpointer, ModelCheckpoint, MostRecent, Interval, Best


Expand All @@ -17,7 +17,7 @@ def test_save_checkpoint_save_filename(self, mock_save):
torchmodel = Mock()
optim = Mock()
state = {
torchbearer.SELF: Model(torchmodel, optim, None, []),
torchbearer.SELF: Trial(torchmodel, optim, None, []),
torchbearer.METRICS: {}
}

Expand All @@ -33,7 +33,7 @@ def test_save_checkpoint_formatting(self, mock_save):
torchmodel = Mock()
optim = Mock()
state = {
torchbearer.SELF: Model(torchmodel, optim, None, []),
torchbearer.SELF: Trial(torchmodel, optim, None, []),
torchbearer.METRICS: {},
torchbearer.EPOCH: 2
}
Expand All @@ -50,7 +50,7 @@ def test_save_checkpoint_formatting_metric(self, mock_save):
torchmodel = Mock()
optim = Mock()
state = {
torchbearer.SELF: Model(torchmodel, optim, None, []),
torchbearer.SELF: Trial(torchmodel, optim, None, []),
torchbearer.METRICS: {'test_metric': 0.001},
torchbearer.EPOCH: 2
}
Expand All @@ -67,7 +67,7 @@ def test_save_checkpoint_subformatting(self, mock_save):
torchmodel = Mock()
optim = Mock()
state = {
torchbearer.SELF: Model(torchmodel, optim, None, []),
torchbearer.SELF: Trial(torchmodel, optim, None, []),
torchbearer.METRICS: {'test_metric': 0.001},
torchbearer.EPOCH: 2
}
Expand All @@ -84,7 +84,7 @@ def test_save_checkpoint_wrong_format(self, _):
torchmodel = Mock()
optim = Mock()
state = {
torchbearer.SELF: Model(torchmodel, optim, None, []),
torchbearer.SELF: Trial(torchmodel, optim, None, []),
torchbearer.METRICS: {'test_metric': 0.001},
torchbearer.EPOCH: 2
}
Expand All @@ -104,7 +104,7 @@ def test_save_checkpoint_overwrite_recent(self, _, __):
torchmodel = Mock()
optim = Mock()
state = {
torchbearer.SELF: Model(torchmodel, optim, None, []),
torchbearer.SELF: Trial(torchmodel, optim, None, []),
torchbearer.EPOCH: 0,
torchbearer.METRICS: {}
}
Expand Down
22 changes: 13 additions & 9 deletions tests/callbacks/test_csv_logger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from unittest import TestCase
from unittest.mock import patch, mock_open
from mock import patch, mock_open

import torchbearer
from torchbearer.callbacks import CSVLogger


class TestCSVLogger(TestCase):

@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_write_header(self, mock_open):
state = {
torchbearer.EPOCH: 0,
Expand All @@ -26,7 +26,7 @@ def test_write_header(self, mock_open):
self.assertTrue('test_metric_1' in header)
self.assertTrue('test_metric_2' in header)

@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_write_no_header(self, mock_open):
state = {
torchbearer.EPOCH: 0,
Expand All @@ -45,7 +45,7 @@ def test_write_no_header(self, mock_open):
self.assertTrue('test_metric_1' not in header)
self.assertTrue('test_metric_2' not in header)

@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_csv_closed(self, mock_open):
state = {
torchbearer.EPOCH: 0,
Expand All @@ -61,7 +61,7 @@ def test_csv_closed(self, mock_open):

self.assertTrue(mock_open.return_value.close.called)

@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_append(self, mock_open):
state = {
torchbearer.EPOCH: 0,
Expand All @@ -75,9 +75,13 @@ def test_append(self, mock_open):
logger.on_end_epoch(state)
logger.on_end(state)

self.assertTrue(mock_open.call_args[0][1] == 'a+')
import sys
if sys.version_info[0] < 3:
self.assertTrue(mock_open.call_args[0][1] == 'ab')
else:
self.assertTrue(mock_open.call_args[0][1] == 'a')

@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_get_field_dict(self, mock_open):
state = {
torchbearer.EPOCH: 0,
Expand All @@ -98,7 +102,7 @@ def test_get_field_dict(self, mock_open):
self.assertDictEqual(logger_fields_dict, correct_fields_dict)

@patch('torchbearer.callbacks.CSVLogger._write_to_dict')
@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_write_on_epoch(self, mock_open, mock_write):
state = {
torchbearer.EPOCH: 0,
Expand All @@ -115,7 +119,7 @@ def test_write_on_epoch(self, mock_open, mock_write):
self.assertEqual(mock_write.call_count, 1)

@patch('torchbearer.callbacks.CSVLogger._write_to_dict')
@patch("builtins.open", new_callable=mock_open)
@patch("torchbearer.callbacks.csv_logger.open", new_callable=mock_open)
def test_batch_granularity(self, mock_open, mock_write):
state = {
torchbearer.EPOCH: 0,
Expand Down
6 changes: 3 additions & 3 deletions tests/callbacks/test_gradient_clipping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import TestCase
from unittest.mock import patch, Mock
from mock import patch, Mock

import torchbearer
from torchbearer.callbacks import GradientNormClipping, GradientClipping
Expand All @@ -20,7 +20,7 @@ def test_not_given_params(self, mock_clip):
clipper.on_start(state)
clipper.on_backward(state)

self.assertTrue(next(mock_clip.mock_calls[0][1][0])() == -1)
self.assertTrue(next(iter(mock_clip.mock_calls[0][1][0]))() == -1)

@patch('torch.nn.utils.clip_grad_norm_')
def test_given_params(self, mock_clip):
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_not_given_params(self, mock_clip):
clipper.on_start(state)
clipper.on_backward(state)

self.assertTrue(next(mock_clip.mock_calls[0][1][0])() == -1)
self.assertTrue(next(iter(mock_clip.mock_calls[0][1][0]))() == -1)

@patch('torch.nn.utils.clip_grad_value_')
def test_given_params(self, mock_clip):
Expand Down
3 changes: 2 additions & 1 deletion tests/callbacks/test_live_loss_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import TestCase
from unittest.mock import patch, Mock, MagicMock

from mock import patch, MagicMock

import torchbearer as tb
from torchbearer.callbacks import LiveLossPlot
Expand Down
38 changes: 31 additions & 7 deletions tests/callbacks/test_printer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import TestCase
from unittest.mock import patch, Mock, MagicMock

from mock import patch, MagicMock

import torchbearer
from torchbearer.callbacks import Tqdm, ConsolePrinter
Expand All @@ -23,7 +24,7 @@ def test_not_string(self):


class TestConsolePrinter(TestCase):
@patch('builtins.print')
@patch('torchbearer.callbacks.printer.print')
def test_console_printer(self, mock_print):
state = {torchbearer.BATCH: 5, torchbearer.EPOCH: 1, torchbearer.MAX_EPOCHS: 10, torchbearer.TRAIN_STEPS: 100, torchbearer.VALIDATION_STEPS: 101, torchbearer.METRICS: {'test': 0.99456}}
printer = ConsolePrinter(validation_label_letter='e')
Expand All @@ -49,7 +50,7 @@ def test_console_printer(self, mock_print):

class TestTqdm(TestCase):
def test_tqdm(self):
state = {torchbearer.EPOCH: 1, torchbearer.MAX_EPOCHS: 10, torchbearer.TRAIN_STEPS: 100, torchbearer.VALIDATION_STEPS: 101, torchbearer.METRICS: {'test': 10}}
state = {torchbearer.EPOCH: 1, torchbearer.MAX_EPOCHS: 10, torchbearer.TRAIN_STEPS: 100, torchbearer.VALIDATION_STEPS: 101, torchbearer.METRICS: {'test': 0.99456}}
tqdm = Tqdm(validation_label_letter='e')
tqdm.tqdm_module = MagicMock()
mock_tqdm = tqdm.tqdm_module
Expand All @@ -59,12 +60,12 @@ def test_tqdm(self):
mock_tqdm.assert_called_once_with(total=100, desc='1/10(t)')

tqdm.on_step_training(state)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=10')
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
mock_tqdm.return_value.update.assert_called_once_with(1)
mock_tqdm.return_value.set_postfix_str.reset_mock()

tqdm.on_end_training(state)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=10')
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
self.assertEqual(mock_tqdm.return_value.close.call_count, 1)

mock_tqdm.reset_mock()
Expand All @@ -77,12 +78,12 @@ def test_tqdm(self):
mock_tqdm.assert_called_once_with(total=101, desc='1/10(e)')

tqdm.on_step_validation(state)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=10')
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
mock_tqdm.return_value.update.assert_called_once_with(1)
mock_tqdm.return_value.set_postfix_str.reset_mock()

tqdm.on_end_validation(state)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=10')
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
self.assertEqual(mock_tqdm.return_value.close.call_count, 1)

def test_tqdm_custom_args(self):
Expand All @@ -105,3 +106,26 @@ def test_tqdm_custom_args(self):

tqdm.on_start(state)
mock_tqdm.assert_called_once_with(initial=1, total=10, ascii=True)

def test_tqdm_on_epoch(self):
state = {torchbearer.EPOCH: 1, torchbearer.MAX_EPOCHS: 10, torchbearer.HISTORY: [0, (1, {'test': 0.99456})],
torchbearer.METRICS: {'test': 0.99456}}
tqdm = Tqdm(validation_label_letter='e', on_epoch=True)
tqdm.tqdm_module = MagicMock()
mock_tqdm = tqdm.tqdm_module

tqdm.on_start(state)
mock_tqdm.assert_called_once_with(initial=2, total=10)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
mock_tqdm.return_value.update.assert_called_once_with(1)
mock_tqdm.return_value.set_postfix_str.reset_mock()
mock_tqdm.return_value.update.reset_mock()

tqdm.on_end_epoch(state)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
mock_tqdm.return_value.update.assert_called_once_with(1)
mock_tqdm.return_value.set_postfix_str.reset_mock()

tqdm.on_end(state)
mock_tqdm.return_value.set_postfix_str.assert_called_once_with('test=0.9946')
self.assertEqual(mock_tqdm.return_value.close.call_count, 1)

0 comments on commit 78b939c

Please sign in to comment.