Skip to content

Commit

Permalink
Test TensorBoardProjector callback (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and MattPainter01 committed Jul 19, 2018
1 parent 5f04b8e commit 7f65bb6
Showing 1 changed file with 141 additions and 1 deletion.
142 changes: 141 additions & 1 deletion tests/callbacks/test_tensor_board.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from unittest.mock import patch, Mock, ANY

import torchbearer
from torchbearer.callbacks import TensorBoard, TensorBoardImages
from torchbearer.callbacks import TensorBoard, TensorBoardImages, TensorBoardProjector
import torch
import torch.nn as nn

Expand Down Expand Up @@ -220,3 +220,143 @@ def test_odd_batches(self, mock_board, mock_grid):
mock_grid.assert_called_once_with(ANY, nrow=9, padding=3, normalize=True, range='tmp', scale_each=True, pad_value=1)
mock_board.return_value.add_image.assert_called_once_with('test', 10, 1)
self.assertTrue(mock_grid.call_args[0][0].size() == torch.ones(40, 3, 10, 10).size())


class TestTensorBoardProjector(TestCase):
@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_log_dir(self, mock_board):
state = {torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardProjector(log_dir='./test', comment='torchbearer')
tboard.on_start(state)

mock_board.assert_called_once_with(log_dir=os.path.join('./test', 'Sequential_torchbearer'))

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_writer_closed_on_end(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.close = Mock()

state = {torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3))}

tboard = TensorBoardProjector()
tboard.on_start(state)
tboard.on_end({})
mock_board.return_value.close.assert_called_once()

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_simple_case(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.add_embedding = Mock()

state = {torchbearer.X: torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 0, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)), torchbearer.Y_TRUE: torch.ones(18), torchbearer.BATCH: 0}

tboard = TensorBoardProjector(num_images=18, avg_data_channels=False, write_data=False, features_key=torchbearer.Y_TRUE)

tboard.on_start(state)
tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='features', global_step=0)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == state[torchbearer.Y_TRUE].unsqueeze(1).size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == state[torchbearer.Y_TRUE].size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == state[torchbearer.X].size())

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_multi_epoch(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.add_embedding = Mock()

state = {torchbearer.X: torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 0, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)), torchbearer.Y_TRUE: torch.ones(18), torchbearer.BATCH: 0}

tboard = TensorBoardProjector(num_images=18, avg_data_channels=False, write_data=False, features_key=torchbearer.Y_TRUE)

tboard.on_start(state)
tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='features', global_step=0)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == state[torchbearer.Y_TRUE].unsqueeze(1).size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == state[torchbearer.Y_TRUE].size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == state[torchbearer.X].size())

tboard.on_end_epoch({})
mock_board.return_value.add_embedding.reset_mock()

tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='features',
global_step=0)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == state[torchbearer.Y_TRUE].unsqueeze(1).size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == state[torchbearer.Y_TRUE].size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == state[torchbearer.X].size())

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_multi_batch(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.add_embedding = Mock()

state = {torchbearer.X: torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 0, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)), torchbearer.Y_TRUE: torch.ones(18), torchbearer.BATCH: 0}

tboard = TensorBoardProjector(num_images=45, avg_data_channels=False, write_data=False, features_key=torchbearer.Y_TRUE)

tboard.on_start(state)
for i in range(3):
state[torchbearer.BATCH] = i
tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='features', global_step=0)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == torch.Size([45, 1]))
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == torch.Size([45]))
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == torch.Size([45, 3, 10, 10]))

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_multi_batch_data(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.add_embedding = Mock()

state = {torchbearer.X: torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 0, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)), torchbearer.Y_TRUE: torch.ones(18), torchbearer.BATCH: 0}

tboard = TensorBoardProjector(num_images=45, avg_data_channels=False, write_data=True, write_features=False)

tboard.on_start(state)
for i in range(3):
state[torchbearer.BATCH] = i
tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='data', global_step=-1)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == torch.Size([45, 300]))
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == torch.Size([45]))
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == torch.Size([45, 3, 10, 10]))

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_channel_average(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.add_embedding = Mock()

state = {torchbearer.X: torch.ones(18, 3, 10, 10), torchbearer.EPOCH: 0, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)), torchbearer.Y_TRUE: torch.ones(18), torchbearer.BATCH: 0}

tboard = TensorBoardProjector(num_images=18, avg_data_channels=True, write_data=True, write_features=False)

tboard.on_start(state)
tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='data', global_step=-1)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == torch.Size([18, 100]))
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == state[torchbearer.Y_TRUE].size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == state[torchbearer.X].size())

@patch('torchbearer.callbacks.tensor_board.SummaryWriter')
def test_no_channels(self, mock_board):
mock_board.return_value = Mock()
mock_board.return_value.add_embedding = Mock()

state = {torchbearer.X: torch.ones(18, 10, 10), torchbearer.EPOCH: 0, torchbearer.MODEL: nn.Sequential(nn.Conv2d(3, 3, 3)), torchbearer.Y_TRUE: torch.ones(18), torchbearer.BATCH: 0}

tboard = TensorBoardProjector(num_images=18, avg_data_channels=False, write_data=True, write_features=False)

tboard.on_start(state)
tboard.on_step_validation(state)

mock_board.return_value.add_embedding.assert_called_once_with(ANY, metadata=ANY, label_img=ANY, tag='data', global_step=-1)
self.assertTrue(mock_board.return_value.add_embedding.call_args[0][0].size() == torch.Size([18, 100]))
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['metadata'].size() == state[torchbearer.Y_TRUE].size())
self.assertTrue(mock_board.return_value.add_embedding.call_args[1]['label_img'].size() == torch.Size([18, 1, 10, 10]))

0 comments on commit 7f65bb6

Please sign in to comment.