From 35e81f35439612c41f25365d7579b81db7f235c2 Mon Sep 17 00:00:00 2001
From: Thomas Fan
Date: Mon, 23 Jul 2018 11:25:58 -0400
Subject: [PATCH] DOC: Adds transfer learning tutorial (#281)

Adds transfer learning tutorial notebook. This tutorial converts the pure
PyTorch approach described in
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
to skorch.
---
 .gitignore                        |   3 +-
 notebooks/Transfer_Learning.ipynb | 400 ++++++++++++++++++++++++++++++
 notebooks/datasets/.gitkeep       |   0
 3 files changed, 401 insertions(+), 2 deletions(-)
 create mode 100644 notebooks/Transfer_Learning.ipynb
 create mode 100644 notebooks/datasets/.gitkeep

diff --git a/.gitignore b/.gitignore
index dfe8dd451..d38e9dcbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,6 @@ data/
 *.pyc
 *.pkl
 *.gz
-*.ipynb
 *.log
 *.c
 *.so
@@ -19,4 +18,4 @@ data/
 *.vocab
 *.w2v
 *prof
-*.py~
\ No newline at end of file
+*.py~

diff --git a/notebooks/Transfer_Learning.ipynb b/notebooks/Transfer_Learning.ipynb
new file mode 100644
index 000000000..fdc64d720
--- /dev/null
+++ b/notebooks/Transfer_Learning.ipynb
@@ -0,0 +1,400 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Transfer Learning with skorch"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this tutorial, you will learn how to train a neural network using transfer learning with the `skorch` API. Transfer learning uses a pretrained model to initialize a network. This tutorial converts the pure PyTorch approach described in [PyTorch's Transfer Learning Tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) to `skorch`.\n",
+    "\n",
+    "We will be using `torchvision` for this tutorial. Instructions on how to install `torchvision` for your platform can be found at https://pytorch.org."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from urllib import request\n",
+    "from zipfile import ZipFile\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "import numpy as np\n",
+    "from torchvision import datasets, models, transforms\n",
+    "\n",
+    "from skorch import NeuralNetClassifier\n",
+    "from skorch.callbacks import LRScheduler, Checkpoint\n",
+    "from skorch.helper import filtered_optimizer\n",
+    "from skorch.helper import filter_requires_grad\n",
+    "from skorch.helper import predefined_split\n",
+    "\n",
+    "torch.manual_seed(360);"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preparations"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before we begin, let's download the data needed for this tutorial:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Data has been downloaded and extracted to datasets.\n"
+     ]
+    }
+   ],
+   "source": [
+    "def download_and_extract_data(dataset_dir='datasets'):\n",
+    "    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')\n",
+    "    data_path = os.path.join(dataset_dir, 'hymenoptera_data')\n",
+    "    url = \"https://download.pytorch.org/tutorial/hymenoptera_data.zip\"\n",
+    "\n",
+    "    if not os.path.exists(data_path):\n",
+    "        if not os.path.exists(data_zip):\n",
+    "            print(\"Starting to download data...\")\n",
+    "            data = request.urlopen(url, timeout=15).read()\n",
+    "            # Save the downloaded archive to data_zip so it can be\n",
+    "            # extracted below.\n",
+    "            with open(data_zip, 'wb') as f:\n",
+    "                f.write(data)\n",
+    "\n",
+    "        print(\"Starting to extract data...\")\n",
+    "        with ZipFile(data_zip, 'r') as zip_f:\n",
+    "            zip_f.extractall(dataset_dir)\n",
+    "\n",
+    "    print(\"Data has been downloaded and extracted to {}.\".format(dataset_dir))\n",
+    "\n",
+    "download_and_extract_data()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## The Problem"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We are going to train a neural network to classify **ants** and **bees**. The dataset consists of about 120 training images and 75 validation images for each class. First, we create the training and validation datasets:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_dir = 'datasets/hymenoptera_data'\n",
+    "train_transforms = transforms.Compose([\n",
+    "    transforms.RandomResizedCrop(224),\n",
+    "    transforms.RandomHorizontalFlip(),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize([0.485, 0.456, 0.406],\n",
+    "                         [0.229, 0.224, 0.225])\n",
+    "])\n",
+    "val_transforms = transforms.Compose([\n",
+    "    transforms.Resize(256),\n",
+    "    transforms.CenterCrop(224),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize([0.485, 0.456, 0.406],\n",
+    "                         [0.229, 0.224, 0.225])\n",
+    "])\n",
+    "\n",
+    "train_ds = datasets.ImageFolder(\n",
+    "    os.path.join(data_dir, 'train'), train_transforms)\n",
+    "val_ds = datasets.ImageFolder(\n",
+    "    os.path.join(data_dir, 'val'), val_transforms)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The train dataset includes data augmentation techniques such as cropping to size 224 and horizontal flips. The train and validation datasets are normalized with mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`. These values are the means and standard deviations of the ImageNet images. We used these values because the pretrained model was trained on ImageNet."
+   ]
+  },
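+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (an illustrative aside, assuming the datasets above loaded correctly), we can inspect one transformed sample to confirm the tensor shape and the class labels that `ImageFolder` inferred from the directory names:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sanity check: each sample is an (image, label) pair.\n",
+    "# After the transforms, every image is a 3x224x224 float tensor.\n",
+    "x, y = train_ds[0]\n",
+    "print(x.shape)           # expected: torch.Size([3, 224, 224])\n",
+    "print(train_ds.classes)  # expected: ['ants', 'bees']\n",
+    "print(y)                 # integer class index into train_ds.classes"
+   ]
+  },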
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading pretrained model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We use a pretrained `ResNet18` neural network with its final fully connected layer replaced with a new one:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_ft = models.resnet18(pretrained=True)\n",
+    "num_ftrs = model_ft.fc.in_features\n",
+    "model_ft.fc = nn.Linear(num_ftrs, 2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since we are training a binary classifier, the output of the final fully connected layer has size 2. Next, we freeze all layers except the final layer by setting `requires_grad` to `False`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for name, param in model_ft.named_parameters():\n",
+    "    if not name.startswith('fc'):\n",
+    "        param.requires_grad_(False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Using skorch's API"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this section, we will create a `skorch.NeuralNetClassifier` to solve our classification problem."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Callbacks"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, we create two callbacks:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lrscheduler = LRScheduler(\n",
+    "    policy='StepLR', step_size=7, gamma=0.1)\n",
+    "\n",
+    "checkpoint = Checkpoint(\n",
+    "    target='best_model.pt', monitor='valid_acc_best')\n",
+    "\n",
+    "callbacks = [lrscheduler, checkpoint]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `LRScheduler` callback defines a learning rate scheduler that uses `torch.optim.lr_scheduler.StepLR` to scale the learning rate by `gamma=0.1` every 7 epochs. The `Checkpoint` callback saves the best model by monitoring the validation accuracy."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Filtered optimizer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since we froze some layers in our `ResNet18` neural network, we need to configure our optimizer to only update the parameters of our final fully connected layer. Luckily, `skorch` provides two functions that make this simple:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = filtered_optimizer(\n",
+    "    optim.SGD, filter_requires_grad\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`filtered_optimizer` wraps `optim.SGD` so that, when the optimizer is instantiated, `filter_requires_grad` filters out every parameter whose `requires_grad` attribute is `False`. As a result, only the parameters of the final fully connected layer are updated during training."
+   ]
+  },
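+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For comparison, the following illustrative sketch shows roughly what this corresponds to in pure PyTorch, where the filtering is done by hand (this cell is an aside and is not used in the rest of the tutorial):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative pure-PyTorch equivalent: pass only the parameters that\n",
+    "# still require gradients to the optimizer. skorch's filtered_optimizer\n",
+    "# performs this filtering for us when it instantiates the optimizer.\n",
+    "trainable_params = [p for p in model_ft.parameters() if p.requires_grad]\n",
+    "manual_optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)\n",
+    "print(len(trainable_params))  # only the weight and bias of model_ft.fc"
+   ]
+  },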
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### skorch.NeutralNetClassifier" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With all the preparations out of the way, we can now define our `NeutralNetClassifier`:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "net = NeuralNetClassifier(\n", + " model_ft, \n", + " criterion=nn.CrossEntropyLoss,\n", + " lr=0.001,\n", + " batch_size=4,\n", + " max_epochs=25,\n", + " optimizer=optimizer,\n", + " optimizer__momentum=0.9,\n", + " iterator_train__shuffle=True,\n", + " iterator_train__num_workers=4,\n", + " iterator_valid__shuffle=True,\n", + " iterator_valid__num_workers=4,\n", + " train_split=predefined_split(val_ds),\n", + " callbacks=callbacks,\n", + " device='cuda' # uncomment to train on gpu\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That is quite a few parameters! Lets walk through each one:\n", + "\n", + "1. `model_ft`: Our `ResNet18` neutral network\n", + "2. `criterion=nn.CrossEntropyLoss`: loss function\n", + "3. `lr`: Initial learning rate\n", + "4. `batch_size`: Size of a batch\n", + "5. `max_epochs`: Number of epochs to train\n", + "6. `optimizer`: Our filtered optimizer\n", + "7. `optimizer__momentum`: The initial momentum\n", + "8. `iterator_{train,valid}__{shuffle,num_workers}`: Parameters that are passed to the dataloader.\n", + "9. `train_split`: A wrapper around `val_ds` to use our validation dataset.\n", + "10. `callbacks`: Our callbacks \n", + "11. `device`: Set to `cuda` to train on gpu.\n", + "\n", + "Now we are ready to train our neutral network:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " epoch train_loss valid_acc valid_loss cp dur\n", + "------- ------------ ----------- ------------ ---- ------\n", + " 1 \u001b[36m0.8333\u001b[0m \u001b[32m0.9346\u001b[0m \u001b[35m0.2150\u001b[0m + 1.4656\n", + " 2 \u001b[36m0.5203\u001b[0m 0.9150 0.2439 1.2813\n", + " 3 \u001b[36m0.4469\u001b[0m 0.8693 0.3494 1.2942\n", + " 4 0.4665 \u001b[32m0.9542\u001b[0m \u001b[35m0.1949\u001b[0m + 1.2363\n", + " 5 \u001b[36m0.3884\u001b[0m 0.9412 0.1962 1.2834\n", + " 6 \u001b[36m0.3807\u001b[0m 0.9412 \u001b[35m0.1923\u001b[0m 1.3064\n", + " 7 \u001b[36m0.3292\u001b[0m 0.9412 \u001b[35m0.1876\u001b[0m 1.3428\n", + " 8 \u001b[36m0.2864\u001b[0m 0.9412 0.1961 1.3132\n", + " 9 0.4199 0.9346 0.1987 1.2935\n", + " 10 0.4462 0.9412 0.2054 1.2886\n", + " 11 0.2971 0.9412 0.1952 1.4172\n", + " 12 0.3474 0.9412 0.2092 1.3306\n", + " 13 0.2891 0.9412 0.2285 1.2747\n", + " 14 0.3648 0.8889 0.2870 1.3012\n", + " 15 0.3090 0.9346 0.2029 1.2845\n", + " 16 0.3348 0.9281 0.2345 1.3291\n", + " 17 0.2940 0.9412 0.1908 1.2621\n", + " 18 0.3961 0.9412 0.2181 1.3413\n", + " 19 0.3652 0.9281 0.2256 1.2868\n", + " 20 0.3038 0.9346 0.2178 1.3177\n", + " 21 0.3489 0.9477 0.2041 1.2754\n", + " 22 0.3161 0.9477 \u001b[35m0.1848\u001b[0m 1.2837\n", + " 23 \u001b[36m0.2622\u001b[0m 0.9477 0.2074 1.3112\n", + " 24 0.3823 0.9477 0.1923 1.2328\n", + " 25 0.2952 0.9412 0.2223 1.3896\n" + ] + } + ], + "source": [ + "net.fit(train_ds, y=None);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The best model is stored at `best_model.pt`, with a validiation accuracy of roughly 0.95.\n", + "\n", + "Congrualations! You now know how to finetune a neutral network using `skorch`. 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Congratulations! You now know how to fine-tune a neural network using `skorch`. Feel free to explore the other tutorials to learn more about using `skorch`."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/datasets/.gitkeep b/notebooks/datasets/.gitkeep
new file mode 100644
index 000000000..e69de29bb