diff --git a/style-transfer/Style_Transfer_Exercise.ipynb b/style-transfer/Style_Transfer_Exercise.ipynb index b6330ef02d..3ea9aec2a1 100644 --- a/style-transfer/Style_Transfer_Exercise.ipynb +++ b/style-transfer/Style_Transfer_Exercise.ipynb @@ -38,11 +38,13 @@ "%matplotlib inline\n", "\n", "from PIL import Image\n", + "from io import BytesIO\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import torch\n", "import torch.optim as optim\n", + "import requests\n", "from torchvision import transforms, models" ] }, @@ -109,8 +111,11 @@ "def load_image(img_path, max_size=400, shape=None):\n", " ''' Load in and transform an image, making sure the image\n", " is <= 400 pixels in the x-y dims.'''\n", - " \n", - " image = Image.open(img_path).convert('RGB')\n", + " if \"http\" in img_path:\n", + " response = requests.get(img_path)\n", + " image = Image.open(BytesIO(response.content)).convert('RGB')\n", + " else:\n", + " image = Image.open(img_path).convert('RGB')\n", " \n", " # large images will slow down processing\n", " if max(image.size) > max_size:\n", diff --git a/style-transfer/Style_Transfer_Solution.ipynb b/style-transfer/Style_Transfer_Solution.ipynb index f0e6760816..67c7bf1d50 100644 --- a/style-transfer/Style_Transfer_Solution.ipynb +++ b/style-transfer/Style_Transfer_Solution.ipynb @@ -38,11 +38,13 @@ "%matplotlib inline\n", "\n", "from PIL import Image\n", + "from io import BytesIO\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import torch\n", "import torch.optim as optim\n", + "import requests\n", "from torchvision import transforms, models" ] }, @@ -158,8 +160,11 @@ "def load_image(img_path, max_size=400, shape=None):\n", " ''' Load in and transform an image, making sure the image\n", " is <= 400 pixels in the x-y dims.'''\n", - " \n", - " image = Image.open(img_path).convert('RGB')\n", + " if \"http\" in img_path:\n", + " response = requests.get(img_path)\n", + " image = Image.open(BytesIO(response.content)).convert('RGB')\n", + " else:\n", + " image = Image.open(img_path).convert('RGB')\n", " \n", " # large images will slow down processing\n", " if max(image.size) > max_size:\n",