Skip to content
Permalink
Browse files

[MRG] Update transform learning to use class in NeutralNet.__init__ (#…

…425)

* DOC: Uses class based init

* DOC: Fixes bugs to update
  • Loading branch information...
thomasjpfan authored and ottonemo committed Jan 21, 2019
1 parent 62ff0da commit a9f1140cb4f28b123a35efa50286def121a3451a
Showing with 27 additions and 18 deletions.
  1. +27 −18 notebooks/Transfer_Learning.ipynb
@@ -39,7 +39,7 @@
"metadata": {},
"outputs": [],
"source": [
"! [ ! -z \"$COLAB_GPU\" ] && pip install torch torchvision pillow==4.1.1 git+https://github.com/dnouri/skorch\n",
"! [ ! -z \"$COLAB_GPU\" ] && pip install torch torchvision pillow==4.1.1 skorch\n",
"! [ ! -z \"$COLAB_GPU\" ] && mkdir -p datasets\n",
"! [ ! -z \"$COLAB_GPU\" ] && wget -nc --no-check-certificate https://download.pytorch.org/tutorial/hymenoptera_data.zip -P datasets\n",
"! [ ! -z \"$COLAB_GPU\" ] && unzip -u datasets/hymenoptera_data.zip -d datasets"
@@ -185,9 +185,16 @@
"metadata": {},
"outputs": [],
"source": [
"model_ft = models.resnet18(pretrained=True)\n",
"num_ftrs = model_ft.fc.in_features\n",
"model_ft.fc = nn.Linear(num_ftrs, 2)"
"class PretrainedModel(nn.Module):\n",
" def __init__(self, output_features):\n",
" super().__init__()\n",
" model = models.resnet18(pretrained=True)\n",
" num_ftrs = model.fc.in_features\n",
" model.fc = nn.Linear(num_ftrs, output_features)\n",
" self.model = model\n",
" \n",
" def forward(self, x):\n",
" return self.model(x)"
]
},
{
@@ -227,7 +234,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -246,7 +253,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -260,18 +267,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Lastly, we create a `Freezer` to freeze all weights besides the final layer named `fc`:"
"Lastly, we create a `Freezer` to freeze all weights besides the final layer named `model.fc`:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from skorch.callbacks import Freezer\n",
"\n",
"freezer = Freezer(lambda x: not x.startswith('fc'))"
"freezer = Freezer(lambda x: not x.startswith('model.fc'))"
]
},
{
@@ -290,16 +297,17 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"net = NeuralNetClassifier(\n",
" model_ft, \n",
" PretrainedModel, \n",
" criterion=nn.CrossEntropyLoss,\n",
" lr=0.001,\n",
" batch_size=4,\n",
" max_epochs=25,\n",
" module__output_features=2,\n",
" optimizer=optim.SGD,\n",
" optimizer__momentum=0.9,\n",
" iterator_train__shuffle=True,\n",
@@ -323,19 +331,20 @@
"3. `lr`: Initial learning rate\n",
"4. `batch_size`: Size of a batch\n",
"5. `max_epochs`: Number of epochs to train\n",
"6. `optimizer`: Our optimizer\n",
"7. `optimizer__momentum`: The initial momentum\n",
"8. `iterator_{train,valid}__{shuffle,num_workers}`: Parameters that are passed to the dataloader.\n",
"9. `train_split`: A wrapper around `val_ds` to use our validation dataset.\n",
"10. `callbacks`: Our callbacks \n",
"11. `device`: Set to `cuda` to train on gpu.\n",
"6. `module__output_features`: Used by `__init__` in our `PretrainedModel` class to set the number of classes.\n",
"7. `optimizer`: Our optimizer\n",
"8. `optimizer__momentum`: The initial momentum\n",
"9. `iterator_{train,valid}__{shuffle,num_workers}`: Parameters that are passed to the dataloader.\n",
"10. `train_split`: A wrapper around `val_ds` to use our validation dataset.\n",
"11. `callbacks`: Our callbacks \n",
"12. `device`: Set to `cuda` to train on gpu.\n",
"\n",
"Now we are ready to train our neutral network:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [
{

0 comments on commit a9f1140

Please sign in to comment.
You can’t perform that action at this time.