Fix/torchbearer imports (#558)
* Remove all torchbearer as tb imports

* Update notebooks
MattPainter01 committed May 30, 2019
1 parent 2c065d8 commit 4ddbc6b
Showing 14 changed files with 497 additions and 493 deletions.
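All 14 files make the same mechanical change: the aliased import (import torchbearer as tb) is replaced by the plain module import, and every tb. reference is expanded to torchbearer. in full. A minimal sketch of the pattern follows; it is an illustrative composite rather than a verbatim excerpt from any one file, and the guarded pip install appears only in the notebook files:

# Before this commit, the module was imported under an alias:
#
#     import torchbearer as tb
#     ...
#     if state[tb.BATCH] % 101 == 1:
#
# After, the plain module name is used throughout:

import torchbearer

# State keys are now referenced via the full module name, e.g.
# state[torchbearer.BATCH], state[torchbearer.MODEL], state[torchbearer.Y_PRED]

# The notebooks also gain a guarded install so they run on a fresh
# environment such as Colab. Note that "!pip" is IPython shell syntax,
# valid only inside a notebook cell:

try:
    import torchbearer
except:  # bare except, as written in the notebooks
    !pip install torchbearer
    import torchbearer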
28 changes: 14 additions & 14 deletions docs/_static/notebooks/amsgrad.ipynb
@@ -37,7 +37,7 @@
"import torch\n",
"from torch.nn import Module\n",
"\n",
"import torchbearer as tb\n",
"import torchbearer\n",
"\n",
"if 'tensorboardX' in sys.modules:\n",
" from torchbearer.callbacks import TensorBoard\n",
@@ -88,7 +88,7 @@
" function to be minimised:\n",
" f(x) = 1010x if t mod 101 = 1, else -10x\n",
" \"\"\"\n",
" if state[tb.BATCH] % 101 == 1:\n",
" if state[torchbearer.BATCH] % 101 == 1:\n",
" res = 1010 * self.x\n",
" else:\n",
" res = -10 * self.x\n",
@@ -120,13 +120,13 @@
" return y_pred\n",
"\n",
"\n",
"@tb.metrics.to_dict\n",
"class est(tb.metrics.Metric):\n",
"@torchbearer.metrics.to_dict\n",
"class est(torchbearer.metrics.Metric):\n",
" def __init__(self):\n",
" super().__init__('est')\n",
"\n",
" def process(self, state):\n",
" return state[tb.MODEL].x.data"
" return state[torchbearer.MODEL].x.data"
]
},
{
@@ -150,12 +150,12 @@
},
"outputs": [],
"source": [
"@tb.callbacks.on_step_training\n",
"@torchbearer.callbacks.on_step_training\n",
"def greedy_update(state):\n",
" if state[tb.MODEL].x > 1:\n",
" state[tb.MODEL].x.data.fill_(1)\n",
" elif state[tb.MODEL].x < -1:\n",
" state[tb.MODEL].x.data.fill_(-1)"
" if state[torchbearer.MODEL].x > 1:\n",
" state[torchbearer.MODEL].x.data.fill_(1)\n",
" elif state[torchbearer.MODEL].x < -1:\n",
" state[torchbearer.MODEL].x.data.fill_(-1)"
]
},
{
@@ -201,12 +201,12 @@
"\n",
"model = Online()\n",
"optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.99])\n",
"tbtrial = tb.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='adam', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial = torchbearer.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='adam', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial.for_train_steps(training_steps).run()\n",
"\n",
"model = Online()\n",
"optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.99], amsgrad=True)\n",
"tbtrial = tb.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='amsgrad', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial = torchbearer.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='amsgrad', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial.for_train_steps(training_steps).run()"
]
},
@@ -305,12 +305,12 @@
"source": [
"model = Stochastic()\n",
"optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.99])\n",
"tbtrial = tb.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='adam', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial = torchbearer.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='adam', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial.for_train_steps(training_steps).run()\n",
"\n",
"model = Stochastic()\n",
"optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.99], amsgrad=True)\n",
"tbtrial = tb.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='amsgrad', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial = torchbearer.Trial(model, optim, loss, [est()], callbacks=[greedy_update, TensorBoard(comment='amsgrad', write_graph=False, write_batch_metrics=True, write_epoch_metrics=False)])\n",
"tbtrial.for_train_steps(training_steps).run()"
]
},
10 changes: 7 additions & 3 deletions docs/_static/notebooks/basic_opt.ipynb
@@ -31,7 +31,11 @@
"import torch\n",
"from torch.nn import Module\n",
"\n",
"import torchbearer as tb"
"try:\n",
" import torchbearer\n",
"except:\n",
" !pip install torchbearer\n",
" import torchbearer"
]
},
{
@@ -58,7 +62,7 @@
},
"outputs": [],
"source": [
"ESTIMATE = tb.state_key('est')\n",
"ESTIMATE = torchbearer.state_key('est')\n",
"\n",
"\n",
"class Net(Module):\n",
@@ -187,7 +191,7 @@
},
"outputs": [],
"source": [
"tbtrial = tb.Trial(model, optim, loss, [tb.metrics.running_mean(ESTIMATE, dim=1), 'loss'])\n",
"tbtrial = torchbearer.Trial(model, optim, loss, [torchbearer.metrics.running_mean(ESTIMATE, dim=1), 'loss'])\n",
"tbtrial.for_train_steps(training_steps).to('cuda')\n",
"tbtrial.run()\n",
"print(list(model.parameters())[0].data)"
24 changes: 12 additions & 12 deletions docs/_static/notebooks/gan.ipynb
@@ -38,10 +38,10 @@
"from torchvision.utils import save_image\n",
"\n",
"try:\n",
" import torchbearer as tb\n",
" import torchbearer\n",
"except:\n",
" !pip install torchbearer\n",
" import torchbearer as tb\n",
" import torchbearer\n",
" \n",
"import torchbearer.callbacks as callbacks\n",
"from torchbearer import state_key\n",
@@ -171,7 +171,7 @@
" )\n",
"\n",
" def forward(self, real_imgs, state):\n",
" z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[tb.DEVICE])\n",
" z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[torchbearer.DEVICE])\n",
" img = self.model(z)\n",
" img = img.view(img.size(0), *img_shape)\n",
" return img\n",
@@ -249,14 +249,14 @@
"outputs": [],
"source": [
"def gen_crit(state):\n",
" loss = adversarial_loss(state[DISC_MODEL](state[tb.Y_PRED], state), valid)\n",
" loss = adversarial_loss(state[DISC_MODEL](state[torchbearer.Y_PRED], state), valid)\n",
" state[G_LOSS] = loss\n",
" return loss\n",
"\n",
"\n",
"def disc_crit(state):\n",
" real_loss = adversarial_loss(state[DISC_MODEL](state[tb.X], state), valid)\n",
" fake_loss = adversarial_loss(state[DISC_MODEL](state[tb.Y_PRED].detach(), state), fake)\n",
" real_loss = adversarial_loss(state[DISC_MODEL](state[torchbearer.X], state), valid)\n",
" fake_loss = adversarial_loss(state[DISC_MODEL](state[torchbearer.Y_PRED].detach(), state), fake)\n",
" loss = (real_loss + fake_loss) / 2\n",
" state[D_LOSS] = loss\n",
" return loss"
@@ -326,8 +326,8 @@
"outputs": [],
"source": [
"from torchbearer.bases import base_closure\n",
"closure_gen = base_closure(tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION, tb.LOSS, GEN_OPT)\n",
"closure_disc = base_closure(tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, tb.LOSS, DISC_OPT)"
"closure_gen = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, GEN_OPT)\n",
"closure_disc = base_closure(torchbearer.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, torchbearer.LOSS, DISC_OPT)"
]
},
{
@@ -384,10 +384,10 @@
"os.makedirs('images', exist_ok=True)\n",
"\n",
"@callbacks.on_step_training\n",
"@callbacks.only_if(lambda state: state[tb.BATCH] % sample_interval == 0)\n",
"@callbacks.only_if(lambda state: state[torchbearer.BATCH] % sample_interval == 0)\n",
"def saver_callback(state):\n",
" samples = state[tb.MODEL](batch, state)\n",
" save_image(samples, 'images/%d.png' % state[tb.BATCH], nrow=5, normalize=True)"
" samples = state[torchbearer.MODEL](batch, state)\n",
" save_image(samples, 'images/%d.png' % state[torchbearer.BATCH], nrow=5, normalize=True)"
]
},
{
@@ -415,7 +415,7 @@
},
"outputs": [],
"source": [
"trial = tb.Trial(generator, None, criterion=gen_crit, metrics=metrics, callbacks=[saver_callback])\n",
"trial = torchbearer.Trial(generator, None, criterion=gen_crit, metrics=metrics, callbacks=[saver_callback])\n",
"trial.with_train_generator(dataloader, steps=200000)\n",
"_ = trial.to(device)"
]
4 changes: 2 additions & 2 deletions docs/_static/notebooks/quickstart.ipynb
@@ -39,9 +39,9 @@
"from torchvision import transforms\n",
"\n",
"try:\n",
-    from torchbearer.cv_utils import DatasetValidationSplitter
+    from torchbearer.cv_utils import DatasetValidationSplitter
"except:\n",
-    !pip install torchbearer==0.3.0
+    !pip install torchbearer
" from torchbearer.cv_utils import DatasetValidationSplitter"
]
},
6 changes: 3 additions & 3 deletions docs/_static/notebooks/vae.ipynb
@@ -38,10 +38,10 @@
"from torchvision.utils import save_image\n",
"\n",
"try:\n",
" import torchbearer\n",
" import torchbearer\n",
"except:\n",
" !pip install torchbearer\n",
" import torchbearer\n",
" !pip install torchbearer\n",
" import torchbearer\n",
"from torchbearer.cv_utils import DatasetValidationSplitter"
]
},
12 changes: 6 additions & 6 deletions tests/callbacks/test_aggregate_predictions.py
@@ -1,7 +1,7 @@
from unittest import TestCase
import torch
from torchbearer.callbacks import AggregatePredictions
-import torchbearer as tb
+import torchbearer


class TestAggregatePredictions(TestCase):
@@ -12,8 +12,8 @@ def test_aggreate_predictions(self):
y_pred_1 = torch.Tensor([1,2,3])
y_pred_2 = torch.Tensor([3,4,5])

-state_1 = {tb.Y_PRED: y_pred_1}
-state_2 = {tb.Y_PRED: y_pred_2}
+state_1 = {torchbearer.Y_PRED: y_pred_1}
+state_2 = {torchbearer.Y_PRED: y_pred_2}
final_state = {}

aggregator.on_step_validation(state_1)
@@ -24,16 +24,16 @@

aggregate = torch.cat([y_pred_1, y_pred_2])
aggregator.on_end_validation(final_state)
-self.assertTrue(list(final_state[tb.FINAL_PREDICTIONS].numpy()) == list(aggregate.numpy()))
+self.assertTrue(list(final_state[torchbearer.FINAL_PREDICTIONS].numpy()) == list(aggregate.numpy()))

def test_aggreate_predictions_multiple_calls(self):
aggregator = AggregatePredictions()

y_pred_1 = torch.Tensor([1,2,3])
y_pred_2 = torch.Tensor([3,4,5])

-state_1 = {tb.Y_PRED: y_pred_1}
-state_2 = {tb.Y_PRED: y_pred_2}
+state_1 = {torchbearer.Y_PRED: y_pred_1}
+state_2 = {torchbearer.Y_PRED: y_pred_2}

aggregator.on_step_validation(state_1)
self.assertTrue(list(aggregator.predictions_list[0].numpy()) == list(y_pred_1.numpy()))
18 changes: 9 additions & 9 deletions tests/callbacks/test_live_loss_plot.py
@@ -2,7 +2,7 @@

from mock import patch, MagicMock

-import torchbearer as tb
+import torchbearer
from torchbearer.callbacks import LiveLossPlot


@@ -17,7 +17,7 @@ def test_on_batch(self):
llp = LiveLossPlot(True, 1, False, False)
llp.batch_plt = MagicMock()
llp.plt = MagicMock()
-state = {tb.BATCH: 1, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 1, torchbearer.METRICS: {'test': 1}}
llp.on_step_training(state)
llp.on_step_training(state)

@@ -28,13 +28,13 @@ def test_on_batch_steps(self):
llp = LiveLossPlot(True, 2, False, False)
llp.batch_plt = MagicMock()
llp.plt = MagicMock()
-state = {tb.BATCH: 1, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 1, torchbearer.METRICS: {'test': 1}}
llp.on_step_training(state)
-state = {tb.BATCH: 2, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 2, torchbearer.METRICS: {'test': 1}}
llp.on_step_training(state)
-state = {tb.BATCH: 3, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 3, torchbearer.METRICS: {'test': 1}}
llp.on_step_training(state)
-state = {tb.BATCH: 4, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 4, torchbearer.METRICS: {'test': 1}}
llp.on_step_training(state)

self.assertTrue(llp.batch_plt.draw.call_count == 2)
@@ -44,7 +44,7 @@ def test_not_on_batch(self):
llp = LiveLossPlot(False, 10, False, False)
llp.batch_plt = MagicMock()
llp.plt = MagicMock()
-state = {tb.BATCH: 1, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 1, torchbearer.METRICS: {'test': 1}}
llp.on_step_training(state)
llp.on_step_training(state)

@@ -54,7 +54,7 @@ def test_on_epoch(self):
llp = LiveLossPlot(False, 10, True, False)
llp.batch_plt = MagicMock()
llp.plt = MagicMock()
-state = {tb.BATCH: 1, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 1, torchbearer.METRICS: {'test': 1}}
llp.on_end_epoch(state)
llp.on_end_epoch(state)

@@ -65,7 +65,7 @@ def test_draw_once(self):
llp = LiveLossPlot(True, 1, True, True)
llp.batch_plt = MagicMock()
llp.plt = MagicMock()
-state = {tb.BATCH: 1, tb.METRICS: {'test': 1}}
+state = {torchbearer.BATCH: 1, torchbearer.METRICS: {'test': 1}}
llp.on_end_epoch(state)
llp.on_end_epoch(state)
llp.on_end(state)
14 changes: 7 additions & 7 deletions tests/callbacks/test_lr_finder.py
@@ -1,5 +1,5 @@
from unittest import TestCase
-import torchbearer as tb
+import torchbearer
from torchbearer.callbacks import lr_finder as lrf
import numpy as np

@@ -76,12 +76,12 @@ def test_end_to_end(self):
], lr=1e-2, momentum=0.9)

clr = lrf.CyclicLR(step_size=75, base_lr=[0.001, 0.0001], max_lr=[0.006, 0.0006])
-clr.on_start({tb.OPTIMIZER: optim})
+clr.on_start({torchbearer.OPTIMIZER: optim})

lrs = []
for i in range(100):
-clr.on_sample({tb.OPTIMIZER: optim})
-clr.on_step_training({tb.OPTIMIZER: optim})
+clr.on_sample({torchbearer.OPTIMIZER: optim})
+clr.on_step_training({torchbearer.OPTIMIZER: optim})
for param_group in optim.param_groups:
lr = param_group['lr']
lrs.append(lr)
@@ -102,12 +102,12 @@ def test_end_to_end_2(self):
], lr=1e-2, momentum=0.9)

clr = lrf.CyclicLR(step_size=[75, 100], base_lr=0.001, max_lr=0.006)
-clr.on_start({tb.OPTIMIZER: optim})
+clr.on_start({torchbearer.OPTIMIZER: optim})

lrs = []
for i in range(100):
-clr.on_sample({tb.OPTIMIZER: optim})
-clr.on_step_training({tb.OPTIMIZER: optim})
+clr.on_sample({torchbearer.OPTIMIZER: optim})
+clr.on_step_training({torchbearer.OPTIMIZER: optim})
for param_group in optim.param_groups:
lr = param_group['lr']
lrs.append(lr)
