In [15]:
import dask.dataframe as dd
import dask.array as da
import dask.bag as db
import dask

import gc
import math
import glob
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator

import cv2
import dask_geopandas as dgpd
import geopandas as gpd
import numpy as np
import pandas as pd
import pygeos.creation
import pygeos.creation
from spatialpandas import sjoin
from geopandas import GeoDataFrame
from pandas import Series
from pyproj import Transformer
from pathlib import Path
from tqdm.notebook import trange, tqdm

from cutil import (
    load_image,
    deg2num,
    nums2degs,
    num2deg
)

In [2]:
from dask.distributed import Client

# Start a local Dask cluster (worker processes on this machine) and connect a client to it.
N_WORKERS = 2
client = Client(n_workers=N_WORKERS)

In [3]:
# Display the cluster summary (dashboard URL, workers, threads, memory).
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 2
Total threads: 20,Total memory: 31.87 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:58732,Workers: 2
Dashboard: http://127.0.0.1:8787/status,Total threads: 20
Started: Just now,Total memory: 31.87 GiB

0,1
Comm: tcp://127.0.0.1:58749,Total threads: 10
Dashboard: http://127.0.0.1:58750/status,Memory: 15.93 GiB
Nanny: tcp://127.0.0.1:58736,
Local directory: C:\Dropbox\uic\deep-shadow\git\dask-worker-space\worker-xj60vbds,Local directory: C:\Dropbox\uic\deep-shadow\git\dask-worker-space\worker-xj60vbds

0,1
Comm: tcp://127.0.0.1:58752,Total threads: 10
Dashboard: http://127.0.0.1:58753/status,Memory: 15.93 GiB
Nanny: tcp://127.0.0.1:58735,
Local directory: C:\Dropbox\uic\deep-shadow\git\dask-worker-space\worker-_zw6w8lw,Local directory: C:\Dropbox\uic\deep-shadow\git\dask-worker-space\worker-_zw6w8lw


In [20]:
def get_tiles(gdf: GeoDataFrame, zoom: int) -> GeoDataFrame:
    """Build the grid of slippy-map tiles at ``zoom`` that covers ``gdf``, expressed in ``gdf.crs``.

    The tile range is derived from the total bounds of ``gdf`` converted to EPSG:4326. The tile
    edges are projected back into ``gdf.crs``. Only tiles whose bounding box intersects the
    bounding box of at least one geometry in ``gdf`` are kept.

    Parameters
    ----------
    gdf : GeoDataFrame
        Input geometries. Its CRS is used for both coordinate transforms, and its spatial
        index (``gdf.sindex``) is used for the final intersection filter.
    zoom : int
        Slippy-map zoom level passed to the ``cutil`` tile helpers.

    Returns
    -------
    GeoDataFrame
        One row per kept tile, in ``gdf.crs``. Index: ``tntw = (tn << 32) | tw`` (uint64,
        sorted ascending). Columns:
        - ``tn``/``tw``: tile indices.
        - ``tpn``/``tpw``: projected north/west edges.
        - ``h``/``w``: signed projected extents (``h = south - north``, ``w = east - west``).
        - geometry: the tile box.
    """
    # total_bounds is (minx, miny, maxx, maxy), i.e. west, south, east, north in gdf.crs.
    pw, ps, pe, pn = gdf.total_bounds

    # Projected -> geographic; always_xy=True keeps (lon, lat) order.
    trans = Transformer.from_crs(gdf.crs, 4326, always_xy=True)
    gw, gn = trans.transform(pw, pn)
    ge, gs = trans.transform(pe, ps)

    # Tile indices (x, y) of the NW and SE corners. deg2num comes from cutil
    # (presumably the usual floor-based slippy-map conversion -- confirm).
    tw, tn = deg2num(gw, gn, zoom, always_xy=True)
    te, ts = deg2num(ge, gs, zoom, always_xy=True)

    # Order the indices so tn <= ts and tw <= te (the smaller y index is treated as north).
    tn, ts = min(tn, ts), max(tn, ts)
    tw, te = min(tw, te), max(tw, te)

    # np.ndarray indexing is [row, column], so [north, west] is used to maintain that convention.
    # Convention: repeat rows, tile columns (row-major order over the tile grid).

    # Slippy Tiles
    # NOTE(review): np.arange excludes its stop value, so row ts and column te are not generated.
    # If deg2num returns the tile *containing* the SE corner, geometry in that last row/column
    # is never covered. Also, a bbox that fits in one tile row/column yields empty arrays, and
    # tn[0] / tw[0] below then raise IndexError. Confirm whether ts + 1 / te + 1 was intended.
    tn = np.arange(tn, ts, dtype=np.uint64)  # y tile indices, north to south
    tw = np.arange(tw, te, dtype=np.uint64)  # x tile indices, west to east

    # Geographic
    # A tile's north edge depends only on its row and its west edge only on its column. So compute
    # them from the west-most column and north-most row: O(rows + columns) instead of O(rows * columns).
    # Assumes num2deg/nums2degs return a tile's NW corner as (lon, lat) -- confirm in cutil.
    _, tgn = nums2degs(np.repeat(tw[0], len(tn)), tn, zoom, always_xy=True)
    tgw, _ = nums2degs(tw, np.repeat(tn[0], len(tw)), zoom, always_xy=True)
    # South edge of each row = north edge of the next row; the last one is row ts's north edge.
    tgs = np.append(
        tgn[1:],
        num2deg(tw[0], ts, zoom, always_xy=True)[1]
    )
    # East edge of each column = west edge of the next column; the last one is column te's west edge.
    tge = np.append(
        tgw[1:],
        num2deg(te, tn[0], zoom, always_xy=True)[0]
    )

    # Projected
    # Same O(rows + columns) approach, now in gdf.crs.
    # NOTE(review): this assumes a row's projected north coordinate does not vary with longitude,
    # and a column's projected west coordinate does not vary with latitude. That holds for
    # cylindrical CRSs such as Web Mercator but is only approximate for others (e.g. UTM) --
    # confirm the CRSs in use.
    trans = Transformer.from_crs(4326, gdf.crs, always_xy=True)
    _, tpn = trans.transform(np.repeat(tgw[0], len(tgn)), tgn)
    tpw, _ = trans.transform(tgw, np.repeat(tgn[0], len(tgw)))
    tps = np.append(
        tpn[1:],
        trans.transform(tgw[0], tgs[-1])[1]
    )
    tpe = np.append(
        tpw[1:],
        trans.transform(tge[-1], tgn[0])[0]
    )

    # Expand to every (row, column) pair in row-major order.
    repeat_rows = len(tw)
    tile_columns = len(tn)
    tn = np.repeat(tn, repeat_rows)
    tw = np.tile(tw, tile_columns)

    # Pack (tn, tw) into one uint64 id: tn in the high 32 bits, tw in the low 32 bits.
    tns = tn << 32
    tntw = np.bitwise_or(tns, tw)
    tntw = pd.Index(tntw, name='tntw', dtype=np.uint64)

    # Per-tile projected edges, in the same row-major order as tn/tw.
    tpw = np.tile(tpw, tile_columns)
    tps = np.repeat(tps, repeat_rows)
    tpe = np.tile(tpe, tile_columns)
    tpn = np.repeat(tpn, repeat_rows)
    geometry = pygeos.creation.box(tpw, tps, tpe, tpn)
    # Signed extents: h = south - north, w = east - west.
    h = (tps - tpn)
    w = (tpe - tpw)

    tiles = GeoDataFrame({
        'tn': tn, 'tw': tw,
        'tpn': tpn, 'tpw': tpw,
        # 'tpw': tpw, 'tps': tps, 'tpe': tpe, 'tpn': tpn,
        'h': h, 'w': w,
        # }, geometry=geometry, crs=gdf.crs)
    }, index=tntw, geometry=geometry, crs=gdf.crs)

    # Keep tiles whose bounding box intersects the bounding box of at least one gdf geometry
    # (query_bulk without a predicate compares envelopes). itile indexes tiles; igdf is unused.
    itile, igdf = gdf.sindex.query_bulk(tiles.geometry)
    loc = tiles.index[itile].unique()
    tiles: GeoDataFrame = tiles.loc[loc]
    # tiles = tiles.sort_values(['tn', 'tw'], ascending=True)
    # Sorting by the packed tntw id orders tiles by tn, then tw.
    tiles = tiles.sort_index(ascending=True)
    return tiles


def get_cells(tiles: GeoDataFrame) -> tuple[dgpd.GeoDataFrame, int, int]:
    """Split every tile into a ``rows x columns`` grid of cell boxes, built lazily with dask.

    Parameters
    ----------
    tiles : GeoDataFrame
        Output of ``get_tiles``: indexed by the packed tile id ``tntw`` in ascending order,
        with projected north/west edges ``tpn``/``tpw`` and signed extents ``h``/``w``.

    Returns
    -------
    tuple[dgpd.GeoDataFrame, int, int]
        ``(cells, rows, columns)``. ``cells`` is in ``tiles.crs``, indexed by ``tntw``. Its columns:
        - ``cn``/``cw``: row/column of the cell inside its tile.
        - ``area``: absolute cell area.
        - ``geometry``: the cell box.
        Every dask partition holds whole tiles.

    Raises
    ------
    ValueError
        If ``tiles`` is empty, or if ``rows`` exceeds 256 (cell indices are stored as uint8).
    """
    # Every tile is rendered as a 256x256 image, so the cell grid is fixed to that size.
    rows = 256
    columns = 256
    if rows > 256:
        raise ValueError(
            f"{rows=}>256. This means that the image will be downscaled, and cells require more than"
            f" uint8. Increase zoom level."
        )
    tile_count = len(tiles)
    if tile_count == 0:
        raise ValueError("tiles is empty; there are no cells to generate.")
    cells_per_tile = rows * columns

    # Each chunk holds whole tiles (~75 MB, budgeting 8 bytes x 8 columns per cell). This makes
    # every partition boundary a tile boundary, which set_index(sorted=True) below relies on.
    mb_per_tile = 8 * 8 * cells_per_tile / 1024 / 1024
    tiles_per_chunk = max(1, math.floor(75 / mb_per_tile))
    chunksize = cells_per_tile * tiles_per_chunk

    # from_array is called without a fixed `name=`. A constant name would give every call
    # (e.g. one per city on the same client) identical dask keys for different data;
    # the default name hashes the data instead.
    def per_tile_value(values):
        # One value per tile -> that value repeated for each of the tile's cells.
        return da.from_array(np.repeat(values, cells_per_tile), chunks=chunksize)

    def per_cell_pattern(pattern):
        # One tile's cell pattern -> repeated for every tile, chunked like per_tile_value.
        return da.from_array(np.tile(pattern, tile_count), chunks=chunksize)

    # Cell (cn, cw) spans grid lines cn..cn+1 (north to south) and cw..cw+1 (west to east).
    # The far lines reach 256, which does not fit in uint8, hence uint16 for cs/ce.
    cn = per_cell_pattern(np.repeat(np.arange(rows, dtype=np.uint8), columns))
    cw = per_cell_pattern(np.tile(np.arange(columns, dtype=np.uint8), rows))
    cs = per_cell_pattern(np.repeat(np.arange(1, rows + 1, dtype=np.uint16), columns))
    ce = per_cell_pattern(np.tile(np.arange(1, columns + 1, dtype=np.uint16), rows))

    # Cell size per tile (signed, same sign as the tile extents).
    dh = tiles['h'].values / rows
    dw = tiles['w'].values / columns
    tpn = per_tile_value(tiles['tpn'].values)
    tpw = per_tile_value(tiles['tpw'].values)
    dhr = per_tile_value(dh)
    dwr = per_tile_value(dw)

    # Projected cell edges: tile origin plus whole multiples of the cell size.
    cpn = tpn + dhr * cn
    cps = tpn + dhr * cs
    cpw = tpw + dwr * cw
    cpe = tpw + dwr * ce
    geometry = da.map_blocks(pygeos.creation.box, cpw, cps, cpe, cpn, dtype=object)

    cells = dd.concat([
        dd.from_dask_array(cn, columns='cn'),
        dd.from_dask_array(cw, columns='cw'),
        dd.from_dask_array(per_tile_value(np.abs(dh * dw)), columns='area'),
        dd.from_dask_array(geometry, columns='geometry'),
        dd.from_dask_array(per_tile_value(tiles.index.values), columns='tntw'),
    ], axis=1)
    cells: dgpd.GeoDataFrame = dgpd.from_dask_dataframe(cells)
    cells.crs = tiles.crs

    # Partition i starts at tile i * tiles_per_chunk, and the final division is the last tile id.
    # That gives ceil(tile_count / tiles_per_chunk) + 1 divisions, one more than the number of
    # partitions, as set_index requires. Dask allows the last division to equal the one before
    # it (a single-tile last partition).
    # Using range(0, tile_count - 1, ...) would drop a start whenever tile_count - 1 is a multiple
    # of tiles_per_chunk (including tile_count == 1), and set_index would then raise.
    first_tiles = list(range(0, tile_count, tiles_per_chunk))
    divisions = list(tiles.index[first_tiles + [tile_count - 1]])
    cells = cells.set_index('tntw', sorted=True, divisions=divisions)
    return cells, rows, columns


def partition_mapping(cells: GeoDataFrame, directory: str, rows: int, columns: int, zoom: int, ):
    """Write one PNG per tile for a single partition of weighted cells.

    Called through ``map_partitions`` from ``run``. Each call handles the tiles present in its
    partition and writes ``{directory}/{zoom}/{x}/{y}.png``, where ``x = tw`` and ``y = tn``.

    Parameters
    ----------
    cells : GeoDataFrame
        One partition indexed by the packed tile id ``tntw``, with ``cn``, ``cw`` and ``weight``
        columns. A cell may appear several times, once per geometry overlapping it.
    directory : str
        Output root directory.
    rows, columns : int
        Image size in cells, forwarded to ``load_image``.
    zoom : int
        Zoom level, used as the first path component.

    Returns
    -------
    None

    Raises
    ------
    OSError
        If ``cv2.imwrite`` reports that an image could not be written.
    """
    # Sum the weights per cell. Clip before casting so sums outside the uint16 range
    # (e.g. a height above max_height) saturate at 0 / 65535 instead of wrapping around.
    weight: Series = cells.groupby(['tntw', 'cn', 'cw'], sort=False)['weight'].sum()
    weight = weight.clip(0, np.iinfo(np.uint16).max).astype(np.uint16)

    # Process tiles one at a time: only one image is held in memory, and failures propagate.
    # (Executor.map whose results are never consumed would silently discard exceptions.)
    for key, subagg in weight.groupby(level='tntw', sort=False):
        # tntw packs the y tile index in the high 32 bits and the x tile index in the low 32 bits.
        # Stay in uint64 arithmetic: mixing uint64 with Python ints can promote to float64.
        tntw = np.uint64(key)
        tn = int(tntw >> np.uint64(32))
        tw = int(tntw & np.uint64(0xFFFFFFFF))

        path = os.path.join(directory, f'{zoom}/{tw}/{tn}.png')
        # exist_ok: many tiles share the same {zoom}/{x} directory.
        os.makedirs(os.path.dirname(path), exist_ok=True)

        image = load_image(
            cn=subagg.index.get_level_values('cn').values,
            cw=subagg.index.get_level_values('cw').values,
            weights=subagg.values,
            rows=rows,
            columns=columns,
        )
        # cv2.imwrite reports failure through its return value rather than raising.
        if not cv2.imwrite(path, image):
            raise OSError(f'cv2.imwrite failed to write {path}')


# @dask.delayed
def run(gdf: GeoDataFrame, zoom: int, max_height: float, outputfolder: str):
    """Rasterize the ``height`` of the geometries in ``gdf`` into slippy-map PNG tiles.

    Pipeline:
    1. Build the tiles covering ``gdf`` at ``zoom`` (``get_tiles``).
    2. Split each tile into cells (``get_cells``).
    3. Join the cells to the geometries of ``gdf``.
    4. Weight every (cell, geometry) pair by the overlapped fraction of the cell times
       ``height / max_height``, scaled to the uint16 range.
    5. Write one image per tile under ``outputfolder`` via ``partition_mapping``.

    Parameters
    ----------
    gdf : GeoDataFrame
        Geometries with a ``height`` column (presumably building footprints -- confirm).
    zoom : int
        Slippy-map zoom level.
    max_height : float
        Height that maps to the top of the uint16 range (2**16 - 1).
    outputfolder : str
        Root directory; ``partition_mapping`` writes ``{zoom}/{x}/{y}.png`` below it.

    Returns
    -------
    The computed result of ``map_partitions``. ``partition_mapping`` has no return statement,
    so this carries no data; the images are written as a side effect.
    """
    tiles = get_tiles(gdf, zoom)
    cells, rows, columns = get_cells(tiles)
    
    # NOTE(review): get_cells already returns a dask_geopandas frame, while from_geopandas
    # expects an in-memory geopandas frame. If the goal is 16 partitions, this was probably
    # meant to be cells.repartition(npartitions=16) -- confirm.
    cells = dgpd.from_geopandas(cells, npartitions=16)

    # NOTE(review): spatialpandas.sjoin is designed for spatialpandas (Dask)GeoDataFrames;
    # here it receives dask_geopandas / geopandas frames. dgpd.sjoin may be the intended
    # call -- confirm.
    cells = sjoin(cells, gdf)
    # TODO: Is this wasteful? Should I just call .intersection(gdf.loc[cells['index_right'], 'geometry'] ?
    # Attach each joined gdf geometry (via the join's index_right column). The cell geometry
    # becomes 'geometry_cells' and the gdf geometry becomes 'geometry_gdf'.
    cells = cells.merge(
        gdf[['geometry']], how='left', left_on='index_right', right_index=True, suffixes=('_cells', '_gdf'),
    )
    # Drops only this function's reference; the caller's frame stays alive.
    del gdf
    cells: dgpd.GeoDataFrame = dgpd.from_dask_dataframe(cells, geometry='geometry_gdf')

    # weight = (overlap area / cell area) * (height / max_height) * 65535, i.e. the covered
    # fraction of the cell times the normalized height, scaled to the uint16 range.
    cells['weight'] = (
            dgpd.GeoSeries.intersection(cells['geometry_gdf'], cells['geometry_cells']).area
            / cells['area']
            * cells['height']
            / max_height
            * (2 ** 16 - 1)
    )
    # NOTE(review): the string below is an unfinished note, not a docstring (a no-op expression).
    """
    When generating the elevation maps, we assigned unitless weights to cells with the function:
    
    """

#     warnings.filterwarnings('ignore', '.*empty Series.*')
    # Output metadata for map_partitions: an empty Series (name None, dtype None), because
    # partition_mapping produces no data.
    meta = dd.utils.make_meta((None, None))
#     warnings.filterwarnings('default', '.*empty Series.*')
#     return partition_mapping
    # partition_mapping only reads these columns (plus the tntw index).
    cells = cells[['cn', 'cw', 'weight']]
    # .compute() triggers the whole pipeline; each partition writes its tiles to disk.
    return cells.map_partitions(
        partition_mapping,
        directory=outputfolder,
        rows=rows,
        columns=columns,
        zoom=zoom,
        meta=meta,
    ).compute()

In [3]:
# files = glob.glob('data/osm_new/*.feather')
# zooms = [16]
# max_height = 550

# delayed_results = []
# count = 1

# for filepath in files:
#     city = os.path.basename(filepath.split('.')[0])
#     gdf = gpd.read_feather('data/osm_new/%s.feather'%city)
#     for i in range(0,len(zooms)):
#         delayed_results.append(delayed(run)(gdf, zooms[i], max_height, 'data/heights_new/%s/'%city))
#         count+=1

# with tqdm(desc="Joblib Calc", total=len(delayed_results)) as progress_bar:
#     Parallel(n_jobs=2, backend='loky')(delayed_results)

In [4]:
# files = glob.glob('data/osm_new/*.feather')
# count = 1
# delayed = []
# elapsed = time.time()
# for filepath in files:
#     city = os.path.basename(filepath.split('.')[0])
# #     gdf = dask.delayed(gpd.read_feather)(filepath)
#     gdf = gpd.read_feather(filepath)
#     aux = run(gdf, 16, 550, './data/heights_new/%s/'%(city))
#     delayed.append(aux)
#     count+=1
#     if count == 4 or filepath == files[-1]:
#         elapsed = time.time() - elapsed
#         dask.compute(*delayed)
#         print('(%d of %d): %f sec'%(count, len(files), elapsed))
#         delayed = []
#         count = 0
#         elapsed = time.time()

In [None]:
# files = glob.glob('data/osm_new/*.feather')
# count = 1
# for filepath in files:
#     city = os.path.basename(filepath.split('.')[0])
#     gdf = gpd.read_feather(filepath)
#     elapsed = time.time()
#     run(gdf, 16, 550, './data/heights_new/%s/'%(city))
#     elapsed = time.time() - elapsed
#     print('%s (%d of %d): %f sec'%(city, count, len(files), elapsed))
#     count+=1
#     break

In [None]:
# Render height tiles for each city file at zoom 16, normalizing heights by 550.
# MAX_CITIES = 1 keeps the current debugging behaviour (stop after the first city);
# set it to None to process every file.
ZOOM = 16
MAX_HEIGHT = 550
MAX_CITIES = 1

# glob order is filesystem-dependent; sort so the processed cities are deterministic.
files = sorted(glob.glob('data/osm_new/*.feather'))
for count, filepath in enumerate(files[:MAX_CITIES], start=1):
    # Path.stem keeps dots inside the name ("st.louis.feather" -> "st.louis");
    # split('.')[0] would truncate it to "st".
    city = Path(filepath).stem
    gdf = gpd.read_feather(filepath)
    started = time.perf_counter()
    run(gdf, ZOOM, MAX_HEIGHT, './data/heights_new/%s/' % city)
    elapsed = time.perf_counter() - started
    print('%s (%d of %d): %f sec' % (city, count, len(files), elapsed))