In [203]:
import os, sys
import glob
import iris
import datetime
import numpy as np
import configparser
import json
import iris.coord_categorisation
import pandas as pd
from bokeh.plotting import figure, show, save, output_file
from bokeh.models import ColumnDataSource, Patches, Plot, Title
from bokeh.models import HoverTool
from bokeh.models import Range1d, LinearColorMapper, ColorBar
from bokeh.models import GeoJSONDataSource
from bokeh.palettes import GnBu9, Magma6, Greys256, Greys9, GnBu9, RdPu9, TolRainbow12
from bokeh.palettes import Iridescent23, TolYlOrBr9, Bokeh8, Greys9, Blues9
from skimage import measure
import warnings

# Set the global warning filter to ignore all warnings
# NOTE(review): this silences every warning category for the whole kernel session
# (including deprecation/units notices from iris or pandas); consider narrowing it with
# warnings.filterwarnings("ignore", category=...) for the specific noisy warnings.
warnings.simplefilter("ignore")
# NOTE(review): selects matplotlib's Tk backend; the figures in the visible cells are
# built with Bokeh, so this magic appears to be unused here.
%matplotlib tk

In [102]:
def prepare_calendar(cube):
    """Ensure *cube* carries 'year', 'month_number', 'day_of_month' and 'hour' coordinates.

    Each coordinate is derived from the cube's 'time' coordinate with the matching
    ``iris.coord_categorisation`` helper. A coordinate is only added when none of that
    name exists yet, so calling this repeatedly is harmless.

    Args:
        cube: An iris cube with a 'time' coordinate. It is modified in place.

    Returns:
        The same cube object, for convenience.
    """
    categorisers = {
        'year': iris.coord_categorisation.add_year,
        'month_number': iris.coord_categorisation.add_month_number,
        'day_of_month': iris.coord_categorisation.add_day_of_month,
        'hour': iris.coord_categorisation.add_hour,
    }
    for name, add_category in categorisers.items():
        if cube.coords(name):
            continue  # coordinate already present, leave it untouched
        add_category(cube, 'time', name=name)
    return cube

def create_dates_dt(cube):
    """Return one ``datetime.datetime`` per time point of *cube*.

    The calendar coordinates are attached first via ``prepare_calendar`` (which modifies
    *cube* in place). Only year, month, day and hour are used, so any minute/second part
    of the time points is dropped.

    Args:
        cube: An iris cube with a 'time' coordinate.

    Returns:
        A list of ``datetime.datetime`` objects in the order of the cube's time points.
    """
    calendar_cube = prepare_calendar(cube)
    fields = ('year', 'month_number', 'day_of_month', 'hour')
    columns = [calendar_cube.coord(field).points for field in fields]
    return [datetime.datetime(*stamp) for stamp in zip(*columns)]

In [103]:
# Forecast initialisation time and the directory holding the processed MOGREPS eq-wave files.
date = datetime.datetime(2024, 10, 20)
mog_forecast_processed_dir = "/scratch/hadpx/SEA_monitoring/processed_SEA_data/mogreps/eqwaves"

# Lead values to plot (-96 to 168 in steps of 6) and the ensemble member labels '000'..'016'.
self_times2plot = list(range(-96, 174, 6))
mem_labels = ['%03d' % member for member in range(17)]
self_pressures = ['850']

# Probability thresholds: 'Precip' on its own, and '<wave>_<pressure>' for the
# divergence (Kelvin, WMRG) and vorticity (R1, R2) wave fields.
self_thresholds = {
    'Precip': 5,
    'Kelvin_850': -1 * 1e-6,
    'Kelvin_200': -2 * 1e-6,
    'WMRG_850': -1 * 1e-6,
    'WMRG_200': -2 * 1e-6,
    'R1_850': 5 * 1e-6,
    'R1_200': 2 * 1e-6,
    'R2_850': 2.5 * 1e-6,
    'R2_200': 2 * 1e-6,
}

# NOTE(review): str_hr is not used by the visible cells; date_label takes its hour from `date`.
str_hr = '00'
date_label = date.strftime('%Y%m%d_%H')
outfile_dir = os.path.join(mog_forecast_processed_dir, date_label)
# Formerly os.path.join(self.config_values['mog_plot_ens'], date_label); the notebook writes to the CWD.
html_file_dir = "./"

In [104]:
# Ensemble-member precipitation files for this forecast; members whose file is missing are skipped.
precip_files = [os.path.join(outfile_dir, f'precipitation_flux_combined_{date_label}Z_{mem}.nc')
                for mem in mem_labels]
precip_files = [file for file in precip_files if os.path.exists(file)]
if not precip_files:
    raise FileNotFoundError(f'No precipitation_flux_combined files for {date_label} found in {outfile_dir}')

# Precip cubes (the member files are loaded together into one cube with a 'realization' dimension)
pr_cube = iris.load_cube(precip_files)

# Keep the last len(self_times2plot) time steps so the time axis lines up with the lead list.
# (The original referenced an undefined `times2plot`, which fails on a fresh kernel.)
ntimes = len(self_times2plot)
pr_cube = pr_cube[:, -ntimes:]
# Guard against silently pairing leads with the wrong dates when fewer time steps exist.
assert pr_cube.shape[1] == ntimes, (
    f'Expected {ntimes} time steps to match self_times2plot, got {pr_cube.shape[1]}')
pr_cube

Unknown (kg m-2),realization,time,latitude,longitude
Shape,17,45,49,360
Dimension coordinates,,,,
realization,x,-,-,-
time,-,x,-,-
latitude,-,-,x,-
longitude,-,-,-,x
Attributes,,,,
Conventions,'CF-1.7','CF-1.7','CF-1.7','CF-1.7'


In [41]:
# Lookup table linking each plotted time step to its position on the cube's time axis ('Index'),
# its lead value from self_times2plot ('Lead') and its valid date ('Date').
# zip stops at the shortest of the three inputs.
data = list(zip(range(ntimes), self_times2plot, create_dates_dt(pr_cube)))

# Creating DataFrame
df = pd.DataFrame(data, columns=['Index', 'Lead', 'Date'])

In [12]:
# Ensemble probability (fraction of members along 'realization') of precipitation reaching
# the 'Precip' threshold. The dictionary key is 'Precip' (capitalised), and '>=' matches the
# wave loop and the ">=" wording of the precipitation colour-bar title.
# NOTE(review): shade_var is not used by the visible plotting cell, which shades
# wave_timestep_dic['Precip'] instead.
precip_threshold = self_thresholds['Precip']
shade_var = pr_cube.collapsed('realization', iris.analysis.PROPORTION,
                              function=lambda values: values >= precip_threshold)

In [257]:
# Fields processed by the wave loop, in processing order: precipitation first, then the waves.
self_wave_names = np.array(('Precip', 'Kelvin', 'WMRG', 'R1', 'R2'))

In [258]:
import operator


def _load_wave_members(file_prefix, pressure_level=None):
    """Load all available ensemble members of one forecast field as a single cube.

    Looks for ``{file_prefix}_{date_label}Z_{mem}.nc`` in ``outfile_dir`` for every label
    in ``mem_labels``; members whose file is missing are skipped.

    Args:
        file_prefix: File-name stem, e.g. 'precipitation_flux_combined' or 'div_wave_Kelvin'.
        pressure_level: If given, the cube is restricted to the 'pressure' coordinate value
            ``float(pressure_level)`` before the time axis is sliced.

    Returns:
        The loaded cube restricted to its last ``ntimes`` time steps.

    Raises:
        FileNotFoundError: if none of the member files exist.
        ValueError: if the requested pressure level is not present.
    """
    member_files = [os.path.join(outfile_dir, f'{file_prefix}_{date_label}Z_{mem}.nc')
                    for mem in mem_labels]
    member_files = [path for path in member_files if os.path.exists(path)]
    if not member_files:
        raise FileNotFoundError(f'No {file_prefix} files for {date_label} found in {outfile_dir}')
    cube = iris.load_cube(member_files)
    if pressure_level is not None:
        cube = cube.extract(iris.Constraint(pressure=float(pressure_level)))
        if cube is None:  # extract() returns None when nothing matches the constraint
            raise ValueError(f'Pressure level {pressure_level} not found in {file_prefix} files')
    return cube[:, -ntimes:]


def _ensemble_probability(cube, compare, threshold):
    """Fraction of members (along 'realization') for which ``compare(value, threshold)`` holds."""
    return cube.collapsed('realization', iris.analysis.PROPORTION,
                          function=lambda values: compare(values, threshold))


for pressure_level in self_pressures:
    # NOTE: rebuilt for every pressure level, so only the last level in self_pressures survives.
    wave_timestep_dic = {}
    for wname in self_wave_names:
        print(f'PLOTTING WAVE {wname}!!!!!!!!!!!!!!!!!!!!!!!!!')
        if wname == 'Precip':
            threshold = self_thresholds[wname]
            wave_variable = _load_wave_members('precipitation_flux_combined')
            contour_var = _ensemble_probability(wave_variable, operator.ge, threshold)
            contour_cbar_title = f"Probability of {wname} >= {threshold}"
        elif wname in ('Kelvin', 'WMRG'):
            # Divergence field: count members at or below the (negative) threshold.
            threshold = self_thresholds[f'{wname}_{pressure_level}']
            wave_variable = _load_wave_members(f'div_wave_{wname}', pressure_level)
            contour_var = _ensemble_probability(wave_variable, operator.le, threshold)
            contour_cbar_title = f"Probability of {wname} divergence <= {threshold:0.1e} s-1"
        elif wname in ('R1', 'R2'):
            # Vorticity field: count members at or above the threshold.
            threshold = self_thresholds[f'{wname}_{pressure_level}']
            wave_variable = _load_wave_members(f'vort_wave_{wname}', pressure_level)
            contour_var = _ensemble_probability(wave_variable, operator.ge, threshold)
            contour_cbar_title = f"Probability of {wname} vorticity >= {threshold:0.1e} s-1"
        else:
            # Without this, contour_var from the previous wave would be silently reused.
            raise ValueError(f'Unknown wave name: {wname}')
        contour_var.rename(contour_cbar_title)

        # Only one lead is extracted (self_times2plot[18:19]); if more were selected, each
        # would overwrite the previous entry for this wave in wave_timestep_dic.
        for lead in self_times2plot[18:19]:
            lead_rows = df.loc[df['Lead'] == lead]
            t = lead_rows['Index'].values[0]
            datetime_string = lead_rows['Date'].iloc[0].strftime('%Y/%m/%d %HZ')
            contour_slice = contour_var[t]
            wave_timestep_dic[wname] = contour_slice.data
            # The grid is overwritten by every wave; the last wave processed wins.
            wave_timestep_dic['latitude'] = contour_slice.coord('latitude').points
            wave_timestep_dic['longitude'] = contour_slice.coord('longitude').points

PLOTTING WAVE Precip!!!!!!!!!!!!!!!!!!!!!!!!!
PLOTTING WAVE Kelvin!!!!!!!!!!!!!!!!!!!!!!!!!
PLOTTING WAVE WMRG!!!!!!!!!!!!!!!!!!!!!!!!!
PLOTTING WAVE R1!!!!!!!!!!!!!!!!!!!!!!!!!
PLOTTING WAVE R2!!!!!!!!!!!!!!!!!!!!!!!!!


In [259]:
# Inspect what the wave loop stored for the selected lead: one probability array per
# wave name plus the 'latitude'/'longitude' points (from the last wave processed).
wave_timestep_dic.keys()

dict_keys(['Precip', 'latitude', 'longitude', 'Kelvin', 'WMRG', 'R1', 'R2'])

In [260]:
def get_skimage_contour_paths(lons, lats, cube_data, levels=(0.5, 0.75)):
    """Trace iso-lines of a 2-D field and return them in longitude/latitude coordinates.

    ``skimage.measure.find_contours`` returns vertices as fractional (row, column) array
    indices. They are mapped onto the coordinate values by linear interpolation along
    each axis. On a regular 1-unit ascending grid this equals ``index + min(coord)``, and
    it is also correct for other spacings and for descending coordinates.

    Args:
        lons: 1-D longitude points matching the columns (last axis) of ``cube_data``.
        lats: 1-D latitude points matching the rows (first axis) of ``cube_data``.
        cube_data: 2-D array shaped (len(lats), len(lons)).
        levels: Iterable of contour values; each level contributes its own paths.

    Returns:
        (paths_x, paths_y): two lists of 1-D arrays, one pair per traced contour path.
    """
    lon_points = np.asarray(lons, dtype=float)
    lat_points = np.asarray(lats, dtype=float)
    col_index = np.arange(lon_points.size)
    row_index = np.arange(lat_points.size)
    paths_x, paths_y = [], []
    for level in levels:
        for contour in measure.find_contours(cube_data, level):
            paths_x.append(np.interp(contour[:, 1], col_index, lon_points))
            paths_y.append(np.interp(contour[:, 0], row_index, lat_points))
    return paths_x, paths_y

In [262]:
#get_skimage_contour_paths(lons, lats, wave_timestep_dic['Kelvin'], levels=[0.5, 0.7, 0.9])

In [352]:
import numpy as np
from bokeh.io import output_file, save
from bokeh.layouts import column, row, Spacer
from bokeh.models import ColumnDataSource, CheckboxGroup, CheckboxButtonGroup, CustomJS, Button, LinearColorMapper, ColorBar
from bokeh.models import Legend, LegendItem
from bokeh.plotting import figure

# Interactive map: ensemble precipitation probability shaded as an image, overlaid with
# toggleable probability contours for each equatorial wave, saved as a static HTML file.

# Set output HTML file (inside the configured html_file_dir)
output_file(os.path.join(html_file_dir, "contour_selection_with_multiselect.html"))

x_range = (0, 180)  # could be anything - e.g.(0,1)
y_range = (-24, 24)
width = 1100
contour_alpha = 0.5
# Built from the threshold actually used so the title cannot drift from the data.
# NOTE(review): the cube metadata reports units of kg m-2; confirm the "mm/day" label.
shade_cbar_title = f"Probability of precipitation >={self_thresholds['Precip']} mm/day"
aspect = (max(x_range) - min(x_range)) / (max(y_range) - min(y_range))
height = int(width / (.65 * aspect))

# Grid of the fields stored in wave_timestep_dic by the wave loop
lats, lons = wave_timestep_dic['latitude'], wave_timestep_dic['longitude']


def _pixel_edges(points):
    """Return (start, extent) of image cells centred on the regularly spaced ``points``."""
    points = np.asarray(points, dtype=float)
    if points.size < 2:
        # A single point gives no spacing; fall back to a one-unit cell around it.
        return float(points.min()) - 0.5, 1.0
    step = (points.max() - points.min()) / (points.size - 1)
    return float(points.min()) - step / 2, step * points.size


# Image extent from the grid: pixel centres sit on the grid points, so the shading lines
# up with the contours (which are traced through the same points).
image_x, image_dw = _pixel_edges(lons)
image_y, image_dh = _pixel_edges(lats)

precip_field = wave_timestep_dic['Precip']
if lats[0] > lats[-1]:
    # Bokeh draws image row 0 at the bottom, so descending latitudes must be flipped.
    precip_field = precip_field[::-1]

# Prepare the initial image source
precip_source = ColumnDataSource(data=dict(Precip=[precip_field]))
contour_levels = [0.66]  # probability level(s) to contour, e.g. [0.5, 0.7, 0.9]

# Set up the plot
# NOTE(review): Bokeh only applies aspect_scale together with match_aspect=True.
plot = figure(height=height, width=width, x_range=x_range, y_range=y_range,
              tools=["pan", "reset", "save", "wheel_zoom", "hover"],
              x_axis_label='Longitude', y_axis_label='Latitude', aspect_scale=4)

# Create a color mapper for the image
# NOTE(review): values below low=0.5 are drawn with the lowest palette colour.
color_mapper_z = LinearColorMapper(palette='Iridescent23', low=0.5, high=1)
color_bar = ColorBar(color_mapper=color_mapper_z, major_label_text_font_size="12pt",
                     label_standoff=6, border_line_color=None, orientation="horizontal",
                     location=(0, 0), width=400, title=shade_cbar_title, title_text_font_size="12pt")
image_renderer = plot.image('Precip', source=precip_source, x=image_x, y=image_y,
                            dw=image_dw, dh=image_dh, alpha=0.8, color_mapper=color_mapper_z)
plot.add_layout(color_bar, 'below')

# Country outlines
self_map_outline_json_file = 'custom.geo.json'
with open(self_map_outline_json_file, "r") as f:
    countries = GeoJSONDataSource(geojson=f.read())

plot.patches("xs", "ys", color=None, line_color="grey", source=countries, alpha=0.75)

# One entry per contoured wave: (key in wave_timestep_dic, checkbox label, legend label, colour).
# The list order is the checkbox index order used by the JS callback below.
contour_specs = [
    ('Kelvin', "Kelvin", "Kelvin convergence", "blue"),
    ('WMRG', "WMRG", "WMRG convergence", "green"),
    ('R1', "n=1 Rossby", "n=1 Rossby cyclonic vorticity", "red"),
    ('R2', "n=2 Rossby", "n=2 Rossby cyclonic vorticity", "orange"),
]

# Add a MultiLine renderer with the contour paths of each wave
contour_renderers = []
for wave_key, _, _, line_colour in contour_specs:
    paths_x, paths_y = get_skimage_contour_paths(lons, lats, wave_timestep_dic[wave_key],
                                                 levels=contour_levels)
    contour_source = ColumnDataSource(data=dict(xs=paths_x, ys=paths_y))
    contour_renderers.append(plot.multi_line(xs='xs', ys='ys', source=contour_source, line_width=4,
                                             color=line_colour, alpha=contour_alpha))
kelvin_renderer, wmrg_renderer, r1_renderer, r2_renderer = contour_renderers

# Legend with one clickable item per contour renderer
legend_items = [LegendItem(label=legend_label, renderers=[renderer])
                for (_, _, legend_label, _), renderer in zip(contour_specs, contour_renderers)]
# NOTE(review): a tuple location is in screen pixels from the bottom-left corner, so (0, 0.5)
# puts the legend at the bottom-left inside the plot; a name such as "top_left" may be intended.
legend = Legend(items=legend_items, title="Click to show/hide", label_text_font_size="10pt",
                title_text_font_size="11pt", location=(0, 0.5), background_fill_alpha=0.75)
# NOTE(review): legend clicks toggle visibility without updating the checkbox buttons.
legend.click_policy = "hide"  # Allow toggling visibility on click

# Added without a side argument, so the legend is drawn inside the main plot area.
plot.add_layout(legend)

# One toggle button per contour field, all initially active
checkbox_group = CheckboxButtonGroup(labels=[spec[1] for spec in contour_specs],
                                     active=list(range(len(contour_specs))))

# Create a "Clear All" button
clear_button = Button(label="Clear All", button_type="danger")

# JavaScript callback: renderer i is visible exactly when checkbox button i is active
checkbox_callback = CustomJS(args=dict(renderers=contour_renderers, checkbox=checkbox_group), code="""
    renderers.forEach((renderer, i) => { renderer.visible = checkbox.active.includes(i); });
""")
checkbox_group.js_on_change('active', checkbox_callback)

# Clearing 'active' fires the 'active' callback above, which hides every contour renderer.
clear_button_callback = CustomJS(args=dict(checkbox=checkbox_group), code="""
    // Clear all active checkboxes
    checkbox.active = [];
    checkbox.change.emit();
""")
clear_button.js_on_click(clear_button_callback)

spacer = Spacer(width=50)  # Adjust width for desired space
layout = column(row(spacer, checkbox_group, clear_button), row(plot))

# Save the plot layout as a static HTML file
save(layout)


'/home/users/prince.xavier/MJO/Monitoring_new/EQWAVES/notebooks/contour_selection_with_multiselect.html'