In [14]:
import os
import re
import threading
import time
from copy import deepcopy
from typing import NamedTuple, Iterable
from urllib.request import urlopen

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import requests
from PIL import Image

In [15]:
import pickle 

def pkl_save(path, obj):
  """Pickle ``obj`` into the file at ``path`` (created or overwritten)."""
  with open(path, mode='wb') as fh:
    pickle.dump(obj, fh)

def pkl_load(path):
  """Unpickle and return the object stored in the file at ``path``.

  WARNING: unpickling can execute arbitrary code; only load files you trust.
  """
  with open(path, mode='rb') as fh:
    obj = pickle.load(fh)
  return obj

In [16]:
class CheekyLoadingBar(threading.Thread):
    """A progress bar that fakes progress on a background thread.

    There is no real progress signal. Every tick adds an increment that is then
    multiplied by ``alpha``, so the increments form a geometric series summing
    to 1. The bar creeps toward 100% without getting there on its own: roughly
    63% (1 - 1/e) after ``eta`` seconds when ``delay`` is much smaller than
    ``eta``. Call ``stop()`` once the real work is done; it races the bar to
    100% and sets its style to 'success'.

    Usage: construct (this displays the bar in ``out``), ``start()``, do the
    slow work, then ``stop()``.
    """

    def __init__(self, out, text, eta, delay=0.01):
        """
        Args:
            out: Output widget the bar is displayed in.
            text: label shown to the left of the bar.
            eta: rough expected duration in seconds; sets the pace. Must be > 0.
            delay: seconds between animation ticks (default 0.01).

        Raises:
            ValueError: if ``eta`` is not positive.
        """
        if eta <= 0:
            raise ValueError(f"eta must be positive, got {eta!r}")
        # Daemon thread: if stop() is never reached, this thread (which does
        # not finish on its own) must not keep the interpreter alive.
        super().__init__(daemon=True)
        label = widgets.Label(text)
        self.bar = widgets.FloatProgress(value=0, min=0, max=1)
        to_display = widgets.HBox([label, self.bar])

        self.delay = delay
        # With alpha = 1 - incr_start the increments sum to exactly 1 (the bar max).
        self.incr_start = self.delay / eta
        self.incr = self.incr_start
        self.alpha = 1 - self.incr

        # Set by run() when it exits, whether normally or by an exception.
        self.done = False

        with out:
            display(to_display)

    def start(self):
        """Record the start time, then launch the animation thread."""
        self.start_time = time.time()
        # Ticks follow an absolute schedule, so sleep overshoot does not accumulate.
        self.wait_target = self.start_time
        super().start()

    def run(self):
        """Advance the bar once per tick until it reaches 1 (its max)."""
        try:
            while self.bar.value < 1:
                self.wait_target += self.delay
                diff = self.wait_target - time.time()
                if diff > 0:
                    time.sleep(diff)
                self.bar.value += self.incr
                # Cheekily slow down progress
                self.incr *= self.alpha
        finally:
            # Always signal completion; otherwise an exception above (e.g. the
            # widget going away) would leave stop() waiting forever.
            self.done = True

    def stop(self):
        """Race the bar to completion, wait for the thread, and mark it 'success'."""
        # Restart from the initial increment and grow it 2% per tick. This
        # outpaces the alpha decay in run() whenever delay/eta < ~2%, so the
        # bar finishes quickly. (Both threads update self.incr without a lock;
        # that only affects the animation speed.)
        self.incr = self.incr_start
        # is_alive() guards against waiting on a thread that is not running.
        while not self.done and self.is_alive():
            self.incr *= 1.02
            time.sleep(self.delay)
        self.join()
        self.bar.bar_style = 'success'

In [17]:
# Class labels for the model's output heads. Order matters: country_preds and
# us_preds pair these positionally with the softmax of the COUNTRY / US logits.
# VOCAB_COUNTRY appears to be ordered by ISO 3166-1 alpha-2 code (AE, AL, AR, ...).
# NOTE(review): confirm both lists match the model's training label order.
VOCAB_COUNTRY = [
    'United Arab Emirates', 'Albania', 'Argentina', 'American Samoa', 'Austria',
    'Australia', 'Bangladesh', 'Belgium', 'Bulgaria', 'Bolivia, Plurinational State of',
    'Brazil', 'Bhutan', 'Botswana', 'Canada', 'Switzerland',
    'Chile', 'China', 'Colombia', 'Czech Republic', 'Germany',
    'Denmark', 'Dominican Republic', 'Ecuador', 'Estonia', 'Egypt',
    'Spain', 'Finland', 'Faroe Islands', 'France', 'United Kingdom',
    'Ghana', 'Greenland', 'Greece', 'Guatemala', 'Hong Kong',
    'Croatia', 'Hungary', 'Indonesia', 'Ireland', 'Israel',
    'India', 'Iceland', 'Italy', 'Jordan', 'Japan',
    'Kenya', 'Kyrgyzstan', 'Cambodia', 'Korea, Republic of', 'Sri Lanka',
    'Lesotho', 'Lithuania', 'Latvia', 'Madagascar', 'Macedonia, the Former Yugoslav Republic of',
    'Mongolia', 'Mexico', 'Malaysia', 'Nigeria', 'Netherlands',
    'Norway', 'New Zealand', 'Peru', 'Philippines', 'Pakistan',
    'Poland', 'Puerto Rico', 'Portugal', 'Romania', 'Serbia',
    'Russian Federation', 'Sweden', 'Singapore', 'Slovenia', 'Slovakia',
    'Senegal', 'Swaziland', 'Thailand', 'Tunisia', 'Turkey',
    'Taiwan, Province of China', 'Ukraine', 'Uganda', 'United States', 'Uruguay',
    'Viet Nam', 'South Africa',
]
# The 50 US states in alphabetical order.
VOCAB_US = [
    'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California',
    'Colorado', 'Connecticut', 'Delaware', 'Florida', 'Georgia',
    'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa',
    'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland',
    'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri',
    'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey',
    'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio',
    'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 'South Carolina',
    'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont',
    'Virginia', 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming',
]

In [18]:
def convert_url(url, api_key=None):
    """Convert a Google Maps Street View link into a Street View Static API image URL.

    The camera pose is read from the ``@<lat>,<lng>,...,<heading>h,<pitch>t``
    part of the Maps URL. A URL without that part (e.g. a direct image link)
    is returned unchanged.

    Args:
        url: URL pasted by the user.
        api_key: Street View Static API key. When omitted, the
            ``GOOGLE_MAPS_API_KEY`` environment variable is used, falling back
            to the key that used to be hard-coded here.

    Returns:
        str: a 480x480, 105-degree-FOV image URL, or ``url`` itself.
    """
    # One signed decimal number. At least one digit is required: the previous
    # ``[\d\.-]*`` also matched the empty string, which then crashed float('').
    num = r'(-?(?:\d+(?:\.\d*)?|\.\d+))'
    # The lazy ``.*?`` skips the zoom/fov fields and stops at the first
    # ``<heading>h,<pitch>t`` pair rather than the last one in the URL.
    match = re.search(fr'@{num},{num},.*?,{num}h,{num}t', url)

    # Not a google maps URL
    if match is None:
        return url

    lat, lng, heading, pitch = (float(g) for g in match.groups())
    # URL pitch ranges 0 to 180, scrape API ranges -90 to 90
    pitch -= 90

    if api_key is None:
        # FIXME(security): this fallback is a credential committed to the
        # notebook; rotate it and supply GOOGLE_MAPS_API_KEY instead.
        api_key = os.environ.get(
            'GOOGLE_MAPS_API_KEY', 'AIzaSyABWCcImw44lzIqLHzBIJLngYLTx5El11M')

    params = {
        'size': '480x480',
        'location': f'{lat},{lng}',
        'heading': str(heading),
        'pitch': str(pitch),
        'fov': '105',
        'key': api_key,
    }
    base_url = "https://maps.googleapis.com/maps/api/streetview"
    url_params = "&".join(f"{k}={v}" for k, v in params.items())
    return f"{base_url}?{url_params}"

In [19]:
def softmax(arr, axis=None):
    """Numerically stable softmax.

    Args:
        arr: array-like of logits.
        axis: axis to normalize over. ``None`` (the default, and the previous
            behaviour) normalizes over all elements.

    Returns:
        ndarray with the same shape as ``arr`` whose entries along ``axis`` sum to 1.
    """
    logits = np.asarray(arr)
    if logits.size == 0:
        # Nothing to normalize; np.max would raise on an empty array.
        return np.exp(logits)
    # Subtracting the max leaves the result mathematically unchanged but stops
    # np.exp from overflowing to inf (and yielding nan) on large logits.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

In [20]:
class ModelOut(NamedTuple):
    """Predictions for one image, as built by ``get_model_out`` from the service's JSON."""
    # Raw logits over VOCAB_COUNTRY (same order); softmaxed by country_preds.
    COUNTRY: np.ndarray
    # Raw geocell logits. NOTE(review): not consumed anywhere in this notebook.
    GEOCELL: np.ndarray
    # Raw logits over VOCAB_US (same order); softmaxed by us_preds.
    US: np.ndarray
    # Coordinates and similarity scores of the most similar places in the
    # reference database (plotted by plot_neighbors, colored by similarity).
    NEIGHBOR_LAT: Iterable[float]
    NEIGHBOR_LNG: Iterable[float]
    SIMILARITY: Iterable[float]
    # Final point prediction (the green dot on the map).
    PRED_LAT: float
    PRED_LNG: float

In [21]:
def get_model_out(url, timeout=300):
    """Classify the image at ``url`` with the remote model and parse the reply.

    Args:
        url: publicly reachable image URL (typically the output of ``convert_url``).
        timeout: seconds ``requests`` may wait for the service. The default is
            generous because the first call can be slow (the app shows a
            "Downloading model" bar during warm-up). Before, there was no
            timeout, so a stalled request could hang the UI indefinitely.

    Returns:
        ModelOut with the three logit vectors as numpy arrays and the
        neighbor / prediction fields copied from the JSON response.

    Raises:
        requests.HTTPError: the service returned an error status.
        requests.Timeout: no response within ``timeout`` seconds.
    """
    params = {
        # FIXME(security): committed access code for the endpoint (looks like an
        # Azure Functions key); rotate it and set COUNTRYFINAL_FUNCTION_KEY instead.
        'code': os.environ.get(
            'COUNTRYFINAL_FUNCTION_KEY',
            'iZP6dHFLCjWvmLQx9v1haxW8Du21Phk/hMTQj4c/aGJseXAMgWuOPw=='),
        'img': url,
    }
    model_url = 'https://countryfinal.azurewebsites.net/api/classify'
    response = requests.get(url=model_url, params=params, timeout=timeout)
    # Surface HTTP errors directly instead of as a confusing JSON decode error.
    response.raise_for_status()
    res = response.json()

    return ModelOut(
        COUNTRY = np.array(res['logits_country']),
        GEOCELL = np.array(res['logits_geocell']),
        US = np.array(res['logits_us']),
        NEIGHBOR_LAT = res['neighbor_lat'],
        NEIGHBOR_LNG = res['neighbor_lng'],
        SIMILARITY = res['similarity'],
        PRED_LAT = res['pred_lat'],
        PRED_LNG = res['pred_lng']
    )

In [22]:
def country_preds(model_out, top_k=10):
    """Tabulate the model's most likely countries.

    Args:
        model_out: ModelOut whose COUNTRY logits line up with VOCAB_COUNTRY.
        top_k: maximum number of rows to return (default 10, as before).

    Returns:
        DataFrame indexed by 'Country' with one 'Confidence' column of
        percentage strings (e.g. '42.17%'), most confident first. Rows whose
        confidence rounds to 0.00% are omitted.

    Raises:
        ValueError: if the number of logits differs from len(VOCAB_COUNTRY).
    """
    probs = softmax(model_out.COUNTRY)
    if len(probs) != len(VOCAB_COUNTRY):
        raise ValueError(
            f"expected {len(VOCAB_COUNTRY)} country logits, got {len(probs)}")
    return (
        pd.DataFrame({'Country': VOCAB_COUNTRY, 'Confidence': probs})
        .sort_values('Confidence', ascending=False)
        # Percentages rounded to 2 dp; anything that rounds to 0.00% is dropped.
        .assign(Confidence=lambda d: (d.Confidence * 100).round(2))
        .loc[lambda d: d.Confidence > 0.]
        .head(top_k)
        .assign(Confidence=lambda d: d.Confidence.astype(str) + '%')
        .set_index('Country')
    )

In [23]:
def us_preds(model_out, top_k=10):
    """Tabulate the model's most likely US states.

    Args:
        model_out: ModelOut whose US logits line up with VOCAB_US.
        top_k: maximum number of rows to return (default 10, as before).

    Returns:
        DataFrame indexed by 'US State' with one 'Confidence' column of
        percentage strings (e.g. '42.17%'), most confident first. Rows whose
        confidence rounds to 0.00% are omitted.

    Raises:
        ValueError: if the number of logits differs from len(VOCAB_US).
    """
    probs = softmax(model_out.US)
    if len(probs) != len(VOCAB_US):
        raise ValueError(
            f"expected {len(VOCAB_US)} US-state logits, got {len(probs)}")
    return (
        pd.DataFrame({'US State': VOCAB_US, 'Confidence': probs})
        .sort_values('Confidence', ascending=False)
        # Percentages rounded to 2 dp; anything that rounds to 0.00% is dropped.
        .assign(Confidence=lambda d: (d.Confidence * 100).round(2))
        .loc[lambda d: d.Confidence > 0.]
        .head(top_k)
        .assign(Confidence=lambda d: d.Confidence.astype(str) + '%')
        .set_index('US State')
    )

In [24]:
def _neighbor_figure(model_out):
    """Build the map of similar reference places plus the final prediction."""
    # Similar places from the model's database, colored by similarity score.
    neighbor_df = pd.DataFrame({
        'lat': model_out.NEIGHBOR_LAT,
        'lng': model_out.NEIGHBOR_LNG,
        'similarity': model_out.SIMILARITY,
    })
    neighbor_df['size'] = 1  # uniform marker size
    fig = px.scatter_mapbox(
        neighbor_df,
        lat='lat',
        lon='lng',
        color='similarity',
        size='size',
        size_max=10,
        color_continuous_scale=px.colors.sequential.Plasma_r,
    )

    # Add green circle for prediction
    pred_df = pd.DataFrame({
        'lat': [model_out.PRED_LAT],
        'lng': [model_out.PRED_LNG],
    })
    pred_df['color'] = 'rgb(50,220,50)'
    pred_df['size'] = 1
    # color_discrete_map='identity' uses the 'color' column values as literal colors.
    trace = px.scatter_mapbox(
        pred_df,
        lat='lat',
        lon='lng',
        size='size',
        size_max=10,
        color='color',
        color_discrete_map='identity',
    ).data[0]
    fig.add_trace(trace)

    # Center on the prediction, zoomed far out, with no margins.
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox=dict(
            center={'lat': pred_df.lat[0], 'lon': pred_df.lng[0]},
            zoom=2,
        ),
        margin=dict(t=0, b=0, l=0, r=0),
    )
    return fig


def plot_neighbors(model_out, cheeky_lb=None):
    """Show the neighbor/prediction map in the current output.

    Args:
        model_out: ModelOut from ``get_model_out``.
        cheeky_lb: optional CheekyLoadingBar. It is started before the map is
            built and stopped after the map is shown, or after a failure.
    """
    if cheeky_lb is not None:
        cheeky_lb.start()
    try:
        _neighbor_figure(model_out).show()
    finally:
        # Stop the bar even if building/showing the map fails. The bar is
        # designed not to finish on its own, so its thread would otherwise
        # keep ticking indefinitely.
        if cheeky_lb is not None:
            cheeky_lb.stop()

    # NOTE(review): with no active matplotlib figure, plt.pause just sleeps
    # briefly. This is presumably a workaround to let widget output flush;
    # confirm it is still needed.
    plt.pause(0.001)

In [25]:
def warmup():
    """Send one prediction for a fixed Street View location and discard the result.

    Called at start-up behind the "Downloading model" bar, before the Predict
    button is shown. Presumably this gets the remote model loaded before the
    first real request.
    """
    sample_location = "https://www.google.com/maps/@42.3008051,-71.2985194,3a,75y,162.05h,90.17t/data=!3m7!1e1!3m5!1sEbBwxPSFSyilap1gl29nIQ!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fpanoid%3DEbBwxPSFSyilap1gl29nIQ%26cb_client%3Dmaps_sv.tactile.gps%26w%3D203%26h%3D100%26yaw%3D262.69263%26pitch%3D0%26thumbfov%3D100!7i13312!8i6656"
    image_url = convert_url(sample_location)
    get_model_out(image_url)

In [26]:
def gmaps_link_text():
    """Return an HTML link to Google Maps that opens in a new tab, styled blue.

    The previous markup had a stray ``=`` before ``>`` and an unquoted href.
    ``rel='noopener noreferrer'`` keeps the new tab from accessing this page.
    """
    return (
        "<a href='https://www.google.com/maps' target='_blank' "
        "rel='noopener noreferrer' style='color:blue;'>Google Maps</a>"
    )

In [27]:
# Step-by-step instructions shown at the top of the app. Step 1 is split across
# three widgets so that "Google Maps" can be a clickable link.
e2a, e2b, e2c = (
    widgets.Label("1. Open "),
    widgets.HTML(value=gmaps_link_text()),
    widgets.Label(" and go to any Street View location in the world"),
)
e3 = widgets.Label(" (do this by dragging the little yellow person in the bottom right corner of Google Maps onto any road)")
e4 = widgets.Label("2. Copy the URL and paste it below")
e5 = widgets.Label('3. Click "Predict" (this button will appear once the models finish downloading)')

# Stack the instruction rows vertically, with no padding or border.
layout = widgets.Layout(
    border='0px',
    padding='0px',
    flex_flow='column',
    display='flex',
)
expl = widgets.VBox(
    [widgets.HBox([e2a, e2b, e2c]), e3, e4, e5],
    layout=layout,
)

In [28]:
# URL input row.
text_url = widgets.Text(placeholder='Paste the url here...')
btn_clear = widgets.Button(description="Clear")
# Holds the start-up "Downloading model" bar; cleared once a prediction returns.
waiting = widgets.Output()

# Shown while the model warms up; start_main hides it afterwards.
s1 = widgets.Label("While you're waiting, go to ")
s2 = widgets.HTML(value=gmaps_link_text())
# Leading space, like e2c, so the text does not run into the link.
s3 = widgets.Label(" and pick a Street View image!")
suggestion = widgets.HBox([s1, s2, s3])

btn_run = widgets.Button(description="Predict Location")
# Messages for the user (e.g. a missing URL).
err = widgets.Output()
# Result panes, filled in by OnClickClassify.click.
model_sees_out = widgets.Output()
model_sees = widgets.Label("What the model sees:")
img_out = widgets.Output()
pred_out_country = widgets.Output()
pred_out_us = widgets.Output()
# Holds the "Predicting Precise Location" bar while the map is built.
waiting_pred_loc = widgets.Output()
loc_expl_output = widgets.Output()
loc_expl = widgets.VBox([
    widgets.Label("Here are the most similar places in my database."),
    widgets.Label("The green dot is my final guess (a weighted average of similar places)"),
])
pred_loc = widgets.Output()

In [29]:
def clear_text_url(_):
    """Empty the URL text box.

    Registered below as the Clear button's click handler. The argument (the
    clicked button) is ignored.
    """
    text_url.value = ''

btn_clear.on_click(clear_text_url)

In [30]:
class OnClickClassify:
    """Handler for the "Predict Location" button: fetch the image, run the model, render results."""

    def click(self, change):
        """Handle a click on the Predict button.

        Args:
            change: the clicked Button (ipywidgets passes it to on_click
                handlers); unused.
        """
        # Remove messages from earlier clicks so they don't pile up or go stale.
        err.clear_output()

        raw_url = text_url.value.strip()
        if not raw_url:
            with err:
                print("Please paste an image URL first!")
            return

        for elt in [model_sees_out, img_out, pred_out_country,
                    pred_out_us, waiting_pred_loc, loc_expl_output,
                    pred_loc]:
            elt.clear_output()

        try:
            self._classify(convert_url(raw_url))
        except Exception as exc:
            # Exceptions raised in widget callbacks may not appear in the
            # notebook output (JupyterLab sends them to the log console), so
            # show the failure where the user is looking.
            with err:
                print(f"Could not classify this URL: {exc}")

    def _classify(self, url):
        """Fetch the image at ``url``, run the model on it, and fill every result pane."""
        with urlopen(url, timeout=60) as response:
            image = Image.open(response)
            # Image.open is lazy; thumbnail() reads the data while the
            # response is still open. Image.ANTIALIAS was removed in Pillow 10,
            # and LANCZOS is the same filter.
            resample = getattr(Image, 'Resampling', Image).LANCZOS
            image.thumbnail((320, 320), resample)

        with model_sees_out:
            display(model_sees)
        with img_out:
            display(image)

        model_out = get_model_out(url)
        # Drop the start-up "Downloading model" bar now that the model has answered.
        waiting.clear_output()

        with pred_out_country:
            display(country_preds(model_out))

        with pred_out_us:
            display(us_preds(model_out))

        pred_loc_lb = CheekyLoadingBar(waiting_pred_loc, "Predicting Precise Location", 10)
        with pred_loc:
            plot_neighbors(model_out, pred_loc_lb)
        with loc_expl_output:
            display(loc_expl)

In [31]:
# One shared handler instance; its click method runs on each "Predict Location" press.
occ = OnClickClassify()
btn_run.on_click(occ.click)

In [32]:
# The app body (instructions, URL box, buttons, result panes). It starts
# hidden; start_main makes it visible when "Start!" is clicked.
main = widgets.VBox(
    [
        expl,
        widgets.HBox([text_url, btn_clear]),
        waiting,
        suggestion,
        btn_run,
        err,
        model_sees_out,
        img_out,
        widgets.HBox([pred_out_country, pred_out_us]),
        waiting_pred_loc,
        loc_expl_output,
        pred_loc,
    ],
    layout=widgets.Layout(visibility='hidden'),
)
btn_start_main = widgets.Button(description="Start!")

def start_main(_):
    """Reveal the app, warm up the model behind a progress bar, then enable the buttons.

    Used as the "Start!" button's click handler; the argument (the clicked
    button) is ignored.
    """
    btn_start_main.layout.display = 'none'
    main.layout.visibility = None

    # Keep Predict/Clear hidden until the warm-up request has finished.
    btn_run.layout.visibility = 'hidden'
    btn_clear.layout.visibility = 'hidden'

    pb = CheekyLoadingBar(waiting, "Downloading model", 25)
    pb.start()
    try:
        warmup()
    except Exception as exc:
        # warmup() discards the model output, so a failure there should not
        # leave the UI stuck with hidden buttons. Report it and carry on.
        with err:
            print(f"Model warm-up failed: {exc}")
    finally:
        # Always stop the bar; it never finishes on its own.
        pb.stop()

    suggestion.layout.display = 'none' # Remove
    btn_run.layout.visibility = None # Put back
    btn_clear.layout.visibility = None # Put back
    
# Clicking "Start!" runs start_main (reveal the app and warm up the model).
btn_start_main.on_click(start_main)

# Top-level container: the Start! button above the initially hidden app body.
main_wrapper = widgets.VBox([btn_start_main, main])

In [33]:
# Render the app: only the Start! button is visible until it is clicked.
display(main_wrapper)

VBox(children=(Button(description='Start!', style=ButtonStyle()), VBox(children=(VBox(children=(HBox(children=…