# Plotly Animations
### By:  Reshama Shaikh
### COVID-19 JH Data: 
https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_time_series

---

In [None]:
from datetime import date

today = date.today()

print("Today's date:", today)

In [None]:
#!pip install matplotlib
#!pip install feather
#!pip install feather-format

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from dateutil import parser
import watermark
import feather
from pathlib import Path

%load_ext watermark
%matplotlib inline

In [None]:
import plotly.express as px
import plotly.graph_objects as go

In [None]:
%watermark -n -v -m -g -iv

---

# READ in cleaned dataframe

In [None]:
path_dataout = 'data_derived'
dfall = feather.read_dataframe(path_dataout+'/df_feather.file')

In [None]:
dfall.info()

In [None]:
dfall.head(2)

---

# Work with subset of data

<div class="admonition note alert alert-info">
<p class="first admonition-title" style="font-weight: bold;">Note</p>
<p class="last">Use subset of <tt class="docutils literal">data</tt> .</p>
</div>

In [None]:
dfall.shape

In [None]:
country_list = ['India', 'US']
df_subset1 = dfall[dfall['Country'].isin(country_list)]

In [None]:
df_subset1.head(2)

In [None]:
df_subset1.tail(2)

In [None]:
xval='day_of_case'
xval_list = df_subset1[xval].tolist()
x_max=max(xval_list)
x_min=min(xval_list)

print(x_min)
print(x_max)

In [None]:
align_cases = 100
xval='day_of_case'
yval='daily_case_count'

usedf = df_subset1[df_subset1[xval] >= int(align_cases)]

In [None]:
# 2) specify a color variable
fig = px.line(
    usedf
    , x=xval
    , y=yval
    , title = "COVID19 over Time (days)"
    , color="Country"
    #, hover_name="date"
    , range_x=[align_cases-50, x_max+50]
)
fig.show()

# 1) Animation: dot plot, two groups

- `method` [docs](https://plotly.com/python/reference/layout/updatemenus/#layout-updatemenus-items-updatemenu-buttons-items-button-method)

In [None]:
# read data in
path_dataout = 'data_derived'
dfall = feather.read_dataframe(path_dataout+'/df_feather.file')

# take subset of the data
country_list = ['India', 'US']
df_subset1 = dfall[dfall['Country'].isin(country_list)]

# get x-axis min and max values
xval_list = df_subset1['day_of_case'].tolist()
x_max=max(xval_list)
x_min=min(xval_list)

# align cases
align_cases = 10
xval='day_of_case'
yval='daily_case_count'

usedf = df_subset1[df_subset1[xval] >= int(align_cases)]

In [None]:
# adding 6/15/21
# this works, please don't touch

fig = px.scatter(usedf, x=xval, y=yval,
           animation_frame="day_of_case", animation_group="Country",
           color="Country",
           #size="pop", color="continent", hover_name="country",
           #log_x=True, size_max=55, 
           range_x=[-10,x_max+50], range_y=[-50000, 500000])

fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 10

frame_length = x_max + 50
slider_length = 0.75  #percent of x-axis length
method_type = 'animate'
method_type = 'restyle'   #['restyle', 'relayout', 'animate', 'update', 'skip']

sliders = [dict(steps = [dict(method= method_type,
                              args= [[f'frame{k+1}'], #HERE IS THE k^th FRAME NAME                          
                              dict(mode= 'immediate',
                                   frame= dict(duration=3, redraw=False),
                                   transition=dict(duration= 0))
                                 ],
                              label=f'{k+1}'  #label for each frame marked on the slider
                             ) 
                         for k in range(1, frame_length)], 
                active=1,
                transition= dict(duration= 0.2 ),
                x=0, # slider starting position; specify as: An int or float in the interval [-2, 3]  
                y=0, # 'y' property is a number, specify as: An int or float in the interval [-2, 3]
                currentvalue=dict(font=dict(size=12), 
                                  prefix=' ', 
                                  visible=True, #True, 
                                  xanchor= 'center'
                                 ),  
               len=slider_length) 
          ]

fig.update_layout(sliders=sliders)
#fig.update_layout(transition_duration=3)
fig.show()

In [None]:
# add

#text="Country",

In [None]:
# added 8/8/21
# don't touch above example; CAN EDIT THIS EXAMPLE

fig = px.scatter(usedf, x=xval, y=yval,
           animation_frame="day_of_case", animation_group="Country",
           color="Country",
           size="cases", 
           hover_name="Country",
           text="Country",
           #log_x=True, 
           size_max=55, 
           range_x=[-10,x_max + 50], range_y=[-10, 500000])

fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 20

sliders = [dict(steps = [dict(method= 'animate',
                              args= [[f'frame{k+1}'], #HERE IS THE k^th FRAME NAME                          
                              dict(mode= 'immediate',
                                   frame= dict(duration=3, redraw=False),
                                   transition=dict(duration= 0))
                                 ],
                              label=f'{k+1}'  #label for each frame marked on the slider
                             ) for k in range(1, len(usedf)-1)], 
                active=1,
                transition= dict(duration= 0 ),
                x=0, # slider starting position  
                y=0, 
                currentvalue=dict(font=dict(size=12), 
                                  prefix='frame: ', 
                                  visible=True, 
                                  xanchor= 'center'
                                 ),  
               len=0.5) #slider length
           ]

fig.update_layout(sliders=sliders)
#fig.update_layout(transition_duration=3)
fig.show()

In [2]:
import feather
dfall = feather.read_dataframe('data_derived' + '/df_feather.file')

In [3]:
# reference:  https://community.plotly.com/t/cumulative-lines-animation-in-python/25707/2
import numpy as np
import pandas as pd
import plotly.graph_objects as go  #plotly 4.0.0rc1

country_list = ['India', 'US']
#country_list = ['Brazil','US']
#country_list = ['Sweden','Norway']
#country_list = ['Italy','Germany']

df_subset = dfall[dfall['Country'].isin(country_list)]
df_subset

df_long = pd.melt(df_subset, id_vars =['date','Country'], value_vars =['daily_case_count'])
df_long

df_wide = df_long.pivot(
    index="date",
    columns="Country",
    values="value"
)

df = df_wide.copy()
df.reset_index(inplace=True)
df


timeval = 'date'
name1 = country_list[0] #'India'
name2 = country_list[1] #'US'
group1 = df[name1].tolist()
group2 = df[name2].tolist()

#x_min = df[timeval][0]
#x_max = df[timeval][len(df)-1]

x_min = '2020-01-01'
x_max = '2021-08-01'
x_max = '2021-10-01'

y_min = min(min(group1), min(group2))
y_max = max(max(group1), max(group2)) + 10000

In [4]:


# ----------------------------------------------------------------
fig = go.Figure(data=[trace1, trace2], frames=frames, layout=layout)

trace1 = go.Scatter(x=df[timeval][:2],
                    y=group1[:2],
                    mode='lines',
                    line=dict(width=1.5),
                    name=name1,
                    
)

trace2 = go.Scatter(x = df[timeval][:2],
                    y = group2[:2],
                    mode='lines',
                    line=dict(width=1.5),
                   name=name2)

increment = 20
frames = [dict(data= [dict(type='scatter',
                           x=df[timeval][:k+increment],
                           y=group1[:k+increment]),
                      dict(type='scatter',
                           x=df[timeval][:k+increment],
                           y=group2[:k+increment])],
               traces= [0, 1],  # frames[k]['data'][0]  updates trace1, and   frames[k]['data'][1], trace2 
              )
          for k  in  range(1, len(group1)-1)] 

layout = go.Layout(width=800,
                   height=500,
                   showlegend=True,
                   hovermode='closest',
                   updatemenus=[dict(type='buttons', showactive=False,
                                y=1.05,
                                x=1.15,
                                xanchor='right',
                                yanchor='bottom',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Play',
                                              method='animate',
                                              args=[None, 
                                                    dict(frame=dict(duration=0.1, 
                                                                    redraw=False),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')
                                                   ])
                                        ])
                               ,dict(type='buttons', showactive=False,
                                y=0.55,
                                x=1.15,
                                xanchor='right',
                                yanchor='bottom',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Stop',
                                              method='restyle',
                                              args=[None, 
                                                    dict(frame=dict(duration=0.1, 
                                                                    redraw=False),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')
                                                   ])
                                        ])
                               ],
                  )

layout.update(xaxis =dict(range=[x_min, x_max], autorange=False),
              yaxis =dict(range=[y_min-20, y_max+20]), 
              title="COVID-19 Cases Over Time",
             )




fig.add_annotation(text = (f"@reshamas / {today}<br>Source: JHU CSSE"), showarrow=False, x = 0, 
                   y = -0.11, xref='paper', yref='paper', xanchor='left', yanchor='bottom', xshift=-3,
                   yshift=-15, font=dict(size=10, color="grey"), align="left")

fig.show()

NameError: name 'trace1' is not defined

In [None]:
df_subset.head(2)

In [None]:
country_list = ['India', 'US']

df_subset = dfall[dfall['Country'].isin(country_list)]
#df_subset

name1 = country_list[0] #'India'
name2 = country_list[1] #'US'

name1

df_long = pd.melt(df_subset, id_vars =['date','Country'], value_vars =['daily_case_count'])
df_long


df_wide = df_long.pivot(
    index="date",
    columns="Country",
    values="value"
)

df = df_wide.copy()
df.reset_index(inplace=True)
df


timeval = 'date'
name1 = country_list[0] #'India'
name2 = country_list[1] #'US'
group1 = df[name1].tolist()
group2 = df[name2].tolist()

#x_min = df[timeval][0]
#x_max = df[timeval][len(df)-1]

x_min = '2020-01-01'
x_max = '2021-08-01'

y_min = min(min(group1), min(group2))
y_max = max(max(group1), max(group2)) + 10000


In [None]:
country_list = ['India', 'US']
df_subset = dfall[dfall['Country'].isin(country_list)]
df_subset['country'] = df_subset['Country']

In [None]:
# this works, but does not have separate circles for 2 countries

# try adjusting this code with COVID
import plotly.express as px

px.scatter(df_subset, x="daily_case_count", y="cases"
           , animation_frame="days_since_first_case"
           , animation_group="Country"
           , size="cases" 
           ,color="country"
           #, hover_name="Country/Region"
           ,log_x=False, size_max=100, range_x=[-100000,500000], range_y=[-10000000,50000000])


---

## trying to make a general function

In [None]:
dfall = feather.read_dataframe('data_derived' + '/df_feather.file')

In [None]:
#dfall = newdf.copy()
df_use=dfall.copy()

In [None]:
# reference:  https://community.plotly.com/t/cumulative-lines-animation-in-python/25707/2
import numpy as np
import pandas as pd
import plotly.graph_objects as go  #plotly 4.0.0rc1

country_list = ['India', 'US']
#country_list = ['Brazil','US']
#country_list = ['Sweden','Norway']
#country_list = ['Italy','Germany']
#country_list = ['Italy','France']
#country_list = ['Germany','France']


df_subset = df_use[df_use['Country'].isin(country_list)]
df_subset.head()

df_long = pd.melt(df_subset, id_vars =['date','Country'], value_vars =['daily_case_count'])
df_long


df_wide = df_long.pivot(
    index="date",
    columns="Country",
    values="value")
df_wide

df = df_wide.copy()
df.reset_index(inplace=True)

# What df looks like
#df[:5]
#Country	date	India	US
#0	2020-01-22	0	0
#1	2020-01-23	0	0
#2	2020-01-24	0	1
#3	2020-01-25	0	0
#4	2020-01-26	0	3

timeval = 'date'
group1 = country_list[0]
group2 = country_list[1] 
group1_list = df[group1].tolist()
group2_list = df[group2].tolist()

x_min = '2020-01-01'
x_max = '2021-10-01'

y_min = min(min(group1_list), min(group2_list))
y_max = max(max(group1_list), max(group2_list)) + 10000

In [None]:
df[timeval][:2]

In [None]:
# ----------------------------------------------------------------

trace1 = go.Scatter(x=df[timeval][:2],
                    y=group1_list[:2],
                    mode='lines',
                    line=dict(width=1.5),
                    name=group1,
                    
)

trace2 = go.Scatter(x = df[timeval][:2],
                    y = group2_list[:2],
                    mode='lines',
                    line=dict(width=1.5),
                    name=group2)


increment = 20
frames = [dict(data= [dict(type='scatter',
                           x=df[timeval][:k+increment],
                           y=group1_list[:k+increment]),
                      dict(type='scatter',
                           x=df[timeval][:k+increment],
                           y=group2_list[:k+increment])],
               traces= [0, 1],  # frames[k]['data'][0]  updates trace1, and   frames[k]['data'][1], trace2 
              )
          for k  in  range(1, len(group1_list)-1)] 

layout = go.Layout(width=800,
                   height=500,
                   showlegend=True,
                   hovermode='closest',
                   updatemenus=[dict(type='buttons', showactive=False,
                                y=1.05,
                                x=1.15,
                                xanchor='right',
                                yanchor='bottom',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Play',
                                              method='animate',
                                              args=[None, 
                                                    dict(frame=dict(duration=0.1, 
                                                                    redraw=False),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')
                                                   ])
                                        ])
                               ,dict(type='buttons', showactive=False,
                                y=0.55,
                                x=1.15,
                                xanchor='right',
                                yanchor='bottom',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Stop',
                                              method='restyle',
                                              args=[None, 
                                                    dict(frame=dict(duration=0.1, 
                                                                    redraw=False),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')
                                                   ])
                                        ])
                               ],
                  )

fig = go.Figure(data=[trace1, trace2], frames=frames, layout=layout)

layout.update(xaxis =dict(range=[x_min, x_max], autorange=False),
              yaxis =dict(range=[y_min-20, y_max+20]), 
              title="COVID-19 Cases Over Time",
             )

fig.add_annotation(text = (f"@reshamas / {today}<br>Source: JHU CSSE"), showarrow=False, x = 0, 
                   y = -0.11, xref='paper', yref='paper', xanchor='left', yanchor='bottom', xshift=-3,
                   yshift=-15, font=dict(size=10, color="grey"), align="left")

fig.show()

---

---

# FINAL USE 1:  Scatter Plot Animation 

Documentation Example: https://plotly.com/python/animations/#animated-figures-with-plotly-express

In [None]:
import pandas as pd
import watermark
import feather
import plotly.express as px
%reload_ext watermark

In [None]:
#%watermark -n -v -m -g -iv

In [None]:
path_dataout="data_derived"

dfall = feather.read_dataframe(path_dataout+'/df_feather.file')

In [None]:
# Get data ready

df_use=dfall.copy()

# rename country for label
df_use['Country'].mask(df_use['Country'] == 'United Kingdom', 'UK', inplace=True)

country_list = ['India', 'US','Brazil','Italy', 
               'UK','Turkey', 'Mexico', 'Peru']

df_subset = df_use[df_use['Country'].isin(country_list)]

keep_cols = ['date', 'Country', 'cases', 'daily_case_count', 'daily_death_count','deaths']
df= df_subset.filter(keep_cols)
df.sample(2)

<div class="admonition note alert alert-info">
<h1 class="first admonition-title" style="font-weight: bold;">To highlight</h1>
    
<li class="first admonition-title" style="font-weight: normal;">Bug with the legend</li>
<li class="first admonition-title" style="font-weight: normal;">x and y data, what makes sense?</li>
    
   
<p class="last">Use subset of <tt class="docutils literal">data</tt> .</p>
</div>

In [None]:
# show a blank graph
#fig = px.scatter()
#fig.show()

In [None]:
# Run animated scatterplot

fig = px.scatter(df, x="cases", y="deaths", animation_frame="date", animation_group="Country", 
                 size="cases", color="Country", text="Country", hover_name="Country",
                 #color_discrete_sequence=px.colors.qualitative.G10,
                 #log_x=True, 
                 size_max=65,
                 range_x=[-100000,45000000], range_y=[-100000,900000])

fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 25

fig.update_layout(showlegend=False)
fig.update_layout(title_text="Coronavirus Cases Over Time")

#fig.write_html(f'/Users/reshamashaikh/ds/my_repos/reshamas.github.io/assets/images/covid/1_anim_scatter_covid.html', include_plotlyjs="cdn")

fig.show()

---

# FINAL USE 2:  Line Plot Animation 

Documentation Example: https://community.plotly.com/t/cumulative-lines-animation-in-python/25707/2

In [None]:
import pandas as pd
import watermark
import feather
#import plotly.express as px
import plotly.graph_objects as go

%reload_ext watermark

In [None]:
from datetime import date
today = date.today()

In [None]:
#%watermark -n -v -m -g -iv

In [None]:
dfall = feather.read_dataframe('data_derived' + '/df_feather.file')

In [None]:
# Get data ready

country_list = ['India', 'US']
#country_list = ['Brazil','US']
#country_list = ['Sweden','Norway']
#country_list = ['Italy','Germany']

df_subset = dfall[dfall['Country'].isin(country_list)]
df_subset

df_long = pd.melt(df_subset, id_vars =['date','Country'], value_vars =['daily_case_count'])
df_long

df_wide = df_long.pivot(
    index="date",
    columns="Country",
    values="value"
)

df = df_wide.copy()
df.reset_index(inplace=True)
df


timeval = 'date'
name1 = country_list[0] #'India'
name2 = country_list[1] #'US'
group1 = df[name1].tolist()
group2 = df[name2].tolist()

#x_min = df[timeval][0]
#x_max = df[timeval][len(df)-1]

x_min = '2020-01-01'
x_max = '2021-08-01'
x_max = '2021-10-01'

y_min = min(min(group1), min(group2))
y_max = max(max(group1), max(group2)) + 10000

In [None]:




trace1 = go.Scatter(x=df[timeval][:2],
                    y=group1[:2],
                    mode='lines',
                    line=dict(width=1.5),
                    name=name1,
                    
)

trace2 = go.Scatter(x = df[timeval][:2],
                    y = group2[:2],
                    mode='lines',
                    line=dict(width=1.5),
                   name=name2)

increment = 20
frames = [dict(data= [dict(type='scatter',
                           x=df[timeval][:k+increment],
                           y=group1[:k+increment]),
                      dict(type='scatter',
                           x=df[timeval][:k+increment],
                           y=group2[:k+increment])],
               traces= [0, 1],  # frames[k]['data'][0]  updates trace1, and   frames[k]['data'][1], trace2 
              )
          for k  in  range(1, len(group1)-1)] 

layout = go.Layout(width=800,
                   height=500,
                   showlegend=True,
                   hovermode='closest',
                   updatemenus=[dict(type='buttons', showactive=False,
                                y=1.05,
                                x=1.15,
                                xanchor='right',
                                yanchor='bottom',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Play',
                                              method='animate',
                                              args=[None, 
                                                    dict(frame=dict(duration=0.1, 
                                                                    redraw=False),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')
                                                   ])
                                        ])
                               ,dict(type='buttons', showactive=False,
                                y=0.55,
                                x=1.15,
                                xanchor='right',
                                yanchor='bottom',
                                pad=dict(t=0, r=10),
                                buttons=[dict(label='Stop',
                                              method='restyle',
                                              args=[None, 
                                                    dict(frame=dict(duration=0.1, 
                                                                    redraw=False),
                                                         transition=dict(duration=0),
                                                         fromcurrent=True,
                                                         mode='immediate')
                                                   ])
                                        ])
                               ],
                  )

# ----------------------------------------------------------------
fig = go.Figure(data=[trace1, trace2], frames=frames, layout=layout)


layout.update(xaxis =dict(range=[x_min, x_max], autorange=False),
              yaxis =dict(range=[y_min-20, y_max+20]), 
              title="COVID-19 Cases Over Time",
             )




fig.add_annotation(text = (f"@reshamas / {today}<br>Source: JHU CSSE"), showarrow=False, x = 0, 
                   y = -0.11, xref='paper', yref='paper', xanchor='left', yanchor='bottom', xshift=-3,
                   yshift=-15, font=dict(size=10, color="grey"), align="left")

fig.show()