In [1]:
import ipywidgets as widgets
from urllib.request import urlopen
from PIL import Image
import requests
import re
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import plotly.express as px
import matplotlib.pyplot as plt
from typing import NamedTuple
from scipy.sparse import csr_matrix
from scipy.ndimage import gaussian_filter
from copy import deepcopy
import time
import threading

In [2]:
import pickle 

def pkl_save(path, obj):
    """Pickle ``obj`` and write it to ``path``, replacing any existing file."""
    with open(path, mode='wb') as sink:
        pickle.dump(obj, sink)

def pkl_load(path):
    """Read the file at ``path`` and return the object unpickled from it."""
    # NOTE: unpickling can run arbitrary code, so only load files you trust.
    with open(path, mode='rb') as source:
        loaded = pickle.load(source)
    return loaded

In [3]:
class CheekyLoadingBar(threading.Thread):
    """Cosmetic progress bar driven by a background thread.

    Every ``delay`` seconds the bar advances by ``incr``, and ``incr`` is then
    multiplied by ``alpha`` (< 1), so the bar keeps slowing down as it fills.
    Calling ``stop()`` ramps the increment back up so the bar completes
    quickly, then marks it with the 'success' style.

    Params:
    - out: output widget (context manager) in which the bar is displayed
    - text: label shown to the left of the bar
    - eta: rough expected duration in seconds; must be > 0
    """

    def __init__(self, out, text, eta):
        if eta <= 0:
            # eta == 0 would divide by zero; eta < 0 would make the bar run backwards forever.
            raise ValueError(f"eta must be positive, got {eta!r}")
        # Daemon thread: if stop() is never reached (e.g. the guarded work raised),
        # the animation loop must not keep the interpreter alive.
        super().__init__(daemon=True)
        label = widgets.Label(text)
        self.bar = widgets.FloatProgress(value=0,min=0,max=1)
        to_display = widgets.HBox([label, self.bar])
        
        self.delay = 0.01
        self.incr_start = self.delay / eta
        self.incr = self.incr_start
        self.alpha = 1 - self.incr
        
        # Set by run() once the bar value has reached 1
        self.done = False
        
        with out:
            display(to_display)
    
    def start(self):
        """Record the start time and launch the animation thread."""
        self.start_time = time.time()
        # Absolute time of the next tick; advanced by `delay` on every iteration
        self.wait_target = self.start_time
        super().start()
    
    def run(self):
        """Thread body: advance the bar until its value reaches 1."""
        while self.bar.value < 1:
            self.wait_target += self.delay
            # Sleep until the scheduled tick so timing does not drift
            diff = self.wait_target - time.time()
            if diff > 0:
                time.sleep(diff)
            self.bar.value += self.incr
            # Cheekily slow down progress
            self.incr *= self.alpha
        self.done = True
    
    def stop(self):
        """Speed the bar up until it completes, join the thread, then flag success."""
        self.incr = self.incr_start
        while not self.done:
            # Grow the step a little on every tick so the bar finishes quickly
            self.incr *= 1.02
            time.sleep(self.delay)
        self.join()
        self.bar.bar_style='success'

In [4]:
### NEAREST NEIGHBORS LOGIC ###

In [5]:
"""
Params:
- k (num neighbors)
    - cur: 20
- m (preference for closest ones)
    - cur: 5
- sigma (units: lat/lng)
    - cur: 4
Minor:
- scale
    - determines the granularity of our map
    - scale=100 creates units of 1/100 lat (resp. lng)
- truncate
    - how many stdevs out to truncate filters
    - smaller is more efficient
    - in general, see how small I can make it without hurting results
"""

class Config(NamedTuple):
    """Tunable parameters for the nearest-neighbour location estimate."""
    K: int         # neighbours requested from NearestNeighbors
    M: int         # exponent applied to inverse distances when weighting neighbours
    SIGMA: int     # Gaussian sigma in degrees (multiplied by SCALE before use)
    SCALE: int     # grid cells per degree of lat/lng
    TRUNCATE: int  # `truncate` argument of the Gaussian filter (stdevs kept)


CONFIG = Config(K=20, M=5, SIGMA=4, SCALE=10, TRUNCATE=2)

In [6]:
# Table of reference locations; only the "lat" and "lng" columns are used below.
data = pd.read_csv("data.csv")

# Pickled logit matrices, one per model head
# (presumably row-aligned with `data` -- TODO confirm).
logits_country = pkl_load("logits/logits_country")
logits_geocell_1 = pkl_load("logits/logits_geocell_1")
logits_geocell_2 = pkl_load("logits/logits_geocell_2")
logits_us = pkl_load('logits/logits_us')

# Feature matrices for the neighbour search: the three shared heads side by side,
# and the same with the US-state head appended as extra columns.
X_no_us = np.concatenate([logits_country, logits_geocell_1, logits_geocell_2], axis=1)
X_us = np.concatenate([X_no_us, logits_us], axis=1)

# Targets: (lat, lng) for every row of `data`.
y = np.array(data[["lat", "lng"]])

In [7]:
def get_base_filter():
    """Build the 2-D Gaussian kernel that is stamped onto the grid for each neighbour.

    A unit impulse is placed at the centre of a square array and blurred with
    ``gaussian_filter``; the array spans ``TRUNCATE * SIGMA * SCALE`` cells on
    each side of the centre.
    """
    sigma_cells = CONFIG.SIGMA * CONFIG.SCALE
    radius = int(CONFIG.TRUNCATE * CONFIG.SIGMA * CONFIG.SCALE)
    side = 2 * radius + 1

    impulse = np.zeros((side, side))
    impulse[radius, radius] = 1
    return gaussian_filter(impulse, sigma=sigma_cells, truncate=CONFIG.TRUNCATE)

# Gaussian kernel computed once and reused (scaled by a weight) in make_filter.
BASE_FILTER = get_base_filter()
# Empty sparse grid: 180 degrees of latitude by 360 degrees of longitude,
# with CONFIG.SCALE cells per degree.
MAIN = csr_matrix((CONFIG.SCALE * 180, CONFIG.SCALE * 360))

In [8]:
def make_filter(lat_lng, weight):
    """Return a sparse grid with a weighted Gaussian bump centred on ``lat_lng``.

    Params:
    - lat_lng: pair (lat, lng) in degrees
    - weight: scalar multiplier applied to BASE_FILTER

    Returns a csr_matrix with the same shape as MAIN. Kernel cells that fall
    outside the grid are dropped (there is no wrap-around in longitude).
    """
    lat, lng = lat_lng
    # Rescale so both coordinates start at zero, then convert degrees to grid cells.
    # int() guarantees integer indices whatever scalar type round() hands back.
    center_row = int(round((lat + 90.0) * CONFIG.SCALE))
    center_col = int(round((lng + 180.0) * CONFIG.SCALE))

    # Multiplication already yields a new array, so BASE_FILTER is left untouched.
    kernel = BASE_FILTER * weight

    grid_rows, grid_cols = MAIN.shape
    kernel_rows, kernel_cols = kernel.shape

    # Row/col index of every kernel cell (C order, matching kernel.ravel()),
    # shifted so the kernel centre lands on (center_row, center_col).
    rows, cols = np.indices(kernel.shape)
    rows = rows.ravel() - kernel_rows // 2 + center_row
    cols = cols.ravel() - kernel_cols // 2 + center_col

    # Remove out-of-bounds cells
    in_bounds = (rows >= 0) & (cols >= 0) & (rows < grid_rows) & (cols < grid_cols)

    return csr_matrix(
        (kernel.ravel()[in_bounds], (rows[in_bounds], cols[in_bounds])),
        shape=MAIN.shape,
    )

In [9]:
def get_best_lat_lng(filts):
    """Convert the peak of the grid ``filts`` back into (lat, lng) degrees.

    The row index comes from the argmax of the per-row maxima and the column
    index from the argmax of the per-column maxima; both are then mapped from
    grid cells back to degrees.
    """
    peak_row = filts.max(1).argmax()
    peak_col = filts.max(0).argmax()
    best_lat = (peak_row / CONFIG.SCALE) - 90
    best_lng = (peak_col / CONFIG.SCALE) - 180
    return best_lat, best_lng

In [10]:
def predict(indices, weights, top=None):
    """Estimate (lat, lng) from the neighbours of a single query.

    Params:
    - indices: neighbour row indices into `y`; only the first query row is used
    - weights: per-neighbour weights, same layout as `indices`
    - top: unused; kept so existing callers keep working

    Each neighbour contributes a weighted Gaussian bump to a world grid, and
    the location of the grid's maximum is returned.
    """
    neighbor_locs = y[indices][0] # (K, 2)
    neighbor_weights = weights[0]

    # Sparse addition returns a new matrix each time, so MAIN itself is never modified.
    heat = MAIN.copy()
    for lat_lng, weight in zip(neighbor_locs, neighbor_weights):
        heat = heat + make_filter(lat_lng=lat_lng, weight=weight)
    return get_best_lat_lng(heat)

In [86]:
def plot_neighbors(nn, logits, cheeky_lb=None):
    """Find the nearest database entries for ``logits``, predict a location and map it.

    Params:
    - nn: fitted neighbour index exposing ``kneighbors``
    - logits: 1-D feature vector for the query image
    - cheeky_lb: optional loading bar; started before and stopped after prediction

    Shows a map with one marker per neighbour, coloured by its normalised
    weight, plus a green marker at the predicted location.
    """
    X_test = logits.reshape(1, -1)
    distances, indices = nn.kneighbors(X_test)

    # Inverse-distance weights raised to the power M, normalised to sum to 1.
    # Distances are floored at a tiny positive value and expressed relative to
    # the nearest neighbour, so an exact match (distance 0) gets the dominant
    # weight instead of producing inf/NaN, and the power cannot overflow.
    safe_distances = np.maximum(np.asarray(distances, dtype=float), np.finfo(float).tiny)
    nearest = safe_distances.min(axis=1, keepdims=True)
    weights = (nearest / safe_distances)**CONFIG.M
    weights /= weights.sum(1, keepdims=True)

    if cheeky_lb: cheeky_lb.start()
    try:
        pred = predict(indices, weights)
    finally:
        # Always stop the bar, even if the prediction failed
        if cheeky_lb: cheeky_lb.stop()
    pred_df = pd.DataFrame([list(pred)]).rename(columns={0:'lat',1:'lng'})
    
    # One row per neighbour: lat, lng and its normalised weight
    z1 = y[indices][0]
    z2 = weights[0].reshape(-1,1)
    data = pd.DataFrame(np.concatenate([z1, z2], 1))
    data = data.rename(columns={0: 'lat', 1: 'lng', 2: 'similarity'})
    # Constant size column so every marker is drawn at the same size
    data['size'] = 1

    fig = px.scatter_mapbox(
        data,
        lat='lat',
        lon='lng',
        color='similarity',
        size='size',
        size_max=10,
        color_continuous_scale=px.colors.sequential.Plasma_r,
    )
    
    # Add green circle for prediction
    pred_df['color'] = 'rgb(50,220,50)'
    pred_df['size'] = 1
    trace = px.scatter_mapbox(
        pred_df, 
        lat='lat', 
        lon='lng',
        size='size',
        size_max=10,
        color='color',
        color_discrete_map='identity',
    ).data[0]
    fig.add_trace(trace)

    # Centre the map on the prediction
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox=dict(
            center={'lat': pred_df.lat[0], 'lon': pred_df.lng[0]},
            zoom=2,
        ),
        margin=dict(t=0, b=0, l=0, r=0),
    )
    
    fig.show()
    # NOTE(review): presumably gives the front-end a moment to render -- confirm it is needed
    plt.pause(0.001)

In [87]:
### MAIN LOGIC ###

In [88]:
# Country labels, paired position-by-position with the country logits
# (see country_preds and the argmax lookup in OnClickClassify.click). Do not reorder.
VOCAB_COUNTRY = ['United Arab Emirates','Albania', 'Argentina','American Samoa','Austria', 'Australia','Bangladesh','Belgium', 'Bulgaria','Bolivia, Plurinational State of','Brazil', 'Bhutan','Botswana', 'Canada','Switzerland','Chile', 'China','Colombia','Czech Republic','Germany','Denmark','Dominican Republic','Ecuador','Estonia','Egypt','Spain','Finland','Faroe Islands','France','United Kingdom','Ghana','Greenland','Greece', 'Guatemala', 'Hong Kong','Croatia','Hungary','Indonesia', 'Ireland','Israel','India','Iceland','Italy','Jordan','Japan','Kenya','Kyrgyzstan','Cambodia','Korea, Republic of','Sri Lanka', 'Lesotho','Lithuania','Latvia','Madagascar','Macedonia, the Former Yugoslav Republic of','Mongolia','Mexico','Malaysia','Nigeria','Netherlands','Norway','New Zealand','Peru','Philippines','Pakistan','Poland','Puerto Rico','Portugal','Romania','Serbia','Russian Federation','Sweden','Singapore','Slovenia','Slovakia','Senegal','Swaziland','Thailand','Tunisia','Turkey','Taiwan, Province of China','Ukraine','Uganda','United States','Uruguay','Viet Nam', 'South Africa']
# US state labels, paired position-by-position with the US logits (see us_preds). Do not reorder.
VOCAB_US = ['Alabama','Alaska','Arizona','Arkansas','California','Colorado','Connecticut','Delaware','Florida','Georgia','Hawaii','Idaho','Illinois','Indiana','Iowa','Kansas','Kentucky','Louisiana','Maine','Maryland','Massachusetts','Michigan','Minnesota','Mississippi','Missouri','Montana','Nebraska','Nevada','New Hampshire','New Jersey','New Mexico','New York','North Carolina','North Dakota','Ohio','Oklahoma','Oregon','Pennsylvania','Rhode Island','South Carolina','South Dakota','Tennessee','Texas','Utah','Vermont','Virginia','Washington','West Virginia','Wisconsin','Wyoming']

In [89]:
def convert_url(url):
    """Turn a Google Maps Street View page URL into a Street View Static API image URL.

    The page URL is expected to contain ``@<lat>,<lng>,...,<heading>h,<pitch>t``.
    If that pattern is absent, or the captured pieces are not valid numbers,
    ``url`` is returned unchanged and treated as a direct image URL.

    The API key can be overridden with the GOOGLE_MAPS_API_KEY environment variable.
    """
    import os  # local import: only needed for the optional key override

    n = r'([\d\.-]*)'
    res = re.search(fr'@{n},{n},.*,{n}h,{n}t', url)
    
    # Not a google maps URL
    if res is None:
        return url
    
    try:
        lat, lng, heading, pitch = [float(res.group(i)) for i in range(1,5)]
    except ValueError:
        # The pattern also accepts strings such as "" or "-"; treat those as "not a maps URL"
        return url
    # URL pitch ranges 0 to 180, scrape API ranges -90 to 90
    pitch -= 90

    # NOTE(review): an API key is hard-coded here; prefer supplying it via the environment.
    scrub3key = os.environ.get('GOOGLE_MAPS_API_KEY', 'AIzaSyABWCcImw44lzIqLHzBIJLngYLTx5El11M')
    params = {
        'size': '480x480',
        'location': f'{lat},{lng}',
        'heading': str(heading),
        'pitch': str(pitch),
        'fov': '105',
        'key': scrub3key,
    }
    url = "https://maps.googleapis.com/maps/api/streetview"
    url_params = "&".join(f"{k}={v}" for k, v in params.items())
    return f"{url}?{url_params}"

In [90]:
def softmax(arr):
    """Return the softmax of ``arr`` taken over all of its elements.

    The maximum is subtracted before exponentiating; this leaves the result
    unchanged mathematically but stops large logits overflowing to inf/NaN.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # Nothing to normalise (and max() of an empty array would raise)
        return np.exp(arr)
    exp = np.exp(arr - arr.max())
    return exp / exp.sum()

In [91]:
class Logits(NamedTuple):
    """Raw model outputs for one image, one array per prediction head."""
    COUNTRY: np.ndarray  # scores paired with VOCAB_COUNTRY
    GEOCELL: np.ndarray  # geocell scores, used only as nearest-neighbour features
    US: np.ndarray       # scores paired with VOCAB_US

In [92]:
def get_model_out(url, timeout=300):
    """Ask the remote classifier for the logits of the image at ``url``.

    Params:
    - url: image URL forwarded to the service as the `img` query parameter
    - timeout: seconds to wait for the service before giving up

    Returns a Logits tuple of numpy arrays (COUNTRY, GEOCELL, US).
    Raises a requests exception on timeout or on a non-2xx HTTP status.
    """
    # NOTE(review): the service access code is hard-coded; consider moving it out of the notebook.
    params = {
        'code': 'iZP6dHFLCjWvmLQx9v1haxW8Du21Phk/hMTQj4c/aGJseXAMgWuOPw==',
        'img': url,
    }
    model_url = 'https://countryfinal.azurewebsites.net/api/classify'
    # Without a timeout a stalled connection would block the notebook indefinitely.
    response = requests.get(url=model_url, params=params, timeout=timeout)
    # Surface HTTP errors directly instead of failing later on a missing JSON key.
    response.raise_for_status()
    res = response.json()
    
    return Logits(
        COUNTRY = np.array(res['logits_country']),
        GEOCELL = np.array(res['logits_geocell']),
        US = np.array(res['logits_us']),
    )

In [93]:
def country_preds(model_out):
    """Return a table of up to ten most likely countries with percentage strings.

    Probabilities are the softmax of the country logits, expressed as
    percentages rounded to two decimals; rows that round to 0 are dropped.
    The result is indexed by 'Country' with a single 'Confidence' column.
    """
    table = pd.DataFrame({
        'Country': VOCAB_COUNTRY,
        'Confidence': softmax(model_out.COUNTRY),
    }).sort_values('Confidence', ascending=False)

    table = table.assign(Confidence=(table['Confidence'] * 100).round(2))
    table = table[table['Confidence'] > 0.].head(10)
    table = table.assign(Confidence=table['Confidence'].astype(str) + '%')
    return table.set_index('Country')

In [94]:
def us_preds(model_out):
    """Return a table of up to ten most likely US states with percentage strings.

    Probabilities are the softmax of the US logits, expressed as percentages
    rounded to two decimals; rows that round to 0 are dropped. The result is
    indexed by 'US State' with a single 'Confidence' column.
    """
    table = pd.DataFrame({
        'US State': VOCAB_US,
        'Confidence': softmax(model_out.US),
    }).sort_values('Confidence', ascending=False)

    table = table.assign(Confidence=(table['Confidence'] * 100).round(2))
    table = table[table['Confidence'] > 0.].head(10)
    table = table.assign(Confidence=table['Confidence'].astype(str) + '%')
    return table.set_index('US State')

In [95]:
def warmup_and_get_nns():
    """Send one request to the model service and fit both neighbour indices.

    Returns (NN_us, NN_no_us): NearestNeighbors fitted on X_us and on X_no_us,
    each configured with CONFIG.K neighbours.
    """
    WARMUP_URL = "https://maps.googleapis.com/maps/api/streetview?size=480x480&location=42.3014176,-71.3006286&heading=233.55&pitch=5.1&fov=105&key=AIzaSyABWCcImw44lzIqLHzBIJLngYLTx5El11M"
    # The response is discarded; presumably the call only wakes the remote model up.
    get_model_out(WARMUP_URL)

    fitted = [
        NearestNeighbors(n_neighbors=CONFIG.K).fit(features)
        for features in (X_us, X_no_us)
    ]
    return fitted[0], fitted[1]

In [96]:
def warmup_and_get_nns_DUMMY():
    """Stand-in for ``warmup_and_get_nns``: wait five seconds, return two zeros."""
    time.sleep(5)
    placeholder_nns = (0, 0)
    return placeholder_nns

In [97]:
def gmaps_link_text():
    """Return the HTML for a "Google Maps" link that opens in a new tab."""
    # Attribute values are quoted and the stray '=' before '>' is gone, so the
    # markup is well-formed; rel='noopener noreferrer' is the usual companion of
    # target='_blank'.
    return (
        "<a href='http://google.com/maps' target='_blank' "
        "rel='noopener noreferrer' style='color:blue;'>Google Maps</a>"
    )

In [98]:
# Column flex container without padding or border, used for the instruction text.
layout = widgets.Layout(display='flex', flex_flow='column', padding='0px', border='0px')

# Step 1 is split across three widgets so the HTML link sits between two labels.
e2a = widgets.Label(value="1. Open ")
e2b = widgets.HTML(value=gmaps_link_text())
e2c = widgets.Label(value=" and go to any Street View location in the world")
# Hint for step 1, then steps 2 and 3.
e3 = widgets.Label(value=" (do this by dragging the little yellow person in the bottom right corner of Google Maps onto any road)")
e4 = widgets.Label(value="2. Copy the URL and paste it below")
e5 = widgets.Label(value='3. Click "Predict"! (this button will appear once my models finish downloading)')

# Instruction panel: the step-1 row followed by the remaining lines.
expl = widgets.VBox(
    children=[widgets.HBox(children=[e2a, e2b, e2c]), e3, e4, e5],
    layout=layout,
)

In [99]:
# URL input with its "Clear" button, plus the output area for the model-loading bar.
text_url = widgets.Text(placeholder='Paste the url here...')
btn_clear = widgets.Button(description="Clear")
waiting = widgets.Output()

# Hint shown while the models load; hidden by start_main once loading finishes.
s1 = widgets.Label("While you're waiting, go to ")
s2 = widgets.HTML(value=gmaps_link_text())
s3 = widgets.Label("and pick a Street View image!")
suggestion = widgets.HBox([s1, s2, s3])

# Prediction trigger and the output areas filled in by OnClickClassify.click.
btn_run = widgets.Button(description="Predict Location")
err = widgets.Output()
model_sees_out = widgets.Output()
model_sees = widgets.Label("What the model sees:")
img_out = widgets.Output()
pred_out_country = widgets.Output()
pred_out_us = widgets.Output()
waiting_pred_loc = widgets.Output()
loc_expl_output = widgets.Output()
# Caption for the neighbour map (the second sentence previously lacked its verb).
loc_expl = widgets.VBox([
    widgets.Label("Here are the most similar places in my database."),
    widgets.Label("The green dot is my final guess (a weighted average of similar places)"),
])
pred_loc = widgets.Output()

In [100]:
def clear_text_url(_):
    """Button callback: empty the URL text box (the callback argument is ignored)."""
    text_url.value = ''

# Run the callback whenever the "Clear" button is clicked.
btn_clear.on_click(clear_text_url)

In [101]:
class OnClickClassify:
    """Callback object for the "Predict Location" button.

    NN_us / NN_no_us hold the fitted neighbour indices; they start as None and
    are assigned from outside once the models have been loaded.
    """

    def __init__(self):
        self.NN_us = None
        self.NN_no_us = None
        
    def click(self, change):
        """Fetch the image, query the model, and render predictions and the map.

        `change` is the argument passed by the button callback and is unused.
        """
        # Drop any message left over from a previous click
        err.clear_output()

        pasted_url = text_url.value.strip()
        if len(pasted_url) == 0:
            with err:
                print("Please paste an image URL first!")
            return

        for elt in [model_sees_out, img_out, pred_out_country, 
                    pred_out_us, waiting_pred_loc, loc_expl_output, 
                    pred_loc]:
            elt.clear_output()

        url = convert_url(pasted_url)
        with urlopen(url) as testImage:
            image = Image.open(testImage)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            image.thumbnail((256,256), Image.LANCZOS)

        with model_sees_out:
            display(model_sees)
        with img_out:
            display(image)

        model_out = get_model_out(url)
        waiting.clear_output()

        with pred_out_country:
            display(country_preds(model_out))

        with pred_out_us:
            display(us_preds(model_out))

        # The US-state logits are only used as features when the top country is the US
        is_us = VOCAB_COUNTRY[model_out.COUNTRY.argmax()] == "United States"
        if is_us:
            nn = self.NN_us
            logits_cat = np.concatenate([model_out.COUNTRY, model_out.GEOCELL, model_out.US])
        else:
            nn = self.NN_no_us
            logits_cat = np.concatenate([model_out.COUNTRY, model_out.GEOCELL])
        pred_loc_lb = CheekyLoadingBar(waiting_pred_loc, "Predicting Precise Location", 10)
        pred_loc_lb.start()
        try:
            with pred_loc:
                plot_neighbors(nn, logits_cat)
            with loc_expl_output:
                display(loc_expl)
        finally:
            # Stop the bar even if plotting failed, so its thread does not keep running
            pred_loc_lb.stop()

In [102]:
# Single handler instance; its NN_us / NN_no_us attributes start as None.
occ = OnClickClassify()
# Run the handler's click method whenever "Predict Location" is pressed.
btn_run.on_click(occ.click)

In [103]:
# Whole application UI, stacked vertically; kept hidden until "Start!" is clicked.
main = widgets.VBox([
    expl,
    widgets.HBox([text_url, btn_clear]),
    waiting,
    suggestion,
    btn_run,
    err,
    model_sees_out,
    img_out,
    widgets.HBox([pred_out_country, pred_out_us]),
    waiting_pred_loc,
    loc_expl_output,
    pred_loc,
])
main.layout.visibility = 'hidden'
btn_start_main = widgets.Button(description="Start!")

def start_main(_):
    """"Start!" callback: reveal the UI, load the models, then enable the buttons."""
    btn_start_main.layout.display = 'none'
    main.layout.visibility = None
    
    # Keep the action buttons hidden until the neighbour indices are available
    btn_run.layout.visibility = 'hidden'
    btn_clear.layout.visibility = 'hidden'

    pb = CheekyLoadingBar(waiting, "Downloading model", 25)
    pb.start()
    try:
        NN_us, NN_no_us = warmup_and_get_nns()
        occ.NN_us = NN_us
        occ.NN_no_us = NN_no_us
    finally:
        # Always stop the bar's thread; on failure the buttons below stay hidden
        pb.stop()

    suggestion.layout.display = 'none' # Remove
    btn_run.layout.visibility = None # Put back
    btn_clear.layout.visibility = None # Put back
    
btn_start_main.on_click(start_main)

main_wrapper = widgets.VBox([btn_start_main, main])

In [104]:
# Render the app: the "Start!" button with the (initially hidden) main UI below it.
display(main_wrapper)

VBox(children=(Button(description='Start!', style=ButtonStyle()), VBox(children=(VBox(children=(HBox(children=…