In [None]:
import os
from dotenv import load_dotenv
import io

import h5py
from tqdm import tqdm
from glob import glob
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
from shapely.geometry import Point
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from IPython.display import clear_output

from numcodecs import GZip, Zstd, Blosc

from albk.data.utils import get_google_maps_static_image

load_dotenv()


## Load boundary

The shapefile was downloaded from:

| Country/City | Link |
| --- | --- |
| Bangladesh | https://www.kaggle.com/datasets/tsgreen/bangladesh-administrative-boundaries-shapefiles/data |

In [None]:
# Region name; it selects the shapefile folder here and is reused further down
# to build the output directory.
geometry = 'bangladesh'

# Read the administrative boundary (presumably country level, judging by
# `adm0` in the file name) and preview it.
shapefile_path = f'shapefiles/{geometry}/bgd_admbnda_adm0_bbs_20180410.shp'
gdf = gpd.read_file(shapefile_path)
gdf.plot();

## Create index of points to download

In [None]:
# Merge every feature of the boundary layer into one geometry.
# `unary_union` is deprecated since geopandas 1.0 in favour of `union_all()`;
# keep the old attribute as a fallback for older installations.
if hasattr(gdf, "union_all"):
    unary_shape = gdf.union_all()
else:
    unary_shape = gdf.unary_union

# Bounding box of the merged geometry: (min x, min y, max x, max y).
min_lon, min_lat, max_lon, max_lat = unary_shape.bounds
print(min_lon, min_lat, max_lon, max_lat)

In [None]:
# Grid spacing in degrees. This spacing gives some overlap, but it is chosen so
# that every grid coordinate is a multiple of 0.01.
lat_gap = 0.01
lon_gap = 0.01

# Build the axes from integer step counts instead of `np.arange` with a float
# step: a float step accumulates rounding error and can yield one point more or
# fewer than intended. One step of padding is added on each side of the bounds.
lat_steps = np.arange(int(np.round(min_lat / lat_gap)) - 1, int(np.round(max_lat / lat_gap)) + 2)
lon_steps = np.arange(int(np.round(min_lon / lon_gap)) - 1, int(np.round(max_lon / lon_gap)) + 2)

# Two decimals match the 0.01 spacing; adjust if the gaps above are changed.
lats = np.round(lat_steps * lat_gap, 2)
lons = np.round(lon_steps * lon_gap, 2)
print(len(lats), len(lons), len(lats)*len(lons))

In [None]:
# Cartesian product of the two grid axes, flattened into an (N, 2) array whose
# columns are (lat, lon).
Lat, Lon = np.meshgrid(lats, lons)
latlon_pairs = np.column_stack((Lat.reshape(-1), Lon.reshape(-1)))
print(latlon_pairs.shape)

Now we discard the grid points that fall outside the boundary geometry. The point-in-polygon test is done with `geopandas` (`within`), which uses `shapely` geometries.

In [None]:
# Shapely points are (x, y) = (lon, lat), hence the swapped order.
point_gdf = gpd.GeoDataFrame(geometry=[Point(lon, lat) for lat, lon in latlon_pairs])


def check_within(gdf_chunk):
    """Return a boolean array telling which rows of `gdf_chunk` lie within `unary_shape`."""
    inside = gdf_chunk.within(unary_shape)
    return inside.values


# Split the points into fixed-size chunks and test the chunks in parallel.
chunk_size = 100
chunk_starts = range(0, len(point_gdf), chunk_size)
chunks = [point_gdf[start:start+chunk_size] for start in chunk_starts]
jobs = (delayed(check_within)(chunk) for chunk in tqdm(chunks))
results = Parallel(n_jobs=48)(jobs)
latlon_bool = np.concatenate(results)
print(len(latlon_bool), latlon_bool.sum())

In [None]:
def convert_to_string(latlon_pair):
    """Format a (lat, lon) pair as the string 'lat,lon' with two decimals each."""
    lat, lon = latlon_pair
    return f'{lat:.2f},{lon:.2f}'


# Keep only the grid points flagged by `latlon_bool`, then turn each one into
# its 'lat,lon' string.
inside_pairs = latlon_pairs[latlon_bool]
latlon_pairs_in_geometry = list(map(convert_to_string, inside_pairs))

print(latlon_pairs_in_geometry[:3])
print(len(latlon_pairs_in_geometry))

## Function to download a single image

In [None]:
# Coordinate values for each dimension of the per-point dataset:
# a 5x5 grid of patches, each 224x224 pixels with 3 channels.
lat_lag = np.arange(-2, 3, dtype=np.int8)
lon_lag = np.arange(-2, 3, dtype=np.int8)
rows = np.arange(224, dtype=np.uint8)
cols = np.arange(224, dtype=np.uint8)
channels = np.arange(3, dtype=np.uint8)

# Every patch label starts at -1 (presumably meaning "not labelled yet").
labels = np.full((5, 5), -1, dtype=np.int8)

# Output directory for this region and the API key read from the environment.
path = os.path.join(os.path.expanduser("~"), 'bkdb', geometry)
key = os.getenv('SMTGML_GMS') # DHRUV_GMS, ZEEL_IITGN_GMS, SURAJ_GMS, ZEEL_GMS, ANONY_GMS, VISHESH_GMS, SMTGML_GMS

def download_it(lat_lon):
    """Download the image centred on `lat_lon` and store it as a zarr dataset.

    The image is fetched with `get_google_maps_static_image`, centre-cropped,
    cut into a 5x5 grid of 224x224 RGB patches and written to
    ``<path>/<lat_lon>.zarr`` together with the module-level `labels` array.

    Parameters
    ----------
    lat_lon : str
        Centre of the image formatted as ``'lat,lon'``.

    Raises
    ------
    AssertionError
        If the cropped image does not have shape (1120, 1120, 3).
    """
    # Download image
    img_bytes = get_google_maps_static_image(key, lat_lon, zoom=16, size=(640, 640), scale=2)
    img = plt.imread(io.BytesIO(img_bytes))

    # Center crop to allow 5x5 patches of size 224x224x3 from 1120x1120x3
    cut_img = img[80:-80, 80:-80, :3]
    assert cut_img.shape == (224*5, 224*5, 3), f"cut_img.shape = {cut_img.shape}"

    # Split into 5x5 patches: (5, 5, 224, 224, 3)
    data = np.array([np.split(tmp_img, 5, axis=1) for tmp_img in np.split(cut_img, 5, axis=0)])

    # `plt.imread` returns floats in [0, 1] for PNG and integer arrays for other
    # formats. Round (rather than truncate) the float case so float32 round-trip
    # error cannot lower a pixel value by one, and leave integer data unscaled so
    # it does not wrap around.
    if np.issubdtype(data.dtype, np.floating):
        data = np.clip(np.rint(data * 255), 0, 255).astype(np.uint8)
    else:
        data = data.astype(np.uint8, copy=False)

    save_path = os.path.join(path, f"{lat_lon}.zarr")
    lat, lon = (float(part) for part in lat_lon.split(','))
    ds = xr.Dataset(
        {
            'data': (['lat_lag', 'lon_lag', 'row', 'col', 'channel'], data),
            'label': (['lat_lag', 'lon_lag'], labels),
        },
        coords={
            'lat': lat,
            'lon': lon,
            'row': rows,
            'col': cols,
            'channel': channels,
            'lat_lag': lat_lag,
            'lon_lag': lon_lag,
        },
    )
    # Maximum gzip level on the image array only; mode='w' overwrites an existing store.
    encoding = {'data': {'compressor': GZip(level=9)}}
    ds.to_zarr(save_path, consolidated=False, encoding=encoding, mode='w')

In [None]:
# Test download
# download_it(latlon_pairs_in_geometry[0])

## Download

In [None]:
# Extra 'lat,lon' strings to download on top of the regular grid, e.g. the
# points found closest to predicted brick kilns. Leave empty when there are none.
special_pairs = list()

In [None]:
# Create the output folder first: on a fresh run it does not exist yet and
# `os.listdir` would raise FileNotFoundError.
os.makedirs(path, exist_ok=True)
existing_folders = os.listdir(path)


def check_if_exists(folder):
    """Return True when `folder` (inside `path`) opens as a zarr store whose
    'data' variable has the expected (5, 5, 224, 224, 3) shape.

    Any failure to open or read the store counts as "not downloaded", so
    incomplete or corrupted stores are fetched again.
    """
    try:
        ds = xr.open_zarr(os.path.join(path, folder), consolidated=False)
        return ds['data'].shape == (5, 5, 224, 224, 3)
    except Exception:
        return False


does_exists = Parallel(n_jobs=48)(delayed(check_if_exists)(folder) for folder in tqdm(existing_folders))
existing_pairs = [folder.replace(".zarr", "") for folder, exists in zip(existing_folders, does_exists) if exists]

# Everything requested minus everything already stored and valid.
requested_pairs = latlon_pairs_in_geometry + special_pairs
to_download_pairs = sorted(set(requested_pairs) - set(existing_pairs))
print(len(requested_pairs), "in", geometry)
print(len(existing_pairs), "already downloaded")
print(len(to_download_pairs), "to download")

_ = Parallel(n_jobs=48)(delayed(download_it)(lat_lon) for lat_lon in tqdm(to_download_pairs))

| File Format                                   | Disk Space Consumption            | Time Taken               |
|-----------------------------------------------|-----------------------------------|--------------------------|
| `png` (200 points, 5000 images)               | 461 MB                            | 8.4 seconds              |
| `npy` (200 points, 5000 images)               | 725 MB                            | 6.4 seconds              |
| `npz` (200 points, 5000 images)               | 241 MB                            | 7.5 seconds              |
| `h5` (gzip compression, 200 points, 5000 images)| 423 MB                           | 6.7 seconds              |
| `zarr` (Zstd compression, 200 points, 5000 images)| 364 MB                         | 7.0 seconds              |
| `zarr` (Gzip(level=1) compression, 200 points, 5000 images)| 299 MB                    | 8.1 seconds              |
| `zarr` (Gzip(level=9) compression, 200 points, 5000 images)| 261 MB                    | 8.7 seconds              |

## Appendix

In [None]:
# for i in range(5):
#     img = np.load(os.path.join(path, "21.03,92.25", f"{i}_0.npz"))['arr_0']
#     plt.figure()
#     plt.imshow(img)
#     plt.show()
#     label = int(input("Enter label: "))
#     clear_output(wait=True)

In [None]:
# `glob` returns files in arbitrary order; sort so the same store is shown on
# every run.
files = sorted(glob(os.path.join(path, "*.zarr")))
if not files:
    raise FileNotFoundError(f"No .zarr stores found in {path}; run the download section first.")

# Open the first store and show its centre patch.
ds = xr.open_zarr(files[0], consolidated=False)
ds.sel(lat_lag=0, lon_lag=0)['data']

In [None]:
# Show the centre patch (lat_lag=0, lon_lag=0) of the opened dataset.
center_patch = ds['data'].sel(lat_lag=0, lon_lag=0)
plt.imshow(center_patch)