In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import plotly

import plotly.graph_objects as go

In [None]:
def _pad_and_round(arr_raster, round_digit):
    '''
    Pad a raster that is at least as tall as it is wide with NaN columns until
    it is exactly square, then round it.

    Rasters with more columns than rows are not padded (only rounded). When the
    column deficit is odd the extra NaN column goes on the right, so the result
    is always exactly nrows x nrows.
    '''
    nrows, ncols = arr_raster.shape
    if ncols <= nrows:
        n_pad = nrows - ncols
        pad_left = n_pad // 2
        pad_right = n_pad - pad_left
        arr_raster = np.pad(arr_raster, pad_width=((0, 0), (pad_left, pad_right)),
                            mode='constant', constant_values=np.nan)
    # Round every raster, padded or not.
    return np.round(arr_raster, round_digit)


def get_dict_raster(df_data, round_digit=0):
    '''
    Get a dictionary of raster for Plotly Surface

    Params
    ---------
    df_data: Pandas Dataframe
        A dataframe of monthly PM2.5 prediction. Its index must provide the
        'lat', 'lon' and 'dem' levels (they are read back via reset_index);
        every column holds the PM2.5 values of one year_month.
    round_digit: int, default 0
        Number of decimals every returned raster is rounded to.

    Returns
    ---------
    A tuple of (arr_raster_dem, dict_raster_pm25)
    arr_raster_dem: Numpy Array of DEM (rows follow lat, columns follow lon),
    dict_raster_pm25: {
        '<year_month_1>': Numpy array of PM2.5 raster,
        '<year_month_2>': Numpy array of PM2.5 raster,
        ...
        '<year_month_N>': Numpy array of PM2.5 raster
    }
    Keys are inserted in sorted order. Every PM2.5 raster is aligned to the
    DEM's lat/lon grid (cells without a PM2.5 value are NaN), so all arrays
    share one shape. Rasters at least as tall as wide are padded with NaN
    columns to exactly square.
    '''
    # Select no value columns: reset_index() then exposes lat/lon/dem as columns.
    df_dem = df_data.iloc[:, []].reset_index()
    df_grid_dem = df_dem.pivot_table(values='dem', index='lat', columns='lon', aggfunc='mean')
    arr_raster_dem = _pad_and_round(df_grid_dem.to_numpy(), round_digit)

    dict_raster_pm25 = {}
    for year_month in sorted(df_data.columns):
        df_pm25 = df_data.loc[:, [year_month]].reset_index()
        # pivot_table drops lat rows / lon columns that are entirely NaN, so
        # reindex onto the DEM grid to keep every month aligned with the terrain.
        df_grid_pm25 = (
            df_pm25
            .pivot_table(values=year_month, index='lat', columns='lon', aggfunc='mean')
            .reindex(index=df_grid_dem.index, columns=df_grid_dem.columns)
        )
        dict_raster_pm25[year_month] = _pad_and_round(df_grid_pm25.to_numpy(), round_digit)

    return arr_raster_dem, dict_raster_pm25

In [None]:
# Load the monthly PM2.5 prediction table; get_dict_raster below treats each column as one year_month.
df_predict_pm25_monthly = pd.read_parquet(
    '../data/df_predict_pm25_monthly.parquet'
)

In [None]:
# Keep the last 10 columns (presumably the most recent months) for a quick look.
df_data = df_predict_pm25_monthly.iloc[:, -10:].copy()

arr_raster_dem, dict_raster_pm25 = get_dict_raster(df_data, 1)

# Quick 2-D look at each monthly raster before building the 3-D figures.
for year_month, arr_raster_pm25 in dict_raster_pm25.items():
    fig_month, ax_month = plt.subplots()
    # Raster rows follow ascending lat (pivot_table sorts its index by default),
    # so flip vertically to draw the northern edge at the top.
    img = ax_month.imshow(np.flip(arr_raster_pm25, axis=0))
    ax_month.set_title(f'Predicted PM2.5 {year_month}')
    fig_month.colorbar(img, ax=ax_month, label='PM2.5 (μg/m3)')
    plt.show()
    # Release each figure so looping over many months does not accumulate memory.
    plt.close(fig_month)



## Version 1: animated 3D surface with a Play button

In [None]:
def get_go_surface_trace(z, s, year_month, cmin=0, cmax=300):
    '''
    Build a Plotly Surface trace whose height is `z` and whose color is `s`.

    Params
    ---------
    z: 2-D array
        Surface heights (the DEM raster; the hover label shows it as elevation in m).
    s: 2-D array
        Values mapped to the surface color (the PM2.5 raster; hover shows μg/m3).
    year_month: str
        Used as the trace name.
    cmin, cmax: float, default 0 and 300
        Fixed color-scale bounds, so every month is drawn on the same scale.

    Returns
    ---------
    plotly.graph_objects.Surface
    '''
    # No `hovertext`: the hovertemplate reads %{surfacecolor} directly, and
    # duplicating the raster as hovertext only inflates every trace/frame.
    return go.Surface(
        z=z,
        surfacecolor=s,
        hovertemplate="PM2.5: %{surfacecolor:0.0f} μg/m3<br>Elevation: %{z} m",
        cmin=cmin,
        cmax=cmax,
        name=year_month,
    )

In [None]:
# Version 1: DEM surface colored by PM2.5, animated month by month with a Play button.
zoom_level = 2.5
list_year_month = sorted(dict_raster_pm25.keys())
year_month_start = list_year_month[0]
arr_raster_pm25_start = dict_raster_pm25[year_month_start]
# The larger raster dimension bounds both horizontal axes.
target_size = np.max(arr_raster_pm25_start.shape)
list_data_start = [
    get_go_surface_trace(z=arr_raster_dem, s=arr_raster_pm25_start, year_month=year_month_start)
]

# Every month, including the first, gets a frame; otherwise Play jumps straight
# to the second month and the first month can never be shown again.
list_frames = [
    go.Frame(
        data=[get_go_surface_trace(z=arr_raster_dem, s=dict_raster_pm25[year_month], year_month=year_month)],
        layout=go.Layout(title_text=year_month),
        name=str(year_month),
    )
    for year_month in list_year_month
]

fig = go.Figure(
    data=list_data_start,
    layout=go.Layout(
        title=year_month_start,
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None])])]
    ),
    frames=list_frames
)

fig.update_scenes(
    xaxis_range=(0, target_size),
    yaxis_range=(0, target_size),
    zaxis_range=(0, 20000),
    xaxis_visible=False,
    yaxis_visible=False,
    zaxis_visible=False
)

# Dividing the eye position by zoom_level moves the camera towards the scene.
xcamera = 3 / zoom_level
ycamera = -4 / zoom_level
zcamera = 1.5 / zoom_level
camera = dict(
    eye=dict(x=xcamera, y=ycamera, z=zcamera)
)

fig.update_layout(
    scene_camera=camera,
    autosize=True,
    width=800, height=800,
    # Keep a top margin: the title (updated per frame) is drawn inside it and
    # is clipped away when t=0.
    margin=dict(l=0, r=0, b=0, t=50),
)

fig.show()

## Version 2 (DEV): Play/Pause buttons and a month slider

In [None]:
# Version 2 (DEV): Play/Pause buttons plus a slider to jump between months.
df_data = df_predict_pm25_monthly.iloc[:, -5:].copy()
arr_raster_dem, dict_raster_pm25 = get_dict_raster(df_data, 1)

zoom_level = 2.5
list_year_month = sorted(dict_raster_pm25.keys())
year_month_start = list_year_month[0]
# Taken from this cell's own raster so the cell does not depend on
# `arr_raster_pm25_start` being left over from the Version 1 cell.
target_size = np.max(arr_raster_dem.shape)

# One frame per month, the first month included. Frames are named so that the
# slider steps below can address them (unnamed frames cannot be targeted).
list_frames = [
    go.Frame(
        data=[get_go_surface_trace(z=arr_raster_dem, s=dict_raster_pm25[year_month], year_month=year_month)],
        layout=go.Layout(title_text=year_month),
        name=str(year_month),
    )
    for year_month in list_year_month
]

sliders_dict = {
    'active': 0,  # the figure opens on the first month
    'yanchor': 'top',
    'xanchor': 'left',
    'currentvalue': {
        'font': {'size': 14},
        'prefix': 'Month: ',
        'visible': True,
        'xanchor': 'right'
    },
    'transition': {'duration': 300},
    'pad': {'b': 10, 't': 10},
    'len': 0.9,
    'x': 0.1,
    'y': 0.1,
    # 'animate' jumps to the frame with the given name; 3-D surfaces only
    # update when the frame is redrawn.
    'steps': [
        {
            'args': [
                [str(year_month)],
                {'frame': {'duration': 500, 'redraw': True},
                 'mode': 'immediate',
                 'transition': {'duration': 0}}
            ],
            'label': str(year_month),
            'method': 'animate'
        }
        for year_month in list_year_month
    ]
}

menu_play_pause = {
    'buttons': [
        {
            'args': [None, {'frame': {'duration': 300, 'redraw': True},
                            'fromcurrent': True, 'transition': {'duration': 0}}],
            'label': 'Play',
            'method': 'animate'
        },
        {
            # Animating to [None] with zero duration is Plotly's pause idiom.
            'args': [[None], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate',
                              'transition': {'duration': 0}}],
            'label': 'Pause',
            'method': 'animate'
        }
    ],
    'direction': 'left',
    'pad': {'r': 10, 't': 10},
    'showactive': False,
    'type': 'buttons',
    'x': 0.1,
    'xanchor': 'right',
    'y': 0,
    'yanchor': 'top'
}

fig = go.Figure(
    data=[get_go_surface_trace(z=arr_raster_dem, s=dict_raster_pm25[year_month_start], year_month=year_month_start)],
    layout=go.Layout(
        title='Chiangmai PM2.5 Map',
        updatemenus=[menu_play_pause],
        sliders=[sliders_dict],
    ),
    frames=list_frames,
)

# Axes stay visible in this development version (Version 1 hides them).
fig.update_scenes(
    xaxis_range=(0, target_size),
    yaxis_range=(0, target_size),
    zaxis_range=(0, 20000),
)

# Dividing the eye position by zoom_level moves the camera towards the scene.
xcamera = 3 / zoom_level
ycamera = -4 / zoom_level
zcamera = 1.5 / zoom_level
camera = dict(
    eye=dict(x=xcamera, y=ycamera, z=zcamera)
)

fig.update_layout(
    scene_camera=camera,
    autosize=True,
    width=500, height=500,
    margin=dict(l=100, r=100, b=200, t=100),
)

fig.show()

## Version 3: slider animation prototype on toy data

In [None]:
# Version 3: minimal Play/Pause + slider animation prototype on toy data.

# Dataset
x = [10, 1, 3, 4, 5, 6, 7, 8, 9, 10]
y = [10, 1, 3, 4, 5, 6, 7, 8, 9, 10]
df = pd.DataFrame(list(zip(x, y)), columns=['x', 'y'])

# Only the first few steps are animated.
n_steps = min(3, len(df) - 1)


def _toy_surface(k):
    '''Return a 2x2 Surface patch spanning toy points k and k + 1.'''
    xs = df.x[k:k + 2].tolist()
    ys = df.y[k:k + 2].tolist()
    # go.Surface needs a 2-D z grid; a 1-D z draws nothing.
    zs = [[v, v] for v in ys]
    return go.Surface(x=xs, y=ys, z=zs, name='Location')


trace = _toy_surface(0)

# Frames are named '0', '1', ...; the slider steps below use the same names.
frames = [go.Frame(name=str(k), data=[_toy_surface(k)], traces=[0]) for k in range(n_steps)]

layout = go.Layout(width=650,
                   height=650,
                   showlegend=False,
                   hovermode='closest',
                   updatemenus=[dict(type='buttons', showactive=False,
                                     y=-.1,
                                     x=0,
                                     xanchor='left',
                                     yanchor='top',
                                     pad=dict(t=0, r=10),
                                     buttons=[dict(label='Play',
                                                   method='animate',
                                                   # 3-D traces only update when redrawn.
                                                   args=[None,
                                                         dict(frame=dict(duration=200, redraw=True),
                                                              transition=dict(duration=0),
                                                              fromcurrent=True,
                                                              mode='immediate')
                                                         ]),
                                              dict(label='Pause',  # https://github.com/plotly/plotly.js/issues/1221 / https://plotly.com/python/animations/#adding-control-buttons-to-animations
                                                   method='animate',
                                                   args=[[None],
                                                         dict(frame=dict(duration=0, redraw=False),
                                                              transition=dict(duration=0),
                                                              fromcurrent=True,
                                                              mode='immediate')
                                                         ])
                                              ])
                                ])

fig = go.Figure(data=[trace], frames=frames, layout=layout)

# Adding a slider: step k targets the frame named str(k); labels stay 1-based.
sliders = [{
    'yanchor': 'top',
    'xanchor': 'left',
    'active': 0,
    'currentvalue': {'font': {'size': 16}, 'prefix': 'Steps: ', 'visible': True, 'xanchor': 'right'},
    'transition': {'duration': 200, 'easing': 'linear'},
    'pad': {'b': 10, 't': 50},
    'len': 0.9, 'x': 0.15, 'y': 0,
    'steps': [{'args': [[str(k)], {'frame': {'duration': 200, 'easing': 'linear', 'redraw': True},
                                   'transition': {'duration': 0, 'easing': 'linear'}}],
               'label': k + 1, 'method': 'animate'} for k in range(n_steps)]
}]

fig.update_layout(sliders=sliders)
fig.show()