In [1]:
import numpy as np
from PIL import Image
import os
import glob
import math
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
DATA_DIR = '../data/tulips/'
IMG_DIR = os.path.join(DATA_DIR, 'sample/')
CHECKPOINTS_DIR = 'checkpoints'

In [3]:
import mxnet as mx
from mxnet import gluon
from mxnet.image import color_normalize

In [5]:
# ctx = mx.gpu(0)
ctx = mx.cpu()

In [20]:
# Load the pretrained network
rn50bin = gluon.model_zoo.vision.resnet101_v2(classes=2, ctx=ctx)
rn50bin.initialize(mx.init.Xavier())
rn50bin.collect_params().reset_ctx(ctx)

In [23]:
# Load results
checkpoints = os.path.join(CHECKPOINTS_DIR, 'resnet100-bin')
rn50bin.load_params(os.path.join(checkpoints, '44-0.params'), ctx)

In [9]:
labels = {0:'clear', 1:'cloudy'}

In [10]:
mean = mx.nd.array([0.485, 0.456, 0.406], ctx=ctx).reshape((3,1,1))
std =  mx.nd.array([0.229, 0.224, 0.225], ctx=ctx).reshape((3,1,1))

In [11]:
augs = [
     mx.image.RandomCropAug((224, 224))
]

def preprocess(data, augs):
    data = mx.nd.array(data).astype('float32').as_in_context(ctx)
    for aug in augs:
        data = aug(data)
    data = mx.nd.transpose(data, (2,0,1))
    data = color_normalize(data/255, mean, std)
    return data

def read_img(filename):
    img = Image.open(filename)
    return preprocess(img, augs)

In [12]:
def load_batch(filenames):
    batch = mx.nd.empty((len(filenames),3,224,224))
    for idx,fn in enumerate(filenames):
        batch[idx] = read_img(fn)
    
    return batch

In [17]:
def filter_folder(root, batch_size, net, ext='.png'):
    files = glob.glob(root + '*' + ext)
    cloudy = []
    
    for n in range(math.ceil(len(files)/batch_size)):
        files_batch = files[n*batch_size:(n+1)*batch_size]
        batch = load_batch(files_batch)
        batch = batch.as_in_context(ctx)
        preds = mx.nd.argmax(net(batch), axis=1)
        idxs = np.nonzero(preds.asnumpy())[0]
        cloudy.extend([files_batch[i] for i in idxs])
        
    # Do something with the cloudy images (delete/move to new directory)
    print(cloudy)

In [24]:
filter_folder(IMG_DIR, 2, rn50bin)

Batch 0
Batch 1
Batch 2
Batch 3
['../data/tulips/sample/img8.png', '../data/tulips/sample/img7.png']
