Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Add wordnet helpers and rename synsets properties
  • Loading branch information
Ludwig Schubert committed Jan 23, 2019
1 parent b24b223 commit 9933879
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 27 deletions.
13 changes: 5 additions & 8 deletions lucid/misc/io/loading.py
Expand Up @@ -44,15 +44,14 @@ def _load_npy(handle, **kwargs):
return np.load(handle)


def _load_img(handle, target_dtype=np.float32, **kwargs):
def _load_img(handle, target_dtype=np.float32, size=None):
"""Load image file as numpy array."""

image_pil = PIL.Image.open(handle)

# resize the image to the requested size, if one was specified
if 'size' in kwargs:
requested_image_size = kwargs['size']
image_pil = image_pil.resize(requested_image_size, resample=PIL.Image.LANCZOS)
if size is not None:
image_pil = image_pil.resize(size, resample=PIL.Image.LANCZOS)

image_array = np.asarray(image_pil)

Expand All @@ -67,9 +66,8 @@ def _load_img(handle, target_dtype=np.float32, **kwargs):
return np.divide(image_array, image_max_value, dtype=target_dtype)


def _load_json(handle, **kwargs):
def _load_json(handle):
"""Load json file as python object."""
del kwargs
return json.load(handle)


Expand All @@ -84,9 +82,8 @@ def _load_text(handle, split=False, encoding="utf-8"):
return string


def _load_graphdef_protobuf(handle, **kwargs):
def _load_graphdef_protobuf(handle):
"""Load GraphDef from a binary proto file."""
del kwargs
return tf.GraphDef.FromString(handle.read())


Expand Down
20 changes: 5 additions & 15 deletions lucid/modelzoo/vision_base.py
Expand Up @@ -22,18 +22,13 @@
import tensorflow as tf
import numpy as np

from lucid.modelzoo.util import load_text_labels, load_graphdef, forget_xy
from lucid.modelzoo.util import load_graphdef, forget_xy
from lucid.modelzoo.aligned_activations import get_aligned_activations as _get_aligned_activations
from lucid.misc.io import load
import lucid.misc.io.showing as showing

# ImageNet classes correspond to WordNet Synsets.
# If NLTK and the WordNet corpus are installed, we can support
# interoperability in a few places.
try:
from nltk.corpus import wordnet
except:
wordnet = None
from lucid.modelzoo.wordnet import synset_from_id


IMAGENET_MEAN = np.array([123.68, 116.779, 103.939])
Expand Down Expand Up @@ -111,15 +106,10 @@ class Model(with_metaclass(ModelPropertiesMetaClass, object)):
def __init__(self):
self.graph_def = None
if hasattr(self, 'labels_path') and self.labels_path is not None:
self.labels = load_text_labels(self.labels_path)
self.labels = load(self.labels_path, split=True)
if hasattr(self, 'synsets_path') and self.synsets_path is not None:
self.synsets = load_text_labels(self.synsets_path)
# If NLTK WordNet is available, provide synsets in that form as well.
if wordnet is not None:
def get_synset(id_str):
pos, offset = id_str[0], int(id_str[1:])
return wordnet.synset_from_pos_and_offset(pos, offset)
self.nltk_synsets = [get_synset(id) for id in self.synsets]
self.synset_ids = load(self.synsets_path, split=True)
self.synsets = [synset_from_id(id) for id in self.synset_ids]

@property
def name(self):
Expand Down
104 changes: 104 additions & 0 deletions lucid/modelzoo/wordnet.py
@@ -0,0 +1,104 @@
# Copyright 2019 The Lucid Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers for using WordNet Synsets.
When comparing different models, be aware that they may encode their predictions in
different orders. Do not compare outputs of models without ensuring their outputs are
in the same order! We recommend relying on WordNet's synsets to uniquely identify a
label. Let's clarify these terms:
## Labels ("Labrador Retriever")
Label are totally informal and vary between implementations. We aim to provide a list of
model labels in the `.labels` property. These may include differen labels in different
orders for each model.
For translating between textual labels and synsets, plase use the labels and synsets
collections on models. There's no other foolproof way of goinfg from a descriptive text
label to a precise synset definition.
## Synset IDs ("n02099712")
Synset IDs are identifiers used by the ILSVRC2012 ImageNet classification contest.
We provide `id_from_synset()` to format them correctly.
## Synsets Names ('labrador_retriever.n.01')
Synset names are a wordnet internal concept. When youw ant to create a synset but don't
know its precise name, we offer `imagenet_synset_from_description()` to search for a
synset containing the description in its name that is also one of the synsets used for
the ILSVRC2012.
## Label indexes (logits[i])
When obtaining predictions from a model, they will often be provided in the form of a
BATCH by NUM_CLASSES multidimensional array. In order to map those to human readable
strings, please use a model's `.labels` or `.synsets` or `.synset_ids` property. We aim
to provide these in the same ordering as the model was trained on. Unfortunately these
may be subtly different between models.
"""

from cachetools.func import lru_cache

import nltk
nltk.download("wordnet")
from nltk.corpus import wordnet as wn

from lucid.misc.io import load


IMAGENET_SYNSETS_PATH = "gs://modelzoo/labels/ImageNet_standard_synsets.txt"


def id_from_synset(synset):
return f"{synset.pos()}{synset.offset():08}"


def synset_from_id(id_str):
assert len(id_str) == 1 + 8
pos, offset = id_str[0], int(id_str[1:])
return wn.synset_from_pos_and_offset(pos, offset)


@lru_cache(maxsize=1)
def imagenet_synset_ids():
return load(IMAGENET_SYNSETS_PATH, split=True)


@lru_cache(maxsize=1)
def imagenet_synsets():
return [synset_from_id(id) for id in imagenet_synset_ids()]


@lru_cache()
def imagenet_synset_from_description(search_term):
names_and_synsets = [(synset.name(), synset) for synset in imagenet_synsets()]
candidates = [
synset for (name, synset) in names_and_synsets if search_term.lower().replace(' ', '_') in name
]
hits = len(candidates)
if hits == 1:
return candidates[0]
if hits == 0:
message = "Could not find any imagenet synset with search term {}."
raise ValueError(message.format(search_term))
else:
message = "Found {} imagenet synsets with search term {}: {}."
names = [synset.name() for synset in candidates]
raise ValueError(message.format(hits, search_term, ", ".join(names)))
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -2,7 +2,7 @@
description-file = README.md

[aliases]
test = pytest -s
test=pytest

[flake8]
ignore = E501,E731,E111
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Expand Up @@ -57,16 +57,17 @@
"scipy",
"scikit-learn",
"umap-learn",
"nltk",
"ipython",
"pillow",
"future",
"decorator",
"pyopengl",
"click",
"filelock",
"cachetools",
"more-itertools"
"more-itertools",
],
setup_requires=["pytest-runner"],
tests_require=test_deps,
extras_require=extras,
classifiers=[
Expand Down
1 change: 0 additions & 1 deletion tests/modelzoo/test_nets_factory.py
Expand Up @@ -18,7 +18,6 @@
from __future__ import print_function

import pytest
import tensorflow as tf

from lucid.modelzoo.nets_factory import get_model, models_map
from lucid.modelzoo.vision_models import InceptionV1
Expand Down
56 changes: 56 additions & 0 deletions tests/modelzoo/test_wordnet.py
@@ -0,0 +1,56 @@
import pytest

import nltk

nltk.download("wordnet")
from nltk.corpus import wordnet as wn

from lucid.modelzoo.wordnet import (
id_from_synset,
synset_from_id,
imagenet_synset_ids,
imagenet_synsets,
imagenet_synset_from_description,
)


@pytest.fixture()
def synset():
return wn.synset("great_white_shark.n.01")


@pytest.fixture()
def synset_id():
return "n01484850"


def test_id_from_synset(synset, synset_id):
result = id_from_synset(synset)
assert result == synset_id


def test_synset_from_id(synset_id, synset):
result = synset_from_id(synset_id)
assert result == synset


def test_imagenet_synset_ids(synset_id):
synset_ids = imagenet_synset_ids()
assert len(synset_ids) == 1000
assert synset_id in synset_ids


def test_imagenet_synsets(synset):
synsets = imagenet_synsets()
assert len(synsets) == 1000
assert synset in synsets


def test_imagenet_synset_from_description(synset):
synset_from_description = imagenet_synset_from_description("white shark")
assert synset == synset_from_description


def test_imagenet_synset_from_description_raises(synset):
with pytest.raises(ValueError, match=r'.*great_white_shark.*tiger_shark.*'):
imagenet_synset_from_description("shark")

0 comments on commit 9933879

Please sign in to comment.