In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
from path import Path as p

In [2]:
PATH = "../../../data/"
PATH2 = "../../../data/Flicker8k_Dataset/"
sz=224
bs = 64
n = 1000 # Number of dogs and cats to train model on
# imbalanced classes
CATDOG = 0 # Dummy class variables
NOTCATDOG = 1 # Not Cat or Not Dog

In [3]:
def get_names(filelist, suffix):
    return [suffix + f.name for f in filelist]

In [4]:
train_cats = get_names(p(PATH + "dogscats/train/cats/").files(), "dogscats/train/cats/")
train_dogs = get_names(p(PATH + "dogscats/train/dogs/").files(), "dogscats/train/dogs/")
valid_cats = get_names(p(PATH + "dogscats/valid/cats/").files(), "dogscats/valid/cats/")
valid_dogs = get_names(p(PATH + "dogscats/valid/dogs/").files(), "dogscats/valid/dogs/")
flickr = get_names(p(PATH2).files(), "Flicker8k_Dataset/")
print(len(train_cats))
print(len(train_dogs))
print(len(valid_cats))
print(len(valid_dogs))
print(len(flickr))

11500
11500
1000
1000
8091


## Part 1: Not Cats

In [7]:
def prep(train, valid):
    p(PATH + "/tmp/").rmtree_p() # remove cached training data
    animals = list(np.random.choice(train, n, False)) + list(valid)
    notanimals = list(np.random.choice(flickr, n + len(valid), False))
    fn = animals + notanimals
    yy = np.array([CATDOG]*len(animals) + [NOTCATDOG]*len(notanimals))
    cs = list(set(yy))
    v_cat_dog_idx = range(n, n + len(valid))
    v_not_idx = range(len(animals) + n, len(fn))
    vi = list(v_cat_dog_idx) + list(v_not_idx)
    return (fn, yy, cs, vi)

# fn or fnames: file names
# yy or y: numpy array which contains target labels ordered by filenames.
# cs or classes: a list of all labels/classifications, [0, 1]
# vi or val_idxs: index of images to be used for validation.

In [8]:
fnames, y, classes, val_idxs = prep(train_cats, valid_cats)

In [9]:
# Let's set up our model using the pretrained Resnet34 Imagenet model
arch=resnet34
data = ImageClassifierData.from_names_and_array(PATH, fnames, y, classes, \
                                                val_idxs, bs=bs, \
                                                tfms=tfms_from_model(arch, sz))
learn = ConvLearner.pretrained(arch, data, precompute=True)

100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:28<00:00,  1.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:23<00:00,  1.35it/s]


In [10]:
# Use a learning rate of 0.01 and train for 5 epochs
lr = 0.01
epochs = 5
learn.fit(lr, epochs)

HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                                                                              
    0      0.14451    0.028791   0.9895    
    1      0.07773    0.022651   0.992                                                                                 
    2      0.052172   0.020733   0.993                                                                                 
    3      0.03796    0.020349   0.992                                                                                 
    4      0.031553   0.02006    0.9925                                                                                



[array([0.02006]), 0.9925]

In [11]:
# Create our prediction function
def predict(learner, pred_files):
    orig_precompute = learner.precompute
    learner.precompute = False
    trn_tfms, val_tfms = tfms_from_model(arch, sz)
    ds = FilesIndexArrayDataset(list(pred_files), np.zeros(len(pred_files)), val_tfms, PATH)
    dl = DataLoader(ds)
    log_preds = learner.predict_dl(dl)
    preds = np.exp(log_preds)
    results = np.argmax(preds, axis=1)
    learner.precompute = orig_precompute
    return results

In [12]:
# Now try to predict on dogs
pred_dogs = predict(learn, valid_dogs)
print(sum(pred_dogs))

802


## Part 2: Not Dogs

In [13]:
fnames, y, classes, val_idxs = prep(list(np.random.choice(flickr, n, False)), valid_dogs)

In [14]:
# Let's set up our model using the pretrained Resnet34 Imagenet model
arch=resnet34
data = ImageClassifierData.from_names_and_array(PATH, fnames, y, classes, \
                                                val_idxs, bs=bs, \
                                                tfms=tfms_from_model(arch, sz))
learn = ConvLearner.pretrained(arch, data, precompute=True)

100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:28<00:00,  1.12it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 32/32 [00:22<00:00,  1.45it/s]


In [18]:
# Use a learning rate of 0.01 and train for 5 epochs
lr = 0.01
epochs = 20
learn.fit(lr, epochs)

HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                                                                              
    0      0.641971   0.764582   0.441     
    1      0.629963   0.774343   0.4275                                                                                
    2      0.632966   0.75218    0.4605                                                                                
    3      0.622216   0.77162    0.4395                                                                                
    4      0.616978   0.764566   0.453                                                                                 
    5      0.603909   0.770566   0.4525                                                                                
    6      0.607761   0.763815   0.4665                                                                                
    7      0.607189   0.789551   0.4365                                                                             

[array([0.83266]), 0.4385]

In [13]:
# Now try to predict on cats
pred_cats = predict(learn, valid_cats)
print(sum(pred_cats))

237
