In [None]:
import shutil
from pathlib import Path
import os

# Fastai related libraries.
import fastbook
from fastai.vision.all import *
from fastai.vision.widgets import *

In [None]:
# Root directory for the scraped images; `download_images` writes each class
# into its own sub-directory below it.
image_download_dir = 'build/data/images'

# Start every run from an empty directory so images from a previous run never
# leak into this dataset. Only the "directory does not exist yet" case is
# skipped; any other deletion failure (e.g. a permission error) is raised
# instead of being ignored, because ignoring it would silently mix stale
# images into the new training data.
# NOTE(review): this forces a full re-download on every run, which is slow.
if os.path.exists(image_download_dir):
    shutil.rmtree(image_download_dir)
os.makedirs(image_download_dir, exist_ok=True)

def download_images(search_terms, label, num_images_to_download):
    """Download images for one class label using DuckDuckGo image search.

    The requested total is split across ``search_terms`` so that exactly
    ``num_images_to_download`` URLs are requested overall; when the total does
    not divide evenly, the earlier terms get one extra image each. Images are
    saved under ``<image_download_dir>/<label>``, and files that fail to load
    as images are deleted afterwards.

    Args:
        search_terms: Search queries for this label. A single string is
            treated as one query (not iterated character by character).
        label: Class name. It is used as the sub-directory name, which
            ``parent_label`` later reads back as the target.
        num_images_to_download: Total number of image URLs to request across
            all terms. The search may return fewer.

    Raises:
        ValueError: If ``search_terms`` is empty.
    """
    if isinstance(search_terms, str):
        search_terms = [search_terms]
    else:
        search_terms = list(search_terms)
    if not search_terms:
        raise ValueError(f"no search terms given for label {label!r}")

    dest_dir = Path(image_download_dir) / label

    # Split the total exactly. round() could over- or under-shoot it
    # (e.g. 200 over 7 terms -> 7 * 29 = 203 requested).
    per_term, remainder = divmod(num_images_to_download, len(search_terms))

    for index, term in enumerate(search_terms):
        quota = per_term + (1 if index < remainder else 0)
        if quota <= 0:
            # Nothing to request for this term; avoid a pointless search call.
            continue
        results = fastbook.search_images_ddg(term, quota)
        fastbook.download_images(dest_dir, urls=results)

    # Remove any files that don't load as proper image files.
    failed = fastbook.verify_images(fastbook.get_image_files(dest_dir))
    for bad_file in failed:
        bad_file.unlink()


# One entry per class. The key is the label, which becomes the sub-directory
# name that `parent_label` reads back as the target; the value lists the
# search queries used to collect that class's images.
# NOTE(review): "panther" may refer to black leopards/jaguars or to cougars,
# so this class could overlap 'leopard' and 'cougar' — check it in the
# confusion matrix.
IMAGES_PER_LABEL = 200
SEARCH_TERMS_BY_LABEL = {
    'panther': ['panther animal'],
    'leopard': ['leopard animal'],
    'snow leopard': ['snow leopard animal'],
    'tiger': ['tiger animal'],
    'lion': ['lion animal'],
    'cheetah': ['cheetah animal'],
    'cougar': ['cougar animal'],
}

for label_name, label_terms in SEARCH_TERMS_BY_LABEL.items():
    download_images(label_terms, label_name, IMAGES_PER_LABEL)

In [None]:
# Seed Python, NumPy and PyTorch before the DataLoaders and learner are created.
# RandomSplitter's seed only fixes the train/valid split, not batch shuffling
# or the new head's weight initialisation. Without this, Restart & Run All
# yields a different model each time. reproducible=True also asks cuDNN for
# deterministic kernels, which can be slower on GPU.
RANDOM_SEED = 1168
set_seed(RANDOM_SEED, reproducible=True)

data_loaders = DataBlock(
    blocks=[ImageBlock, CategoryBlock],
    get_items=get_image_files,
    # Hold out 20% of the images for validation, reproducibly.
    splitter=RandomSplitter(valid_pct=0.2, seed=RANDOM_SEED),
    # The class is the name of the directory the image was downloaded into.
    get_y=parent_label,
    # Squish (instead of crop) to 192x192 so no part of the animal is cut off.
    item_tfms=[Resize(192, method='squish')]
).dataloaders(image_download_dir, bs=150)

data_loaders.valid.show_batch(max_n=20, nrows=5)

In [None]:
# Build a ResNet-18 based classifier on the DataLoaders above and track the
# error rate as the metric. fastai's `fine_tune` first trains only the new
# head (one frozen epoch by default), then trains the whole network for
# NUM_FINE_TUNE_EPOCHS epochs.
NUM_FINE_TUNE_EPOCHS = 2
learn = vision_learner(data_loaders, resnet18, metrics=error_rate)
learn.fine_tune(NUM_FINE_TUNE_EPOCHS)

In [None]:
# Inspect where the trained model goes wrong: build an interpretation object
# from `learn`'s predictions and plot predicted vs. actual classes.
# `interp` is reused by the top-losses cell that follows.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

In [None]:
# Show the images with the highest loss (confidently wrong, or right but very
# unsure), one per row. This helps spot mislabeled or off-topic downloads.
NUM_TOP_LOSSES = 8
interp.plot_top_losses(NUM_TOP_LOSSES, ncols=1)

In [None]:
# Save the trained learner for later inference (e.g. with `load_learner`);
# with no filename given, fastai writes 'export.pkl' under `learn.path`.
# NOTE(review): the export is a pickle — only load it from trusted sources.
learn.export()