From ce0c525de8005d4ac133d12d6deead759f6529b0 Mon Sep 17 00:00:00 2001 From: Gabriele Picco Date: Fri, 30 Nov 2018 16:37:45 +0100 Subject: [PATCH] Added ability to load images from an URL Enhanced the LoadImage function to also allow loading from URL --- style-transfer/Style_Transfer_Exercise.ipynb | 9 +++++++-- style-transfer/Style_Transfer_Solution.ipynb | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/style-transfer/Style_Transfer_Exercise.ipynb b/style-transfer/Style_Transfer_Exercise.ipynb index b6330ef02d..3ea9aec2a1 100644 --- a/style-transfer/Style_Transfer_Exercise.ipynb +++ b/style-transfer/Style_Transfer_Exercise.ipynb @@ -38,11 +38,13 @@ "%matplotlib inline\n", "\n", "from PIL import Image\n", + "from io import BytesIO\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import torch\n", "import torch.optim as optim\n", + "import requests\n", "from torchvision import transforms, models" ] }, @@ -109,8 +111,11 @@ "def load_image(img_path, max_size=400, shape=None):\n", " ''' Load in and transform an image, making sure the image\n", " is <= 400 pixels in the x-y dims.'''\n", - " \n", - " image = Image.open(img_path).convert('RGB')\n", + " if \"http\" in img_path:\n", + " response = requests.get(img_path)\n", + " image = Image.open(BytesIO(response.content)).convert('RGB')\n", + " else:\n", + " image = Image.open(img_path).convert('RGB')\n", " \n", " # large images will slow down processing\n", " if max(image.size) > max_size:\n", diff --git a/style-transfer/Style_Transfer_Solution.ipynb b/style-transfer/Style_Transfer_Solution.ipynb index f0e6760816..67c7bf1d50 100644 --- a/style-transfer/Style_Transfer_Solution.ipynb +++ b/style-transfer/Style_Transfer_Solution.ipynb @@ -38,11 +38,13 @@ "%matplotlib inline\n", "\n", "from PIL import Image\n", + "from io import BytesIO\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import torch\n", "import torch.optim as optim\n", + "import requests\n", "from torchvision import transforms, models" ] }, @@ -158,8 +160,11 @@ "def load_image(img_path, max_size=400, shape=None):\n", " ''' Load in and transform an image, making sure the image\n", " is <= 400 pixels in the x-y dims.'''\n", - " \n", - " image = Image.open(img_path).convert('RGB')\n", + " if \"http\" in img_path:\n", + " response = requests.get(img_path)\n", + " image = Image.open(BytesIO(response.content)).convert('RGB')\n", + " else:\n", + " image = Image.open(img_path).convert('RGB')\n", " \n", " # large images will slow down processing\n", " if max(image.size) > max_size:\n",