## Image Segmentation Demo

In [None]:
# standard library
import io
import os
import random

# data connection
import boto3

# visualization
from matplotlib import pyplot as plt

# computer vision
import mxnet as mx
from mxnet import image
from mxnet.gluon.data.vision import transforms
import gluoncv
from gluoncv.data.transforms.presets.segmentation import test_transform
from gluoncv.utils.viz import get_color_pallete

### Access Data

In [1]:
class DataLakeConnector:
    """Thin wrapper around a boto3 S3 client bound to a single bucket.

    Parameters
    ----------
    access_key_id, secret_access_key : str
        AWS credentials passed straight to ``boto3.client('s3', ...)``.
    bucket : str
        Name of the bucket that every listing/download call targets.
    """

    def __init__(self, access_key_id, secret_access_key, bucket):
        self.access_key_id = access_key_id
        self.secret_access_key = secret_access_key

        self.client = self.create_s3_client()

        self.bucket = bucket

    def create_s3_client(self):
        """Return a new boto3 S3 client built from the stored credentials."""
        return boto3.client(
            's3',
            aws_access_key_id=self.access_key_id,
            aws_secret_access_key=self.secret_access_key,
        )

    def list_files(self, dir_path):
        """Return every object key in the bucket that starts with ``dir_path``.

        A single ``list_objects_v2`` call returns at most 1000 keys, so the
        listing is driven through the client's paginator to collect every
        page. A page with no ``Contents`` entry (e.g. a prefix that matches
        nothing) contributes no keys, so an unmatched prefix yields ``[]``
        instead of raising ``TypeError``.
        """
        paginator = self.client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket, Prefix=dir_path)
        return [obj['Key'] for page in pages for obj in page.get('Contents', [])]

    def download_img(self, img_path):
        """Download ``img_path`` from the bucket and decode it as a JPEG.

        The object is buffered in memory (no temporary file) and decoded
        with ``plt.imread``; the decoded image is returned.
        """
        outfile = io.BytesIO()
        self.client.download_fileobj(self.bucket, img_path, outfile)
        # Rewind so imread starts reading at the beginning of the buffer.
        outfile.seek(0)
        img = plt.imread(outfile, 'jpg')
        return img

##### Connect to Data Lake

In [None]:
# set keys
# Read the AWS credentials from the standard environment variables rather than
# hard-coding them in the notebook, where they could easily be committed.
# The 'set_key' fallback preserves the original placeholder when unset.
access_key_id = os.environ.get('AWS_ACCESS_KEY_ID', 'set_key')
secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'set_key')

In [None]:
# S3 connector bound to the 'zoetrope-downloads' bucket.
data_connector = DataLakeConnector(
    access_key_id=access_key_id,
    secret_access_key=secret_access_key,
    bucket='zoetrope-downloads',
)

##### List images

In [None]:
# object keys stored under one specific address prefix
data_connector.list_files(dir_path="places/9796 Nature Trail Way, Elk Grove, CA 95757, USA/")

In [None]:
# every downloaded image: an empty prefix matches every key in the bucket
all_image_paths = data_connector.list_files(dir_path="")

In [None]:
# all New York images: keep the keys whose path contains "NY"
ny_images = list(filter(lambda path: "NY" in path, all_image_paths))

##### Visualize images

In [None]:
# pick one New York image key at random, fetch it, and display it
ny_key = random.choice(ny_images)
sample_image = data_connector.download_img(ny_key)
plt.imshow(sample_image)
plt.show()

### Perform Image Segmentation

In [None]:
# device context for model inputs: CPU 0 (mx.cpu's default device_id)
ctx = mx.cpu()

In [None]:
# choose the segmentation model from the GluonCV model zoo
# ('psp_resnet101_ade', loaded with its pretrained weights)
model = gluoncv.model_zoo.get_model(name='psp_resnet101_ade', pretrained=True)

In [2]:
def segment_image(img, model, ctx=None, save_path='output.png', dataset='ade20k'):
    """Segment one image with ``model`` and display the coloured class mask.

    Parameters
    ----------
    img : array-like
        Image to segment; converted with ``mx.nd.array``. Presumably an
        HxWx3 array as returned by ``DataLakeConnector.download_img`` —
        TODO confirm for non-JPEG inputs.
    model :
        Segmentation network exposing ``predict`` (e.g. the GluonCV
        ``psp_resnet101_ade`` model-zoo model).
    ctx : mx.Context, optional
        Device the input tensor is placed on; defaults to ``mx.cpu(0)``.
    save_path : str or None, optional
        File the coloured mask is written to (``'output.png'`` by default,
        as before). Pass ``None`` to skip writing a file.
    dataset : str, optional
        Palette name passed to ``get_color_pallete``; should match the
        dataset the model was trained on (default ``'ade20k'``).
    """
    if ctx is None:
        ctx = mx.cpu(0)
    # Use the ``img`` argument (the original read the notebook-global
    # ``sample_image`` and silently ignored its parameter).
    tensor = test_transform(mx.nd.array(img), ctx)
    output = model.predict(tensor)
    # Per-pixel label = argmax over axis 1 (presumably the class-score axis
    # of an NCHW output); squeeze drops the singleton dimensions.
    predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
    mask = get_color_pallete(predict, dataset)
    if save_path is not None:
        mask.save(save_path)
    # Show the palette mask directly instead of re-reading the PNG from disk.
    plt.imshow(mask.convert('RGB'))
    plt.show()

In [None]:
%%time
# show a random a image downloaded via zoetrope
sample_image = data_connector.download_img(random.choice(all_image_paths))
plt.imshow(sample_image)
plt.show()

# segment chosen image
segment_image(sample_image, model)