DOCS: Apply Ruff and Black to image_examples (#3300)
connortann committed Oct 7, 2023
1 parent c0bb3c9 commit 4316c41
Showing 13 changed files with 568 additions and 415 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -38,6 +38,7 @@ repos:
 notebooks/api_examples/.*.ipynb|
 notebooks/benchmarks/.*.ipynb|
 notebooks/genomic_examples/.*.ipynb|
+notebooks/image_examples/.*.ipynb|
 notebooks/tabular_examples/.*.ipynb|
 )$
 - repo: https://github.com/astral-sh/ruff-pre-commit
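The line added above extends the notebook path pattern so image_examples notebooks are handled the same way as the other listed directories. A minimal sketch (mine, not part of the commit) of how such a pattern matches repository paths; the filenames and the escaped dots are illustrative assumptions:

    import re

    # Hypothetical reconstruction of the pattern above, anchored at the
    # end of the path as in the config.
    pattern = re.compile(
        r"(notebooks/api_examples/.*\.ipynb|"
        r"notebooks/benchmarks/.*\.ipynb|"
        r"notebooks/genomic_examples/.*\.ipynb|"
        r"notebooks/image_examples/.*\.ipynb|"
        r"notebooks/tabular_examples/.*\.ipynb)$"
    )
    print(bool(pattern.search("notebooks/image_examples/demo.ipynb")))  # True
    print(bool(pattern.search("notebooks/text_examples/demo.ipynb")))   # False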

(Diffs of two large files are not rendered.)

(Diff of another changed notebook; filename not rendered.)
@@ -25,12 +25,12 @@
"outputs": [],
"source": [
"import json\n",
"\n",
"import numpy as np\n",
"import torchvision\n",
"import torch\n",
"import torch.nn as nn\n",
"import shap\n",
"from PIL import Image"
"import torchvision\n",
"\n",
"import shap"
]
},
{
@@ -88,7 +88,7 @@
"with open(shap.datasets.cache(url)) as file:\n",
" class_names = [v[1] for v in json.load(file).values()]\n",
"print(\"Number of ImageNet classes:\", len(class_names))\n",
"#print(\"Class names:\", class_names)"
"# print(\"Class names:\", class_names)"
]
},
{
@@ -102,33 +102,35 @@
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"\n",
"def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:\n",
" if x.dim() == 4:\n",
" x = x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)\n",
" elif x.dim() == 3:\n",
" x = x if x.shape[0] == 3 else x.permute(2, 0, 1)\n",
" return x\n",
"\n",
"\n",
"def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:\n",
" if x.dim() == 4:\n",
" x = x if x.shape[3] == 3 else x.permute(0, 2, 3, 1)\n",
" elif x.dim() == 3:\n",
" x = x if x.shape[2] == 3 else x.permute(1, 2, 0)\n",
" return x \n",
" \n",
" return x\n",
"\n",
"\n",
"transform= [\n",
"transform = [\n",
" torchvision.transforms.Lambda(nhwc_to_nchw),\n",
" torchvision.transforms.Lambda(lambda x: x*(1/255)),\n",
" torchvision.transforms.Lambda(lambda x: x * (1 / 255)),\n",
" torchvision.transforms.Normalize(mean=mean, std=std),\n",
" torchvision.transforms.Lambda(nchw_to_nhwc),\n",
"]\n",
"\n",
"inv_transform= [\n",
"inv_transform = [\n",
" torchvision.transforms.Lambda(nhwc_to_nchw),\n",
" torchvision.transforms.Normalize(\n",
" mean = (-1 * np.array(mean) / np.array(std)).tolist(),\n",
" std = (1 / np.array(std)).tolist()\n",
" mean=(-1 * np.array(mean) / np.array(std)).tolist(),\n",
" std=(1 / np.array(std)).tolist(),\n",
" ),\n",
" torchvision.transforms.Lambda(nchw_to_nhwc),\n",
"]\n",
@@ -168,7 +170,7 @@
"Xtr = transform(torch.Tensor(X))\n",
"out = predict(Xtr[1:3])\n",
"classes = torch.argmax(out, axis=1).cpu().numpy()\n",
"print(f'Classes: {classes}: {np.array(class_names)[classes]}')"
"print(f\"Classes: {classes}: {np.array(class_names)[classes]}\")"
]
},
{
@@ -218,8 +220,12 @@
"\n",
"# feed only one image\n",
"# here we explain two images using 100 evaluations of the underlying model to estimate the SHAP values\n",
"shap_values = explainer(Xtr[1:2], max_evals=n_evals, batch_size=batch_size,\n",
" outputs=shap.Explanation.argsort.flip[:topk])"
"shap_values = explainer(\n",
" Xtr[1:2],\n",
" max_evals=n_evals,\n",
" batch_size=batch_size,\n",
" outputs=shap.Explanation.argsort.flip[:topk],\n",
")"
]
},
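The outputs argument tells the explainer which classes to explain: shap.Explanation.argsort.flip[:topk] is evaluated against the model's output scores and keeps the topk highest-scoring classes. The plain-numpy equivalent of that idiom, as a sketch (mine, with made-up scores):

    import numpy as np

    # Indices of the topk largest scores, highest first.
    scores = np.array([0.05, 0.80, 0.10, 0.05])
    topk = 2
    top_classes = np.flip(np.argsort(scores))[:topk]
    print(top_classes)  # [1 2]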
{
@@ -253,7 +259,7 @@
"outputs": [],
"source": [
"shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0]\n",
"shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)]"
"shap_values.values = [val for val in np.moveaxis(shap_values.values[0], -1, 0)]"
]
},
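This cell reshapes the result for plotting: np.moveaxis splits the single SHAP array into one array per explained class, which is the layout shap.image_plot expects below. A shape-only sketch (mine; the dimensions are assumptions for illustration):

    import numpy as np

    # For one image, SHAP values of shape (H, W, C, topk) become a list
    # of topk arrays of shape (H, W, C), one per explained class.
    vals = np.zeros((224, 224, 3, 4))
    per_class = list(np.moveaxis(vals, -1, 0))
    print(len(per_class), per_class[0].shape)  # 4 (224, 224, 3)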
{
@@ -275,10 +281,12 @@
}
],
"source": [
"shap.image_plot(shap_values=shap_values.values, \n",
" pixel_values=shap_values.data, \n",
" labels=shap_values.output_names,\n",
" true_labels=[class_names[132]])"
"shap.image_plot(\n",
" shap_values=shap_values.values,\n",
" pixel_values=shap_values.data,\n",
" labels=shap_values.output_names,\n",
" true_labels=[class_names[132]],\n",
")"
]
},
{
@@ -368,8 +376,12 @@
"\n",
"# feed only one image\n",
"# here we explain two images using 100 evaluations of the underlying model to estimate the SHAP values\n",
"shap_values = explainer(Xtr[1:4], max_evals=n_evals, batch_size=batch_size,\n",
" outputs=shap.Explanation.argsort.flip[:topk])"
"shap_values = explainer(\n",
" Xtr[1:4],\n",
" max_evals=n_evals,\n",
" batch_size=batch_size,\n",
" outputs=shap.Explanation.argsort.flip[:topk],\n",
")"
]
},
{
@@ -399,7 +411,7 @@
"outputs": [],
"source": [
"shap_values.data = inv_transform(shap_values.data).cpu().numpy()\n",
"shap_values.values = [val for val in np.moveaxis(shap_values.values,-1, 0)]"
"shap_values.values = [val for val in np.moveaxis(shap_values.values, -1, 0)]"
]
},
{
@@ -441,9 +453,11 @@
}
],
"source": [
"shap.image_plot(shap_values=shap_values.values,\n",
" pixel_values=shap_values.data,\n",
" labels=shap_values.output_names)"
"shap.image_plot(\n",
" shap_values=shap_values.values,\n",
" pixel_values=shap_values.data,\n",
" labels=shap_values.output_names,\n",
")"
]
},
{
(Diff of another changed notebook; filename not rendered.)
@@ -16,9 +16,9 @@
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input\n",
"\n",
"import shap"
]
},
@@ -38,7 +38,7 @@
"outputs": [],
"source": [
"# load pre-trained model and data\n",
"model = ResNet50(weights='imagenet')\n",
"model = ResNet50(weights=\"imagenet\")\n",
"X, y = shap.datasets.imagenet50()"
]
},
@@ -52,8 +52,8 @@
"url = \"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\"\n",
"with open(shap.datasets.cache(url)) as file:\n",
" class_names = [v[1] for v in json.load(file).values()]\n",
"#print(\"Number of ImageNet classes:\", len(class_names))\n",
"#print(\"Class names:\", class_names)"
"# print(\"Number of ImageNet classes:\", len(class_names))\n",
"# print(\"Class names:\", class_names)"
]
},
{
@@ -91,20 +91,23 @@
}
],
"source": [
"# python function to get model output; replace this function with your own model function. \n",
"# python function to get model output; replace this function with your own model function.\n",
"def f(x):\n",
" tmp = x.copy()\n",
" preprocess_input(tmp)\n",
" return model(tmp)\n",
"\n",
"# define a masker that is used to mask out partitions of the input image. \n",
"\n",
"# define a masker that is used to mask out partitions of the input image.\n",
"masker = shap.maskers.Image(\"inpaint_telea\", X[0].shape)\n",
"\n",
"# create an explainer with model and image masker \n",
"# create an explainer with model and image masker\n",
"explainer = shap.Explainer(f, masker, output_names=class_names)\n",
"\n",
"# here we explain two images using 500 evaluations of the underlying model to estimate the SHAP values\n",
"shap_values = explainer(X[1:3], max_evals=100, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]) "
"shap_values = explainer(\n",
" X[1:3], max_evals=100, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]\n",
")"
]
},
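The wrapper f copies its input before calling preprocess_input because the notebook relies on the preprocessing modifying tmp in place; the copy keeps the caller's X intact across the many masked evaluations. A small sketch of the pattern (mine, assuming the in-place behavior the notebook depends on):

    import numpy as np
    from tensorflow.keras.applications.resnet50 import preprocess_input

    x = np.random.rand(1, 224, 224, 3) * 255.0
    tmp = x.copy()
    preprocess_input(tmp)              # preprocesses tmp in place
    assert not np.array_equal(tmp, x)  # tmp changed, x is untouched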
{
@@ -214,20 +217,23 @@
}
],
"source": [
"# python function to get model output; replace this function with your own model function. \n",
"# python function to get model output; replace this function with your own model function.\n",
"def f(x):\n",
" tmp = x.copy()\n",
" preprocess_input(tmp)\n",
" return model(tmp)\n",
"\n",
"# define a masker that is used to mask out partitions of the input image. \n",
"\n",
"# define a masker that is used to mask out partitions of the input image.\n",
"masker_blur = shap.maskers.Image(\"blur(128,128)\", X[0].shape)\n",
"\n",
"# create an explainer with model and image masker \n",
"# create an explainer with model and image masker\n",
"explainer_blur = shap.Explainer(f, masker_blur, output_names=class_names)\n",
"\n",
"# here we explain two images using 500 evaluations of the underlying model to estimate the SHAP values\n",
"shap_values_fine = explainer_blur(X[1:3], max_evals=5000, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]) "
"shap_values_fine = explainer_blur(\n",
" X[1:3], max_evals=5000, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]\n",
")"
]
},
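The two cells differ mainly in the masker and the evaluation budget: "inpaint_telea" fills hidden regions by inpainting, while "blur(128,128)" replaces them with a heavily blurred copy of the image, and the finer attribution here comes from raising max_evals from 100 to 5000. Side by side (a sketch, with the input shape assumed):

    import shap

    # Both maskers take the image shape so they know how to partition pixels.
    shape = (224, 224, 3)  # assumed ImageNet-style input, i.e. X[0].shape
    masker_inpaint = shap.maskers.Image("inpaint_telea", shape)  # used with max_evals=100
    masker_blur = shap.maskers.Image("blur(128,128)", shape)     # used with max_evals=5000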
{
(Diff of another changed notebook; filename not rendered.)
@@ -17,12 +17,13 @@
"metadata": {},
"outputs": [],
"source": [
"import torch, torchvision\n",
"from torch import nn\n",
"from torchvision import transforms, models, datasets\n",
"import shap\n",
"import json\n",
"import numpy as np"
"\n",
"import numpy as np\n",
"import torch\n",
"from torchvision import models\n",
"\n",
"import shap"
]
},
{
@@ -34,6 +35,7 @@
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"\n",
"def normalize(image):\n",
" if image.max() > 1:\n",
" image /= 255\n",
@@ -64,7 +66,7 @@
"# load the model\n",
"model = models.vgg16(pretrained=True).eval()\n",
"\n",
"X,y = shap.datasets.imagenet50()\n",
"X, y = shap.datasets.imagenet50()\n",
"\n",
"X /= 255\n",
"\n",
@@ -75,9 +77,11 @@
"fname = shap.datasets.cache(url)\n",
"with open(fname) as f:\n",
" class_names = json.load(f)\n",
" \n",
"\n",
"e = shap.GradientExplainer((model, model.features[7]), normalize(X))\n",
"shap_values,indexes = e.shap_values(normalize(to_explain), ranked_outputs=2, nsamples=200)\n",
"shap_values, indexes = e.shap_values(\n",
" normalize(to_explain), ranked_outputs=2, nsamples=200\n",
")\n",
"\n",
"# get the names for the classes\n",
"index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)\n",
@@ -118,8 +122,12 @@
"source": [
"# note that because the inputs are scaled to be between 0 and 1, the local smoothing also has to be\n",
"# scaled compared to the Keras model\n",
"explainer = shap.GradientExplainer((model, model.features[7]), normalize(X), local_smoothing=0.5)\n",
"shap_values,indexes = explainer.shap_values(normalize(to_explain), ranked_outputs=2, nsamples=200)\n",
"explainer = shap.GradientExplainer(\n",
" (model, model.features[7]), normalize(X), local_smoothing=0.5\n",
")\n",
"shap_values, indexes = explainer.shap_values(\n",
" normalize(to_explain), ranked_outputs=2, nsamples=200\n",
")\n",
"\n",
"# get the names for the classes\n",
"index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)\n",
(Diff of another changed notebook; filename not rendered.)
@@ -42,30 +42,37 @@
}
],
"source": [
"from keras.applications.vgg16 import VGG16\n",
"from keras.applications.vgg16 import preprocess_input, decode_predictions\n",
"import json\n",
"\n",
"import keras.backend as K\n",
"import numpy as np\n",
"from keras.applications.vgg16 import VGG16, preprocess_input\n",
"\n",
"import shap\n",
"import keras.backend as K\n",
"import json\n",
"\n",
"# load pre-trained model and choose two images to explain\n",
"model = VGG16(weights='imagenet', include_top=True)\n",
"X,y = shap.datasets.imagenet50()\n",
"to_explain = X[[39,41]]\n",
"model = VGG16(weights=\"imagenet\", include_top=True)\n",
"X, y = shap.datasets.imagenet50()\n",
"to_explain = X[[39, 41]]\n",
"\n",
"# load the ImageNet class names\n",
"url = \"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\"\n",
"fname = shap.datasets.cache(url)\n",
"with open(fname) as f:\n",
" class_names = json.load(f)\n",
"\n",
"\n",
"# explain how the input to the 7th layer of the model explains the top two classes\n",
"def map2layer(x, layer):\n",
" feed_dict = dict(zip([model.layers[0].input], [preprocess_input(x.copy())]))\n",
" return K.get_session().run(model.layers[layer].input, feed_dict)\n",
"e = shap.GradientExplainer((model.layers[7].input, model.layers[-1].output), map2layer(preprocess_input(X.copy()), 7))\n",
"shap_values,indexes = e.shap_values(map2layer(to_explain, 7), ranked_outputs=2)\n",
"\n",
"\n",
"e = shap.GradientExplainer(\n",
" (model.layers[7].input, model.layers[-1].output),\n",
" map2layer(preprocess_input(X.copy()), 7),\n",
")\n",
"shap_values, indexes = e.shap_values(map2layer(to_explain, 7), ranked_outputs=2)\n",
"\n",
"# get the names for the classes\n",
"index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)\n",
@@ -109,9 +116,9 @@
"explainer = shap.GradientExplainer(\n",
" (model.layers[7].input, model.layers[-1].output),\n",
" map2layer(preprocess_input(X.copy()), 7),\n",
" local_smoothing=100\n",
" local_smoothing=100,\n",
")\n",
"shap_values,indexes = explainer.shap_values(map2layer(to_explain, 7), ranked_outputs=2)\n",
"shap_values, indexes = explainer.shap_values(map2layer(to_explain, 7), ranked_outputs=2)\n",
"\n",
"# get the names for the classes\n",
"index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)\n",
