# Monitor download state and task progress
Run this notebook to visualise download state and monitor download tasks in progress on Google Earth Engine.

In [None]:
# Necessary imports
import os
os.environ['USE_PYGEOS'] = '0'
import pandas as pd
import geopandas as gpd
import time
import folium
import geemap.foliumap as geemap
import branca.colormap

from db_utils import DB
from dotenv import load_dotenv

## Load environment and project details

As with the other notebooks, we load credentials and project details from a hidden ```.env``` file.

In [None]:
# Load environment variables (including path to credentials) from '.env' file
env_file_path = "../.env"

# Validate the environment up front so that later cells fail early with a
# clear message rather than deep inside a database or Earth Engine call.
# The result of load_dotenv() is checked for truthiness directly.
assert load_dotenv(dotenv_path=env_file_path), "[ERR] Failed to load environment!"

# Google credentials: the variable must be set and must point at an existing key file.
assert "GOOGLE_APPLICATION_CREDENTIALS" in os.environ, "[ERR] Missing $GOOGLE_APPLICATION_CREDENTIALS!"
assert "GS_USER_PROJECT" in os.environ, "[ERR] Missing $GS_USER_PROJECT!"
key_file_path = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
assert os.path.exists(key_file_path), f"[ERR] Google credential key file does not exist: \n{key_file_path} "

# Base directory: the variable must be set and the directory must exist.
assert "ML4FLOODS_BASE_DIR" in os.environ, "[ERR] Missing $ML4FLOODS_BASE_DIR!"
base_path = os.environ["ML4FLOODS_BASE_DIR"]
assert os.path.exists(base_path), f"[ERR] Base path does not exist: \n{base_path} "
print("[INFO] Successfully loaded FloodMapper environment.")

## Query the download state from the database

In [None]:
# All work is conducted under a unique session name.
# This value is passed as the SQL parameter that filters both the patch query
# and the task-tracker query further down in this notebook.
session_name = "EMSR586"

In [None]:
# Connect to the database (point to the .env file for credentials).
# `DB` comes from the project-local `db_utils` module; the resulting object is
# reused by every query cell below via `db_conn.run_query(...)`.
db_conn = DB(env_file_path)

In [None]:
# Fetch every image record (with its download status) for the patches that
# belong to this session, together with the patch polygon as WKT text.
query = " ".join([
    "SELECT DISTINCT im.image_id, im.satellite,",
    "im.patch_name, im.status, ST_AsText(gr.geometry)",
    "FROM image_downloads im",
    "INNER JOIN grid_loc gr",
    "ON im.patch_name = gr.patch_name",
    "INNER JOIN session_patches sp",
    "ON im.patch_name = sp.patch_name",
    "WHERE sp.session = %s ;",
])
data = (session_name,)
grid_df = db_conn.run_query(query, data, fetch=True)
print(f"[INFO] Returned {len(grid_df)} rows.")

# Swap the WKT text column for parsed geometries, then promote the result
# to a GeoDataFrame in WGS84 (EPSG:4326).
wkt_column = "st_astext"
grid_df = (
    grid_df
    .assign(geometry=gpd.GeoSeries.from_wkt(grid_df[wkt_column]))
    .drop(columns=[wkt_column])
)
grid_gdf = gpd.GeoDataFrame(grid_df, geometry="geometry", crs="EPSG:4326")
grid_gdf

## Parse number of downloads and skipped files in each patch

In [None]:
# Extract the unique patch polygons, indexed by patch name
geom = grid_gdf[["patch_name", "geometry"]].drop_duplicates()
geom = geom.set_index("patch_name")

# Count the downloaded (status == 1) and skipped (status == 0) files per patch
grid_dl_gdf = grid_gdf.loc[grid_gdf.status == 1]
downloads = grid_dl_gdf.groupby("patch_name").image_id.count()
grid_skip_gdf = grid_gdf.loc[grid_gdf.status == 0]
skipped = grid_skip_gdf.groupby("patch_name").image_id.count()

# Create a downloads gdf.
# The concat is an outer join on patch name, so a patch without any downloaded
# image would otherwise carry a NaN count; that becomes `null` in the GeoJSON
# and breaks the colour-map lookup when the map is styled. Such patches
# genuinely have zero downloads, so record them as 0.
downloads_df = pd.concat([downloads, geom], axis=1)
downloads_df = downloads_df.rename(columns={"image_id": "count"})
downloads_df["count"] = downloads_df["count"].fillna(0).astype(int)
downloads_gdf = gpd.GeoDataFrame(downloads_df, geometry='geometry', crs="EPSG:4326")

# Create a skipped gdf (same zero-fill for patches without any skipped image)
skipped_df = pd.concat([skipped, geom], axis=1)
skipped_df = skipped_df.rename(columns={"image_id": "count"})
skipped_df["count"] = skipped_df["count"].fillna(0).astype(int)
skipped_gdf = gpd.GeoDataFrame(skipped_df, geometry='geometry', crs="EPSG:4326")

## Plot the number of downloaded files

In [None]:
# Create an outline of the map area by merging all patch geometries into a
# single shape, then wrap it in a one-row GeoDataFrame (EPSG:4326) so it can
# be drawn as the base layer of the maps below.
# NOTE(review): newer GeoPandas releases appear to deprecate `unary_union` in
# favour of `union_all()` - confirm against the installed version.
aoi_outline = grid_gdf.geometry.unary_union
aoi_outline_gdf = gpd.GeoDataFrame(geometry=[aoi_outline], crs="EPSG:4326")

In [None]:
# Colour scale: yellow-orange-red, stretched over the observed range of
# per-patch download counts.
# (Alternative kept for reference: branca.colormap.LinearColormap with
#  ['red', 'orange', 'yellow', 'cyan', 'blue', 'darkblue'].)
download_counts = downloads_gdf["count"]
cm = branca.colormap.linear.YlOrRd_07.scale(
    vmin=download_counts.min(),
    vmax=download_counts.max(),
)


def style_fn(feature):
    """Return the folium style for one patch, coloured by its 'count' property."""
    patch_colour = cm(feature['properties']['count'])
    return dict(
        fillColor=patch_colour,
        color=patch_colour,
        weight=0.5,
        fillOpacity=0.5,
    )


# Base layer: the unfilled black outline of the area of interest
outline_style = {"fillOpacity": 0.0, "weight": 3}
m = aoi_outline_gdf.explore(
    color="black",
    style_kwds=outline_style,
    name="AoI Outline",
    highlight=False,
)

# Overlay: the patches colour-coded by number of downloaded images
downloads_layer = folium.GeoJson(
    downloads_gdf,
    style_function=style_fn,
    name="Downloaded Images",
    tooltip=folium.features.GeoJsonTooltip(["count"]),
)
downloads_layer.add_to(m)

# Attach the colour-scale legend and the layer control, then display the map
m.add_child(cm)
folium.LayerControl(collapsed=False).add_to(m)
m

## Plot the number of skipped files

In [None]:
# Colour scale: yellow-orange-red, stretched over the observed range of
# per-patch skipped-file counts.
# (Alternative kept for reference: branca.colormap.LinearColormap with
#  ['red', 'yellow', 'green'].)
skipped_counts = skipped_gdf["count"]
cm = branca.colormap.linear.YlOrRd_07.scale(
    vmin=skipped_counts.min(),
    vmax=skipped_counts.max(),
)


def style_fn(feature):
    """Return the folium style for one patch, coloured by its 'count' property."""
    patch_colour = cm(feature['properties']['count'])
    return dict(
        fillColor=patch_colour,
        color=patch_colour,
        weight=0.5,
        fillOpacity=0.5,
    )


# Base layer: the unfilled black outline of the area of interest
outline_style = {"fillOpacity": 0.0, "weight": 3}
m = aoi_outline_gdf.explore(
    color="black",
    style_kwds=outline_style,
    name="AoI Outline",
    highlight=False,
)

# Overlay: the patches colour-coded by number of skipped files
skipped_layer = folium.GeoJson(
    skipped_gdf,
    style_function=style_fn,
    name="Skipped Images",
    tooltip=folium.features.GeoJsonTooltip(["count"]),
)
skipped_layer.add_to(m)

# Attach the colour-scale legend and the layer control, then display the map
m.add_child(cm)
folium.LayerControl(collapsed=False).add_to(m)
m

## Display a progress bar

The cells here can be quickly run in sequence to produce progress bars for the tasks being tracked by the database. Note that the ```01_download_images.py``` script must remain running for this notebook to work. 

In [None]:
# Fetch the description and state code of every tracked GEE task that
# belongs to the current session.
query = " ".join([
    "SELECT description, state_code",
    "FROM gee_task_tracker",
    "WHERE session = %s;",
])
data = (session_name,)
tasks_df = db_conn.run_query(query, data, fetch=True)
tasks_df

In [None]:
# Initialise progress bar for all available tasks.
#
# NOTE(review): everything below is wrapped in a triple-quoted string, so none
# of it is executed - the cell merely evaluates that string literal.
# As written, the loop could not run from the code visible in this notebook:
#   * `tqdm` is not imported in the imports cell;
#   * `image_db` is not defined in any cell shown here;
#   * the task query above selects only `description` and `state_code`, so
#     `tasks_df` has no 'gridname' column to group by.
# Intended logic, per the embedded comments: poll each task's download status,
# treat -1 as still in progress, and drop tasks with status 0 or 1 while
# advancing the bar - TODO confirm the status codes against the database schema.
"""
batch_bar = tqdm(total=len(tasks_df), 
                 dynamic_ncols=True, 
                 leave=False, 
                 position=0, 
                 desc="All Tasks",
                 colour="GREEN")

# Logic : Check all tasks, keep removing them as and when the 
# in_progress flag is set to 0 for the task in the database.
while len(tasks_df) >= 1:
    
    # Loop through the tasks grouped by gridname
    for name, gdf in tasks_df.groupby(by='gridname'):
        for i, task in gdf.iterrows():
            
            # Check if download is still marked as in-progress in the DB
            desc = task['description']
            ip = image_db[image_db['image_id'] == desc]['status'].item()
            
            # Do nothing if still in-progress
            if ip == -1:
                continue

            ## Drop entry if not still in-progress
            if ip == 0 or ip == 1:
                tasks_df.drop(i, inplace = True)
                batch_bar.update()
        
        time.sleep(0.25)
"""