# Monitor download state and task progress
Run this notebook to visualise download state and monitor download tasks in progress on Google Earth Engine.

In [None]:
# Necessary imports
import os
os.environ['USE_PYGEOS'] = '0'
import pandas as pd
import geopandas as gpd
import time
import folium
import geemap.foliumap as geemap
import branca.colormap
from tqdm.notebook import tqdm
from datetime import datetime

from db_utils import DB
from dotenv import load_dotenv

## Load environment and project details

As with the other notebooks, we load credentials and project details from a hidden ```.env``` file.

In [None]:
# Load environment variables (including path to credentials) from '.env' file
env_file_path = "../.env"

# Uncomment for alternative version for Windows (r"" indicates raw string)
#env_file_path = r"C:/Users/User/floodmapper/.env"

# Stop immediately if the .env file could not be loaded
assert load_dotenv(dotenv_path=env_file_path), "[ERR] Failed to load environment!"

# Google credentials: both variables must be set and the key file must exist
assert "GOOGLE_APPLICATION_CREDENTIALS" in os.environ, "[ERR] Missing $GOOGLE_APPLICATION_CREDENTIALS!"
assert "GS_USER_PROJECT" in os.environ, "[ERR] Missing $GS_USER_PROJECT!"
key_file_path = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
assert os.path.exists(key_file_path), f"[ERR] Google credential key file does not exist: \n{key_file_path} "

# Base directory: must be set and must exist on disk
assert "ML4FLOODS_BASE_DIR" in os.environ, "[ERR] Missing $ML4FLOODS_BASE_DIR!"
base_path = os.environ["ML4FLOODS_BASE_DIR"]
assert os.path.exists(base_path), f"[ERR] Base path does not exist: \n{base_path} "

# Bucket: use .get() so that a missing variable reaches the assertion below
# instead of raising a bare KeyError on the lookup itself
bucket_name = os.environ.get("BUCKET_URI")
assert bucket_name, f"Bucket name not defined {bucket_name}"
print("[INFO] Successfully loaded FloodMapper environment.")

## Query the download state from the database

**Set the name of the session here and run all remaining cells in order.**

In [None]:
# EDIT THE NAME OF THE SESSION
# This value is bound to the `session = %s` filter of every database query
# in this notebook (patch geometry, image status and task tracking).
session_name = "EMSR586"

In [None]:
# Connect to the database (point to the .env file for credentials)
# The resulting handle is the one used by all run_query() calls in this notebook.
db_conn = DB(env_file_path)

In [None]:
# Query the geometry of the selected area: one row per distinct patch in the
# session, with the patch polygon returned as WKT text
query = ("SELECT DISTINCT sp.patch_name, ST_AsText(gr.geometry) "
         "FROM session_patches sp "
         "INNER JOIN grid_loc gr "
         "ON sp.patch_name = gr.patch_name "
         "WHERE sp.session = %s ;")
data = (session_name,)
grid_sel_df = db_conn.run_query(query, data, fetch=True)
print(f"[INFO] Returned {len(grid_sel_df)} rows.")

# Format the results into a correct GeoDataFrame: parse the WKT column into
# geometries and discard the raw text (no in-place mutation)
grid_sel_df = (
    grid_sel_df
    .assign(geometry=gpd.GeoSeries.from_wkt(grid_sel_df['st_astext']))
    .drop(columns=['st_astext'])
)
grid_sel_gdf = gpd.GeoDataFrame(grid_sel_df, geometry='geometry', crs="EPSG:4326")
print(grid_sel_gdf.head(3))

# Create an outline of the map area by merging all patch polygons.
# GeoPandas 1.0 deprecates `unary_union` in favour of `union_all()`; prefer
# the new method when it exists and fall back on older installations.
if hasattr(grid_sel_gdf.geometry, "union_all"):
    aoi_outline = grid_sel_gdf.geometry.union_all()
else:
    aoi_outline = grid_sel_gdf.geometry.unary_union
aoi_outline_gdf = gpd.GeoDataFrame(geometry=[aoi_outline], crs="EPSG:4326")

In [None]:
# Fetch every image record (id, satellite, patch, status) belonging to the
# session, together with the WKT footprint of the patch it falls in
query = ("SELECT DISTINCT im.image_id, im.satellite, "
         "im.patch_name, im.status, ST_AsText(gr.geometry) "
         "FROM image_downloads im "
         "INNER JOIN grid_loc gr "
         "ON im.patch_name = gr.patch_name "
         "INNER JOIN session_patches sp "
         "ON im.patch_name = sp.patch_name "
         "WHERE sp.session = %s ;")
data = (session_name,)
grid_df = db_conn.run_query(query, data, fetch=True)
print(f"[INFO] Returned {len(grid_df)} rows.")

# Nothing to visualise for this session: halt the notebook here
if not len(grid_df):
    print(f"[INFO] No patches for session '{session_name}'.")
    raise KeyboardInterrupt

# Turn the WKT text column into a geometry column and wrap as a GeoDataFrame
grid_df = (
    grid_df
    .assign(geometry=gpd.GeoSeries.from_wkt(grid_df['st_astext']))
    .drop(columns=['st_astext'])
)
grid_gdf = gpd.GeoDataFrame(grid_df, geometry='geometry', crs="EPSG:4326")
grid_gdf.head(3)

## Parse number of downloads and skipped files in each patch

In [None]:
# Extract the patch polygons: one row per patch, indexed by patch name
geom = grid_gdf[["patch_name", "geometry"]].drop_duplicates()
geom = geom.set_index("patch_name")

# Count the downloaded (status == 1) and skipped (status == 0) files per patch.
# Passing a scalar to Series.rename() sets the Series *name*, which becomes the
# column label in the concat below (a dict would relabel index entries instead).
grid_dl_gdf = grid_gdf.loc[grid_gdf.status == 1]
downloads = grid_dl_gdf.groupby("patch_name").image_id.count().rename("count")
grid_skip_gdf = grid_gdf.loc[grid_gdf.status == 0]
skipped = grid_skip_gdf.groupby("patch_name").image_id.count().rename("count")

# Create a downloads gdf. Patches without any matching image have no entry in
# the counts, so they come out of the concat as NaN; fill only the count
# column (never the geometry column) and keep the counts as integers.
downloads_df = pd.concat([downloads, geom], axis=1)
downloads_df["count"] = downloads_df["count"].fillna(0).astype(int)
downloads_gdf = gpd.GeoDataFrame(downloads_df, geometry='geometry', crs="EPSG:4326")

# Create a skipped gdf (same treatment as above)
skipped_df = pd.concat([skipped, geom], axis=1)
skipped_df["count"] = skipped_df["count"].fillna(0).astype(int)
skipped_gdf = gpd.GeoDataFrame(skipped_df, geometry='geometry', crs="EPSG:4326")

## Plot the number of downloaded files

In [None]:
# Colour scale: YlOrRd_07 stretched between the smallest and largest
# per-patch download count
cm = branca.colormap.linear.YlOrRd_07.scale(
    vmin=downloads_gdf["count"].min(),
    vmax=downloads_gdf["count"].max(),
)


def style_fn(feature):
    """Return the style dict for one GeoJSON feature, coloured by its 'count' property."""
    patch_colour = cm(feature['properties']['count'])
    return {
        'fillColor': patch_colour,
        'color': patch_colour,
        'weight': 0.5,
        "fillOpacity": 0.5,
    }


# Base map: unfilled black outline of the whole mapped area
m = aoi_outline_gdf.explore(
    color="black",
    style_kwds={"fillOpacity": 0.0, "weight": 3},
    name="AoI Outline",
    highlight=False,
)

# Overlay the patches shaded by number of downloaded images; hovering a patch
# shows its count
m.add_child(
    folium.GeoJson(
        downloads_gdf,
        style_function=style_fn,
        name="Downloaded Images",
        tooltip=folium.features.GeoJsonTooltip(["count"]),
    )
)

# Attach the colour legend and an expanded layer control, then display the map
cm.add_to(m)
folium.LayerControl(collapsed=False).add_to(m)
m

## Plot the number of skipped files

In [None]:
# Colour scale: YlOrRd_07 stretched between the smallest and largest
# per-patch skipped-image count
cm = branca.colormap.linear.YlOrRd_07.scale(
    vmin=skipped_gdf["count"].min(),
    vmax=skipped_gdf["count"].max(),
)


def style_fn(feature):
    """Return the style dict for one GeoJSON feature, coloured by its 'count' property."""
    patch_colour = cm(feature['properties']['count'])
    return {
        'fillColor': patch_colour,
        'color': patch_colour,
        'weight': 0.5,
        "fillOpacity": 0.5,
    }


# Base map: unfilled black outline of the whole mapped area
m = aoi_outline_gdf.explore(
    color="black",
    style_kwds={"fillOpacity": 0.0, "weight": 3},
    name="AoI Outline",
    highlight=False,
)

# Overlay the patches shaded by number of skipped images; hovering a patch
# shows its count
m.add_child(
    folium.GeoJson(
        skipped_gdf,
        style_function=style_fn,
        name="Skipped Images",
        tooltip=folium.features.GeoJsonTooltip(["count"]),
    )
)

# Attach the colour legend and an expanded layer control, then display the map
cm.add_to(m)
folium.LayerControl(collapsed=False).add_to(m)
m

## Monitor the task count and progress

The cell below can be run to check the number of tasks being tracked by the database. Note that the ```01_download_images.py``` script must remain running for this notebook to work correctly.

Individual tasks can also be viewed on the GEE [task tracking page](https://code.earthengine.google.com/tasks), provided you are logged in to Google Earth Engine via your web browser. 

In [None]:
# Function to query the task status in the database
def query_tasks():
    """Count the tracked and completed tasks for the current session.

    Reads the module-level ``session_name`` and ``db_conn`` and selects all
    rows of ``gee_task_tracker`` belonging to that session.

    Returns
    -------
    tuple of (int, int)
        ``(num_tasks, num_completed)``: the total number of rows returned and
        the number of those whose ``state_code`` equals ``"COMPLETED"``.
    """
    query = ("SELECT description, state_code "
             "FROM gee_task_tracker "
             "WHERE session = %s;")
    data = (session_name,)
    tasks_df = db_conn.run_query(query, data, fetch=True)

    # No tracked tasks: return early rather than touching a column that an
    # empty result may not carry
    if tasks_df is None or len(tasks_df) == 0:
        return 0, 0

    # Get the total number of tasks
    num_tasks = len(tasks_df)

    # Get the number completed: count matching rows directly, which avoids
    # positional indexing into a label-indexed Series and does not depend on
    # whether 'description' is populated
    num_completed = int((tasks_df["state_code"] == "COMPLETED").sum())

    return num_tasks, num_completed

# Query the total number of tasks being tracked (index the result rather than
# unpacking into `_`, which IPython uses for the last cell output)
num_tasks_old = query_tasks()[0]
print(f"[INFO] There are currently {num_tasks_old} tasks being tracked.")

# Initialise a progress bar
pbar = tqdm(total=num_tasks_old,
            dynamic_ncols=True,
            leave=False,
            position=0,
            desc="Task Progress")

# Query the status every <interval_s> and update bar
interval_s = 1
num_completed_old = 0
try:
    while True:
        num_tasks, num_completed = query_tasks()

        # The bar total was fixed at start-up, so it cannot follow newly
        # submitted tasks: ask the user to restart instead
        if num_tasks > num_tasks_old:
            print(f"[WARN] Total number of tasks has changed!\n"
                  f"       Restart cell after all tasks have been submitted.")
            break

        # Advance the bar by the number of tasks completed since the last poll
        num_increment = num_completed - num_completed_old
        num_completed_old = num_completed
        pbar.update(num_increment)
        now = datetime.now()
        poll_time = now.strftime("%H:%M:%S")
        # The carriage return makes each poll overwrite the previous status line
        print(f"[INFO] Last polling time: {poll_time}", end="\r")

        # Exit if all tasks completed
        if num_tasks == 0 or num_completed == num_tasks:
            print(f"[INFO] No active tasks remaining.")
            break

        # Increment after interval
        time.sleep(interval_s)
finally:
    # Always release the progress bar, including when the loop is interrupted
    # from the kernel or a query raises
    pbar.close()