# Downloading dataset

In [1]:
# install kaggle python api
!pip install kaggle --upgrade

Collecting kaggle
  Downloading kaggle-1.5.12.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 1.7 MB/s eta 0:00:01
Collecting python-slugify
  Downloading python_slugify-6.1.2-py2.py3-none-any.whl (9.4 kB)
Collecting text-unidecode>=1.3
  Downloading text_unidecode-1.3-py2.py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 5.3 MB/s eta 0:00:01
Building wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py) ... [?25ldone
[?25h  Created wheel for kaggle: filename=kaggle-1.5.12-py3-none-any.whl size=73053 sha256=ca6724ad6aa14af2a117bacdbc1a913d871a936bfe1346ddbc1715af638db02b
  Stored in directory: /Users/zhaoyudong/Library/Caches/pip/wheels/29/da/11/144cc25aebdaeb4931b231e25fd34b394e6a5725cbb2f50106
Successfully built kaggle
Installing collected packages: text-unidecode, python-slugify, kaggle
Successfully installed kaggle-1.5.12 python-slugify-6.1.2 text-unidecode-1.3


In [2]:
import os, json, subprocess

In [3]:
def check_dataset(path=''):
    """Report whether the eye-disease dataset is present on disk.

    Args:
        path: Root directory of the dataset, with or without a trailing
            path separator. When empty, ``<current working dir>/dataset``
            is used.

    Returns:
        True only if all four class sub-directories (``cataract``,
        ``diabetic_retinopathy``, ``glaucoma``, ``normal``) exist under
        ``path``; False otherwise.
    """
    root = path or os.path.join(os.getcwd(), 'dataset')
    class_dirs = ('cataract', 'diabetic_retinopathy', 'glaucoma', 'normal')
    # os.path.join (instead of string concatenation) makes the check work
    # whether or not the caller supplied a trailing slash.
    return all(os.path.exists(os.path.join(root, name)) for name in class_dirs)
    
def init_on_kaggle(username, api_key):
    """Write Kaggle API credentials to ``~/.kaggle/kaggle.json``.

    The file is created readable/writable by the owner only (mode 0600).

    Args:
        username: Kaggle account name, stored under the ``"username"`` key.
        api_key: Kaggle API token, stored under the ``"key"`` key.

    Raises:
        OSError: If the config directory or file cannot be created/written.
    """
    # expanduser('~') still resolves when $HOME is unset, whereas
    # expandvars('$HOME') would leave a literal '$HOME' directory name.
    config_dir = os.path.join(os.path.expanduser('~'), '.kaggle')
    os.makedirs(config_dir, exist_ok=True)
    config_path = os.path.join(config_dir, 'kaggle.json')

    credentials = {"username": username, "key": api_key}

    # Create the file with restrictive permissions from the start so the
    # secret is never briefly readable by other users.
    fd = os.open(config_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding='utf-8') as f:
        json.dump(credentials, f)

    # A pre-existing file keeps its old mode on open(), so tighten it
    # explicitly. os.chmod avoids spawning `chmod` with a space-split command
    # line, which broke for home directories containing spaces.
    os.chmod(config_path, 0o600)
    
    
def download_dataset_from_kaggle(username=None, api_key=None):
    """Download and extract the eye-diseases dataset into the working directory.

    Credentials are deliberately NOT hard-coded in the notebook. Provide them
    in one of these ways:

    * pass ``username`` and ``api_key`` explicitly (e.g. obtained through
      ``getpass.getpass()``); they are then written to ``~/.kaggle/kaggle.json``
      via ``init_on_kaggle``;
    * or leave both as None and rely on the Kaggle client's own lookup
      (the ``KAGGLE_USERNAME`` / ``KAGGLE_KEY`` environment variables or an
      existing ``~/.kaggle/kaggle.json``).

    NOTE(security): any API key that was ever committed in this notebook must
    be considered leaked and should be revoked on kaggle.com.

    Args:
        username: Optional Kaggle account name.
        api_key: Optional Kaggle API token.

    Raises:
        zipfile.BadZipFile / OSError: If the downloaded archive is missing or
            cannot be extracted.
    """
    import zipfile

    if username and api_key:
        init_on_kaggle(username, api_key)

    # Imported lazily: the kaggle package authenticates at import time, so
    # the credentials have to be in place before this line runs.
    import kaggle

    dataset_name = "gunavenkatdoddi/eye-diseases-classification"
    print(kaggle.api.dataset_view(dataset_name))

    # Pin the download location so the archive path below is predictable.
    download_dir = os.getcwd()
    kaggle.api.dataset_download_files(dataset_name, path=download_dir)

    # Extract with the standard library instead of shelling out to `unzip`:
    # portable, never blocks on an overwrite prompt, and raises on failure
    # instead of silently returning a non-zero exit status.
    archive_path = os.path.join(download_dir, 'eye-diseases-classification.zip')
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(download_dir)
    

In [4]:
# Root folder that holds one sub-directory per class; the trailing slash is kept on purpose.
dataset_location = f"{os.getcwd()}/dataset/"

In [5]:
# Only hit the network when the class folders are not already on disk.
dataset_present = check_dataset(dataset_location)
if not dataset_present:
    print("The dataset doesn't exist. Try to load from kaggle")
    download_dataset_from_kaggle()

In [6]:
!ls

README.md                         eye-diseases-classification.zip
[34mdataset[m[m                           eye_diseases_classification.ipynb


# Data Preparation

## Import Images

In [7]:
!pip install Pillow



In [9]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.3.1.tar.gz (281.4 MB)
[K     |████████████████████████████████| 281.4 MB 69 kB/s  eta 0:00:012    |█████                           | 44.9 MB 6.8 MB/s eta 0:00:35     |█████▍                          | 47.8 MB 6.8 MB/s eta 0:00:35     |████████▎                       | 72.5 MB 5.1 MB/s eta 0:00:42     |█████████▊                      | 85.1 MB 11.8 MB/s eta 0:00:17     |████████████                    | 105.1 MB 6.8 MB/s eta 0:00:26     |█████████████                   | 114.6 MB 12.5 MB/s eta 0:00:14     |███████████████▎                | 134.7 MB 9.3 MB/s eta 0:00:16     |███████████████▋                | 137.3 MB 9.3 MB/s eta 0:00:16     |██████████████████████▌         | 197.4 MB 5.9 MB/s eta 0:00:15     |██████████████████████▌         | 197.7 MB 5.9 MB/s eta 0:00:15     |██████████████████████▊         | 199.7 MB 9.9 MB/s eta 0:00:09     |███████████████████████▏        | 203.7 MB 9.9 MB/s eta 0:00:08     |████████████████████████▉      

In [10]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col,when,length,sum,avg,max,count,round
import seaborn as sns
from PIL import Image, ImageOps, ImageFilter
import io

In [12]:
# Create (or reuse) the Spark session. Without this, `spark` is undefined on a
# fresh kernel and the read below fails with NameError.
spark = SparkSession.builder.appName("eye_diseases_classification").getOrCreate()

# Load every image found recursively under `dataset_location` with Spark's
# image data source. `dropInvalid` asks the reader to drop files it cannot
# decode as images; the result is cached because it is reused by later cells.
df = spark.read.format('image').option('dropInvalid', True) \
    .option("recursiveFileLookup","true").load(dataset_location).cache()

In [13]:
# Build the class label from the file path: the first class name found in
# `image.origin` wins, and anything that matches none of them is 'normal'.
origin_path = col('image.origin')
label = when(origin_path.contains('cataract'), 'cataract')
for class_name in ('diabetic_retinopathy', 'glaucoma'):
    label = label.when(origin_path.contains(class_name), class_name)
label = label.otherwise('normal')

df = df.withColumn('type', label)

In [14]:
df.printSchema()

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)
 |-- type: string (nullable = false)



In [18]:
df.count()

2977

In [20]:
df.select('type', 'image.origin', 'image.width', 'image.height', 'image.nChannels', 'image.mode').show(10, truncate=False)

+--------+-----------------------------------------------------------------------------------+-----+------+---------+----+
|type    |origin                                                                             |width|height|nChannels|mode|
+--------+-----------------------------------------------------------------------------------+-----+------+---------+----+
|cataract|file:///Users/zhaoyudong/workspace/clarku/BAN5600/dataset/cataract/cataract_024.png|2464 |1632  |3        |16  |
|glaucoma|file:///Users/zhaoyudong/workspace/clarku/BAN5600/dataset/glaucoma/Glaucoma_081.png|2464 |1632  |3        |16  |
|glaucoma|file:///Users/zhaoyudong/workspace/clarku/BAN5600/dataset/glaucoma/Glaucoma_072.png|2464 |1632  |3        |16  |
|glaucoma|file:///Users/zhaoyudong/workspace/clarku/BAN5600/dataset/glaucoma/Glaucoma_024.png|2464 |1632  |3        |16  |
|glaucoma|file:///Users/zhaoyudong/workspace/clarku/BAN5600/dataset/glaucoma/Glaucoma_071.png|2464 |1632  |3        |16  |
|glaucoma|file:/

In [24]:
df.select('type', 'image.origin', 'image.width', 'image.height', 'image.nChannels', 'image.mode').summary().show()

+-------+--------+--------------------+------------------+------------------+---------+----+
|summary|    type|              origin|             width|            height|nChannels|mode|
+-------+--------+--------------------+------------------+------------------+---------+----+
|  count|    2977|                2977|              2977|              2977|     2977|2977|
|   mean|    null|                null| 648.1746724890829|  591.086328518643|      3.0|16.0|
| stddev|    null|                null|507.94338819232763|295.39371099230834|      0.0| 0.0|
|    min|cataract|file:///Users/zha...|               512|               512|        3|  16|
|    25%|    null|                null|               512|               512|        3|  16|
|    50%|    null|                null|               512|               512|        3|  16|
|    75%|    null|                null|               512|               512|        3|  16|
|    max|  normal|file:///Users/zha...|              2592|            

## Feature Extraction

In [None]:
from pyspark.sql.functions import udf, lit
from pyspark.sql.types import BinaryType, ArrayType, IntegerType
import numpy as np
from pyspark.ml.linalg import Vectors, VectorUDT

### Grayscale

In [None]:
def to_grayscale(image):
    """Convert one Spark image struct into 8-bit grayscale pixel bytes.

    Args:
        image: A row/struct following Spark's image schema; the fields
            ``width``, ``height``, ``nChannels`` and ``data`` are read.

    Returns:
        A ``bytearray`` of ``width * height`` luminance bytes (row-major),
        or None when ``image`` is None.

    Raises:
        ValueError: If ``nChannels`` is not 1, 3 or 4.
    """
    if image is None:
        return None

    # Spark's image data source stores pixels in OpenCV channel order
    # (BGR / BGRA). Decoding them as plain RGB would swap red and blue and
    # skew the weighted grayscale conversion, so tell Pillow the raw layout.
    layouts = {1: ('L', 'L'), 3: ('RGB', 'BGR'), 4: ('RGBA', 'BGRA')}
    n_channels = image['nChannels']
    if n_channels not in layouts:
        raise ValueError(f"Unsupported number of image channels: {n_channels}")
    mode, raw_mode = layouts[n_channels]

    size = (image['width'], image['height'])
    img_obj = Image.frombytes(mode, size, bytes(image['data']), 'raw', raw_mode)
    return bytearray(ImageOps.grayscale(img_obj).tobytes())

In [None]:
# Sanity check on a real row: the UDF body must return a bytearray so that it
# maps onto Spark's BinaryType. (The previous version referenced an undefined
# `img` variable and failed on a fresh kernel.)
type(to_grayscale(df.first()['image']))

In [None]:
# Wrap the converter as a Spark UDF that yields a binary column.
grayscale_udf = udf(to_grayscale, BinaryType())

In [None]:
# Add a "gray" column holding the grayscale bytes of each image.
gray_df = df.withColumn("gray", grayscale_udf(df["image"]))

### Edge

In [None]:
# def find_edge(image):
#     img_obj = Image.frombytes('RGB', (image['width'], image['height']), bytes(image['data']))
#     img_obj =  img_obj.convert("L").filter(ImageFilter.FIND_EDGES).tobytes()
#     return bytearray(img_obj)


In [None]:
# edge_udf = udf(lambda x:find_edge(x), ArrayType(IntegerType()))

In [None]:
# edge_df = df.withColumn("edge", edge_udf(df.image))

Compare the original image, the grayscale image, and the edge image.

In [None]:
# Show the original colour image of the first row for comparison. (`img` was
# never defined in this notebook, so the image is rebuilt from the DataFrame.)
# Spark stores pixel data in BGR order, hence the explicit raw mode.
# NOTE(review): assumes a 3-channel image, as reported by the summary above.
orig_image = gray_df.first()['image']
Image.frombytes(
    'RGB',
    (orig_image['width'], orig_image['height']),
    bytes(orig_image['data']),
    'raw',
    'BGR',
).show()

In [None]:
first_ = gray_df.first()

In [None]:
Image.frombytes('L', (first_['image']['width'], first_['image']['height']), bytes(first_['gray'])).show()