
Example of bag of visual words (BoVW) using SIFT and a RandomForestClassifier #6126

Open
glemaitre opened this issue Dec 13, 2021 · 6 comments
Labels: 📄 type: Documentation · 😴 Dormant

Comments

glemaitre (Contributor) commented Dec 13, 2021

I'm super happy to have seen SIFT released recently. It brings back memories from my computer vision courses in the pre-deep-learning era :)

I quickly drafted a scikit-learn-compatible transformer that uses SIFT and can be integrated into a scikit-learn pipeline (it should be double-checked for bugs):

# %%
from tudarmstadt import fetch_tu_darmstadt

image_filenames, labels = fetch_tu_darmstadt()

# %%
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import MiniBatchKMeans
from sklearn.utils import check_random_state
from skimage.feature import SIFT
from skimage.io import imread


def _load_and_extract_sift(filename, sift):
    # With joblib's default process-based backend, each worker receives its
    # own pickled copy of ``sift``, so the in-place ``detect_and_extract``
    # call is safe; with a threading backend it would race.
    img = imread(filename, as_gray=True)
    sift.detect_and_extract(img)
    return sift.descriptors


def _descriptors_to_histogram(descriptors, dictionary):
    # One bin per visual word requires ``n_clusters + 1`` bin edges.
    return np.histogram(
        dictionary.predict(descriptors),
        bins=np.arange(dictionary.n_clusters + 1),
        density=True,
    )[0]


class BagOfVisualWords(TransformerMixin, BaseEstimator):
    def __init__(self, n_words, batch_size=1024, n_jobs=None, random_state=None):
        self.n_words = n_words
        self.batch_size = batch_size
        self.n_jobs = n_jobs
        self.random_state = random_state

    def fit(self, X, y=None):
        self.fit_transform(X, y)
        return self

    def fit_transform(self, X, y=None):
        random_state = check_random_state(self.random_state)

        # learn the visual-word dictionary by clustering all SIFT descriptors
        self.dictionary_ = MiniBatchKMeans(
            n_clusters=self.n_words,
            batch_size=self.batch_size,
            random_state=random_state,
        )
        self.sift_ = SIFT()

        descriptors = Parallel(n_jobs=self.n_jobs)(
            delayed(_load_and_extract_sift)(filename, self.sift_) for filename in X
        )

        self.dictionary_.fit(np.concatenate(descriptors))

        # encode each image as a normalized histogram of visual words
        X_trans = Parallel(n_jobs=self.n_jobs)(
            delayed(_descriptors_to_histogram)(descr_img, self.dictionary_)
            for descr_img in descriptors
        )

        return np.array(X_trans)

    def transform(self, X):
        descriptors = Parallel(n_jobs=self.n_jobs)(
            delayed(_load_and_extract_sift)(filename, self.sift_) for filename in X
        )

        X_trans = Parallel(n_jobs=self.n_jobs)(
            delayed(_descriptors_to_histogram)(descr_img, self.dictionary_)
            for descr_img in descriptors
        )

        return np.array(X_trans)


# %%
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline

bovw = BagOfVisualWords(n_words=1000, n_jobs=-1)
classifier = RandomForestClassifier(n_estimators=100, n_jobs=-1)
model = make_pipeline(bovw, classifier)

# %%
from sklearn.model_selection import cross_validate

cv_results = cross_validate(
    model, image_filenames, labels, cv=5, scoring='accuracy', return_train_score=True
)

# %%
import pandas as pd

cv_results = pd.DataFrame(cv_results)
cv_results
   fit_time  score_time  test_score  train_score
0  21.868036    5.510977    0.938462          1.0
1  22.021476    5.568267    0.938462          1.0
2  24.324928    4.670179    0.938462          1.0
3  25.808603    5.165597    0.984615          1.0
4  28.789002    6.064161    0.907692          1.0

Let me know if this is in the scope of scikit-image and whether you would be interested in such an example.
Here, I am using a custom fetcher (see below), and I did not really tune the example to reduce the number of descriptors, so it is quite slow.

Also, the BoVW class could be made more generic so that it takes any detector/descriptor instead of SIFT; a rough sketch follows. I don't know whether such a class would be of broader interest as well.
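Something like the following (a minimal sketch; the ``detector_factory`` parameter is hypothetical, not an existing API): any scikit-image extractor exposing ``detect_and_extract`` and ``descriptors``, such as SIFT or ORB, would plug in.

from skimage.feature import ORB, SIFT
from skimage.io import imread


def extract_descriptors(filename, detector_factory=SIFT):
    # ``detector_factory`` is any zero-argument callable returning an object
    # with ``detect_and_extract`` and ``descriptors``.
    detector = detector_factory()  # fresh instance: detect_and_extract mutates it
    detector.detect_and_extract(imread(filename, as_gray=True))
    return detector.descriptors


# The descriptor type can then be swapped without touching the clustering
# and histogram code, e.g.:
# descriptors = extract_descriptors("image.png", detector_factory=lambda: ORB(n_keypoints=200))

Note that ORB descriptors are binary, so k-means distances behave differently than with SIFT's 128-dimensional float descriptors.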

Code for fetching data:

"""TU Darmstadt dataset.
The original database was available from
    http://host.robots.ox.ac.uk/pascal/VOC/download/tud.tar.gz
"""

import os
from os.path import join, exists
from urllib.error import HTTPError
from urllib.request import urlopen

import numpy as np
import tarfile

from sklearn.datasets import get_data_home

DATA_URL = "http://host.robots.ox.ac.uk/pascal/VOC/download/tud.tar.gz"
TARGET_FILENAME = "tud.pkz"


def fetch_tu_darmstadt(data_home=None):
    """Loader for the TU Darmstadt dataset.

    Read more in the :ref:`User Guide <datasets>`.


    Parameters
    ----------
    data_home : optional, default: None
        Specify another download and cache folder for the datasets. By default
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    Returns
    -------
    images_list : list
        Python list with the path of each image to consider during the
        classification.

    labels : array-like, shape (n_images,)
        An array with the label corresponding to each image's category.
        0: motorbikes - 1: cars - 2: cows.

    Notes
    -----
    The dataset is composed of 124 motorbike images, 100 car images, and 112
    cow images.

    Examples
    --------
    Load the 'tu-darmstadt' dataset:

    >>> from tudarmstadt import fetch_tu_darmstadt
    >>> import tempfile
    >>> test_data_home = tempfile.mkdtemp()
    >>> im_list, labels = fetch_tu_darmstadt(data_home=test_data_home)
    """

    # check if the data has already been downloaded
    data_home = get_data_home(data_home=data_home)
    data_home = join(data_home, 'tu_darmstadt')
    if not exists(data_home):
        os.makedirs(data_home)

    # dataset tar file
    filename = join(data_home, 'tud.tar.gz')

    # if the file does not exist, download it
    if not exists(filename):
        try:
            db_url = urlopen(DATA_URL)
            with open(filename, 'wb') as f:
                f.write(db_url.read())
            db_url.close()
        except HTTPError as e:
            if e.code == 404:
                e.msg = 'TU Darmstadt dataset not found.'
            raise
    # try to extract the complete archive
    try:
        tarfile.open(filename, "r:gz").extractall(path=data_home)
    finally:
        os.remove(filename)

    # the file 'motorbikes023' is a grayscale image and needs to be removed
    file_removal = [
        join(data_home,
             'TUDarmstadt/PNGImages/motorbike-testset/motorbikes023.png'),
        join(data_home,
             'TUDarmstadt/Annotations/motorbike-testset/motorbikes023.txt'),
    ]
    for f in file_removal:
        os.remove(f)

    # list the different images
    data_path = join(data_home, 'TUDarmstadt/PNGImages')
    images_list = [os.path.join(root, name)
                   for root, dirs, files in os.walk(data_path)
                   for name in files
                   if name.endswith(".png")]

    # create the label array
    labels = []
    for imf in images_list:
        if 'motorbike' in imf:
            labels.append(0)
        elif 'cars' in imf:
            labels.append(1)
        elif 'cows' in imf:
            labels.append(2)

    # return the image paths and the label array
    return images_list, np.array(labels)
mkcor (Member) commented Dec 14, 2021

Hello @glemaitre,

Yay, thanks for sharing! We currently have only one gallery example showcasing SIFT; I think your addition would be very valuable.

Would you have time to submit your example to the gallery? I think it would fit well under ./doc/examples/applications/.

glemaitre (Contributor, Author) commented:

> Would you have time to submit your example to the gallery?

I can find time to submit a PR. I am just wondering if there is already a dataset (or a fetcher) in scikit-image that I could reuse?

mkcor (Member) commented Dec 14, 2021

I guess you could use the images from the current SIFT example (they are based on rgb2gray(data.astronaut()))? A quick sketch of how to generate them is below.
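A minimal sketch of generating such an image pair (the exact transformation used in the gallery example may differ):

from skimage import data, transform
from skimage.color import rgb2gray

img1 = rgb2gray(data.astronaut())   # grayscale astronaut image
img2 = transform.rotate(img1, 180)  # a transformed copy to match keypoints against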

Otherwise, I was looking at http://host.robots.ox.ac.uk/pascal/VOC/ but I can't find any licensing info regarding the datasets...

grlee77 (Contributor) commented Dec 14, 2021

There is one existing face detection demo that uses a small faces vs. non-faces dataset in combination with scikit-learn.

That one uses Haar-like features instead of SIFT; could it be interesting to compare the performance of the two (see the sketch below)? I'm not sure that we have any other bundled datasets that could be used. We do store some data externally from this repository at https://gitlab.com/scikit-image/data, and it only gets downloaded as needed.
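For context, extracting Haar-like features with scikit-image looks roughly like this (a minimal sketch; the window size and feature types are chosen arbitrarily here, not taken from the demo):

from skimage.color import rgb2gray
from skimage.data import astronaut
from skimage.feature import haar_like_feature
from skimage.transform import integral_image, resize

# build a small grayscale window and its integral image
patch = resize(rgb2gray(astronaut()), (19, 19))
ii = integral_image(patch)

# compute two families of Haar-like features over the full window
features = haar_like_feature(
    ii, 0, 0, ii.shape[1], ii.shape[0], feature_type=['type-2-x', 'type-2-y']
)
print(features.shape)  # one value per rectangle configuration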

grlee77 added the 📄 type: Documentation label on Dec 14, 2021
glemaitre (Contributor, Author) commented:

Yep, I can have a look at this dataset.

> That one uses Haar-like features instead of SIFT; could it be interesting to compare the performance of the two?

Now that I see this example, I think that at the time I wanted to reproduce an approach similar to the Haar cascade of Viola and Jones. I could indeed modify this example a bit to make it more readable.

Basically, the BoVW class could be modified to take different detectors/descriptors, and then we could make a comparison.

github-actions bot commented Oct 1, 2022

Hey, there hasn't been any activity on this issue for more than 180 days. For now, we have marked it as "dormant" until there is some new activity. You are welcome to reach out to people by mentioning them here or on our forum if you need more feedback! If you think that this issue is no longer relevant, you may close it by yourself; otherwise, we may do it at some point (either way, it will be done manually). In any case, thank you for your contributions so far!

github-actions bot added the 😴 Dormant label on Oct 1, 2022