Skip to content

Commit

Permalink
Feature/examples (#576)
Browse files Browse the repository at this point in the history
* Update examples, fix small errors

* Update adversarial example notebook

* Update svm example

* Update GAN example
  • Loading branch information
ethanwharris committed Jun 21, 2019
1 parent 2c6294f commit fb55c8d
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 213 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -108,4 +108,5 @@ ENV/
/docs/_static/examples/images
/docs/_static/examples/logs
/docs/_static/notebooks/data
/docs/_static/notebooks/*.png
/docs/_static/notebooks/*.png
/docs/_static/notebooks/images
199 changes: 155 additions & 44 deletions docs/_static/notebooks/adversarial.ipynb

Large diffs are not rendered by default.

121 changes: 61 additions & 60 deletions docs/_static/notebooks/amsgrad.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/_static/notebooks/basic_opt.ipynb
Expand Up @@ -17,7 +17,7 @@
"\n",
"> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n",
"\n",
"## Dependencies\n",
"## Install Torchbearer\n",
"\n",
"First we install torchbearer if needed. "
]
Expand Down
2 changes: 1 addition & 1 deletion docs/_static/notebooks/callbacks.ipynb
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Note**: The easiest way to use this tutorial is as a colab notebook, which allows you to dive in with no setup. You won't need a GPU for these examples.\n",
"\n",
"## Dependencies\n",
"## Install Torchbearer\n",
"\n",
"First we install torchbearer if needed. "
]
Expand Down
149 changes: 95 additions & 54 deletions docs/_static/notebooks/gan.ipynb
Expand Up @@ -7,45 +7,40 @@
"id": "mpfN9z4hrrYx"
},
"source": [
"Training a GAN\n",
"====================================\n",
"# Training a GAN\n",
"\n",
"We shall try to implement something more complicated using torchbearer - a Generative Adverserial Network (GAN). This tutorial is a modified version of the [GAN](https://github.com/eriklindernoren/PyTorch-GAN#gan) from the brilliant collection of GAN implementations [PyTorch_GAN](https://github.com/eriklindernoren/PyTorch-GAN) by eriklindernoren on github.\n",
"\n",
"**Note**: The easiest way to use this tutorial is as a colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU with\n",
"\n",
"First lets import all the necessary packages. "
"> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n",
"\n",
"## Install Torchbearer\n",
"\n",
"First we install torchbearer if needed. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "vcMl2qlkr8UM"
},
"outputs": [],
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.3.2\n"
]
}
],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as transforms\n",
"from torch.autograd import Variable\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision.utils import save_image\n",
"\n",
"try:\n",
" import torchbearer\n",
" import torchbearer\n",
"except:\n",
" !pip install torchbearer\n",
" import torchbearer\n",
" \n",
"import torchbearer.callbacks as callbacks\n",
"from torchbearer import state_key\n",
"from torchbearer.bases import base_closure"
" !pip install -q torchbearer\n",
" import torchbearer\n",
" \n",
"print(torchbearer.__version__)"
]
},
{
Expand All @@ -63,24 +58,26 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4uuz1gTYsGLp"
},
"outputs": [],
"source": [
"import torch\n",
"from torchbearer import state_key\n",
"\n",
"# Define constants\n",
"epochs = 200\n",
"batch_size = 64\n",
"lr = 0.0002\n",
"nworkers = 8\n",
"train_steps = 50000\n",
"batch_size = 128\n",
"lr = 0.002\n",
"latent_dim = 100\n",
"sample_interval = 400\n",
"img_shape = (1, 28, 28)\n",
"adversarial_loss = torch.nn.BCELoss()\n",
"device = 'cuda'\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"valid = torch.ones(batch_size, 1, device=device)\n",
"fake = torch.zeros(batch_size, 1, device=device)\n",
"batch = torch.randn(25, latent_dim).to(device)\n",
Expand Down Expand Up @@ -112,14 +109,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GBFLYyH5sP1C"
},
"outputs": [],
"source": [
"from torchvision import datasets, transforms\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" ])\n",
Expand All @@ -142,14 +141,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "y2YUr7jWsXTX"
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import numpy as np\n",
"\n",
"class Generator(nn.Module):\n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
Expand All @@ -165,13 +167,12 @@
" *block(latent_dim, 128, normalize=False),\n",
" *block(128, 256),\n",
" *block(256, 512),\n",
" *block(512, 1024),\n",
" nn.Linear(1024, int(np.prod(img_shape))),\n",
" nn.Tanh()\n",
" nn.Linear(512, int(np.prod(img_shape))),\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, real_imgs, state):\n",
" z = Variable(torch.Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim)))).to(state[torchbearer.DEVICE])\n",
" z = torch.randn(real_imgs.shape[0], latent_dim, device = state[torchbearer.DEVICE])\n",
" img = self.model(z)\n",
" img = img.view(img.size(0), *img_shape)\n",
" return img\n",
Expand Down Expand Up @@ -209,7 +210,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
Expand Down Expand Up @@ -240,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
Expand Down Expand Up @@ -278,7 +279,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
Expand All @@ -298,13 +299,12 @@
},
"source": [
"\n",
"Closures\n",
"------------------------------------\n",
"## Closures\n",
"The training loop of a GAN is a bit different to a standard model training loop.\n",
"GANs require separate forward and backward passes for the generator and discriminator.\n",
"To achieve this in torchbearer we can write a new closure.\n",
"Since the individual training loops for the generator and discriminator are the same as a\n",
"standard training loop we can use a :func:`~torchbearer.bases.base_closure`.\n",
"standard training loop we can use a [`~torchbearer.bases.base_closure`](https://torchbearer.readthedocs.io/en/latest/code/main.html#torchbearer.bases.base_closure).\n",
"The base closure takes state keys for required objects (data, model, optimiser, etc.) and returns a standard closure consisting of:\n",
"\n",
"1. Zero gradients\n",
Expand All @@ -317,7 +317,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {
"colab": {},
"colab_type": "code",
Expand All @@ -342,7 +342,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {
"colab": {},
"colab_type": "code",
Expand Down Expand Up @@ -373,14 +373,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "FRmFEk8EtklV"
},
"outputs": [],
"source": [
"from torchvision.utils import save_image\n",
"from torchbearer import callbacks\n",
"import os\n",
"os.makedirs('images', exist_ok=True)\n",
"\n",
"@callbacks.on_step_training\n",
Expand All @@ -407,7 +410,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {
"colab": {},
"colab_type": "code",
Expand All @@ -416,7 +419,7 @@
"outputs": [],
"source": [
"trial = torchbearer.Trial(generator, None, criterion=gen_crit, metrics=metrics, callbacks=[saver_callback])\n",
"trial.with_train_generator(dataloader, steps=200000)\n",
"trial.with_train_generator(dataloader, steps=train_steps)\n",
"_ = trial.to(device)"
]
},
Expand All @@ -432,7 +435,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -456,7 +459,45 @@
"id": "6G-pLJLMtOwo",
"outputId": "6a8e2b33-0d19-42b7-c250-da8a80572168"
},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6325852d4ae743a8be2e8c79655578f2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='0/1(t)', max=50000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"[((50000, None),\n",
" {'running_loss': 0.32024282217025757,\n",
" 'running_d_loss': 0.32024282217025757,\n",
" 'running_g_loss': 2.202518939971924,\n",
" 'loss': 0.4093643128871918,\n",
" 'd_loss': 0.4093643128871918,\n",
" 'g_loss': 1.8098881244659424})]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_keys = {DISC_MODEL: discriminator.to(device), DISC_OPT: optimizer_D, GEN_OPT: optimizer_G, DISC_CRIT: disc_crit}\n",
"trial.state.update(new_keys)\n",
Expand Down Expand Up @@ -501,7 +542,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion docs/_static/notebooks/quickstart.ipynb
Expand Up @@ -15,7 +15,7 @@
"\n",
"> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n",
"\n",
"## Dependencies\n",
"## Install Torchbearer\n",
"\n",
"First we install torchbearer if needed. \n",
"\n"
Expand Down

0 comments on commit fb55c8d

Please sign in to comment.