# Paw Patrol classification training

In [None]:
import os

from google_images_download import google_images_download

# All downloaded images go under ./train/<class_name>/; the training cell
# below reads this same "train" folder with ImageDataBunch.from_folder.
data_folder = "train"
# exist_ok=True replaces the exists()+mkdir() check-then-create pair.
os.makedirs(data_folder, exist_ok=True)

# One shared Google Images downloader, used by download_paw_patrol below.
response = google_images_download.googleimagesdownload()

def download_paw_patrol(folder, query, limit=50, output_directory=None):
    """Download Google Images results for one pup into its own class folder.

    Images are written to ``<output_directory>/<folder>``. If that folder
    already exists, the download is skipped entirely, so re-running the cell
    does not fetch the same class again. Note that a partially filled folder
    left by an interrupted run is skipped too.

    Args:
        folder: Class/sub-folder name, e.g. ``"chase"``.
        query: Search string passed to the downloader as ``keywords``.
        limit: Maximum number of images to request (default 50).
        output_directory: Parent directory for the class folder. Defaults to
            the module-level ``data_folder``.

    Returns:
        Whatever ``response.download`` returned, or ``None`` when the folder
        already existed or the download raised.
    """
    if output_directory is None:
        output_directory = data_folder

    if os.path.exists(os.path.join(output_directory, folder)):
        return None

    arguments = {"keywords": query,
                 "output_directory": output_directory,
                 "image_directory": folder,
                 "limit": limit,
                 "print_urls": True}

    try:
        paths = response.download(arguments)
    except Exception as e:
        # Best-effort by design: one failing query must not stop the remaining
        # classes from downloading, so report it and carry on.
        print('Exception ' + str(e))
        return None

    print(paths)
    return paths

# Every query mentions all seven pups. The target pup appears as a plain
# search term; every other pup (Everest included, although she gets no class
# folder here) appears as a "-name" exclusion.
_PUP_ORDER = ("chase", "marshall", "rubble", "rocky", "skye", "zuma", "everest")


def _pup_query(target):
    """Return the search string that keeps *target* and excludes the other pups."""
    terms = [name if name == target else "-" + name for name in _PUP_ORDER]
    return " ".join(["paw patrol"] + terms)


chase_query = _pup_query("chase")
marshall_query = _pup_query("marshall")
rubble_query = _pup_query("rubble")
rocky_query = _pup_query("rocky")
zuma_query = _pup_query("zuma")
skye_query = _pup_query("skye")

# One class folder per pup, downloaded in the order chase, marshall, rubble,
# rocky, zuma, skye.
for _folder_name, _search in (("chase", chase_query),
                              ("marshall", marshall_query),
                              ("rubble", rubble_query),
                              ("rocky", rocky_query),
                              ("zuma", zuma_query),
                              ("skye", skye_query)):
    download_paw_patrol(_folder_name, _search)

In [None]:
from fastai.vision import *
from fastai.metrics import error_rate

# Load data. Seeding NumPy first keeps the random 80/20 train/validation
# split (valid_pct=0.2) the same from run to run.
np.random.seed(42)
data = ImageDataBunch.from_folder("train", valid_pct=0.2, ds_tfms=get_transforms(),
                                  size=224, bs=32).normalize(imagenet_stats)
data.show_batch(rows=3, figsize=(7, 6))

# Use ResNet34 as the architecture
learn = cnn_learner(data, models.resnet34, metrics=error_rate)

# Resume from a previously saved "final" checkpoint and export it, but only if
# one exists. Nothing in this notebook saves "final" (the training cells save
# 'stage-1.1' and 'stage-2'), so an unconditional learn.load("final") makes a
# fresh run fail before any training can happen.
final_weights = learn.path / learn.model_dir / "final.pth"
if final_weights.exists():
    learn.load("final")
    learn.export()

In [None]:
# Stage 1: two back-to-back one-cycle runs of 4 epochs each, then checkpoint.
# The body is presumably still frozen here; it is only unfrozen in the
# second pass.
for _cycle in range(2):
    learn.fit_one_cycle(4)
learn.save('stage-1.1')

In [None]:
# Inspect the stage-1 model on the validation set: which classes it confuses
# with each other, then the nine images with the highest loss.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
interp.plot_top_losses(9, figsize=(15, 11))

In [None]:
# Sweep learning rates and plot loss against LR. The plot is presumably what
# the hand-picked max_lr range in the second pass was read from.
learn.lr_find()
learn.recorder.plot()

In [None]:
# Second pass: unfreeze the whole network and fine-tune for 2 epochs.
# slice(1e-6, 1e-4) gives the earliest layer group the smallest learning rate
# and the last group the largest, with groups in between spread across the range.
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6, 1e-4))
learn.save('stage-2')

In [None]:
# Re-inspect after fine-tuning: recompute the interpretation for the updated
# model, then show the class confusion matrix and the nine worst predictions.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
interp.plot_top_losses(9, figsize=(15, 11))