In [33]:
import os
import re
from copy import deepcopy
from typing import NamedTuple
from urllib.request import urlopen

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import requests
from PIL import Image
from scipy.ndimage import gaussian_filter
from scipy.sparse import csr_matrix
from sklearn.neighbors import NearestNeighbors

In [34]:
import pickle 

def pkl_save(path, obj):
    """Pickle ``obj`` and write it to ``path``.

    Args:
        path: Destination file; opened in binary write mode, so an existing
            file at that location is overwritten.
        obj: Any picklable Python object.
    """
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)

def pkl_load(path):
    """Read the file at ``path`` and return the object pickled inside it.

    NOTE(review): unpickling can execute arbitrary code, so this must only be
    pointed at files from a trusted source.

    Args:
        path: Location of a file previously written with ``pickle``.

    Returns:
        The deserialized Python object.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [35]:
### NEAREST NEIGHBORS LOGIC ###

In [36]:
"""
Params:
- k (num neighbors)
    - cur: 5
- m (preference for closest ones)
    - cur: 10
- sigma (units: lat/lng)
    - cur: 4
Minor:
- scale
    - determines the granularity of our map
    - scale=100 creates units of 1/100 lat (resp. lng)
- truncate
    - how many stdevs out to truncate filters
    - smaller is more efficient
    - in general, see how small I can make it without hurting results
"""

class Config(NamedTuple):
    """Tunable parameters of the nearest-neighbour location estimate."""

    # Number of neighbours requested from the NearestNeighbors index.
    K: int
    # Exponent applied to the inverse neighbour distance; larger values favour
    # the closest neighbours more strongly.
    M: int
    # Gaussian spread in degrees of lat/lng; multiplied by SCALE to get grid cells.
    SIGMA: float
    # Grid cells per degree: the accumulator is 180*SCALE by 360*SCALE.
    SCALE: int
    # Number of standard deviations after which the Gaussian kernel is cut off.
    TRUNCATE: float


CONFIG = Config(
    K=20,
    M=5,
    SIGMA=4,
    SCALE=10,
    TRUNCATE=2,
)

In [37]:
# Reference-set metadata; its "lat"/"lng" columns become the neighbour locations.
data = pd.read_csv("data.csv")

# Pickled logit matrices, one per file (presumably one per model head and
# row-aligned with `data` -- TODO confirm against the export script).
logits_country = pkl_load("logits/country")
logits_geocell_1 = pkl_load("logits/geocell_1")
logits_geocell_2 = pkl_load("logits/geocell_2")
logits_us = pkl_load('logits/us')

# Feature matrix: every logit block joined along axis 1.
X = np.concatenate(
    (logits_country, logits_geocell_1, logits_geocell_2, logits_us),
    axis=1,
)
# Location array: one (lat, lng) pair per row of `data`.
y = np.array(data[["lat", "lng"]])

In [38]:
# Neighbour index over the reference logits; queries return the CONFIG.K closest rows of X.
NN = NearestNeighbors(n_neighbors=CONFIG.K).fit(X=X)

In [39]:
def get_base_filter():
    """Build the 2-D Gaussian kernel that is stamped around each neighbour.

    A unit impulse is placed at the centre of a square array whose half-width
    is ``int(TRUNCATE * SIGMA * SCALE)`` cells, then blurred with a Gaussian of
    standard deviation ``SIGMA * SCALE`` cells truncated at ``TRUNCATE``
    standard deviations.

    Returns:
        Square ``numpy.ndarray`` with side ``2 * half_width + 1``.
    """
    radius = int(CONFIG.TRUNCATE * CONFIG.SIGMA * CONFIG.SCALE)
    side = 2 * radius + 1

    impulse = np.zeros((side, side))
    impulse[radius, radius] = 1

    return gaussian_filter(
        impulse,
        sigma=CONFIG.SIGMA * CONFIG.SCALE,
        truncate=CONFIG.TRUNCATE,
    )

# Empty sparse accumulator: CONFIG.SCALE rows per degree over 180 degrees and
# CONFIG.SCALE columns per degree over 360 degrees.
MAIN = csr_matrix((CONFIG.SCALE * 180, CONFIG.SCALE * 360))
# Gaussian kernel computed once up front and reused by make_filter.
BASE_FILTER = get_base_filter()

In [40]:
def make_filter(lat_lng, weight):
    """Return a sparse grid holding one weighted Gaussian centred at ``lat_lng``.

    The kernel ``BASE_FILTER`` is scaled by ``weight`` and placed so that its
    centre lands on the grid cell for the given coordinate. Any part of the
    kernel that falls outside the grid is dropped (there is no wrap-around at
    the antimeridian or the poles).

    Args:
        lat_lng: Pair ``(lat, lng)`` in degrees.
        weight: Scalar multiplier applied to the kernel.

    Returns:
        ``csr_matrix`` with the same shape as ``MAIN``.
    """
    lat, lng = lat_lng

    # Shift so both axes start at zero, then convert degrees to cell indices.
    center_row = int(round((lat + 90.0) * CONFIG.SCALE))
    center_col = int(round((lng + 180.0) * CONFIG.SCALE))

    grid_rows, grid_cols = MAIN.shape
    kernel_rows, kernel_cols = BASE_FILTER.shape

    # Grid position of the kernel's top-left corner (may be negative).
    top = center_row - kernel_rows // 2
    left = center_col - kernel_cols // 2

    # Clip the kernel window to the grid by slicing instead of building and
    # filtering an explicit index table.
    row_start, row_stop = max(top, 0), min(top + kernel_rows, grid_rows)
    col_start, col_stop = max(left, 0), min(left + kernel_cols, grid_cols)
    if row_start >= row_stop or col_start >= col_stop:
        # The kernel lies completely outside the grid.
        return csr_matrix(MAIN.shape)

    # The multiplication allocates a new array, so BASE_FILTER is not modified.
    patch = BASE_FILTER[
        row_start - top:row_stop - top,
        col_start - left:col_stop - left,
    ] * weight

    rows, cols = np.mgrid[row_start:row_stop, col_start:col_stop]
    return csr_matrix(
        (patch.ravel(), (rows.ravel(), cols.ravel())),
        shape=MAIN.shape,
    )

In [41]:
def get_best_lat_lng(filts):
    """Return the (lat, lng) of the largest cell in the accumulated grid.

    The row and column are derived from one flat ``argmax`` so they always
    refer to the same cell. Taking the best row and the best column separately
    can combine two different cells when the maximum is tied.

    Args:
        filts: 2-D grid (sparse matrix or array) indexed ``[lat_idx, lng_idx]``.

    Returns:
        Tuple ``(lat, lng)`` in degrees, obtained by dividing the indices by
        ``CONFIG.SCALE`` and undoing the +90 / +180 offsets.
    """
    flat_idx = filts.argmax()
    lat_idx_max, lng_idx_max = np.unravel_index(flat_idx, filts.shape)
    lat_max = (int(lat_idx_max) / CONFIG.SCALE) - 90
    lng_max = (int(lng_idx_max) / CONFIG.SCALE) - 180
    return lat_max, lng_max

In [42]:
def predict(indices, weights, top=None):
    """Estimate a location from weighted neighbours via a Gaussian heatmap.

    Each neighbour contributes one weighted Gaussian (see ``make_filter``);
    the contributions are summed and the peak of the sum is returned.

    Args:
        indices: Neighbour row indices into ``y``; only ``indices[0]`` (the
            first query) is used.
        weights: Neighbour weights; only ``weights[0]`` is used.
        top: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(lat, lng)`` in degrees from ``get_best_lat_lng``.
    """
    neighbor_locs = y[indices][0]  # (K, 2)
    neighbor_weights = weights[0]

    # Start from a copy of the empty accumulator so the shared MAIN is never touched.
    heatmap = MAIN.copy()
    for lat_lng, weight in zip(neighbor_locs, neighbor_weights):
        heatmap += make_filter(lat_lng=lat_lng, weight=weight)
    return get_best_lat_lng(heatmap)

In [43]:
def plot_neighbors(logits):
    """Show the most similar reference locations and the predicted location on a map.

    Looks up the nearest neighbours of ``logits`` in ``NN``, weights them by
    inverse distance raised to ``CONFIG.M``, predicts a location with
    ``predict`` and draws neighbours (coloured by weight) plus the prediction
    (green) on a map.

    Args:
        logits: 1-D array of concatenated model logits; reshaped to a single
            query row, so its length must match the feature width ``NN`` was
            fitted on.
    """
    X_test = logits.reshape(1, -1)
    distances, indices = NN.kneighbors(X_test)

    # A zero distance would make 1/distance infinite and turn every weight into
    # NaN after normalisation, so exact matches take all of the weight instead.
    exact_match = distances == 0
    if exact_match.any():
        weights = exact_match.astype(float)
    else:
        weights = (1 / distances) ** CONFIG.M
    weights /= weights.sum(1, keepdims=True)

    pred = predict(indices, weights)
    pred_df = pd.DataFrame([list(pred)], columns=['lat', 'lng'])

    # One row per neighbour: its location and its normalised weight.
    neighbor_df = pd.DataFrame(
        np.concatenate([y[indices][0], weights[0].reshape(-1, 1)], 1),
        columns=['lat', 'lng', 'similarity'],
    )
    neighbor_df['size'] = 1

    fig = px.scatter_mapbox(
        neighbor_df,
        lat='lat',
        lon='lng',
        color='similarity',
        size='size',
        size_max=10,
        color_continuous_scale=px.colors.sequential.Plasma_r,
    )

    # Add green circle for prediction
    pred_df['color'] = 'rgb(50,220,50)'
    pred_df['size'] = 1
    trace = px.scatter_mapbox(
        pred_df,
        lat='lat',
        lon='lng',
        size='size',
        size_max=10,
        color='color',
        color_discrete_map='identity',
    ).data[0]
    fig.add_trace(trace)

    fig.update_layout(
        title=(
            ".       Most Similar Locations"
            "<br>.       (the green dot is my guess for the true location)"
        ),
        mapbox_style="open-street-map",
        mapbox=dict(
            center={'lat': pred_df.lat[0], 'lon': pred_df.lng[0]},
            zoom=2,
        ),
    )

    fig.show()
    plt.pause(0.001)

In [44]:
# plot_neighbors(X[0]+np.random.randn(412)/100)

In [45]:
### MAIN LOGIC ###

In [46]:
# Country labels, paired positionally with the country logits in country_preds:
# entry i names logit i, so the order must not be changed.
VOCAB_COUNTRY = ['United Arab Emirates','Albania', 'Argentina','American Samoa','Austria', 'Australia','Bangladesh','Belgium', 'Bulgaria','Bolivia, Plurinational State of','Brazil', 'Bhutan','Botswana', 'Canada','Switzerland','Chile', 'China','Colombia','Czech Republic','Germany','Denmark','Dominican Republic','Ecuador','Estonia','Egypt','Spain','Finland','Faroe Islands','France','United Kingdom','Ghana','Greenland','Greece', 'Guatemala', 'Hong Kong','Croatia','Hungary','Indonesia', 'Ireland','Israel','India','Iceland','Italy','Jordan','Japan','Kenya','Kyrgyzstan','Cambodia','Korea, Republic of','Sri Lanka', 'Lesotho','Lithuania','Latvia','Madagascar','Macedonia, the Former Yugoslav Republic of','Mongolia','Mexico','Malaysia','Nigeria','Netherlands','Norway','New Zealand','Peru','Philippines','Pakistan','Poland','Puerto Rico','Portugal','Romania','Serbia','Russian Federation','Sweden','Singapore','Slovenia','Slovakia','Senegal','Swaziland','Thailand','Tunisia','Turkey','Taiwan, Province of China','Ukraine','Uganda','United States','Uruguay','Viet Nam', 'South Africa']
# US state labels, paired positionally with the US logits in us_preds.
VOCAB_US = ['Alabama','Alaska','Arizona','Arkansas','California','Colorado','Connecticut','Delaware','Florida','Georgia','Hawaii','Idaho','Illinois','Indiana','Iowa','Kansas','Kentucky','Louisiana','Maine','Maryland','Massachusetts','Michigan','Minnesota','Mississippi','Missouri','Montana','Nebraska','Nevada','New Hampshire','New Jersey','New Mexico','New York','North Carolina','North Dakota','Ohio','Oklahoma','Oregon','Pennsylvania','Rhode Island','South Carolina','South Dakota','Tennessee','Texas','Utah','Vermont','Virginia','Washington','West Virginia','Wisconsin','Wyoming']

In [47]:
def convert_url(url):
    """Turn a Google Maps Street View page URL into a Street View image URL.

    The page URL is expected to contain ``@<lat>,<lng>,...,<heading>h,<pitch>t``.
    URLs without that pattern (for example direct image links) are returned
    unchanged.

    Args:
        url: URL pasted by the user.

    Returns:
        A ``maps.googleapis.com/maps/api/streetview`` request URL when the
        pattern is found, otherwise ``url`` itself.
    """
    # A signed decimal with at least one digit. A pattern that also accepts ''
    # or '-' would let float() raise on URLs that only resemble the format.
    num = r'(-?(?:\d+\.?\d*|\.\d+))'
    match = re.search(fr'@{num},{num},.*,{num}h,{num}t', url)

    # Not a google maps URL
    if match is None:
        return url

    lat, lng, heading, pitch = (float(group) for group in match.groups())
    # URL pitch ranges 0 to 180, scrape API ranges -90 to 90
    pitch -= 90

    # NOTE(review): credentials should not live in the notebook. The key is
    # read from GOOGLE_MAPS_API_KEY when set; the embedded value is only a
    # fallback for existing setups and should be rotated and removed.
    api_key = os.environ.get(
        'GOOGLE_MAPS_API_KEY',
        'AIzaSyABWCcImw44lzIqLHzBIJLngYLTx5El11M',
    )
    params = {
        'size': '480x480',
        'location': f'{lat},{lng}',
        'heading': str(heading),
        'pitch': str(pitch),
        'fov': '105',
        'key': api_key,
    }
    base_url = "https://maps.googleapis.com/maps/api/streetview"
    query = "&".join(f"{k}={v}" for k, v in params.items())
    return f"{base_url}?{query}"

In [48]:
def softmax(arr):
    """Return the softmax of ``arr`` computed over all of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    unchanged mathematically but keeps ``np.exp`` from overflowing to ``inf``
    (and the output from becoming NaN) when logits are large.

    Args:
        arr: Array-like of logits.

    Returns:
        ``numpy.ndarray`` of the same shape whose entries sum to 1; an empty
        input yields an empty array.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # Nothing to normalise, and max() of an empty array would raise.
        return np.exp(arr)
    exp = np.exp(arr - arr.max())
    return exp / exp.sum()

In [49]:
class Logits(NamedTuple):
    """Raw model outputs, one array per prediction head."""

    # Country logits; labelled with VOCAB_COUNTRY in country_preds.
    COUNTRY: np.ndarray
    # Geocell logits; only used as nearest-neighbour features.
    GEOCELL: np.ndarray
    # US-state logits; labelled with VOCAB_US in us_preds.
    US: np.ndarray

In [50]:
def get_model_out(url):
    """Send an image URL to the hosted classifier and return its logits.

    Args:
        url: Publicly reachable image URL, passed as the ``img`` query parameter.

    Returns:
        ``Logits`` built from the ``logits_country``, ``logits_geocell`` and
        ``logits_us`` entries of the JSON response.

    Raises:
        requests.RequestException: On connection problems, a timeout, or a
            non-2xx HTTP status.
    """
    # NOTE(review): credentials should not live in the notebook. The function
    # key is read from GEO_MODEL_API_CODE when set; the embedded value is only
    # a fallback for existing setups and should be rotated and removed.
    api_code = os.environ.get(
        'GEO_MODEL_API_CODE',
        'iZP6dHFLCjWvmLQx9v1haxW8Du21Phk/hMTQj4c/aGJseXAMgWuOPw==',
    )
    params = {
        'code': api_code,
        'img': url,
    }
    model_url = 'https://countryfinal.azurewebsites.net/api/classify'

    # Without a timeout a stalled service would block the UI callback forever.
    response = requests.get(url=model_url, params=params, timeout=120)
    # Fail with the HTTP error rather than a confusing JSON/KeyError below.
    response.raise_for_status()
    res = response.json()

    return Logits(
        COUNTRY=np.array(res['logits_country']),
        GEOCELL=np.array(res['logits_geocell']),
        US=np.array(res['logits_us']),
    )

In [51]:
def country_preds(model_out):
    """Build a table of the most likely countries for display.

    Args:
        model_out: Object exposing a ``COUNTRY`` logit array with one entry
            per label in ``VOCAB_COUNTRY``.

    Returns:
        DataFrame indexed by ``Country`` with a ``Confidence`` column of
        percentage strings (two decimals, e.g. ``'12.34%'``). Rows are sorted
        by descending probability, rows that round to 0 are dropped, and at
        most ten rows are kept.
    """
    probs = softmax(model_out.COUNTRY)
    table = pd.DataFrame({'Country': VOCAB_COUNTRY, 'Confidence': probs})

    # Each step returns a new frame, so no column is ever assigned on a
    # filtered slice (which triggers pandas' SettingWithCopyWarning).
    return (
        table
        .sort_values('Confidence', ascending=False)
        .assign(Confidence=lambda frame: (frame['Confidence'] * 100).round(2))
        .loc[lambda frame: frame['Confidence'] > 0.]
        .head(10)
        .assign(Confidence=lambda frame: frame['Confidence'].astype(str) + '%')
        .set_index('Country')
    )

In [52]:
def us_preds(model_out):
    """Build a table of the most likely US states for display.

    Args:
        model_out: Object exposing a ``US`` logit array with one entry per
            label in ``VOCAB_US``.

    Returns:
        DataFrame indexed by ``US State`` with a ``Confidence`` column of
        percentage strings (two decimals, e.g. ``'12.34%'``). Rows are sorted
        by descending probability, rows that round to 0 are dropped, and at
        most ten rows are kept.
    """
    probs = softmax(model_out.US)
    table = pd.DataFrame({'US State': VOCAB_US, 'Confidence': probs})

    # Each step returns a new frame, so no column is ever assigned on a
    # filtered slice (which triggers pandas' SettingWithCopyWarning).
    return (
        table
        .sort_values('Confidence', ascending=False)
        .assign(Confidence=lambda frame: (frame['Confidence'] * 100).round(2))
        .loc[lambda frame: frame['Confidence'] > 0.]
        .head(10)
        .assign(Confidence=lambda frame: frame['Confidence'].astype(str) + '%')
        .set_index('US State')
    )

In [53]:
# Input controls.
text_url = widgets.Text(placeholder='Paste the url here...')
btn_run = widgets.Button(description="Predict Location")

# Output areas, each filled by on_click_classify.
err = widgets.Output()               # validation message
img_out = widgets.Output()           # thumbnail of the submitted image
waiting = widgets.Output()           # progress message during the model call
pred_out_country = widgets.Output()  # country confidence table
pred_out_us = widgets.Output()       # US-state confidence table
pred_loc = widgets.Output()          # map of neighbours and prediction

In [54]:
def on_click_classify(change):
    """Button handler: fetch the image, query the model and render all outputs.

    Args:
        change: Argument supplied by the widget callback; unused.
    """
    # Drop any message left over from a previous click.
    err.clear_output()

    raw_url = text_url.value.strip()
    if not raw_url:
        with err:
            print("Please paste an image URL first!")
        return

    for elt in [img_out, pred_out_country, pred_out_us, pred_loc]:
        elt.clear_output()

    url = convert_url(raw_url)
    try:
        with urlopen(url) as response:
            image = Image.open(response)
            # LANCZOS is the filter formerly exposed as Image.ANTIALIAS, an
            # alias that was removed in Pillow 10.
            image.thumbnail((256, 256), Image.LANCZOS)
    except Exception as exc:
        # Show the failure in the UI; errors raised inside a widget callback
        # are otherwise easy to miss.
        with err:
            print(f"Could not load an image from that URL: {exc}")
        return

    with img_out:
        display(image)

    with waiting:
        print("Predicting... (takes about 15 seconds)")
    try:
        model_out = get_model_out(url)
    except Exception as exc:
        with err:
            print(f"Prediction request failed: {exc}")
        return
    finally:
        # Always remove the progress message, even when the request fails.
        waiting.clear_output()

    with pred_out_country:
        display(country_preds(model_out))

    with pred_out_us:
        display(us_preds(model_out))

    # Same feature layout as a row of X: country, geocell, then US logits.
    logits_cat = np.concatenate([model_out.COUNTRY, model_out.GEOCELL, model_out.US])
    with pred_loc:
        plot_neighbors(logits_cat)

In [55]:
# Run the full prediction pipeline whenever the button is pressed.
btn_run.on_click(on_click_classify, remove=False)

In [56]:
# Upper section: instructions, input, status areas and the two confidence
# tables placed side by side.
content = widgets.VBox(children=[
    widgets.Label(value="Enter any image URL (can be copy-pasted from Google Street View)"),
    text_url,
    btn_run,
    err,
    img_out,
    waiting,
    widgets.HBox(children=[pred_out_country, pred_out_us]),
])

# Whole app: the upper section followed by the map, centred in a column.
main = widgets.VBox(
    children=[content, pred_loc],
    layout=widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='center',
        width='80%',
    ),
)

In [57]:
# Bare expression as the cell's last line so the notebook renders the widget tree.
main

VBox(children=(VBox(children=(Label(value='Enter any image URL (can be copy-pasted from Google Street View)'),…

In [None]:
"""
Wellesley, MA
https://www.google.com/maps/@42.3001618,-71.2873389,3a,75y,282.13h,90.85t/data=!3m7!1e1!3m5!1sAfuLRKxkRWTKquuYzD9P4A!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fpanoid%3DAfuLRKxkRWTKquuYzD9P4A%26cb_client%3Dmaps_sv.tactile.gps%26w%3D203%26h%3D100%26yaw%3D17.52565%26pitch%3D0%26thumbfov%3D100!7i13312!8i6656

Ireland
https://www.google.com/maps/@52.4622135,-8.5005979,3a,75y,283.8h,93.24t/data=!3m7!1e1!3m5!1snwfVyxU9kwNpqZTjleScFg!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fpanoid%3DnwfVyxU9kwNpqZTjleScFg%26cb_client%3Dmaps_sv.tactile.gps%26w%3D203%26h%3D100%26yaw%3D193.45317%26pitch%3D0%26thumbfov%3D100!7i13312!8i6656
"""