In [1]:
"""
Solution notebook for visualizing large images.
"""

from localtileserver import TileClient, get_leaflet_tile_layer
from niceview.utils.raster import geo_ref_raster
from niceview.utils.mask import mask_overlay_image, mask_to_bbox
from scipy.sparse import load_npz
import plotly.graph_objects as go
import numpy as np
import rasterio

In [2]:
# Input and output directories. The trailing slash matters: paths are built
# from these constants by plain string concatenation.
DATA_PATH, PLOT_PATH = './data/', './plots/'

In [3]:
# mask overlay
# Read the sparse mask from disk and densify it in a single step, so `mask`
# is only ever bound to the dense array.
mask = load_npz(DATA_PATH + 'raster.npz').toarray()

# Overlay the mask on the source raster. The third argument is the PNG target;
# presumably the returned value is the path of that overlay (verify against
# niceview.utils.mask.mask_overlay_image).
mask_overlay_path = mask_overlay_image(
    DATA_PATH + 'raster.tiff', mask, PLOT_PATH + 'raster.png',
)

File ./plots/raster.png already exists.


In [4]:
# geo reference image
# Source overlay PNG and destination GeoTIFF, both under the plot directory.
overlay_png = f'{PLOT_PATH}raster.png'
georef_tiff = f'{PLOT_PATH}raster.tiff'

# NOTE(review): the six affine coefficients are passed through untouched; their
# ordering convention is defined by geo_ref_raster — confirm before changing.
georef_img_path = geo_ref_raster(
    overlay_png,
    georef_tiff,
    affine_coefs=(0.1, 0.0, 0.0, 0.0, 0.1, 0.0),
    overwrite=True,
)

In [5]:
# bboxes
# Derive per-object information from the dense mask and keep the full result
# dict around; only the 'centroids' entry is pulled out here.
# NOTE(review): the exact contents of the dict are defined by mask_to_bbox.
# Downstream cells index `centroids` as a 2-D integer array whose column 0
# addresses the y axis and column 1 the x axis — confirm against the helper.
mask_info_dict = mask_to_bbox(mask)
centroids = mask_info_dict['centroids']

In [6]:
# read geo-referenced image
# Open the raster in a context manager so the file handle is released as soon
# as the metadata has been read; everything afterwards only needs plain numbers.
# `geo_ref_img` stays bound after the block but refers to a closed dataset.
with rasterio.open(georef_img_path) as geo_ref_img:
    # Probe a single pixel to see how many decimals the coordinates carry.
    lon, lat = geo_ref_img.xy(5000, 5000)
    lon_min, lat_min, lon_max, lat_max = geo_ref_img.bounds
    img_width, img_height = geo_ref_img.width, geo_ref_img.height

# Count decimals on the positional rendering of the float. Subtracting string
# lengths of str(lon) and str(int(lon)) gives a wrong count whenever the value
# prints in scientific notation (as lat_min does for this raster).
num_digits = len(np.format_float_positional(lon, trim='0').partition('.')[2])
print(f'lon min: {lon_min}, lon max: {lon_max}')
print(f'lat min: {lat_min}, lat max: {lat_max}')
print(f'number of digits after decimal point: {num_digits}')

# make meshgrid
# One coordinate per pixel column / row, spanning the raster bounds inclusively.
# NOTE(review): ys ascends from lat_min at row 0; if the raster is north-up,
# row 0 is the top edge (lat_max) and this axis would be flipped — confirm
# against the dataset transform.
xs = np.linspace(lon_min, lon_max, img_width, dtype=np.float64)
ys = np.linspace(lat_min, lat_max, img_height, dtype=np.float64)
# NOTE(review): meshgrid materialises two height x width float64 arrays; at
# the 9626 x 9843 size printed below that is roughly 0.75 GB each.
xx, yy = np.meshgrid(xs, ys)
y_max, x_max = xx.shape
print(f'y max: {y_max}, x max: {x_max}')

lon min: 4.511256064316617, lon max: 4.5200747984478875
lat min: 4.9540469370876195e-08, lat max: 0.008682156662469064
number of digits after decimal point: 15
y max: 9626, x max: 9843


In [7]:
# ravel for easy access
xx_1d = xx.ravel()
yy_1d = yy.ravel()
# Convert each centroid's pair of grid indices into one flat index into the
# raveled grids; the transpose supplies one index array per grid axis.
centroids_1d = np.ravel_multi_index(centroids.T, (y_max, x_max))

# Spot-check that flat and 2-D lookups agree. Cap the number of checks at the
# number of centroids available so a mask with fewer than three objects does
# not raise IndexError.
num_checks = min(3, len(centroids_1d))
for idx in range(num_checks):
    flat_pos = centroids_1d[idx]
    grid_y, grid_x = centroids[idx, 0], centroids[idx, 1]
    point_from_1d = xx_1d[flat_pos], yy_1d[flat_pos]
    point_from_2d = xx[grid_y, grid_x], yy[grid_y, grid_x]
    print(f'idx {idx} - whether two points are equal: {point_from_1d == point_from_2d}')
    print(f'loc {idx} - lon: {point_from_1d[0]}, lat: {point_from_1d[1]}')
    print()

idx 0 - whether two points are equal: True
loc 0 - lon: 4.51269956977135, lat: 0.006934910792826113

idx 1 - whether two points are equal: True
loc 1 - lon: 4.513743445534704, lat: 0.0024445700911830508

idx 2 - whether two points are equal: True
loc 2 - lon: 4.5183024497269315, lat: 0.0017139200372797



In [8]:
# Look up the grid coordinates of every centroid through its flat index.
lons, lats = xx_1d[centroids_1d], yy_1d[centroids_1d]

In [15]:
# plot image

# Start a tile server for the geo-referenced raster and build a tile layer
# from it; only the layer's URL is used below.
map_client = TileClient(georef_img_path, cors_all=True)
map_tile_layer = get_leaflet_tile_layer(map_client)

# Element 0 of the client's center is used as latitude, element 1 as longitude.
tile_center = map_client.center()

# Fully transparent markers: the points exist only to provide hover
# information (lat, lon and the 'center' label) on top of the raster.
centroid_trace = go.Scattermapbox(
    lon=lons,
    lat=lats,
    mode='markers',
    marker=go.scattermapbox.Marker(size=0.5, opacity=0.0),
    text='center',
    hoverinfo='lat+lon+text',
)

# Raster layer drawn underneath the traces, sourced from the tile URL.
raster_layer = {
    'below': 'traces',
    'sourcetype': 'raster',
    'source': [map_tile_layer.url],
}

fig = go.Figure(
    data=(centroid_trace,),
    layout=go.Layout(
        mapbox={
            'style': 'white-bg',
            'center': {'lon': tile_center[1], 'lat': tile_center[0]},
            'zoom': map_client.default_zoom,
            'pitch': 0,
            'layers': [raster_layer],
        },
        margin={'l': 30, 'r': 30, 't': 30, 'b': 30},
        autosize=False,
        width=700,
        height=700,
        paper_bgcolor='LightSteelBlue',
    ),
)
fig.show()