In [None]:
# COVID-19 infections per country
# Copyright 2020 Denis Meyer
# Data source: https://github.com/CSSEGISandData/COVID-19

In [None]:
import io
import requests
import os
import pandas as pd
import matplotlib.pyplot as plt

from heapq import nlargest

In [None]:
# Name of the local cache file (identical to the remote file name)
CSV_FILENAME = 'time_series_19-covid-Confirmed.csv'

# Remote location of the CSV file
DATA_CSV_URL = (
    'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/'
    'csse_covid_19_data/csse_covid_19_time_series/'
    + CSV_FILENAME
)

In [None]:
# --- Data loading ---
# When True, the cached CSV file is ignored and the data is downloaded again
REFRESH_DATA = True

# --- Plot selection ---
# Plot every country into one figure (may take some time)
PLOT_ALL_COUNTRIES = False

# Plot only the countries listed in PLOT_COUNTRIES
PLOT_SPECIFIC_COUNTRIES = False
PLOT_COUNTRIES = ['Germany', 'Spain', 'Italy']

# Plot only the NR_OF_HIGHEST_COUNTRIES countries with the highest values
PLOT_HIGHEST_COUNTRIES = True
NR_OF_HIGHEST_COUNTRIES = 10

# --- Plot appearance ---
PLOT_SIZE = (20, 10)  # Used as matplotlib "figsize"
PLOT_TITLE = 'COVID-19 infections per country'
PLOT_LABEL_X = 'Date'
PLOT_LABEL_Y = 'Nr of infections'

In [None]:
def download_csv_data(url, timeout=30):
    '''Downloads the CSV data and parses it into a DataFrame

    :param url: The data source URL
    :param timeout: Seconds to wait for the server before giving up
    :return: The parsed pandas DataFrame, or None if no URL was given
    :raises requests.exceptions.HTTPError: If the server answers with an error status
    '''
    if not url:
        return None

    # Without a timeout the request could block forever on a stalled connection
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of parsing an error page as CSV
    response.raise_for_status()
    return pd.read_csv(io.StringIO(response.content.decode('utf-8')))

def get_data(dir_csv, filename_csv, url, refresh_data=False):
    '''Retrieves the data, either from the cached file or by downloading it

    Freshly downloaded data is written to the cache file so that later calls
    with refresh_data=False can load it from disk.

    :param dir_csv: The CSV directory
    :param filename_csv: The CSV filename
    :param url: The URL to download the data from
    :param refresh_data: Boolean whether to ignore the cached file and download again
    :return: The data as pandas DataFrame, or None if nothing could be downloaded
    '''
    # Build the cache path from the arguments, not from notebook globals
    csv_file = os.path.join(dir_csv, filename_csv)

    if refresh_data:
        print('Refreshing data')
    else:
        print('Not refreshing data')
        print('Trying to load from file "{}"'.format(csv_file))
        try:
            df = pd.read_csv(csv_file, encoding='utf-8')
        except FileNotFoundError:
            # No cache yet: fall through to the download below
            print('File "{}" not found'.format(csv_file))
        else:
            print('Successfully loaded data from file "{}"'.format(csv_file))
            return df

    print('Downloading fresh data from "{}"...'.format(url))
    df = download_csv_data(url)
    if df is None:
        # download_csv_data returns None when no URL is given; nothing to cache
        print('No data downloaded, nothing saved')
        return None

    print('Trying to save to file "{}"'.format(csv_file))
    df.to_csv(csv_file, encoding='utf-8', index=False)
    print('Successfully saved to file "{}"'.format(csv_file))

    return df

In [None]:
# The cache file lives in the current working directory
csv_dir = os.getcwd()

# Load the data and remove the columns that are not needed for plotting
columns_to_drop = ['Province/State', 'Lat', 'Long']
df = get_data(
    csv_dir,
    CSV_FILENAME,
    DATA_CSV_URL,
    refresh_data=REFRESH_DATA,
).drop(columns=columns_to_drop)

In [None]:
# Collapse all rows sharing a 'Country/Region' into a single row by summing them
df_grouped_summed = (
    df
    .groupby('Country/Region')
    .sum()
    .reset_index()
)

# Every column after the first ('Country/Region') is treated as a date
dates = df_grouped_summed.columns[1:].tolist()
first_date, last_date = dates[0], dates[-1]
print('Plotting data from {} to {}'.format(first_date, last_date))

In [None]:
# Plot: All countries

plot_name = 'All countries'

if PLOT_ALL_COUNTRIES:
    print('Plotting "{}"'.format(plot_name))

    # Wide format: one row per date, one column per country.
    # Plotting this frame once draws one line per country, which avoids
    # filtering and melting the whole frame again for every single country.
    df_by_date = df_grouped_summed.set_index('Country/Region').T

    # Plot
    fig, ax = plt.subplots(figsize=PLOT_SIZE)
    df_by_date.plot(kind='line', ax=ax)

    ax.set_title('{} - {}'.format(PLOT_TITLE, plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_X)
    ax.set_ylabel(PLOT_LABEL_Y)

    # Place the legend outside of the axes, to the right of the plot
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))

    plt.show()

In [None]:
# Plot: Specific countries

plot_name = 'Specific countries'

if PLOT_SPECIFIC_COUNTRIES:
    print('Plotting "{}"'.format(plot_name))

    df_by_country = df_grouped_summed.set_index('Country/Region')

    # Only plot countries that exist in the data and report the others,
    # e.g. a misspelled name in PLOT_COUNTRIES
    countries = [cr for cr in PLOT_COUNTRIES if cr in df_by_country.index]
    missing = [cr for cr in PLOT_COUNTRIES if cr not in df_by_country.index]
    if missing:
        print('Skipping countries not found in the data: {}'.format(', '.join(missing)))

    if countries:
        # Plot
        fig, ax = plt.subplots(figsize=PLOT_SIZE)

        # Wide format: one row per date, one column per selected country,
        # so a single plot call draws one line per country
        df_by_country.loc[countries].T.plot(kind='line', ax=ax)

        ax.set_title('{} - {}'.format(PLOT_TITLE, plot_name), loc='center')
        ax.set_xlabel(PLOT_LABEL_X)
        ax.set_ylabel(PLOT_LABEL_Y)

        # Place the legend outside of the axes, to the right of the plot
        ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))

        plt.show()
    else:
        print('Nothing to plot for "{}"'.format(plot_name))

In [None]:
# Plot: Countries with highest infection rates

plot_name = '{} Countries with highest infection rates'.format(NR_OF_HIGHEST_COUNTRIES)

if PLOT_HIGHEST_COUNTRIES:
    print('Plotting "{}"'.format(plot_name))

    df_by_country = df_grouped_summed.set_index('Country/Region')

    # Highest value per country over all dates, then the names of the n
    # countries with the largest of those values (in descending order)
    highest_per_country = df_by_country.max(axis=1)
    countries = highest_per_country.nlargest(NR_OF_HIGHEST_COUNTRIES).index.tolist()

    # Plot
    fig, ax = plt.subplots(figsize=PLOT_SIZE)

    # Wide format: one row per date, one column per selected country,
    # so a single plot call draws one line per country
    df_by_country.loc[countries].T.plot(kind='line', ax=ax)

    ax.set_title('{} - {}'.format(PLOT_TITLE, plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_X)
    ax.set_ylabel(PLOT_LABEL_Y)

    # Place the legend outside of the axes, to the right of the plot
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.show()