Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
support Python 3, fixed issues in model loading and missing example i…
…mages (#18)

Change print and urlopen syntax to support Python 3. Change iteration over plots
to support single-row plots.
Create 'checkpoints' drectory if missing and load images from the magenta
package because they are missing in the repository. This solves issue #14.
  • Loading branch information
Banus authored and adarob committed Oct 25, 2018
1 parent f696db3 commit 58db7bb
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions jupyter-notebooks/Image_Stylization.ipynb
Expand Up @@ -24,17 +24,19 @@
"directory run 'jupyter notebook'\n",
"\"\"\"\n",
"from __future__ import absolute_import\n",
"from __future__ import print_function\n",
"from __future__ import division\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"import ast\n",
"import os\n",
"import numpy as np\n",
"import sys\n",
"import random\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import urllib2\n",
"from six.moves.urllib.request import urlopen\n",
"\n",
"from magenta.models.image_stylization import image_utils\n",
"from magenta.models.image_stylization import model\n",
Expand All @@ -46,18 +48,23 @@
" for checkpoint in checkpoints:\n",
" full_checkpoint = os.path.join(checkpoint_dir, checkpoint)\n",
" if not os.path.exists(full_checkpoint):\n",
" print 'Downloading', full_checkpoint\n",
" response = urllib2.urlopen(url_prefix + checkpoint)\n",
" print('Downloading {}'.format(full_checkpoint))\n",
" response = urlopen(url_prefix + checkpoint)\n",
" data = response.read()\n",
" with open(full_checkpoint, 'wb') as fh:\n",
" fh.write(data)\n",
"\n",
"# Select an image (any jpg or png).\n",
"input_image = 'evaluation_images/guerrillero_heroico.jpg'\n",
"example_path = os.path.dirname(sys.modules['magenta.models.image_stylization'].__file__)\n",
"input_image = os.path.join(example_path, 'evaluation_images/guerrillero_heroico.jpg')\n",
"\n",
"# Select a demo ('varied' or 'monet')\n",
"demo = 'varied'\n",
"\n",
"# create 'checkpoints' directory if it doesn't exist\n",
"if not os.path.isdir('checkpoints'):\n",
" os.makedirs('checkpoints')\n",
"\n",
"DownloadCheckpointFiles()\n",
"image = np.expand_dims(image_utils.load_np_image(\n",
" os.path.expanduser(input_image)), 0)\n",
Expand All @@ -71,14 +78,14 @@
"# Styles from checkpoint file to render. They are done in batch, so the more \n",
"# rendered, the longer it will take and the more memory will be used.\n",
"# These can be modified as you like. Here we randomly select six styles.\n",
"styles = range(num_styles)\n",
"styles = list(range(num_styles))\n",
"random.shuffle(styles)\n",
"which_styles = styles[0:6]\n",
"num_rendered = len(which_styles) \n",
"num_rendered = len(which_styles)\n",
"\n",
"with tf.Graph().as_default(), tf.Session() as sess:\n",
" stylized_images = model.transform(\n",
" tf.concat([image for _ in range(len(which_styles))], 0),\n",
" tf.concat([image for _ in range(num_rendered)], 0),\n",
" normalizer_params={\n",
" 'labels': tf.constant(which_styles),\n",
" 'num_categories': num_styles,\n",
Expand All @@ -88,15 +95,13 @@
" model_saver.restore(sess, checkpoint)\n",
" stylized_images = stylized_images.eval()\n",
" \n",
" # Plot the images.\n",
" counter = 0\n",
" num_cols = 3\n",
" f, axarr = plt.subplots(num_rendered // num_cols, num_cols, figsize=(25, 25))\n",
" for col in range(num_cols):\n",
" for row in range( num_rendered // num_cols):\n",
" axarr[row, col].imshow(stylized_images[counter])\n",
" axarr[row, col].set_xlabel('Style %i' % which_styles[counter])\n",
" counter += 1\n",
"# Plot the images.\n",
"counter = 0\n",
"num_cols = 3\n",
"f, _ = plt.subplots(num_rendered // num_cols, num_cols, figsize=(25, 25))\n",
"for counter, axis in enumerate(f.axes):\n",
" axis.imshow(stylized_images[counter])\n",
" axis.set_xlabel('Style %i' % which_styles[counter])\n",
" "
]
},
Expand Down

0 comments on commit 58db7bb

Please sign in to comment.