In [None]:
import os
import io
import warnings
from glob import glob

from dotenv import load_dotenv
import h5py
from tqdm import tqdm
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
from shapely.geometry import Point
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from IPython.display import clear_output
import requests

load_dotenv()  # populate os.environ from a local .env file; the Maps API key is read later via os.getenv('SMTGML_GMS')


In [None]:
def get_google_maps_static_image(
    api_key, lat_lon: str, zoom=17, scale=1, size=(224, 244), format="png8", maptype="satellite",
    timeout=60,
):
    """Get a static image from the Google Maps Static API.

    Args:
        api_key: API key for Google Maps API. See https://developers.google.com/maps/documentation/maps-static/get-api-key. You may first need to enable "Google Maps Static API" in your Google Cloud Console.
        lat_lon: Map centre in the format "lat,lon".
        zoom: Zoom level passed to Google Maps API.
        scale: Scale passed to Google Maps API. If scale=2, the image has twice as many pixels
            along each side for the same ground area.
        size: Requested image size as (width, height) in pixels. It is sent as the API's
            "{width}x{height}" `size` parameter, so size[0] is the horizontal extent.
            NOTE(review): the default (224, 244) looks like a typo for (224, 224); it is kept
            unchanged so existing callers relying on the default are not affected.
        format: Format of the image. See https://developers.google.com/maps/documentation/maps-static/start#ImageFormats
        maptype: Type of map. See https://developers.google.com/maps/documentation/maps-static/start#MapTypes
        timeout: Seconds to wait for the server before giving up (passed to `requests.get`).
            Without a timeout, a stalled connection would block the caller indefinitely.

    Raises:
        ValueError: If the response code is not 200 or the payload is too small to be an image.
        requests.exceptions.Timeout: If the server does not respond within `timeout` seconds.

    Returns:
        Raw image bytes as returned by the API.
    """
    base_url = "https://maps.googleapis.com/maps/api/staticmap"
    params = {
        "center": lat_lon,
        "zoom": zoom,
        "size": f"{size[0]}x{size[1]}",
        "maptype": maptype,
        "key": api_key,
        "scale": scale,
        "format": format,
    }
    response = requests.get(base_url, params=params, timeout=timeout)
    img = response.content
    code = response.status_code
    if code != 200:
        raise ValueError("Request failed with code: " + str(code) + " and message: " + str(img))
    # Tiny payloads are treated as error/placeholder responses rather than real images.
    if len(img) < 1000:
        raise ValueError("Image not found. Message: " + str(img))
    return img

## Load boundary

The boundary shapefile was downloaded from:

| Country/City | Link |
| --- | --- |
| Bangladesh | https://www.kaggle.com/datasets/tsgreen/bangladesh-administrative-boundaries-shapefiles/data |

In [None]:
# Name of the region; also used below as a sub-folder name for outputs.
geometry = 'bangladesh'
shapefile_path = f'/path/shapefiles/{geometry}/bgd_admbnda_adm0_bbs_20180410.shp'
gdf = gpd.read_file(shapefile_path)
gdf.plot();

## Create index of points to download

In [None]:
# Merge all boundary polygons into one geometry. `GeoDataFrame.unary_union` is deprecated in
# geopandas >= 1.0 in favour of `union_all()`; fall back to the old attribute on older installs.
unary_shape = gdf.union_all() if hasattr(gdf, "union_all") else gdf.unary_union
# Bounding box of the region: (min_lon, min_lat, max_lon, max_lat), i.e. shapely's (minx, miny, maxx, maxy).
min_lon, min_lat, max_lon, max_lat = unary_shape.bounds
print(min_lon, min_lat, max_lon, max_lat)

In [None]:
# Grid spacing in degrees. The gap leaves some overlap between neighbouring images, but it is
# chosen so that every grid coordinate is a multiple of 0.01 (exactly two decimals).
lat_gap = 0.01
lon_gap = 0.01
# Pad the bounding box by one gap below and (the arange stop is exclusive) one gap above.
# np.arange with a 0.01 float step accumulates drift (e.g. 20.589999999999998); rounding to the
# 2 decimals also used for the "lat,lon" download keys keeps the points tested against the
# boundary identical to the coordinates that are actually requested.
lats = np.round(np.arange(np.round(min_lat, 2) - lat_gap, np.round(max_lat, 2) + lat_gap * 2, lat_gap), 2)
lons = np.round(np.arange(np.round(min_lon, 2) - lon_gap, np.round(max_lon, 2) + lon_gap * 2, lon_gap), 2)
print(len(lats), len(lons), len(lats) * len(lons))

In [None]:
# Every (lat, lon) combination of the grid, one row per candidate image centre.
Lat, Lon = np.meshgrid(lats, lons)
latlon_pairs = np.column_stack((Lat.ravel(), Lon.ravel()))
print(latlon_pairs.shape)

Next, we discard the grid points that fall outside the boundary geometry, using point-in-polygon tests (`GeoSeries.within`, backed by `shapely`) run in parallel with `joblib`.

In [None]:
# Build the point geometries in one vectorized call. Shapely points are (x, y) = (lon, lat),
# while each row of latlon_pairs is (lat, lon).
point_gdf = gpd.GeoDataFrame(
    geometry=gpd.points_from_xy(latlon_pairs[:, 1], latlon_pairs[:, 0])
)

def check_within(gdf_chunk, shape=None):
    """Return a boolean array that is True where a point of `gdf_chunk` lies within `shape`.

    Args:
        gdf_chunk: GeoDataFrame (or GeoSeries) of point geometries.
        shape: Geometry to test against. Defaults to the notebook-level `unary_shape`; passing
            it explicitly makes the dependency visible when the call runs in a joblib worker.
    """
    if shape is None:
        shape = unary_shape
    return gdf_chunk.within(shape).values

# The boundary geometry is serialised to a worker with every task, so use one chunk per worker
# instead of thousands of 100-point chunks. Never request more workers than the machine has.
n_jobs_within = min(48, os.cpu_count() or 1)
chunk_size = max(1, -(-len(point_gdf) // n_jobs_within))  # ceil division
chunks = [point_gdf.iloc[i:i + chunk_size] for i in range(0, len(point_gdf), chunk_size)]
results = Parallel(n_jobs=n_jobs_within)(delayed(check_within)(chunk, unary_shape) for chunk in tqdm(chunks))
# np.concatenate([]) raises ValueError, so handle an empty grid explicitly.
latlon_bool = np.concatenate(results) if results else np.zeros(0, dtype=bool)
print(len(latlon_bool), latlon_bool.sum())

In [None]:
# Keep only the (lat, lon) rows whose point lies inside the boundary.
latlon_pairs_in_geometry = latlon_pairs[latlon_bool, :]

def convert_to_string(latlon_pair):
    """Format a (lat, lon) pair as the "lat,lon" string used as API centre and file name.

    Both coordinates are rendered with exactly two decimals,
    e.g. (23.5, 90.25) -> "23.50,90.25".
    """
    lat, lon = latlon_pair
    return ",".join(format(coord, ".2f") for coord in (lat, lon))

# Turn each numeric (lat, lon) row into its "lat,lon" string key.
latlon_pairs_in_geometry = list(map(convert_to_string, latlon_pairs_in_geometry))

print(latlon_pairs_in_geometry[:3])
print(len(latlon_pairs_in_geometry))

## Function to download a single image

In [None]:
# Metadata describing the 5x5 grid of 224x224x3 patches cut from each downloaded image.
# NOTE(review): these arrays and `path` are not used elsewhere in this notebook; presumably they
# are consumed by a later dataset-building step — confirm before removing.
lat_lag = np.array([-2, -1, 0, 1, 2], dtype=np.int8)
lon_lag = np.array([-2, -1, 0, 1, 2], dtype=np.int8)
rows = np.arange(224, dtype=np.uint8)
cols = np.arange(224, dtype=np.uint8)
channels = np.array([0, 1, 2], dtype=np.uint8)
labels = np.full((5, 5), -1, dtype=np.int8)  # presumably -1 means "not labelled yet" — confirm
path = os.path.join(os.path.expanduser("~"), 'bkdb', geometry)
save_path = "/path/images"  # path to save the images
key = os.getenv('SMTGML_GMS')  # put your Google Maps Static API key here (e.g. in a .env file)
if not key:
    # With key=None the request is sent without a key and is expected to be rejected by the API;
    # warn now instead of failing on the first download.
    warnings.warn("Environment variable SMTGML_GMS is not set; Google Maps Static API downloads will fail.")

def download_it(lat_lon, overwrite=True):
    """Download one satellite image centred on `lat_lon` and save a centre crop as PNG.

    The image is requested at zoom 16 with size 640x640 and scale=2. It is then cropped by
    80 pixels on every side to 1120x1120x3, which is exactly a 5x5 grid of 224x224x3 patches.

    Args:
        lat_lon: Centre as a "lat,lon" string; also used as the output file name
            (`<save_path>/<lat_lon>.png`).
        overwrite: If False and the output file already exists, skip the API call. This lets an
            interrupted run be resumed without re-downloading (and re-paying for) images that
            are already on disk. Note that a file left truncated by a killed process would also
            be skipped. Defaults to True (always download and overwrite).

    Raises:
        ValueError: If the API request fails (raised by `get_google_maps_static_image`) or
            the decoded image does not yield a 1120x1120x3 crop.
    """
    out_file = os.path.join(save_path, f'{lat_lon}.png')
    if not overwrite and os.path.exists(out_file):
        return

    # Download and decode the image
    img_bytes = get_google_maps_static_image(key, lat_lon, zoom=16, size=(640, 640), scale=2)
    img = plt.imread(io.BytesIO(img_bytes))
    if img.ndim == 2:
        # A grayscale PNG decodes to a 2-D array; replicate it so the channel slice below works.
        img = np.stack([img, img, img], axis=-1)

    # Center crop to allow 5x5 patches of size 224x224x3 from 1120x1120x3
    cut_img = img[80:-80, 80:-80, :3]
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if cut_img.shape != (224 * 5, 224 * 5, 3):
        raise ValueError(f"Unexpected crop for {lat_lon}: cut_img.shape = {cut_img.shape}")

    # plt.imsave does not create missing directories.
    os.makedirs(save_path, exist_ok=True)
    plt.imsave(out_file, cut_img)

## Download

In [None]:
n_jobs = 4  # number of parallel download workers; set based on number of cores
download_tasks = (delayed(download_it)(lat_lon) for lat_lon in tqdm(latlon_pairs_in_geometry))
_ = Parallel(n_jobs=n_jobs)(download_tasks)