Feature/dataset splitter (#245)
* Add dataset validation splitter

* Commenting

* Update quickstart example and README for new train/val splitter

* Update CHANGELOG.md
MattPainter01 authored and ethanwharris committed Jul 24, 2018
1 parent a8752d0 commit 8e37753
Showing 6 changed files with 170 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
### Added
- Added an on_validation_criterion callback hook
+- Added a DatasetValidationSplitter which can be used to create a validation split, if required, for datasets like Cifar10 or MNIST
### Changed
### Deprecated
### Removed
26 changes: 20 additions & 6 deletions README.md
@@ -41,16 +41,22 @@ The easiest way to install torchbearer is with pip:

## Quickstart

-- Define your data and model as usual (here we use a simple CNN on Cifar10):
+- Define your data and model as usual (here we use a simple CNN on Cifar10). Note that we use torchbearer's DatasetValidationSplitter here to create a validation set (10% of the data). This is essential to avoid [over-fitting to your test data](http://blog.kaggle.com/2012/07/06/the-dangers-of-overfitting-psychopathy-post-mortem/):

```python
BATCH_SIZE = 128

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

-trainset = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True,
+dataset = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True,
                                        transform=transforms.Compose([transforms.ToTensor(), normalize]))
+splitter = DatasetValidationSplitter(len(dataset), 0.1)
+trainset = splitter.get_train_dataset(dataset)
+valset = splitter.get_val_dataset(dataset)

traingen = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)
+valgen = torch.utils.data.DataLoader(valset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)


testset = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True,
Expand Down Expand Up @@ -93,20 +99,28 @@ loss = nn.CrossEntropyLoss()
from torchbearer import Model

torchbearer_model = Model(model, optimizer, loss, metrics=['acc', 'loss']).to('cuda')
-torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen)
+torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=valgen)

+torchbearer_model.evaluate_generator(testgen)
```
- Running that code gives output using Tqdm, providing running accuracies and losses during the training phase:

```
-0/10(t): 100%|██████████| 391/391 [00:01<00:00, 211.19it/s, running_acc=0.549, running_loss=1.25, acc=0.469, acc_std=0.499, loss=1.48, loss_std=0.238]
-0/10(v): 100%|██████████| 79/79 [00:00<00:00, 265.14it/s, val_acc=0.556, val_acc_std=0.497, val_loss=1.25, val_loss_std=0.0785]
+0/10(t): 100%|██████████| 352/352 [00:01<00:00, 233.36it/s, running_acc=0.536, running_loss=1.32, acc=0.459, acc_std=0.498, loss=1.52, loss_std=0.239]
+0/10(v): 100%|██████████| 40/40 [00:00<00:00, 239.40it/s, val_acc=0.536, val_acc_std=0.499, val_loss=1.29, val_loss_std=0.0731]
.
.
.
+9/10(t): 100%|██████████| 352/352 [00:01<00:00, 215.76it/s, running_acc=0.741, running_loss=0.735, acc=0.754, acc_std=0.431, loss=0.703, loss_std=0.0897]
+9/10(v): 100%|██████████| 40/40 [00:00<00:00, 222.72it/s, val_acc=0.68, val_acc_std=0.466, val_loss=0.948, val_loss_std=0.181]
+0/1(e): 100%|██████████| 79/79 [00:00<00:00, 268.70it/s, val_acc=0.678, val_acc_std=0.467, val_loss=0.925, val_loss_std=0.109]
```
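The batch counts reflect the new split: with 10% of the 50,000 CIFAR-10 training images held out, training runs over ceil(45000 / 128) = 352 batches and validation over ceil(5000 / 128) = 40, while the final `(e)` line is the separate evaluation pass over the 10,000-image test set (ceil(10000 / 128) = 79 batches).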

<a name="docs"/>

## Documentation

-Our documentation containing the API reference, examples and some notes can be found at [https://torchbearer.readthedocs.io](https://torchbearer.readthedocs.io)
+Our documentation containing the API reference, examples and some notes can be found at [torchbearer.readthedocs.io](https://torchbearer.readthedocs.io)

<a name="others"/>

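If you need the same split on every run, the splitter also takes an optional seed. A minimal sketch using the `shuffle_seed` parameter of the new class (not part of this diff):

```python
splitter = DatasetValidationSplitter(len(dataset), 0.1, shuffle_seed=7)
```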
13 changes: 11 additions & 2 deletions docs/_static/examples/quickstart.py
@@ -4,14 +4,21 @@
import torchvision
from torchvision import transforms

+from torchbearer.cv_utils import DatasetValidationSplitter

BATCH_SIZE = 128

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

-trainset = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True,
+dataset = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True,
                                        transform=transforms.Compose([transforms.ToTensor(), normalize]))
+splitter = DatasetValidationSplitter(len(dataset), 0.1)
+trainset = splitter.get_train_dataset(dataset)
+valset = splitter.get_val_dataset(dataset)

traingen = torch.utils.data.DataLoader(trainset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)
+valgen = torch.utils.data.DataLoader(valset, pin_memory=True, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)


testset = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True,
@@ -50,4 +57,6 @@ def forward(self, x):
from torchbearer import Model

torchbearer_model = Model(model, optimizer, loss, metrics=['acc', 'loss']).to('cuda')
-torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=testgen)
+torchbearer_model.fit_generator(traingen, epochs=10, validation_generator=valgen)

+torchbearer_model.evaluate_generator(testgen)
56 changes: 30 additions & 26 deletions docs/examples/quickstart.rst
@@ -10,43 +10,47 @@ Let's get using torchbearer. Here's some data from Cifar10 and a simple 3 layer

.. literalinclude:: /_static/examples/quickstart.py
   :language: python
-   :lines: 7-45
+   :lines: 9-52

-Typically we would need a training loop and a series of calls to backward, step etc.
-Instead, with torchbearer, we can define our optimiser and some metrics (just 'acc' and 'loss' for now) and let it do the work.
+Note that we use torchbearer's :class:`.DatasetValidationSplitter` here to create a validation set (10% of the data).
+This is essential to avoid `over-fitting to your test data <http://blog.kaggle.com/2012/07/06/the-dangers-of-overfitting-psychopathy-post-mortem/>`_.

Training on Cifar10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

+Typically we would need a training loop and a series of calls to backward, step etc.
+Instead, with torchbearer, we can define our optimiser and some metrics (just 'acc' and 'loss' for now) and let it do the work.

.. literalinclude:: /_static/examples/quickstart.py
-   :lines: 47-53
+   :lines: 54-62

Running the above produces the following output:

.. code::
-    Files already downloaded and verified
-    Files already downloaded and verified
-    0/10(t): 100%|██████████| 391/391 [00:01<00:00, 211.19it/s, running_acc=0.549, running_loss=1.25, acc=0.469, acc_std=0.499, loss=1.48, loss_std=0.238]
-    0/10(v): 100%|██████████| 79/79 [00:00<00:00, 265.14it/s, val_acc=0.556, val_acc_std=0.497, val_loss=1.25, val_loss_std=0.0785]
-    1/10(t): 100%|██████████| 391/391 [00:01<00:00, 209.80it/s, running_acc=0.61, running_loss=1.09, acc=0.593, acc_std=0.491, loss=1.14, loss_std=0.0968]
-    1/10(v): 100%|██████████| 79/79 [00:00<00:00, 227.97it/s, val_acc=0.593, val_acc_std=0.491, val_loss=1.14, val_loss_std=0.0865]
-    2/10(t): 100%|██████████| 391/391 [00:01<00:00, 220.70it/s, running_acc=0.656, running_loss=0.972, acc=0.645, acc_std=0.478, loss=1.01, loss_std=0.0915]
-    2/10(v): 100%|██████████| 79/79 [00:00<00:00, 218.91it/s, val_acc=0.631, val_acc_std=0.482, val_loss=1.04, val_loss_std=0.0951]
-    3/10(t): 100%|██████████| 391/391 [00:01<00:00, 208.67it/s, running_acc=0.682, running_loss=0.906, acc=0.675, acc_std=0.468, loss=0.922, loss_std=0.0895]
-    3/10(v): 100%|██████████| 79/79 [00:00<00:00, 86.95it/s, val_acc=0.657, val_acc_std=0.475, val_loss=0.97, val_loss_std=0.0925]
-    4/10(t): 100%|██████████| 391/391 [00:01<00:00, 211.22it/s, running_acc=0.693, running_loss=0.866, acc=0.699, acc_std=0.459, loss=0.86, loss_std=0.092]
-    4/10(v): 100%|██████████| 79/79 [00:00<00:00, 249.74it/s, val_acc=0.662, val_acc_std=0.473, val_loss=0.957, val_loss_std=0.093]
-    5/10(t): 100%|██████████| 391/391 [00:01<00:00, 205.12it/s, running_acc=0.71, running_loss=0.826, acc=0.713, acc_std=0.452, loss=0.818, loss_std=0.0904]
-    5/10(v): 100%|██████████| 79/79 [00:00<00:00, 230.12it/s, val_acc=0.661, val_acc_std=0.473, val_loss=0.962, val_loss_std=0.0966]
-    6/10(t): 100%|██████████| 391/391 [00:01<00:00, 210.87it/s, running_acc=0.714, running_loss=0.81, acc=0.727, acc_std=0.445, loss=0.779, loss_std=0.0904]
-    6/10(v): 100%|██████████| 79/79 [00:00<00:00, 241.95it/s, val_acc=0.667, val_acc_std=0.471, val_loss=0.952, val_loss_std=0.11]
-    7/10(t): 100%|██████████| 391/391 [00:01<00:00, 209.94it/s, running_acc=0.727, running_loss=0.791, acc=0.74, acc_std=0.439, loss=0.747, loss_std=0.0911]
-    7/10(v): 100%|██████████| 79/79 [00:00<00:00, 223.23it/s, val_acc=0.673, val_acc_std=0.469, val_loss=0.938, val_loss_std=0.107]
-    8/10(t): 100%|██████████| 391/391 [00:01<00:00, 203.16it/s, running_acc=0.747, running_loss=0.736, acc=0.752, acc_std=0.432, loss=0.716, loss_std=0.0899]
-    8/10(v): 100%|██████████| 79/79 [00:00<00:00, 221.55it/s, val_acc=0.679, val_acc_std=0.467, val_loss=0.923, val_loss_std=0.113]
-    9/10(t): 100%|██████████| 391/391 [00:01<00:00, 213.23it/s, running_acc=0.756, running_loss=0.701, acc=0.759, acc_std=0.428, loss=0.695, loss_std=0.0915]
-    9/10(v): 100%|██████████| 79/79 [00:00<00:00, 245.33it/s, val_acc=0.676, val_acc_std=0.468, val_loss=0.951, val_loss_std=0.111]
+    Files already downloaded and verified
+    Files already downloaded and verified
+    0/10(t): 100%|██████████| 352/352 [00:01<00:00, 233.36it/s, running_acc=0.536, running_loss=1.32, acc=0.459, acc_std=0.498, loss=1.52, loss_std=0.239]
+    0/10(v): 100%|██████████| 40/40 [00:00<00:00, 239.40it/s, val_acc=0.536, val_acc_std=0.499, val_loss=1.29, val_loss_std=0.0731]
+    1/10(t): 100%|██████████| 352/352 [00:01<00:00, 211.19it/s, running_acc=0.599, running_loss=1.13, acc=0.578, acc_std=0.494, loss=1.18, loss_std=0.096]
+    1/10(v): 100%|██████████| 40/40 [00:00<00:00, 232.97it/s, val_acc=0.594, val_acc_std=0.491, val_loss=1.14, val_loss_std=0.101]
+    2/10(t): 100%|██████████| 352/352 [00:01<00:00, 216.68it/s, running_acc=0.636, running_loss=1.04, acc=0.631, acc_std=0.482, loss=1.04, loss_std=0.0944]
+    2/10(v): 100%|██████████| 40/40 [00:00<00:00, 210.73it/s, val_acc=0.626, val_acc_std=0.484, val_loss=1.07, val_loss_std=0.0974]
+    3/10(t): 100%|██████████| 352/352 [00:01<00:00, 190.88it/s, running_acc=0.671, running_loss=0.929, acc=0.664, acc_std=0.472, loss=0.957, loss_std=0.0929]
+    3/10(v): 100%|██████████| 40/40 [00:00<00:00, 221.79it/s, val_acc=0.639, val_acc_std=0.48, val_loss=1.02, val_loss_std=0.103]
+    4/10(t): 100%|██████████| 352/352 [00:01<00:00, 212.43it/s, running_acc=0.685, running_loss=0.897, acc=0.689, acc_std=0.463, loss=0.891, loss_std=0.0888]
+    4/10(v): 100%|██████████| 40/40 [00:00<00:00, 249.99it/s, val_acc=0.655, val_acc_std=0.475, val_loss=0.983, val_loss_std=0.113]
+    5/10(t): 100%|██████████| 352/352 [00:01<00:00, 209.45it/s, running_acc=0.711, running_loss=0.835, acc=0.706, acc_std=0.456, loss=0.844, loss_std=0.088]
+    5/10(v): 100%|██████████| 40/40 [00:00<00:00, 240.80it/s, val_acc=0.648, val_acc_std=0.477, val_loss=0.965, val_loss_std=0.107]
+    6/10(t): 100%|██████████| 352/352 [00:01<00:00, 216.89it/s, running_acc=0.713, running_loss=0.826, acc=0.72, acc_std=0.449, loss=0.802, loss_std=0.0903]
+    6/10(v): 100%|██████████| 40/40 [00:00<00:00, 238.17it/s, val_acc=0.655, val_acc_std=0.475, val_loss=0.97, val_loss_std=0.0997]
+    7/10(t): 100%|██████████| 352/352 [00:01<00:00, 213.82it/s, running_acc=0.737, running_loss=0.773, acc=0.734, acc_std=0.442, loss=0.765, loss_std=0.0878]
+    7/10(v): 100%|██████████| 40/40 [00:00<00:00, 202.45it/s, val_acc=0.677, val_acc_std=0.468, val_loss=0.936, val_loss_std=0.0985]
+    8/10(t): 100%|██████████| 352/352 [00:01<00:00, 211.36it/s, running_acc=0.732, running_loss=0.744, acc=0.746, acc_std=0.435, loss=0.728, loss_std=0.0902]
+    8/10(v): 100%|██████████| 40/40 [00:00<00:00, 204.52it/s, val_acc=0.674, val_acc_std=0.469, val_loss=0.949, val_loss_std=0.124]
+    9/10(t): 100%|██████████| 352/352 [00:01<00:00, 215.76it/s, running_acc=0.741, running_loss=0.735, acc=0.754, acc_std=0.431, loss=0.703, loss_std=0.0897]
+    9/10(v): 100%|██████████| 40/40 [00:00<00:00, 222.72it/s, val_acc=0.68, val_acc_std=0.466, val_loss=0.948, val_loss_std=0.181]
+    0/1(e): 100%|██████████| 79/79 [00:00<00:00, 268.70it/s, val_acc=0.678, val_acc_std=0.467, val_loss=0.925, val_loss_std=0.109]
Source Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
33 changes: 33 additions & 0 deletions tests/test_cv_utils.py
@@ -190,4 +190,37 @@ def test_get_train_valid_sets_no_valid(self):
        self.assertTrue(valset is None)
        self.assertTrue(len(trainset) == len(x))

+    def test_DatasetValidationSplitter(self):
+        data = torch.Tensor(list(range(1000)))
+        dataset = TensorDataset(data)
+
+        splitter = DatasetValidationSplitter(len(dataset), 0.1)
+        trainset = splitter.get_train_dataset(dataset)
+        validset = splitter.get_val_dataset(dataset)
+
+        self.assertTrue(len(trainset) == 900)
+        self.assertTrue(len(validset) == 100)
+
+        # Check that no id appears in both the train and validation sets
+        collision = False
+        for sample_id in trainset.ids:
+            if sample_id in validset.ids:
+                collision = True
+        self.assertFalse(collision)
+
+    def test_DatasetValidationSplitter_seed(self):
+        data = torch.Tensor(list(range(1000)))
+        dataset = TensorDataset(data)
+
+        splitter_1 = DatasetValidationSplitter(len(dataset), 0.1, shuffle_seed=1)
+        trainset_1 = splitter_1.get_train_dataset(dataset)
+        validset_1 = splitter_1.get_val_dataset(dataset)
+
+        splitter_2 = DatasetValidationSplitter(len(dataset), 0.1, shuffle_seed=1)
+        trainset_2 = splitter_2.get_train_dataset(dataset)
+        validset_2 = splitter_2.get_val_dataset(dataset)
+
+        self.assertTrue(trainset_1.ids[0] == trainset_2.ids[0])
+        self.assertTrue(validset_1.ids[0] == validset_2.ids[0])
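The collision check above can also be phrased as a set intersection over the id lists. A minimal equivalent sketch (not part of this diff):

```python
import torch
from torch.utils.data import TensorDataset

from torchbearer.cv_utils import DatasetValidationSplitter

dataset = TensorDataset(torch.Tensor(list(range(1000))))
splitter = DatasetValidationSplitter(len(dataset), 0.1)
trainset = splitter.get_train_dataset(dataset)
validset = splitter.get_val_dataset(dataset)

# The train and validation ids should partition range(1000):
# disjoint id sets whose sizes sum to the dataset length.
assert not set(trainset.ids) & set(validset.ids)
assert len(trainset) + len(validset) == 1000
```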


80 changes: 75 additions & 5 deletions torchbearer/cv_utils.py
@@ -1,11 +1,12 @@
import math

import torch
-from torch.utils.data import TensorDataset
+from torch.utils.data import TensorDataset, Dataset
+import random


def train_valid_splitter(x, y, split, shuffle=True):
-    ''' Generate training and validation tensors from whole dataset data and label tensors
+    """ Generate training and validation tensors from whole dataset data and label tensors

    :param x: Data tensor for whole dataset
    :type x: torch.Tensor
@@ -17,7 +18,7 @@ def train_valid_splitter(x, y, split, shuffle=True):
    :type shuffle: bool
    :return: Training and validation tensors (training data, training labels, validation data, validation labels)
    :rtype: tuple
-    '''
+    """
    num_samples_x = x.size()[0]
    num_valid_samples = math.floor(num_samples_x * split)

@@ -32,7 +33,7 @@


def get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True):
-    ''' Generate validation and training datasets from whole dataset tensors
+    """ Generate validation and training datasets from whole dataset tensors

    :param x: Data tensor for dataset
    :type x: torch.Tensor
@@ -46,7 +47,7 @@ def get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True):
    :type shuffle: bool
    :return: Training and validation datasets
    :rtype: tuple
-    '''
+    """

    valset = None

@@ -62,3 +63,72 @@ def get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True):
        valset = TensorDataset(x_val, y_val)

    return trainset, valset


+class DatasetValidationSplitter:
+    def __init__(self, dataset_len, split_fraction, shuffle_seed=None):
+        """ Generates training and validation split indices for a given dataset length and creates training and
+        validation datasets using these splits
+
+        :param dataset_len: The length of the dataset to be split into training and validation
+        :param split_fraction: The fraction of the whole dataset to be used for validation
+        :param shuffle_seed: Optional random seed for the shuffling process
+        """
+        super().__init__()
+        self.dataset_len = dataset_len
+        self.split_fraction = split_fraction
+        self.valid_ids = None
+        self.train_ids = None
+        self._gen_split_ids(shuffle_seed)
+
+    def _gen_split_ids(self, seed):
+        all_ids = list(range(self.dataset_len))
+
+        # Shuffle the ids, optionally with a fixed seed for reproducible splits
+        if seed is not None:
+            random.seed(seed)
+        random.shuffle(all_ids)
+
+        # The first floor(dataset_len * split_fraction) shuffled ids become the validation set
+        num_valid_ids = math.floor(self.dataset_len * self.split_fraction)
+        self.valid_ids = all_ids[:num_valid_ids]
+        self.train_ids = all_ids[num_valid_ids:]
+
+    def get_train_dataset(self, dataset):
+        """ Creates a training dataset from an existing dataset
+
+        :param dataset: Dataset to be split into a training dataset
+        :type dataset: torch.utils.data.Dataset
+        :return: Training dataset split from whole dataset
+        :rtype: torch.utils.data.Dataset
+        """
+        return SubsetDataset(dataset, self.train_ids)
+
+    def get_val_dataset(self, dataset):
+        """ Creates a validation dataset from an existing dataset
+
+        :param dataset: Dataset to be split into a validation dataset
+        :type dataset: torch.utils.data.Dataset
+        :return: Validation dataset split from whole dataset
+        :rtype: torch.utils.data.Dataset
+        """
+        return SubsetDataset(dataset, self.valid_ids)
+
+
+class SubsetDataset(Dataset):
+    def __init__(self, dataset, ids):
+        """ Dataset that consists of a subset of a previous dataset
+
+        :param dataset: Complete dataset
+        :type dataset: torch.utils.data.Dataset
+        :param ids: List of subset IDs
+        :type ids: list
+        """
+        super().__init__()
+        self.dataset = dataset
+        self.ids = ids
+
+    def __getitem__(self, index):
+        # Delegate to the wrapped dataset, remapping the subset index to its underlying id
+        return self.dataset.__getitem__(self.ids[index])
+
+    def __len__(self):
+        return len(self.ids)
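To see how `SubsetDataset` remaps indices, a small sketch (not part of this diff): element `i` of the subset is element `ids[i]` of the wrapped dataset.

```python
import torch
from torch.utils.data import TensorDataset

from torchbearer.cv_utils import SubsetDataset

ds = TensorDataset(torch.arange(10).float())
sub = SubsetDataset(ds, [7, 2, 5])

print(len(sub))  # 3 -- the length of the id list, not of the wrapped dataset
print(sub[0])    # (tensor(7.),) -- subset index 0 maps to underlying id 7
```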
