# Create .PNG images of all timesteps in IDX firesmoke dataset

## Import necessary libraries

In [22]:
# for numerical work
import numpy as np

# for accessing file system
import os

# for rendering frames in parallel worker processes
import multiprocessing

# for downloading latest firesmoke netCDF
import requests

# for loading netcdf files, for metadata
import xarray as xr
# for connecting OpenVisus framework to xarray
# from https://github.com/sci-visus/openvisuspy
from openvisuspy.xarray_backend import OpenVisusBackendEntrypoint

# Used for processing netCDF time data
import time
import datetime

# Used for indexing via metadata
import pandas as pd

# for plotting
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt

# for exporting the dictionary of issue files at the end of notebook
import pickle

# Stores the OpenVisus cache in the local directory
import os
os.environ["VISUS_CACHE"]="./visus_cache_can_be_erased"

# Accessory, used to generate progress bar for running for loops
# from tqdm.notebook import tqdm
# import ipywidgets
# import jupyterlab_widgets
from tqdm import tqdm

from OpenVisus import *

### In this section, we load our data using `xr.open_dataset`.

In [2]:
# path to tiny NetCDF
url = 'https://github.com/sci-visus/NSDF-WIRED/raw/main/data/firesmoke_metadata.nc'
local_netcdf = 'firesmoke_metadata.nc'

# Download the file only when no local copy exists, so the notebook runs on a
# fresh checkout while repeated runs reuse the file already on disk
if not os.path.exists(local_netcdf):
    response = requests.get(url, timeout=60)
    # fail loudly on an HTTP error instead of saving an error page as the netCDF
    response.raise_for_status()
    with open(local_netcdf, 'wb') as f:
        f.write(response.content)

# open tiny netcdf with xarray and OpenVisus backend
ds = xr.open_dataset(local_netcdf, engine=OpenVisusBackendEntrypoint)

ov.LoadDataset(http://atlantis.sci.utah.edu/mod_visus?dataset=UBC_fire_smoke_BSC&cached=1)
PM25
Adding field  PM25 shape  [27357, 381, 1081, 21] dtype  float32 labels  ['time', 'ROW', 'COL', 'resolution'] Max Resolution  20


## Calculate derived metadata using original metadata above to create coordinates
### This is required to allow for indexing of data via metadata

#### Calculate latitude and longitude grid

In [3]:
# Read the grid metadata needed to compute lon and lat
xorig, xcell, ncols = ds.XORIG, ds.XCELL, ds.NCOLS
yorig, ycell, nrows = ds.YORIG, ds.YCELL, ds.NROWS

# Last coordinate along each axis: the origin advanced by (count - 1) cell sizes
lon_last = xorig + xcell * (ncols - 1)
lat_last = yorig + ycell * (nrows - 1)

# One evenly spaced coordinate per column (lon) and per row (lat)
longitude = np.linspace(start=xorig, stop=lon_last, num=ncols)
latitude = np.linspace(start=yorig, stop=lat_last, num=nrows)

#### Using calculated latitude and longitude, create coordinates allowing for indexing data using lat/lon

In [4]:
# Create coordinates for lat and lon (credit: Aashish Panta)
# 'lat' is attached along the ROW dimension and 'lon' along the COL dimension
ds.coords['lat'] = ('ROW', latitude)
ds.coords['lon'] = ('COL', longitude)

# Replace col and row dimensions with newly calculated lon and lat arrays (credit: Aashish Panta)
# NOTE(review): this rebinds `ds` to the swapped dataset, so re-running this cell without
# reloading the dataset presumably does not behave the same way - confirm
ds = ds.swap_dims({'COL': 'lon', 'ROW': 'lat'})

##### Return an array of the tflags as pandas timestamps

In [5]:
def parse_tflag(tflag):
    """
    Convert one tflag pair into a datetime object
    :param list tflag: two int32 values; tflag[0] encodes the date as year followed by a
        3-digit day of year, tflag[1] encodes the time of day as hours, minutes, seconds (HHMMSS)
    :return: datetime.datetime for that date and time of day
    """
    # split the date into year (leading digits) and day of year (last 3 digits)
    year, day_of_year = divmod(int(tflag[0]), 1000)

    # resolve the day of year to a calendar month and day by counting from January 1st
    calendar_day = datetime.datetime(year, 1, 1) + datetime.timedelta(days=day_of_year - 1)

    # peel hours off the front of the time value, then split the rest into minutes and seconds
    hours, minutes_seconds = divmod(int(tflag[1]), 10000)
    minutes, seconds = divmod(minutes_seconds, 100)

    # combine the calendar date with the time of day
    return datetime.datetime(year, calendar_day.month, calendar_day.day, hours, minutes, seconds)

In [6]:
# raw TFLAG values for every timestep
tflag_values = ds['TFLAG'].values

# one pandas timestamp per timestep, parsed from the first entry of each timestep's tflags
timestamps = [pd.Timestamp(parse_tflag(step_flags[0])) for step_flags in tflag_values]

# preview the first 3 timestamps
timestamps[:3]

[Timestamp('2021-03-04 00:00:00'),
 Timestamp('2021-03-04 01:00:00'),
 Timestamp('2021-03-04 02:00:00')]

In [7]:
# set coordinates to each timestep with these pandas timestamps
# (attaches `timestamps` as the coordinate of the 'time' dimension, which is what
# lets later cells select a timestep by its timestamp)
ds.coords['time'] = ('time', timestamps)

### Load the ECCC data, generated as in `eccc_data_clean.ipynb`

In [8]:
# read the ECCC table; the CSV is looked up relative to the working directory
df_eccc = pd.read_csv('PM25_2021_2022.csv')
# longitude and latitude columns as numpy arrays, one entry per CSV row
eccc_lons = df_eccc['Longitude//Longitude'].values
eccc_lats = df_eccc['Latitude//Latitude'].values
# sorted unique values of the date column
eccc_dates = np.sort(df_eccc['Date//Date'].unique())

### Create set of all hours to query for

In [9]:
# every hour (0 through 23) of every ECCC date, in date order
all_dates = [
    pd.Timestamp(f'{d} {i}:00:00')
    for d in eccc_dates
    for i in range(24)
]

In [10]:
# preview the first 5 (index, timestamp) pairs
for indexed_date in enumerate(all_dates[:5]):
    print(indexed_date)

(0, Timestamp('2021-01-01 00:00:00'))
(1, Timestamp('2021-01-01 01:00:00'))
(2, Timestamp('2021-01-01 02:00:00'))
(3, Timestamp('2021-01-01 03:00:00'))
(4, Timestamp('2021-01-01 04:00:00'))


### Get IDX forecast at points closest to ECCC points

In [11]:
# find nearest lat, lon values in idx to eccc lat, lon values
# (the dataset is first restricted to a single resolution label)
data_resolution = 0
ds_eccc_pts = ds.loc[dict(resolution=data_resolution)].sel(
    lat=eccc_lats, lon=eccc_lons, method='nearest')

# pair the selected lat and lon coordinate values column-wise:
# column 0 holds the latitudes, column 1 holds the longitudes
idx_coords = np.column_stack([ds_eccc_pts.lat.values, ds_eccc_pts.lon.values])

In [12]:
# split the paired coordinates into separate arrays: column 0 is lat, column 1 is lon
idx_lats, idx_lons = idx_coords[:, 0], idx_coords[:, 1]

## Create the frames

In [74]:
# --- matplotlib parameters shared by every frame ---
# color scaling: normalization name, colormap, and the value range it covers
my_norm = "log"
my_cmap = 'hot'
my_vmin, my_vmax = 1e-1, 1500
# image placement options handed to imshow
my_aspect = 'auto'
my_origin = 'lower'
# figure size (width, height)
fig_w, fig_h = 15, 6
# bounding box of the ECCC points: [lon min, lon max, lat min, lat max]
my_extent = [eccc_lons.min(), eccc_lons.max(), eccc_lats.min(), eccc_lats.max()]

# --- output ---
issue_files = {}
# choose the output folder according to the kind of frames being generated
full_data = True
save_dir = (
    "/usr/sci/scratch_nvme/arleth/frames/idx_v_eccc/full_data/"
    if full_data
    else "/usr/sci/scratch_nvme/arleth/dump/idx_vs_eccc_frames/point_data/"
)

# --- google map tile parameters ---
tile_style = 'satellite'
tile_zoom = 5

def create_frame_catch_issues(frame_date_tuple):
    """
    Render the PM2.5 field at one timestamp to a PNG in `save_dir`.
    When the frame cannot be rendered, a placeholder PNG reading
    'IDX Data UNAVAILABLE' is written under the same file name instead.
    :param tuple frame_date_tuple: (frame number used in the PNG file name, timestamp to visualize)
    :return: None when the frame rendered, otherwise a
        (frame number, timestamp, error description) tuple describing the issue
    """
    # frame number to save PNG as and date to visualize
    frame_num = frame_date_tuple[0]
    date = frame_date_tuple[1]

    # make sure the output folder exists; exist_ok keeps parallel workers from racing each other
    os.makedirs(save_dir, exist_ok=True)
    frame_path = save_dir + "frames%010d.png" % frame_num

    issue = None
    my_fig = None
    try:  # visualize data if it's available
        # select the requested timestep at the configured resolution
        data_array_at_time = ds['PM25'].loc[date, :, :, data_resolution]

        # imshow stretches the array over `extent`, so the extent must describe the
        # data grid itself; the ECCC bounding box is only used as the visible map window
        lon_vals = data_array_at_time['lon'].values
        lat_vals = data_array_at_time['lat'].values
        data_extent = [float(np.min(lon_vals)), float(np.max(lon_vals)),
                       float(np.min(lat_vals)), float(np.max(lat_vals))]

        # set up visualization
        my_fig, my_plt = plt.subplots(figsize=(fig_w, fig_h), subplot_kw=dict(projection=ccrs.PlateCarree()))
        my_plt.set_extent(my_extent, crs=ccrs.PlateCarree())
        my_plt.set_aspect('auto')

        my_plt.gridlines(draw_labels=True)

        # optional map tiles underneath the data (disabled):
        # google_terrain = cimgt.GoogleTiles(style=tile_style, cache=True)
        # my_plt.add_image(google_terrain, tile_zoom)
        plot = my_plt.imshow(data_array_at_time, extent=data_extent, transform=ccrs.PlateCarree(),
                             aspect=my_aspect, origin=my_origin, cmap=my_cmap,
                             norm=my_norm, vmax=my_vmax, vmin=my_vmin, alpha=1)

        my_fig.suptitle(f'Ground level concentration of PM2.5 microns and smaller {date}\n')
        my_fig.colorbar(plot, location='right', label='ug/m^3')
        # add caption showing this is from IDX dataset
        my_plt.text(0.5, -0.1, 'IDX Data', ha='center', va='center', transform=my_plt.transAxes)

        # save visualization as a .PNG to our folder
        my_fig.savefig(frame_path, dpi=280)
    except Exception as err:
        # remember what went wrong as plain text so it can be sent back from a worker process;
        # KeyboardInterrupt and SystemExit are deliberately not caught
        issue = f'{type(err).__name__}: {err}'
    finally:
        # close the figure whether or not rendering succeeded
        if my_fig is not None:
            plt.close(my_fig)

    if issue is None:
        return None

    # rendering failed: write a placeholder image so the frame numbering stays continuous
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        ax.axis('off')
        ax.text(.5, .5, 'IDX Data UNAVAILABLE', fontsize=20, horizontalalignment='center',
                verticalalignment='center')
        # save visualization as a .PNG to our folder
        fig.savefig(frame_path, dpi=280)
    finally:
        plt.close(fig)  # close the figure after saving
    return frame_num, date, issue

In [75]:
# number of worker processes rendering frames in parallel
proc_lim = 10
# create frames, capturing issues
with multiprocessing.Pool(processes=proc_lim) as pool:
    # Start a timer to measure how long the rendering takes
    start_time = time.time()
    print('starting')
    # (frame number, timestamp) pairs to render; currently two sample timestamps only
    frame_jobs = enumerate([all_dates[0], all_dates[2000]])
    issues = pool.map(create_frame_catch_issues, frame_jobs)
    print('done!')
    # End the timer and print the elapsed time
    end_time = time.time()
    elapsed = end_time - start_time
    print(f'Total elapsed time: {elapsed}')

starting
Using Max Resolution:  20
Time: 512, max_resolution: 20, logic_box=(0, 1081, 0, 381), field: PM25
done!
Total elapsed time: 3.4202749729156494


#### old basemap attempt

Collect the latitude and longitude of every grid cell, together with the corresponding PM2.5 value, so they can be plotted with Basemap.

In [None]:
# # populate these with all coordinates and values at current tstep
# # get PM25 values and provide 4 values, the colons mean select all lat and lon indices
# curr_ds = ds['PM25'].loc[0, :, :, 0]
# arr_size = ds.sizes['lat'] * ds.sizes['lon']
# lons = np.zeros(arr_size)
# lats = np.zeros(arr_size)
# vals = np.zeros(arr_size)
# c = 0

# for i in np.arange(ds.sizes['lat']):
#     for j in np.arange(ds.sizes['lon']):
#         lats[c] = curr_ds.lat.values[i]
#         lons[c] = curr_ds.lon.values[j]
#         vals[c] = curr_ds.values[i][j]
#         c += 1
# a = lons
# b = lats
# c = vals

In [None]:
# data_resolution = 0
# save_dir = "/usr/sci/scratch_nvme/arleth/dump/idx_vs_eccc_frames/"
# issue_files = {}

# def create_frame_catch_issues(tstep):
#     # get PM25 values and provide 4 values, the colons mean select all lat and lon indices
#     curr_ds = ds['PM25'].loc[tstep, :, :, data_resolution]
    
#     # populate these with all coordinates and values at current tstep
#     arr_size = ds.sizes['lat'] * ds.sizes['lon']
#     lons = np.zeros(arr_size)
#     lats = np.zeros(arr_size)
#     vals = np.zeros(arr_size)
#     c = 0
    
#     for i in np.arange(ds.sizes['lat']):
#         for j in np.arange(ds.sizes['lon']):
#             lats[c] = curr_ds.lat.values[i]
#             lons[c] = curr_ds.lon.values[j]
#             vals[c] = curr_ds.values[i][j]
#             c += 1
#     # # create visualization using matplotlib and cartopy geography lines, 
#     # # catch exceptions accordingly
#     # try:
#     # use basemap to plot values: https://basemaptutorial.readthedocs.io/en/latest/plotting_data.html#scatter
#     # use `cyl` project: https://matplotlib.org/basemap/stable/users/cyl.html
#     # set parameters: https://basemaptutorial.readthedocs.io/en/latest/basemap.html
#     m = Basemap(projection='cyl', llcrnrlat=np.min(lats),urcrnrlat=np.max(lats),
#             llcrnrlon=np.min(lons),urcrnrlon=np.max(lons), lat_0=51, lon_0=-106, resolution='l',
#             fix_aspect=False, area_thresh=1e6)

#     # Draw map features
#     m.drawcoastlines()
#     # m.drawparallels(np.arange(45.,66.,5.),labels=[1,1,1,1]) # draw parallels
#     # m.drawmeridians(np.arange(-120.,-59.,20.),labels=[1,1,1,1]) # draw parallels

#     # Convert lat/lon to map coordinates for Basemap scatter plot
#     x, y = m(lons, lats)

#     # Plot the lats and lons
#     m.scatter(x, y, marker=',', s=.001, c=vals, cmap='hot')
#     plt.show()
#     # except:
#     #     t = pd.Timestamp(parse_tflag(ds.TFLAG[tstep][0]))
#     #     print(f"issue! {t}")
#     #     return t, curr_values