## Import libraries

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import plotly
import plotly.graph_objects as go
import cv2
import os, sys


## Define functions

In [None]:
def _square_pad_resize_round(arr_raster, max_img_size, round_digit, replace_nan_with_zero):
    '''
    Pad a 2D raster towards a square, optionally downsize it, then round it.

    Params
    ---------
    arr_raster: Numpy Array
        2D raster of floats (NaN marks missing cells)
    max_img_size: int
        If the longer side of the input exceeds this, the padded raster is
        resized to (max_img_size, max_img_size) with cubic interpolation
    round_digit: int
        Number of decimals passed to np.round
    replace_nan_with_zero: bool
        If True, NaN cells in the result are replaced with 0

    Returns
    ---------
    Numpy Array: the processed raster
    '''
    nrows, ncols = arr_raster.shape

    # The same amount of NaN padding goes on both sides of the shorter axis.
    # With an odd difference the result can be one cell away from a perfect
    # square, because round() resolves halves to the nearest even integer.
    pad_size = round(abs(nrows - ncols) / 2)
    if ncols <= nrows:
        pad_width = ((0, 0), (pad_size, pad_size))
    else:
        pad_width = ((pad_size, pad_size), (0, 0))
    arr_raster = np.pad(arr_raster, pad_width=pad_width, mode='constant', constant_values=np.nan)

    if max(nrows, ncols) > max_img_size:
        arr_raster = cv2.resize(arr_raster, dsize=(max_img_size, max_img_size), interpolation=cv2.INTER_CUBIC)

    arr_raster = np.round(arr_raster, round_digit)

    if replace_nan_with_zero:
        arr_raster = np.where(np.isnan(arr_raster), 0, arr_raster)
    return arr_raster


def get_dict_raster(df_data, round_digit=0, max_img_size=300, replace_nan_with_zero=False):
    '''
    Get a DEM raster and a dictionary of PM2.5 rasters for Plotly Surface

    Every raster goes through the same steps: pivot to a lat x lon grid,
    NaN-pad the shorter axis towards a square, resize when the longer side
    exceeds max_img_size, round, and optionally replace NaN with 0.
    Rasters that are wider than tall are handled the same way as rasters
    that are taller than wide.

    Params
    ---------
    df_data: Pandas Dataframe
        A dataframe of monthly PM2.5 prediction, one column per year-month.
        reset_index() must yield 'lat', 'lon' and 'dem' columns, i.e. these
        have to be index levels of df_data.
    round_digit: int
        Number of decimals kept in the output rasters
    max_img_size: int
        Maximum raster side length; larger rasters are resized to
        (max_img_size, max_img_size)
    replace_nan_with_zero: bool
        If True, NaN cells are replaced with 0 in every output raster

    Returns
    ---------
    A tuple of (arr_raster_dem, dict_raster_pm25)
    arr_raster_dem:Numpy Array of DEM,
    dict_raster_pm25:{
        '<year_month_1>': Numpy array of PM2.5 raster,
        '<year_month_2>': Numpy array of PM2.5 raster,
        ...
        '<year_month_N>': Numpy array of PM2.5 raster
    }
    Keys are inserted in sorted column order.
    '''

    # DEM: drop all value columns so that only the index levels remain.
    df_dem = df_data.iloc[:, []].reset_index()
    arr_raster_dem = df_dem.pivot_table(values='dem', index='lat', columns='lon', aggfunc='mean').values
    nrows, ncols = arr_raster_dem.shape
    arr_raster_dem = _square_pad_resize_round(
        arr_raster_dem, max_img_size, round_digit, replace_nan_with_zero
    )

    # PM2.5: one raster per year-month column.
    dict_raster_pm25 = {}
    for year_month in sorted(df_data.columns):
        df_pm25 = df_data.loc[:, [year_month]].reset_index()
        arr_raster_pm25 = df_pm25.pivot_table(values=year_month, index='lat', columns='lon', aggfunc='mean').values
        # Every month must sit on the same grid as the DEM before padding.
        assert arr_raster_pm25.shape == (nrows, ncols), (
            f'PM2.5 raster for {year_month} has shape {arr_raster_pm25.shape}, '
            f'expected {(nrows, ncols)}'
        )
        dict_raster_pm25[year_month] = _square_pad_resize_round(
            arr_raster_pm25, max_img_size, round_digit, replace_nan_with_zero
        )

    return arr_raster_dem, dict_raster_pm25

In [None]:
def get_go_surface_trace(z, s, year_month, cmin=0, cmax=300):
    '''
    Build a Plotly Surface trace whose height is `z` and whose colour is `s`

    Params
    ---------
    z: 2D array
        Surface height (shown as "Elevation" in the hover label)
    s: 2D array
        Values used for the surface colour and hover text
        (shown as "PM2.5" in the hover label)
    year_month: str
        Name given to the trace
    cmin: number
        Lower bound of the colour scale (default 0)
    cmax: number
        Upper bound of the colour scale (default 300)

    Returns
    ---------
    plotly.graph_objects.Surface
    '''
    return go.Surface(
        z=z,
        surfacecolor=s,
        colorscale='Reds',
        hovertemplate="PM2.5: %{surfacecolor:0.0f} μg/m3<br>Elevation: %{z} m",
        hovertext=s,
        # A fixed colour range keeps colours comparable between traces.
        cmax=cmax,
        cmin=cmin,
        name=year_month,
    )

## Load data

In [None]:
# Read the monthly-mean PM2.5 predictions and keep one decimal place.
df_predict_pm25_monthly = (
    pd.read_parquet(r'../data/df_predict_pm25_monthly_mean.parquet')
    .iloc[:, ]
    .round(1)
    .copy()
)
df_predict_pm25_monthly

In [None]:
# Write the rounded monthly predictions to CSV (the index is written too, since to_csv keeps it by default).
df_predict_pm25_monthly.to_csv(r'../data/df_predict_pm25_monthly.csv')

## Animated 3D Surface

In [None]:
# Collect the column names that end in '-04'.
list_apr_col = [
    col_name
    for col_name in df_predict_pm25_monthly.columns
    if col_name.endswith('-04')
]
list_apr_col

In [None]:
# Build an animated 3D surface: the DEM gives the height, the PM2.5 raster of
# each selected column gives the colour, and every column becomes one
# animation frame plus one slider step.

# Upper bound on the raster side length handed to get_dict_raster.
max_resol = 200

# Restrict the data to the columns listed in list_apr_col, then rasterise.
df_data = df_predict_pm25_monthly.loc[:, list_apr_col].copy()
arr_raster_dem, dict_raster_pm25 = get_dict_raster(df_data, 1, max_resol, replace_nan_with_zero=True)



# NOTE: cmax, cmin, year_month_start and list_data are assigned here but are
# not referenced again in this cell.
cmax = 300
cmin = 0
list_year_month = sorted(dict_raster_pm25.keys())
year_month_start = list_year_month[0]
# Side length used for the x/y axis ranges, capped at max_resol.
target_size = np.max(arr_raster_dem.shape)
if max_resol < target_size:
    target_size = max_resol



list_data = []


# Slider definition; its "steps" list is filled in the frame loop below.
sliders_dict = {
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "Year-Month:",
        "visible": True,
        "xanchor": "right"
    },
    "transition": {"duration": 300, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": []
}



# Empty figure with a title and Play / Pause buttons; the trace, frames and
# slider are attached further down.
fig = go.Figure(
    layout=go.Layout(
        title='3D Map of Chiangmai PM2.5 in April (Y2000 to Present)',
        updatemenus=[
    {
        "buttons": [
            {
                "args": [None, {"frame": {"duration": 500, "redraw": True},
                                "fromcurrent": True, "transition": {"duration": 300,
                                                                    "easing": "quadratic-in-out"}}],
                "label": "Play",
                "method": "animate"
            },
            {
                "args": [[None], {"frame": {"duration": 0, "redraw": True},
                                  "mode": "immediate",
                                  "transition": {"duration": 0}}],
                "label": "Pause",
                "method": "animate"
            }
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "showactive": False,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }
]
    ),

)

# One frame and one slider step per raster. Only the first raster is also
# added to the figure as its initial trace.
list_frames = []
first_year_month = True
for year_month in dict_raster_pm25.keys():
    arr_raster_pm25 = dict_raster_pm25[year_month]
    # go_tmp is built on every iteration but only used on the first one.
    go_tmp = get_go_surface_trace(z=arr_raster_dem, s=arr_raster_pm25, year_month=year_month)
    if first_year_month:
        fig.add_trace(go_tmp)
        first_year_month = False
    list_frames.append(
        go.Frame(
            data=[get_go_surface_trace(z=arr_raster_dem, s=arr_raster_pm25, year_month=year_month)],
            name=year_month
        )
    )
    # The step targets the frame by name, so it must match go.Frame(name=...).
    sliders_dict['steps'].append(
            {'args': [
                [year_month],
                {'frame': {'duration': 500, 'redraw': True},
                'mode': 'immediate',
            'transition': {'duration': 0}}
            ],
            'label': year_month,
            'method': 'animate'}
                )


# Fix the axis ranges and hide all three axes.
# NOTE(review): the z range of 0-50000 presumably flattens the relief
# visually - confirm the intended vertical scale.
fig.update_scenes(
    xaxis_range=(0, target_size),
    yaxis_range=(0, target_size),
    zaxis_range=(0, 50000),
    xaxis_visible=False,
    yaxis_visible=False,
    zaxis_visible=False
)



# Camera eye position and the point it looks at.
x_eye = 0.25
y_eye = -.75
z_eye = 0.3
x_cen = 0.1
y_cen = -.3
z_cen = -0.2
camera = dict(
    eye=dict(x=x_eye, y=y_eye, z=z_eye),
    center=dict(x=x_cen, y=y_cen, z=z_cen)
    # up=dict(x=0, y=0., z=3),
)

fig.update_layout(scene_camera=camera,)

fig.update_layout(
    autosize=True,
    width=800, height=800,
    margin=dict(l=100, r=100, b=200, t=100),
    )

# Attach the animation frames and the slider built above.
fig.frames = list_frames

fig.update_layout(
    sliders=[sliders_dict]
)


fig.show()

# Export the figure. The HTML output is a fragment (full_html=False) that does
# not bundle plotly.js (include_plotlyjs=False), so the page embedding it has
# to load plotly.js itself.
fig.write_json(f'fig-{max_resol}.json')
fig.write_html(f'fig-{max_resol}.html', full_html=False, include_plotlyjs=False)

In [None]:
def get_camera_setup(xeye, yeye, zeye, xup, yup, zup, xcen, ycen, zcen):
    '''
    Pack nine coordinates into a camera dictionary

    Returns
    ---------
    A dict with keys 'eye', 'up' and 'center', each holding a dict of
    its 'x', 'y' and 'z' components.
    '''
    return {
        'eye': {'x': xeye, 'y': yeye, 'z': zeye},
        'up': {'x': xup, 'y': yup, 'z': zup},
        'center': {'x': xcen, 'y': ycen, 'z': zcen},
    }

In [None]:

# Sweep the z component of the camera "up" vector from -5 to 4 and show the
# figure for each setting, collecting every camera dictionary that was tried.
# The list is created here because nothing else in the notebook defines it;
# appending to it without this line raises NameError on the first iteration.
list_camera_setup = []

for i in range(-5, 5):
    dict_camera_setup = get_camera_setup(
        0,0,0,
        0,0,i,
        1,-1,5
    )
    print(dict_camera_setup)
    list_camera_setup.append(dict_camera_setup)
    # NOTE: this changes the camera and size of the shared `fig` in place.
    fig.update_layout(
        scene_camera=dict_camera_setup,
        width=500, height=500,
        )
    fig.show()
    print('##############')


In [None]:

# for i in range(0, 10):
#     dict_camera_setup = get_camera_setup(*np.random.randint(-3, 3, size=9))
#     print(dict_camera_setup)
#     list_camera_setup.append(dict_camera_setup)
#     fig.update_layout(
#         scene_camera=dict_camera_setup,
#         width=500, height=500,
#         )
#     fig.show()
#     print('##############')
