Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 27 additions & 32 deletions contrib/colab/style_transfer_inference-xrt-1-15.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@
"colab_type": "text"
},
"source": [
"### [Installs PyTorch & Loads the Network]\n",
"### Installs PyTorch & Loads the Networks\n",
"(This may take a couple minutes.)\n",
"\n",
"The pre-trained fast neural style transfer network will be stored in you Google Drive. You may be prompted to enter an authorization code (check the output of the next cell)."
"Fast neural style transfer networks use the same architecture but different weights to encode their styles. This notebook creates four fast neural style transfer networks: \"rain princess,\" \"candy,\" \"mosaic,\" and \"udnie.\" You can apply these styles below."
]
},
{
Expand Down Expand Up @@ -110,11 +110,9 @@
"from google.colab.patches import cv2_imshow\n",
"import cv2\n",
"import sys\n",
"from google.colab import drive\n",
"drive.mount('/content/gdrive')\n",
"\n",
"# Setup repo in google drive\n",
"REPO_DIR='/content/gdrive/My Drive/demo'\n",
"# Configures repo in local colab fs\n",
"REPO_DIR = '/demo'\n",
"%mkdir -p \"$REPO_DIR\"\n",
"%cd \"$REPO_DIR\" \n",
"%rm -rf examples\n",
Expand All @@ -125,6 +123,8 @@
"!python download_saved_models.py\n",
"%cd \"$REPO_DIR/examples/fast_neural_style/neural_style\"\n",
"\n",
"\n",
"## Creates pre-trained style networks\n",
"import argparse\n",
"import os\n",
"import sys\n",
Expand All @@ -145,50 +145,31 @@
"import torch_xla.utils.utils as xu\n",
"import utils\n",
"from transformer_net import TransformerNet\n",
"from vgg import Vgg16"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xp3osjyjjwcZ",
"colab_type": "text"
},
"source": [
"The following snippet loads four sets of weights for the same fast neural style network. Each set of weights encodes a different style. Here we have \"rain princess,\" \"candy,\" \"mosaic,\" and \"udnie.\" You can apply each of these below.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "9w1UZhTff7tV",
"colab_type": "code",
"colab": {}
},
"source": [
"from vgg import Vgg16\n",
"\n",
"# Acquires the XLA device (a TPU core)\n",
"device = xm.xla_device()\n",
"\n",
"# Loads various style models\n",
"# Loads pre-trained weights\n",
"rain_princess_path = '../saved_models/rain_princess.pth'\n",
"candy_path = '../saved_models/candy.pth'\n",
"mosaic_path = '../saved_models/mosaic.pth'\n",
"udnie_path = '../saved_models/udnie.pth'\n",
"\n",
"# Loads the model onto the TPU\n",
"# Loads the pre-trained weights into the fast neural style transfer\n",
"# network architecture and puts the network on the Cloud TPU core.\n",
"def load_style(path):\n",
" with torch.no_grad():\n",
" model = TransformerNet()\n",
" state_dict = torch.load(path)\n",
" # filters deprecated running_* keys in InstanceNorm from the checkpoint\n",
" # filters deprecated running_* keys from the checkpoint\n",
" for k in list(state_dict.keys()):\n",
" if re.search(r'in\\d+\\.running_(mean|var)$', k):\n",
" del state_dict[k]\n",
" model.load_state_dict(state_dict)\n",
" return model.to(device)\n",
"\n",
"# Creates each fast neural style transfer network\n",
"rain_princess = load_style(rain_princess_path)\n",
"candy = load_style(candy_path)\n",
"mosaic = load_style(mosaic_path)\n",
Expand All @@ -197,6 +178,20 @@
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j1w1G4AcWw9f",
"colab_type": "text"
},
"source": [
"## Try it out!\n",
"\n",
"The next cell loads and display an image from a URL. This image is styled by the following cell. You can re-run these two cells as often as you like to style multiple images.\n",
"\n",
"Start by copying and pasting an image URL here (or use the default corgi)."
]
},
{
"cell_type": "code",
"metadata": {
Expand Down