<https://github.com/CSSEGISandData/COVID-19>

In [1]:
import re
from datetime import date, datetime, timedelta

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

from plotly.offline import init_notebook_mode, iplot
import plotly.graph_objs as go
import plotly.express as px

init_notebook_mode(connected=True)

In [2]:
class AmericanDatesTransformer:
    
    PATTERN = r'(\d+)/(\d+)/(\d+)'
    
    def __init__(self, countries_col="Country/Region"):
        self.countries_col = countries_col
        
    @classmethod
    def get_day_month_year(cls, match):
        month = int(match.group(1))
        day = int(match.group(2))
        year = int('20' + match.group(3))
        return str(date(year, month, day))
    
    def transform_raw(self, df_raw):
        df = df_raw.copy()
        rename_dict = {}
        dates_cols = []
        for colname in df.columns.values:
            match = re.search(self.PATTERN, colname)
            if match:
                d = self.get_day_month_year(match)
                rename_dict[colname] = d
                dates_cols.append(d)
            else:
                rename_dict[colname] = colname
        return df.rename(rename_dict, axis=1), dates_cols

    def transpose_df(self, df, dates_cols):
        df_countries = df[[self.countries_col] + dates_cols]
        df_countries = df_countries\
            .groupby(self.countries_col)\
            .sum().transpose()\
            .reset_index().rename({"index": "date"}, axis=1).copy()
        dates = df_countries.date
        date0 = dates[0]
        days = []
        for date in dates:
            days.append(
                (datetime.strptime(date, "%Y-%m-%d") - datetime.strptime(date0, "%Y-%m-%d")).days
            )
        df_countries['day'] = days
        df_countries.columns.name = ""
        return df_countries

    def transform(self, df_raw):
        df, dates_cols = self.transform_raw(df_raw)
        return self.transpose_df(df, dates_cols)
    
class EuropeanDatesTransformer(AmericanDatesTransformer):
    
    PATTERN = r'(\d+)-(\d+)-(\d+)'
        
    @classmethod
    def get_day_month_year(cls, match):
        year = int(match.group(1))
        month = int(match.group(2))
        day = int(match.group(3))
        return str(date(year, month, day))

In [3]:
american_transformer = AmericanDatesTransformer()

df_confirmed_raw = pd.read_csv("https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv")
print(df_confirmed_raw.shape)
df_confirmed = american_transformer.transform(df_confirmed_raw)

(501, 66)


In [4]:
df_confirmed.columns.values

array(['date', 'Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola',
       'Antigua and Barbuda', 'Argentina', 'Armenia', 'Australia',
       'Austria', 'Azerbaijan', 'Bahamas, The', 'Bahrain', 'Bangladesh',
       'Barbados', 'Belarus', 'Belgium', 'Benin', 'Bhutan', 'Bolivia',
       'Bosnia and Herzegovina', 'Brazil', 'Brunei', 'Bulgaria',
       'Burkina Faso', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada',
       'Cape Verde', 'Central African Republic', 'Chad', 'Chile', 'China',
       'Colombia', 'Congo (Brazzaville)', 'Congo (Kinshasa)',
       'Costa Rica', "Cote d'Ivoire", 'Croatia', 'Cruise Ship', 'Cuba',
       'Cyprus', 'Czechia', 'Denmark', 'Djibouti', 'Dominica',
       'Dominican Republic', 'East Timor', 'Ecuador', 'Egypt',
       'El Salvador', 'Equatorial Guinea', 'Eritrea', 'Estonia',
       'Eswatini', 'Ethiopia', 'Fiji', 'Finland', 'France',
       'French Guiana', 'Gabon', 'Gambia, The', 'Georgia', 'Germany',
       'Ghana', 'Greece', 'Greenland', 'Grenada', 

In [5]:
df_deaths_raw = pd.read_csv("https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Deaths.csv")
df_deaths = american_transformer.transform(df_deaths_raw)
df_deaths.head()

Unnamed: 0,date,Afghanistan,Albania,Algeria,Andorra,Angola,Antigua and Barbuda,Argentina,Armenia,Australia,...,Ukraine,United Arab Emirates,United Kingdom,Uruguay,Uzbekistan,Venezuela,Vietnam,Zambia,Zimbabwe,day
0,2020-01-22,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0
1,2020-01-23,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
2,2020-01-24,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2
3,2020-01-25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3
4,2020-01-26,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4


In [6]:
df_spain_deaths_raw = pd.read_csv("https://raw.githubusercontent.com/datadista/datasets/master/COVID%2019/ccaa_covid19_fallecidos.csv")

In [7]:
european_transformer = EuropeanDatesTransformer(countries_col="CCAA")

df_spain_deaths = european_transformer.transform(df_spain_deaths_raw)
df_spain_deaths.head()

Unnamed: 0,date,Andalucía,Aragón,Asturias,Baleares,C. Valenciana,Canarias,Cantabria,Castilla y León,Castilla-La Mancha,...,Extremadura,Galicia,La Rioja,Madrid,Melilla,Murcia,Navarra,País Vasco,Total,day
0,2020-03-03,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,2020-03-04,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,1,1
2,2020-03-05,0,0,0,0,1,0,0,0,0,...,0,0,0,1,0,0,0,1,3,2
3,2020-03-06,0,1,0,0,1,0,0,0,0,...,0,0,0,2,0,0,0,1,5,3
4,2020-03-07,0,1,0,0,1,0,0,0,0,...,0,0,0,4,0,0,0,1,8,4


In [8]:
all_ccaa = df_spain_deaths.columns.values[2:-1]
all_ccaa

array(['Aragón', 'Asturias', 'Baleares', 'C. Valenciana', 'Canarias',
       'Cantabria', 'Castilla y León', 'Castilla-La Mancha', 'Cataluña',
       'Ceuta', 'Extremadura', 'Galicia', 'La Rioja', 'Madrid', 'Melilla',
       'Murcia', 'Navarra', 'País Vasco', 'Total'], dtype=object)

In [9]:
df_spain_deaths.columns.values

array(['date', 'Andalucía', 'Aragón', 'Asturias', 'Baleares',
       'C. Valenciana', 'Canarias', 'Cantabria', 'Castilla y León',
       'Castilla-La Mancha', 'Cataluña', 'Ceuta', 'Extremadura',
       'Galicia', 'La Rioja', 'Madrid', 'Melilla', 'Murcia', 'Navarra',
       'País Vasco', 'Total', 'day'], dtype=object)

In [10]:
COUNTRIES_POLISH = {
    "Germany": "Niemcy",
    "Italy": "Włochy",
    "Poland": "Polska",
    "Spain": "Hiszpania", 
    'United Kingdom': "Wielka Brytania", 
    "France": "Francja", 
    "Belgium": "Belgia", 
    "Sweden": "Szwecja"
}


def fig_update(fig):
    fig.update_layout(
        xaxis=dict(
            showline=True,
            showgrid=False,
            showticklabels=True,
            linecolor='rgb(204, 204, 204)',
            linewidth=2,
            ticks='outside',
            tickfont=dict(
                family='Arial',
                size=12,
                color='rgb(82, 82, 82)',
            )
        ),
        yaxis=dict(
            showgrid=True,
            zeroline=False,
            showline=False,
            showticklabels=True,
            gridcolor='aliceblue'
        ),
        showlegend=True,
        plot_bgcolor='white'
    )
    return fig

class CountryStat:
    
    def __init__(self, country_name, global_df, start_with=0, max_days=None, language="English"):
        self.country_name = country_name
        self.global_df = global_df
        self.start_with = start_with
        self.max_days = max_days
        self.df = self.global_df[["day", "date", country_name]].rename({country_name: 'total'}, axis=1)
        self.language = language
        
    @property
    def data(self):
        cdf = self.df.copy()
        cdf['new_cases'] = cdf['total'] - cdf['total'].shift(1)    
        cdf = cdf[(cdf['new_cases'] != 0) & (cdf['new_cases'].notnull())]
        cdf = cdf[cdf['total'] >= self.start_with].copy()
        if len(cdf) == 0:
            return cdf
        cdf['day'] = cdf['day'] - cdf.iloc[0]['day']
        if self.max_days is not None:
            cdf = cdf[cdf.day <= self.max_days]
        return cdf
    
    def log_linear_prediction(self, last_day=None):
        data = self.data
        model = LogLinModel(last_day)
        return model.predict(data)
        
    def plotly_lm_plot(self, xcol='day', color=None):
        data = self.data
        if len(data) != 0:
            return PlotData(data, self.country_name, xcol, ycol='total', color=color, language=self.language).plot
        return None
    
    def plotly_log_linear_prediction(self, xcol='day', last_day=None, color=None):
        data = self.log_linear_prediction(last_day)
        if len(data) != 0:
            return PlotPrediction(data, self.country_name, xcol, ycol='total', color=color, last_day=last_day).plot
        return None
    
    def plotly_fig(self, log_scale=False, xcol='day', color=None):
        trace = self.plotly_lm_plot(xcol, color)
        fig = go.Figure()
        if trace is not None:
            fig.add_trace(trace)
            fig = fig_update(fig)
            if log_scale:
                fig.layout.yaxis.update(type='log')
        return fig
    
    def plotly_fig_with_prediction(self, log_scale=False, xcol='day', last_day=None, color=None):
        fig = self.plotly_fig(log_scale, xcol, color)
        pred_trace = self.plotly_log_linear_prediction(xcol, last_day, color)
        fig.add_trace(pred_trace)
        return fig    

    def add_last_date_annotation(self, fig, xcol='day', axy=None):
        data = self.data
        x = data.iloc[-1][xcol]
        y = data.iloc[-1]['total']
        fig.add_annotation(
            x=x,
            y=y,
            text=data.iloc[-1]['date'])
        if axy is None:
            ax = 0
            ay = 40
        else:
            ax = axy[0]
            ay = axy[1]
        fig.update_annotations(dict(
            xref="x",
            yref="y",
            showarrow=True,
            arrowhead=7,
            ax=ax,
            ay=ay
        ))
        return fig
    
class LogLinModel:
    
    def __init__(self, last_day=None):
        self.last_day = last_day
    
    def predict(self, data):
        x_train = data["day"]
        if self.last_day is None:
            self.last_day = x_train.iloc[-1] + 1 + 7
        day0 = data["date"].iloc[0]
        X_train = np.array([x_train]).T
        y_train = np.log(data['total'])
        regr = LinearRegression()       
        regr.fit(X_train, y_train)
        days = list(range(x_train.iloc[0], self.last_day + 1))
        X_test = np.array([days]).T
        y_pred = regr.predict(X_test)
        pred = np.exp(y_pred)
        dt0 = datetime.strptime(day0, '%Y-%m-%d')
        dates = [dt0 + timedelta(days=day) for day in days]
        pred_df = pd.DataFrame({"day": days, "date": dates, "total": pred})       
        return pred_df

class PlotData:
    
    def __init__(self, data, country_name, xcol='day', ycol='total', color=None, language="English"):
        self.data = data
        self._country_name = country_name
        self.xcol = xcol
        self.ycol = ycol 
        self.color = color
        self.language=language
    
    @property
    def country_name(self):
        if self.language == "Polish":
            return COUNTRIES_POLISH.get(self._country_name, self._country_name)
        return self._country_name

    def trace(self, x, y):
        return go.Scatter(x=x, y=y, mode='lines+markers', name=self.country_name)
    
    @property
    def plot(self):
        x=self.data[self.xcol]
        y=self.data[self.ycol]
        trace = self.trace(x, y)
        if self.color is not None:
            trace.line.color = self.color
        return trace
    
class PlotPrediction(PlotData):
    
    def __init__(self, *args, last_day=None, **kwargs):
        self.last_day = last_day
        super().__init__(*args, **kwargs)
        
    def trace(self, x, y):
        return go.Scatter(x=x, y=y, mode='lines', line = dict(dash='dot'), showlegend=False)

In [11]:
spain = CountryStat("Spain", df_confirmed, start_with=100, language="Polish")
fig = spain.plotly_fig(False)
spain.add_last_date_annotation(fig)

In [12]:
spain.log_linear_prediction()

Unnamed: 0,day,date,total
0,0,2020-03-02,136.054476
1,1,2020-03-03,181.813678
2,2,2020-03-04,242.963073
3,3,2020-03-05,324.678844
4,4,2020-03-06,433.878081
5,5,2020-03-07,579.804298
6,6,2020-03-08,774.809879
7,7,2020-03-09,1035.401688
8,8,2020-03-10,1383.638341
9,9,2020-03-11,1848.997429


In [13]:
iplot(spain.plotly_fig_with_prediction(True, 'date', 20, color="LightBlue"))

In [14]:
madrid = CountryStat("Madrid", df_spain_deaths, start_with=1)
madrid.log_linear_prediction(last_day=17)

Unnamed: 0,day,date,total
0,0,2020-03-05,2.481648
1,1,2020-03-06,3.599946
2,2,2020-03-07,5.222179
3,3,2020-03-08,7.575435
4,4,2020-03-09,10.989131
5,5,2020-03-10,15.941132
6,6,2020-03-11,23.124638
7,7,2020-03-12,33.545228
8,8,2020-03-13,48.661618
9,9,2020-03-14,70.589864


In [15]:
madrid.data

Unnamed: 0,day,date,total,new_cases
2,0,2020-03-05,1,1.0
3,1,2020-03-06,2,1.0
4,2,2020-03-07,4,2.0
5,3,2020-03-08,8,4.0
7,5,2020-03-10,21,13.0
8,6,2020-03-11,31,10.0
9,7,2020-03-12,56,25.0
10,8,2020-03-13,81,25.0
11,9,2020-03-14,86,5.0
12,10,2020-03-15,213,127.0


In [16]:
madrid.plotly_fig(True, "date")

In [17]:
madrid.plotly_fig_with_prediction(True, 'date', 17, color="black")

In [18]:
class ConfirmedCasesStat:
    
    COLOR_SCALE = px.colors.sequential.Agsunset
    
    def __init__(self, countries, global_df, start_with=1, max_days=None, prediction=None, 
            annotations=None, language="English"):
        self.countries = countries
        self.global_df = global_df
        self.start_with = start_with
        self.max_days = max_days
        self.language = language
        self.countries_stats = self.get_coutries_stats()
        self.prediction = prediction
        self.annotations = annotations
        
    def get_coutries_stats(self):
        countries_stats = {}
        for country_name in self.countries:
            countries_stats[country_name] =\
                CountryStat(country_name, self.global_df, self.start_with, self.max_days, language=self.language)
        return countries_stats
   
    @property
    def plot_title(self):
        return "COVID-19: Confirmed Cases"
        
    @property
    def xaxis_title(self):
        if self.language == "Polish":
            return f"Ilość dni od {self.start_with} potrwierdzonych przypadków"
        return f"Days since {self.start_with} confirmed cases"
    
    @property
    def log_scale_subtitle(self):
        if self.language == "Polish":
            return "skala logarytmiczna"
        return "log scale"
    
    
    def plotly_lm_plot(self, y_log_scale=False):      
        fig = go.Figure()
        fig = fig_update(fig)
        fig.update_layout(xaxis_title=self.xaxis_title)
        if y_log_scale:
            fig.layout.yaxis.update(type='log')
            fig.update_layout(title=self.plot_title+", " + self.log_scale_subtitle)
            fig.update_yaxes(tickvals=[100, 1000, 10000])
        else:
            fig.update_layout(title=self.plot_title)
        countries_stats = list(self.countries_stats.values())
        for country_idx in range(len( countries_stats)):
            country_stat = countries_stats[country_idx]
            color = self.COLOR_SCALE[country_idx % len(self.COLOR_SCALE)]
            plot = country_stat.plotly_lm_plot(color=color)
            if plot is not None:
                fig.add_trace(plot)
                if self.prediction is not None:
                    if (self.prediction == True) or self.prediction[country_idx]:
                        pred_plot = country_stat.plotly_log_linear_prediction(last_day=self.max_days, color=color)
                        fig.add_trace(pred_plot)
                if self.annotations is not None:
                    if self.annotations is True:
                        country_stat.add_last_date_annotation(fig)
                    else:
                        country_stat.add_last_date_annotation(fig, axy=self.annotations[country_idx])
        return fig

class DeathsStat(ConfirmedCasesStat):
    
    @property
    def plot_title(self):
        if self.language == "Polish":
            return f"Koronawirus: przypadki śmiertelne"
        return "COVID-19: Deaths"
        
    @property
    def xaxis_title(self):
        if self.language == "Polish":
            if self.start_with == 1:
                return f"Ilość dni od pierwszego przypadku śmiertelnego"
            return f"Ilość dni od {self.start_with} przypadków śmiertelnych"
        if self.start_with == 1:
            return "Days since first death"
        return f"Days since {self.start_with} deaths"

In [31]:
countries = ConfirmedCasesStat(
    ["Poland", "Spain"], 
    df_confirmed,
    start_with=25,
    max_days=15, 
    prediction=[True, False],
    annotations=[(3, 40), (40, -1)], 
    language="Polish"
)
iplot(countries.plotly_lm_plot())

In [20]:
'United Kingdom', "France" "Spain","Japan", "France",

('United Kingdom', 'FranceSpain', 'Japan', 'France')

In [44]:
countries = ConfirmedCasesStat(
    ["Italy", "Spain", "Mexico"], 
    df_confirmed, 
    start_with=10
)
countries.COLOR_SCALE = px.colors.sequential.Agsunset
iplot(countries.plotly_lm_plot(True))

In [22]:
"US",  'United Kingdom', "Italy",  "Germany", "Spain", 'Korea, South', 'United Kingdom'
countries = ConfirmedCasesStat(
    ["Chile"], 
    df_confirmed, 
    start_with=100
)
countries.COLOR_SCALE = px.colors.sequential.Agsunset
iplot(countries.plotly_lm_plot(False))

In [23]:
"Italy", "Spain", "France", 
countries = DeathsStat(
    ["Chile"], 
    df_deaths,
    start_with=1, 
    language="Polish"
)
iplot(countries.plotly_lm_plot(False))

In [24]:
countries = DeathsStat(
    ["Italy",  "Germany", "Spain", "France", "Poland"], 
    df_deaths,
    start_with=5,
    language="Polish"
)
iplot(countries.plotly_lm_plot(False))

In [32]:
"France", 'Korea, South', 'United Kingdom', 'Japan', "France", "Belgium", "US", "Chile", "Italy",  "Germany",
countries = DeathsStat(
    ["Spain", "Poland", 'United Kingdom', "Germany"], 
    df_deaths,
    start_with=10,
    language="Polish"
)
iplot(countries.plotly_lm_plot(True))

In [26]:
print(countries.countries_stats["Spain"].data)
print(countries.countries_stats["Germany"].data)

    day        date   total  new_cases
45    0  2020-03-07    10.0        5.0
46    1  2020-03-08    17.0        7.0
47    2  2020-03-09    28.0       11.0
48    3  2020-03-10    35.0        7.0
49    4  2020-03-11    54.0       19.0
50    5  2020-03-12    55.0        1.0
51    6  2020-03-13   133.0       78.0
52    7  2020-03-14   195.0       62.0
53    8  2020-03-15   289.0       94.0
54    9  2020-03-16   342.0       53.0
55   10  2020-03-17   533.0      191.0
56   11  2020-03-18   623.0       90.0
57   12  2020-03-19   830.0      207.0
58   13  2020-03-20  1043.0      213.0
59   14  2020-03-21  1375.0      332.0
60   15  2020-03-22  1772.0      397.0
    day        date  total  new_cases
53    0  2020-03-15   11.0        2.0
54    1  2020-03-16   17.0        6.0
55    2  2020-03-17   24.0        7.0
56    3  2020-03-18   28.0        4.0
57    4  2020-03-19   44.0       16.0
58    5  2020-03-20   67.0       23.0
59    6  2020-03-21   84.0       17.0
60    7  2020-03-22   94.0       

In [27]:
countries = DeathsStat(
    ["Italy",  "Germany", "Spain", "France", ], 
    df_deaths,
    start_with=10,
    prediction=[False, True, True, True],
    annotations=True
)
iplot(countries.plotly_lm_plot(False))

In [28]:
countries = DeathsStat(
    ["Italy",  "Germany", "Spain", "France"], 
    df_deaths,
    start_with=5,
    max_days=19,
    prediction=[False, True, True, True]
)
iplot(countries.plotly_lm_plot(False))

In [29]:
ccaa = DeathsStat(
    all_ccaa, 
    df_spain_deaths,
     start_with=5,
)
ccaa.COLOR_SCALE = px.colors.sequential.Agsunset
iplot(ccaa.plotly_lm_plot(True))

In [30]:
ccaa = DeathsStat(
        ['Cataluña',
        'La Rioja', 'Madrid', 'País Vasco'], 
    df_spain_deaths,
    start_with=1,
    max_days=13,
    prediction=True
)
ccaa.COLOR_SCALE = px.colors.sequential.Agsunset
iplot(ccaa.plotly_lm_plot(False))