In [1]:
import mxnet as mx
%matplotlib inline
import os
import sys
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile
import boto3
import botocore

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
BUCKET_NAME = 'reinvent2018-builder-fair-recycle-arm-us-east-1'
dataset = 'imagenet_expanded'
project = 'imagenet_updated'

train_folder = 'train'
validation_folder = 'validation'

data_path_list = ['/tmp/data/{}'.format(train_folder), '/tmp/data/{}'.format(validation_folder)]


DATA_BUCKET_KEY = 'data/raw_data/{}'.format(dataset)
IMAGENET_CATEGORIES = {"bottle" : "n04557648", 
                       "coffee_cup": "n03216710", 
#                        "plastic_bag" : "n03958227", 
                      "soda_can" : "n04255586",
#                       "cardboard": "n03871724",
#                       "trash": "trash",
#                       "plastic": "plastic",
#                       "glass": "glass"
#                       "wine_bottle": "n04591713"
                      "aerosol_can":"n02682922"
                      }

In [3]:
import shutil
import os
shutil.rmtree('/tmp/data/', ignore_errors=True, onerror=None)
os.mkdir('/tmp/data/');

In [4]:
s3 = boto3.resource('s3')

for category_key, category_value in IMAGENET_CATEGORIES.items():
    Key = "{}/{}.tar".format(DATA_BUCKET_KEY, category_value)
    try:
        print("Downloading {}/{}.tar".format(BUCKET_NAME,category_value))
        s3.Bucket(BUCKET_NAME).download_file(Key, '/tmp/data/{}.tar'.format(category_key))
    except botocore.exceptions.ClientError as e:
        if e.response['Error']['Code'] == "404":
            print("The object does not exist.")
        else:
            raise

Downloading reinvent2018-builder-fair-recycle-arm-us-east-1/n04557648.tar
Downloading reinvent2018-builder-fair-recycle-arm-us-east-1/n03216710.tar
Downloading reinvent2018-builder-fair-recycle-arm-us-east-1/n04255586.tar
Downloading reinvent2018-builder-fair-recycle-arm-us-east-1/n02682922.tar


In [5]:
for root, dirs, files in os.walk("/tmp/data/"):  
    for filename in files:
        print(filename)

coffee_cup.tar
soda_can.tar
bottle.tar
aerosol_can.tar


In [6]:
import tarfile
for root, dirs, files in os.walk("/tmp/data/"):  
    for filename in files:
        print(filename)
        tar = tarfile.open("/tmp/data/{}".format(filename))
        tar.extractall(path='/tmp/data/{}/{}'.format(train_folder, filename.split(".tar")[0]))
        tar.close()


coffee_cup.tar
soda_can.tar
bottle.tar
aerosol_can.tar


In [7]:
#Create train and test data split
import tarfile
from shutil import copyfile
file_count = 0;
for category_key, category_value in IMAGENET_CATEGORIES.items():
    for root, dirs, files in os.walk("/tmp/data/{}/{}".format(train_folder, category_key)):
        file_count = file_count + len(files)
        i = 1
        for filename in files:
            #print(os.path.join(root, filename))
            os.makedirs("/tmp/data/{}/{}".format(validation_folder, category_key), exist_ok=True)
            copyfile(os.path.join(root, filename), os.path.join("/tmp/data/{}/{}"\
                                                                .format(validation_folder, category_key), filename))
            i = i + 1
            if i == 100:
                break

print("Total file count {}".format(file_count))

Total file count 5085


In [8]:
import os
import glob

files = glob.glob('/tmp/data/*.tar')
for f in files:
    os.remove(f)

In [9]:
#im2rec_path = mx.test_utils.get_im2rec_path()
im2rec_path = "/home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/tools/im2rec.py"

for data_path in data_path_list:
    print (data_path)
    data_path_split = data_path.split("/")
    print(data_path_split)
    prefix_path = "/" + data_path_split[1] + "/" + data_path_split[2] + "/" + "imagenet_" + data_path_split[3]
    print(prefix_path)
    with open(os.devnull, 'wb') as devnull:
        subprocess.check_call(['python', im2rec_path, '--list', '--recursive', prefix_path, data_path],
                              stdout=devnull)

    with open(os.devnull, 'wb') as devnull:
        subprocess.check_call(['python', im2rec_path, '--num-thread=4', '--quality=100', '--resize=480', prefix_path, data_path],
                              stdout=devnull)

/tmp/data/train
['', 'tmp', 'data', 'train']
/tmp/data/imagenet_train
/tmp/data/validation
['', 'tmp', 'data', 'validation']
/tmp/data/imagenet_validation


In [10]:
import boto3
BUCKET_NAME = 'deeplens-image-classification-varunrao'

s3 = boto3.resource('s3')

try:
    for data_path in data_path_list:
        data_path_split = data_path.split("/")
        prefix_path = "/" + data_path_split[1] + "/" + data_path_split[2] + "/" + "imagenet_" + data_path_split[3]
        s3.Bucket(BUCKET_NAME).upload_file('{}.rec'.format(prefix_path), '{}/{}'.format(project, data_path_split[3]) + "/{}.rec".format("imagenet_" + data_path_split[3]))
except botocore.exceptions.ClientError as e:
    if e.response['Error']['Code'] == "404":
        print("The object does not exist.")
    else:
        raise

In [None]:
data_iter = mx.io.ImageRecordIter(
    path_imgrec=os.path.join('/tmp','data','imagenet_validation.rec'),
    data_shape=(3, 500, 500), # output data shape. An 227x227 region will be cropped from the original image.
    batch_size=4, # number of samples per batch
    resize=256 # resize the shorter edge to 256 before cropping
    # ... you can add more augmentation options as defined in ImageRecordIter.
    )
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()

In [None]:
# data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
#                               path_imgrec=os.path.join('/tmp','data','imagenet_validation.rec'),
#                               path_imgidx=os.path.join('/tmp','data','imagenet_validation.idx') )
# data_iter.reset()
# batch = data_iter.next()
# data = batch.data[0]
# for i in range(4):
#     plt.subplot(1,4,i+1)
#     plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
# plt.show()

In [None]:
file_name = '/tmp/data/train/coffee_cup/n03216710_10011.JPEG'
# test image
from IPython.display import Image
Image(file_name)