In [1]:
# Fish Classifier V0.1

# Curious what fish you've caught? Upload a photo and this tool will suggest its species.

In [2]:
%%capture
# Install Voila and enable it as a (classic) notebook server extension so this
# notebook can be served as a standalone web app. %%capture (which must stay
# on the first line of the cell) hides the installer output.
# NOTE(review): the install is unpinned; consider pinning a version, or moving
# it into the deployment's requirements file so it does not run again every
# time this notebook is executed.
# NOTE(review): `!pip` uses whichever pip is first on PATH; `%pip` targets the
# running kernel's environment — confirm which one this deployment needs.
!pip install voila
!jupyter serverextension enable --sys-prefix voila

In [3]:
# ignore pandas filter warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
#import pandas as pd

#import fastai.vision.all and vision.widgets to create widgets
from fastai.vision.all import *
from fastai.vision.widgets import *
from ipywidgets import *
import re

In [4]:
path = Path('.')

# The exported learner (~87 MB) is fetched from Google Drive at start-up
# because bundling it pushes the Heroku slug past its 500 MB limit.
import urllib.request
MODEL_URL = r"https://drive.google.com/uc?export=download&confirm=yTib&id=1cOSP83I03sybH1_dXjMChSKaQ3TiAMlI"
MODEL_PATH = path/'model.pkl'

# Only download when the file is not already on disk, so re-running the
# notebook in the same environment does not re-fetch ~87 MB every time.
if not MODEL_PATH.exists():
    urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)

# When a Drive direct-download link stops working, Drive answers with an HTML
# page instead of the file. Detect that up front (and delete it so the next run
# retries) instead of failing later with an opaque unpickling error.
with open(MODEL_PATH, 'rb') as model_file:
    looks_like_html = model_file.read(1) == b'<'
if looks_like_html:
    MODEL_PATH.unlink()
    raise RuntimeError(f'Download from {MODEL_URL} returned an HTML page, not a model file; '
                       'check that the Google Drive link is still valid.')

# SECURITY: load_learner unpickles the file, which can execute arbitrary code.
# Only load model files from a source you control.
learn_inf = load_learner(MODEL_PATH, cpu=True)

In [5]:
# Widgets shared by the layout cell and the callbacks below.

# Upload button. accept='image/*' asks the browser's file picker to offer
# image files, since the chosen file is handed straight to PILImage.create.
btn_upload = widgets.FileUpload(accept='image/*', button_style='info',
                                layout=widgets.Layout(width='100%', height='75%'))

# Button that runs the classification.
btn_runner = widgets.Button(description='Classify Fish Image', icon='check', button_style='danger',
                            layout=widgets.Layout(width='100%', height='75%'))

# Output widget holding the thumbnail of the uploaded image.
output_img = widgets.Output()

# Labels for the predicted species and its probability.
lbl_pred = widgets.Label()
lbl_pred.add_class('box_label_style')

lbl_probs = widgets.Label()
lbl_probs.add_class('box_label_style')

# Output containers that the two labels are displayed into on classification.
lbl_pred_out = widgets.Output()
lbl_probs_out = widgets.Output()

# Output container for the top-5 probability table.
df_widget = widgets.Output()

def on_click_classify(change):
    """Run the model on the most recent upload and refresh every result widget.

    Registered as the click handler of ``btn_runner``.

    Args:
        change: the argument ipywidgets passes to ``on_click`` callbacks
            (the clicked Button); not used.

    Side effects:
        Redraws the thumbnail in ``output_img``, replaces the predicted species
        shown in ``lbl_pred_out`` and its probability in ``lbl_probs_out``,
        and replaces the top-5 table in ``df_widget``.
    """
    # Clicking "Classify" before uploading anything used to raise IndexError
    # on ``btn_upload.data[-1]`` inside the callback; tell the user instead.
    if not btn_upload.data:
        output_img.clear_output()
        with output_img:
            print('Please upload an image first.')
        return

    img = PILImage.create(btn_upload.data[-1])
    output_img.clear_output()
    with output_img:
        display(img.to_thumb(512, 512))

    pred, pred_idx, probs = learn_inf.predict(img)

    # Each Output is cleared before re-displaying its label; without this,
    # every click appended another copy of the label below the previous one.
    lbl_pred.remove_class('box_label_style')
    lbl_pred.value = f'Predicted Species:\n{pred.title()}'
    lbl_pred_out.clear_output()
    lbl_pred_out.add_class('box_style')
    with lbl_pred_out:
        display(lbl_pred)

    # Probability (as a percentage) of the predicted class.
    lbl_probs.remove_class('box_label_style')
    lbl_probs.value = f'Probability of Class Membership:\n{100*probs[pred_idx]:.04f}%'
    lbl_probs_out.clear_output()
    lbl_probs_out.add_class('box_style')
    with lbl_probs_out:
        display(lbl_probs)

    # Top-5 table. round(p, 8) gives the same values as the previous
    # format -> regex -> float() round trip, without the string manipulation.
    species = learn_inf.dls.vocab
    rounded = [round(float(p), 8) for p in probs.tolist()]
    df_probs = pd.DataFrame(list(zip(species, rounded)), columns=['Species', 'Probability'])
    top5 = df_probs.sort_values(by='Probability', ascending=False).head(5)

    # Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0;
    # use Styler.hide(axis='index') when available, else fall back.
    styler = top5.style
    styler = styler.hide(axis='index') if hasattr(styler, 'hide') else styler.hide_index()
    df_widget.clear_output()
    with df_widget:
        display(styler)
    
def on_upload_show(change):
    """Preview the newest upload as soon as ``btn_upload.data`` changes.

    Registered with ``btn_upload.observe(..., names=['data'])``.

    Args:
        change: the change notification ipywidgets passes to observers; not used.
    """
    # If the trait is ever reset to an empty list there is nothing to preview;
    # indexing [-1] would raise IndexError inside the callback.
    if not btn_upload.data:
        output_img.clear_output()
        return
    img = PILImage.create(btn_upload.data[-1])
    output_img.clear_output()
    with output_img:
        display(img.to_thumb(512, 512))
    # The labels and table still describe the previous image until "Classify"
    # is clicked again; clear them so they are not shown beside the new one.
    for stale_output in (lbl_pred_out, lbl_probs_out, df_widget):
        stale_output.clear_output()
    
# Wire up the callbacks: show a preview whenever the upload's `data` trait
# changes, and run the prediction when the classify button is clicked.
btn_upload.observe(on_upload_show, names=['data'])
btn_runner.on_click(on_click_classify)

In [6]:
# config *******************************************
# NOTE(review): neither layout below appears to be used by the app-build cell;
# kept so the names stay defined for any other code that refers to them.
primary_layout = widgets.Layout(justify_content='center')

box_layout = widgets.Layout(display='flex',
                            flex_flow='column',
                            align_items='stretch',
                            width='100%')
# **************************************************

# images *******************************************
# Raw JPEG bytes for the header banner (widgets.Image takes bytes).
# Path.read_bytes() opens and closes the file; the previous bare
# open(...).read() left the handle open for the lifetime of the kernel.
banner_image = Path("banner.jpeg").read_bytes()
# **************************************************

In [7]:
# static content
# **************************************************
# Header HTML. Besides the title, its <style> block defines the CSS classes
# ('box_style', 'box_style_2', 'box_label_style') that other cells attach to
# widgets via .add_class(). It is an HTML fragment placed inside a widget, so
# it must not contain a <body> element.
title_content = """
    <style>
        .div-1 {
            background-color: #F2F2F2;
            display: block;
            text-align: center;
            padding: 10px;
        }

        .div-2 {
            background-color: #FFF;
        }

        .div-3 {
            background-color: #FBD603;
        }
        .box_style{
        width:80%;
        height:30%;
        background-color: #F2F2F2;
        display: flex;
        align-items: center;
        justify-content: center;
        }
        .box_style_2{
        width:100%;
        height:90%;
        background-color: #F2F2F2;
        display: flex;
        align-items: center;
        justify-content: center;
        }
        .box_label_style{
        background-color: #D9D9C7;
        height: 90%;
        display: flex;
        align-items: center;
        justify-content: center;
    }
    </style>

    <div class='div-1'>
        <h1>Fish Species Classifier</h1>
    </div>
"""
# **************************************************
# Project description shown to the right of the title in the header.
description_content = """
    <div class="div-2">
            <p> This is a demo of using a computer vision model to classify images, in this case of sportfish.
            The training dataset covers 60 species of fish, with up to 150 images per species (class). Most
            species are from the United States, though some are from India; only freshwater fish are
            included. The model was trained with transfer learning using fastai; the base model is a
            convolutional neural network (CNN) with the ResNet34 architecture.</p>
        </div>
        """

# **************************************************
# Heading markup for the top-5 probability table.
df_title_content = """<div class='div-1'><h4>Top 5 Class Probabilities</h4></div>"""
# **************************************************
# Hosting notice, rendered through footer_grid in the app-build cell.
# NOTE(review): the model-loading cell mentions Heroku while this says Binder;
# confirm which host the notice should name.
footer_content = """<div class="div-1"><p><small><b>NOTICE</b> This is hosted on Binder and as such inference capabilities may be throttled by available compute.</small></p></div>"""
# **************************************************

In [10]:
# **************************************************
# app build

# Header: banner across the top two rows; title and description below it.
header_grid = widgets.GridspecLayout(n_rows=4, n_columns=3, grid_gap='2px', height='260px')
header_grid[0:2, :] = widgets.Image(value=banner_image, format='jpeg', width='100%')
header_grid[2:, 0] = widgets.HTML(title_content)
header_grid[2:, 1:] = widgets.HTML(description_content)

# Center: upload button + image preview on the left, classify button across
# the top right, prediction labels in the middle, top-5 table on the right.
main_grid = widgets.GridspecLayout(n_rows=6, n_columns=10, grid_gap='2px', align_items='center', height='700px')
main_grid[:1, :4] = btn_upload
main_grid[1:, :4] = widgets.HBox(
    [output_img],
    layout=widgets.Layout(display='flex', flex_flow='column', align_items='center'),
).add_class('box_style_2')
main_grid[:1, 4:] = btn_runner
main_grid[1:, 4:8] = widgets.VBox(
    [lbl_pred_out, lbl_probs_out],
    layout=widgets.Layout(display='flex', align_items='center', justify_content='center'),
).add_class('box_label_style')

# A disabled Button serves as the table heading. Button.icon takes the Font
# Awesome name WITHOUT the 'fa-' prefix (ipywidgets adds it), so the previous
# icon='fa-table' produced the non-existent class 'fa-fa-table'.
top5_heading = widgets.Button(description='Top 5 Class Probabilities', icon='table',
                              layout=widgets.Layout(width='100%', height='80px'), disabled=True)
main_grid[1:, 8:] = widgets.HBox(
    [widgets.VBox([top5_heading, df_widget])],
    layout=widgets.Layout(display='flex', flex_flow='column', align_items='center'),
).add_class('box_style_2')

# Footer with the hosting notice (built here, but not attached: see footer=None).
footer_grid = widgets.GridspecLayout(n_rows=2, n_columns=1, height='60px')
footer_grid[:, 0] = widgets.HTML(footer_content)
# **************************************************
# compile app into container
app = widgets.AppLayout(header=header_grid,
                        center=main_grid,
                        footer=None,  # pass footer_grid to show the hosting notice
                        merge=True,
                        grid_gap='2px')


In [11]:
# **************************************************
# Render the assembled AppLayout as this cell's output; this is the UI the
# user interacts with.
display(app)
# **************************************************

AppLayout(children=(GridspecLayout(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00H\x…