In [None]:
# COVID-19 infections per country
# Copyright 2020 Denis Meyer

In [None]:
import logging
import io
import requests
import os
import datetime

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import scipy.optimize
import numpy as np

from heapq import nlargest

In [None]:
# Settings


### Misc settings ###

# If True, ignores the file cache and always downloads fresh data;
# otherwise first tries to load the data for the current day from the file cache
FORCE_REFRESH_DATA = False
# If True, creates every plot regardless of its individual "PLOT_..." flag (except "all countries")
FORCE_PLOT = False
# If True, saves every created plot to file regardless of its individual "..._SAVE_PLOT_TO_FILE" flag
# (except "all countries") - the plot itself must still be activated
FORCE_SAVE_PLOT_TO_FILE = False

# Image path to save the created plot images to (relative to the current directory);
# images are stored in one sub-directory per data date
PLOT_IMAGE_PATH = 'images'

# Directory names of the CSV file cache: <cwd>/<CSV_SUBDIR_NAME>/<infections|deaths>/
CSV_SUBDIR_NAME = 'data'
CSV_INFECTIONS_SUBDIR_NAME = 'infections'
CSV_DEATHS_SUBDIR_NAME = 'deaths'

# Cache file names; '{}' is replaced by the current date (YYYY-MM-DD),
# so a new day means a cache miss and a fresh download
CSV_INFECTIONS_FILENAME = 'time_series_19-covid-Confirmed-{}.csv'
CSV_DEATHS_FILENAME = 'time_series_19-covid-Deaths-{}.csv'

# Plot configuration (figure size in inches, titles and axis labels)
PLOT_SIZE=(20, 15)
PLOT_TITLE = 'COVID-19 infections per country'
PLOT_LABEL_X = 'Date'
PLOT_LABEL_Y = 'Nr of infections'
PLOT_LABEL_DEATHS_X = 'Date'
PLOT_LABEL_DEATHS_Y = 'Nr of deaths'

# Logging configuration (only used by the currently commented-out initialize_logger call)
LOGGING_LOGLEVEL = logging.INFO
LOGGING_DATE_FORMAT = '%d-%m-%Y %H:%M:%S'
LOGGING_FORMAT = '[%(asctime)s] [%(levelname)-5s] [%(module)-20s:%(lineno)-4s] %(message)s'


### Data plotting settings ###

# Day numbers are 0-based indices into the date columns of the data.
# The plotted range is [start day, end day): an end day <= 0 (or beyond the data) means
# "until the last day", a start day <= 0 (or not before the end day) means "from the first day".

# Boolean flag whether to create a plot containing all countries
# (not affected by FORCE_PLOT / FORCE_SAVE_PLOT_TO_FILE)
PLOT_ALL_COUNTRIES = False # May take some time in the current implementation
# Boolean flag whether to save the plot to file
PLOT_ALL_COUNTRIES_SAVE_PLOT_TO_FILE = False
# Plot start and end day, use a number <= 0 as end day to plot until the last day
ALL_COUNTRIES_START_DAY = -1
ALL_COUNTRIES_END_DAY = -1

# Boolean flag whether to create a plot containing specific countries
PLOT_SPECIFIC_COUNTRIES = False
# Boolean flag whether to save the plot to file
PLOT_SPECIFIC_COUNTRIES_SAVE_PLOT_TO_FILE = False
# Country names as spelled in the "Country/Region" column of the data; unknown names are skipped
PLOT_COUNTRIES = ['Germany', 'Spain', 'Iran', 'US', 'France', 'Korea, South', 'Switzerland', 'United Kingdom']
# Plot start and end day, use a number <= 0 as end day to plot until the last day
SPECIFIC_COUNTRIES_START_DAY = 40
SPECIFIC_COUNTRIES_END_DAY = -1

# Boolean flag whether to create a plot containing just the n countries with highest number of infections
PLOT_HIGHEST_COUNTRIES = False
# Boolean flag whether to save the plot to file
PLOT_HIGHEST_COUNTRIES_SAVE_PLOT_TO_FILE = False
NR_OF_HIGHEST_COUNTRIES = 10
# Plot start and end day, use a number <= 0 as end day to plot until the last day
HIGHEST_COUNTRIES_START_DAY = 40
HIGHEST_COUNTRIES_END_DAY = -1

# Boolean flag whether to create a plot containing just the n countries with highest number of deaths
PLOT_HIGHEST_DEATHS_COUNTRIES = False
# Boolean flag whether to save the plot to file
PLOT_HIGHEST_DEATHS_COUNTRIES_SAVE_PLOT_TO_FILE = False
NR_OF_HIGHEST_DEATHS_COUNTRIES = 10
# Plot start and end day, use a number <= 0 as end day to plot until the last day
HIGHEST_DEATHS_COUNTRIES_START_DAY = 40
HIGHEST_DEATHS_COUNTRIES_END_DAY = -1


### Curve fitting settings ###

# Boolean flag whether to label the x-axis with day numbers (0-based index from the first date) instead of dates
CURVE_FIT_PLOT_DAYS_AS_X_LABEL = False

# Boolean flag whether to create a plot with a curve fit for a specific country
PLOT_CURVE_FIT = False
# Boolean flag whether to save the plot to file
PLOT_CURVE_FIT_SAVE_PLOT_TO_FILE = False
CURVE_FIT_COUNTRY = 'Germany'
# Day range [start, end) of the data the curve is fitted to
CURVE_FIT_START_DAY = 51
CURVE_FIT_END_DAY = 58
# Plot start and end day, use a number <= 0 as end day to plot until the last day
CURVE_FIT_PLOT_START_DAY = 40
CURVE_FIT_PLOT_END_DAY = -1
# For debugging and parameter tweaking purposes: Activate to skip the curve fit and plot only
# the raw data (and deaths); also honored by the multi-country curve fit
CURVE_FIT_PLOT_RAW_DATA_ONLY = False

# Boolean flag whether to create a plot with a multi curve fit for a specific country
PLOT_MULTI_CURVE_FIT = False
# Boolean flag whether to save the plot to file
PLOT_MULTI_CURVE_FIT_SAVE_PLOT_TO_FILE = False
MULTI_CURVE_FIT_COUNTRY = 'Germany'
# One entry per fitted segment:
#   'start_day' / 'end_day':           day range [start, end) the segment's curve is fitted to
#   'plot_start_day' / 'plot_end_day': day range the segment's fitted curve is drawn over (end <= 0: until the last day)
#   'color':                           line color of the fitted curve
#   'fit_func' (optional):             fit function called as fit_func(x, a, b, c); defaults to func_fit
MULTI_CURVE_FIT_DAYS = [
    {
        'start_day': 40,
        'end_day': 50,
        'plot_start_day': 40,
        'plot_end_day': 52,
        'color': 'lightskyblue'
    },
    {
        'start_day': 51,
        'end_day': 58,
        'plot_start_day': 49,
        'plot_end_day': 64,
        'color': 'blue'#,
        #'fit_func': lambda x, a, b, c: a * np.exp(b * x)
    },
    {
        'start_day': 58,
        'end_day': 65,
        'plot_start_day': 56,
        'plot_end_day': -1,
        'color': 'steelblue'
    }
]
# For debugging and parameter tweaking purposes: Activate to skip the curve fits and plot only the raw data (and deaths)
MULTI_CURVE_FIT_PLOT_RAW_DATA_ONLY = False

# Boolean flag whether to create a plot with a curve fit for specific countries
PLOT_CURVE_FIT_MULTI = False
# Boolean flag whether to save the plot to file
PLOT_CURVE_FIT_MULTI_SAVE_PLOT_TO_FILE = False
# Country data (see PLOT_CURVE_FIT for reference); one entry per country:
#   'name':                  country name as spelled in the "Country/Region" column
#   'start_day' / 'end_day': day range [start, end) the curve is fitted to
#   'color':                 color of both the data points and the fitted curve
# Good fitting parameters (as of 2020-03-26)
# Named colors: https://matplotlib.org/3.1.0/gallery/color/named_colors.html
CURVE_FIT_MULTI_COUNTRIES = [
    {
        'name': 'Italy',
        'start_day': 51,
        'end_day': 60,
        'color': 'tomato'
    },
    {
        'name': 'US',
        'start_day': 50,
        'end_day': 59,
        'color': 'seagreen'
    },
    {
        'name': 'Spain',
        'start_day': 51,
        'end_day': 59,
        'color': 'gold'
    },
    {
        'name': 'Germany',
        'start_day': 51,
        'end_day': 58,
        'color': 'lightskyblue'#,
        #'fit_func': lambda x, a, b, c: a * np.exp(b * x)
    }
]
# Plot start and end day, use a number <= 0 as end day to plot until the last day
CURVE_FIT_MULTI_PLOT_START_DAY = 40
CURVE_FIT_MULTI_PLOT_END_DAY = -1

In [None]:
def initialize_logger(loglevel, frmt, datefmt):
    '''Configures the root logger through logging.basicConfig

    :param loglevel: The minimum level of records to emit
    :param frmt: The format string of a log record
    :param datefmt: The format string of the asctime field
    '''
    settings = dict(level=loglevel, format=frmt, datefmt=datefmt)
    logging.basicConfig(**settings)

def download_csv_data(url, timeout=60):
    '''Downloads a CSV file and parses it into a DataFrame

    :param url: The data source URL; a falsy value returns None without any request
    :param timeout: Seconds to wait for the server before giving up
    :return: The parsed DataFrame, or None if no URL was given
    :raises requests.HTTPError: If the server answers with an error status
    '''
    if not url:
        return None

    # Without a timeout the request could block forever on a stalled connection
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx: otherwise the error page would be parsed as CSV
    # and get_data() would cache it as the data of the day
    response.raise_for_status()
    return pd.read_csv(io.StringIO(response.content.decode('utf-8')))

def get_data(dir_csv, csv_subpath, filename_csv, url, force_refresh_data=False):
    '''Retrieves the data, either from the file cache or by downloading it

    A successful download is written to the cache file, so later runs with the
    same file name can skip the download.

    :param dir_csv: The CSV directory
    :param csv_subpath: The CSV sub-directory (created if missing)
    :param filename_csv: The CSV filename
    :param url: The URL to download from if the cache file is missing or skipped
    :param force_refresh_data: Boolean whether to skip the cache and always download
    :return: The DataFrame, or None if nothing was cached and nothing could be downloaded
    '''
    path_data = os.path.join(dir_csv, csv_subpath)
    # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs()
    os.makedirs(path_data, exist_ok=True)
    csv_file = os.path.join(path_data, filename_csv)

    if force_refresh_data:
        logging.info('Force refreshing data')
    else:
        logging.info('Not force refreshing data')
        logging.info('Trying to load from file "{}"'.format(csv_file))
        try:
            df = pd.read_csv(csv_file, encoding='utf-8')
        except FileNotFoundError:
            pass  # Cache miss: fall through to the download
        else:
            logging.info('Successfully loaded data from file "{}"'.format(csv_file))
            return df

    logging.info('Downloading fresh data from "{}"...'.format(url))
    df = download_csv_data(url)
    if df is None:
        # Nothing to cache; calling to_csv on None would raise AttributeError
        logging.info('Could not download data from "{}"'.format(url))
        return None

    logging.info('Trying to save to file "{}"'.format(csv_file))
    df.to_csv(csv_file, encoding='utf-8', index=False)
    logging.info('Successfully saved to file "{}"'.format(csv_file))

    return df

def get_clean_image_name(name):
    '''Turns a name into a file-name-friendly image name

    Every comma becomes a dash and every space is removed.

    :param name: The image name
    :return: The cleaned image name
    '''
    table = str.maketrans({',': '-', ' ': ''})
    return name.translate(table)

def save_plot(curr_dir, fig, path, date, name):
    '''Saves the plot of the fig to "<curr_dir>/<path>/<YYYY-MM-DD of date>/<clean name>"

    Saving is best-effort: errors are logged and reported via the return value
    instead of being raised.

    :param curr_dir: The base directory
    :param fig: The figure (must provide ``savefig(filename)``)
    :param path: The image sub-directory; 'images' is used if falsy
    :param date: The date (datetime); its ``.date()`` names the sub-directory, 'unknown' is used if falsy
    :param name: The name of the image, cleaned with get_clean_image_name()
    :return: True if the plot was saved, False otherwise
    '''
    path_data = None
    try:
        path_data = os.path.join(curr_dir, path if path else 'images', str(date.date()) if date else 'unknown')
        # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs()
        os.makedirs(path_data, exist_ok=True)
        full_path = os.path.join(path_data, get_clean_image_name(name))
        logging.info('Saving plot to "{}"'.format(full_path))
        fig.savefig(full_path)
        return True
    except Exception as e:
        # Deliberately broad: a failed save must not abort the remaining notebook cells.
        # Report the directory actually used (falls back to the raw argument if it was never built).
        logging.info('Could not save plot to file "{}" in path "{}": {}'.format(name, path_data if path_data else path, e))
        return False

def func_fit(x, a, b, c=1.0):
    '''Exponential model used for curve fitting: exp(a + b * x)

    :param x: x value(s), a scalar or a numpy array
    :param a: Offset in the exponent (the model value at x = 0 is exp(a))
    :param b: Growth rate in the exponent per unit of x
    :param c: Unused; accepted so callers may pass a third parameter
    '''
    exponent = b * x + a
    return np.exp(exponent)

def func_sigma(y):
    '''Curve fitting error function: Poisson-style uncertainty of a count

    Counts below 1 (i.e. 0 for count data) are clamped to 1 so the uncertainty is
    never 0: curve_fit divides the residuals by sigma, and a zero sigma would make
    them non-finite and the fit fail.

    :param y: y (a count, scalar or numpy array)
    :return: sqrt(max(y, 1))
    '''
    return np.sqrt(np.maximum(y, 1))

In [None]:
# Proper logger setup, currently disabled (argument order: level, record format, date format):
# initialize_logger(LOGGING_LOGLEVEL, LOGGING_FORMAT, LOGGING_DATE_FORMAT)
# Logging + Jupyter is currently not working together (on my machine...),
# so every logging.info(...) call in this notebook is redirected to print() instead.
# Only logging.info is replaced; other levels still go through the logging module.
logging.info = print

# Apply seaborn's default theme with the 'muted' color palette
sns.set(palette='muted')

In [None]:
# Data source:
# 2019 Novel Coronavirus COVID-19 (2019-nCoV) Data Repository by Johns Hopkins CSSE
# https://github.com/CSSEGISandData/COVID-19
# Both global time series live in the same repository folder; only the file name differs.
DATA_CSV_BASE_URL = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/'
DATA_CONFIRMED_CSV_URL = DATA_CSV_BASE_URL + 'time_series_covid19_confirmed_global.csv'
DATA_DEATHS_CSV_URL = DATA_CSV_BASE_URL + 'time_series_covid19_deaths_global.csv'

In [None]:
# Today's date as YYYY-MM-DD, used to name the per-day cache files
current_date_str = datetime.date.today().isoformat()
# Root of the CSV file cache, below the current working directory
csv_dir = os.path.join(os.getcwd(), CSV_SUBDIR_NAME)

In [None]:
# Load 'confirmed' data (from today's cache file if present, otherwise downloaded)

df = get_data(
    csv_dir,
    CSV_INFECTIONS_SUBDIR_NAME,
    CSV_INFECTIONS_FILENAME.format(current_date_str),
    DATA_CONFIRMED_CSV_URL,
    force_refresh_data=FORCE_REFRESH_DATA,
)

# Drop the province, latitude and longitude columns
df = df.drop(columns=['Province/State', 'Lat', 'Long'])

In [None]:
# Load 'deaths' data (from today's cache file if present, otherwise downloaded)

df2 = get_data(
    csv_dir,
    CSV_DEATHS_SUBDIR_NAME,
    CSV_DEATHS_FILENAME.format(current_date_str),
    DATA_DEATHS_CSV_URL,
    force_refresh_data=FORCE_REFRESH_DATA,
)

# Drop the province, latitude and longitude columns
df2 = df2.drop(columns=['Province/State', 'Lat', 'Long'])

In [None]:
# Sum all provinces/states per Country/Region; the country becomes column 0 again
df_grouped_summed = df.groupby('Country/Region').sum().reset_index()
# All remaining columns are dates, expected in "m/d/yy" format
dates = df_grouped_summed.columns.values[1:].tolist()
date_first, date_last = (datetime.datetime.strptime(d, '%m/%d/%y') for d in (dates[0], dates[-1]))
logging.info('Working with "confirmed" data from {} to {}'.format(date_first.date(), date_last.date()))

In [None]:
# Sum all provinces/states per Country/Region; the country becomes column 0 again
df2_grouped_summed = df2.groupby('Country/Region').sum().reset_index()
# All remaining columns are dates, expected in "m/d/yy" format
dates2 = df2_grouped_summed.columns.values[1:].tolist()
date_first2, date_last2 = (datetime.datetime.strptime(d, '%m/%d/%y') for d in (dates2[0], dates2[-1]))
logging.info('Working with "deaths" data from {} to {}'.format(date_first2.date(), date_last2.date()))

In [None]:
# Gather all country names (one entry per Country/Region after grouping)
all_countries_list = df_grouped_summed['Country/Region'].tolist()

In [None]:
# Plot: All countries
# One line per country over the configured day range. Deliberately not affected by
# FORCE_PLOT / FORCE_SAVE_PLOT_TO_FILE (see the settings).

plot_name = 'All countries'

if PLOT_ALL_COUNTRIES:
    logging.info('Plotting "{}"'.format(plot_name))

    countries = df_grouped_summed['Country/Region']

    # Plot
    fig, ax = plt.subplots(figsize=PLOT_SIZE)

    # Validate plot start and end days: an out-of-range end means "until the last day",
    # an out-of-range start means "from the first day"
    plot_day_end = ALL_COUNTRIES_END_DAY if (ALL_COUNTRIES_END_DAY > 0 and ALL_COUNTRIES_END_DAY < len(dates)) else len(dates)
    plot_day_start = ALL_COUNTRIES_START_DAY if ALL_COUNTRIES_START_DAY > 0 and ALL_COUNTRIES_START_DAY < plot_day_end else 0
    logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))

    for cr in countries:
        df_tmp = df_grouped_summed[df_grouped_summed['Country/Region']==cr]
        # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
        # Result: one row per date ('Country/Region', 'Date', 'Value'), sliced to the plot range
        df_melted = df_tmp.melt(id_vars=df_tmp.columns.values[:1], var_name='Date', value_name='Value')[plot_day_start:plot_day_end]
        df_melted.plot(kind='line', x='Date', y='Value', ax=ax, label=cr)

    ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last.date(), plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_X)
    ax.set_ylabel(PLOT_LABEL_Y)

    # Legend anchored to the right of the axes
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))

    plt.show()

    if PLOT_ALL_COUNTRIES_SAVE_PLOT_TO_FILE:
        save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last, 'All-Countries.png')

    plt.close(fig)

In [None]:
# Plot: Specific countries
# One line per configured country in PLOT_COUNTRIES; names not found in the data are logged and skipped.

plot_name = 'Specific countries: "{}"'.format(', '.join(PLOT_COUNTRIES))

if FORCE_PLOT or PLOT_SPECIFIC_COUNTRIES:
    logging.info('Plotting "{}"'.format(plot_name))

    countries = PLOT_COUNTRIES

    # Plot
    fig, ax = plt.subplots(figsize=PLOT_SIZE)

    # Validate plot start and end days: an out-of-range end means "until the last day",
    # an out-of-range start means "from the first day"
    plot_day_end = SPECIFIC_COUNTRIES_END_DAY if (SPECIFIC_COUNTRIES_END_DAY > 0 and SPECIFIC_COUNTRIES_END_DAY < len(dates)) else len(dates)
    plot_day_start = SPECIFIC_COUNTRIES_START_DAY if SPECIFIC_COUNTRIES_START_DAY > 0 and SPECIFIC_COUNTRIES_START_DAY < plot_day_end else 0
    logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))

    for cr in countries:
        if cr in all_countries_list:
            df_tmp = df_grouped_summed[df_grouped_summed['Country/Region']==cr]
            # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
            # Result: one row per date ('Country/Region', 'Date', 'Value'), sliced to the plot range
            df_melted = df_tmp.melt(id_vars=df_tmp.columns.values[:1], var_name='Date', value_name='Value')[plot_day_start:plot_day_end]
            df_melted.plot(kind='line', x='Date', y='Value', ax=ax, label=cr)
        else:
            logging.info('Could not find given country "{}"'.format(cr))

    ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last.date(), plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_X)
    ax.set_ylabel(PLOT_LABEL_Y)

    # Legend anchored to the right of the axes
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))

    plt.show()

    if FORCE_SAVE_PLOT_TO_FILE or PLOT_SPECIFIC_COUNTRIES_SAVE_PLOT_TO_FILE:
        save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last, 'Specific-Countries-{}.png'.format('-'.join(PLOT_COUNTRIES)))

    plt.close(fig)

In [None]:
# Plot: Countries with highest infection rates
# (countries ranked by the highest number of confirmed infections reached on any date)

plot_name = '{} Countries with highest infection rates'.format(NR_OF_HIGHEST_COUNTRIES)

if FORCE_PLOT or PLOT_HIGHEST_COUNTRIES:
    logging.info('Plotting "{}"'.format(plot_name))

    # Validate plot start and end days: an out-of-range end means "until the last day",
    # an out-of-range start means "from the first day"
    plot_day_end = HIGHEST_COUNTRIES_END_DAY if (HIGHEST_COUNTRIES_END_DAY > 0 and HIGHEST_COUNTRIES_END_DAY < len(dates)) else len(dates)
    plot_day_start = HIGHEST_COUNTRIES_START_DAY if HIGHEST_COUNTRIES_START_DAY > 0 and HIGHEST_COUNTRIES_START_DAY < plot_day_end else 0
    logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))

    # Calculate the highest value per country in one row-wise pass (all non-index columns
    # are dates) instead of filtering and melting the whole frame once per country
    dict_highest_all = df_grouped_summed.set_index('Country/Region').max(axis=1).to_dict()

    # Extract the n highest country names
    countries = nlargest(NR_OF_HIGHEST_COUNTRIES, dict_highest_all, key=dict_highest_all.get)

    # Plot
    fig, ax = plt.subplots(figsize=PLOT_SIZE)

    for cr in countries:
        df_tmp = df_grouped_summed[df_grouped_summed['Country/Region']==cr]
        # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
        # Result: one row per date ('Country/Region', 'Date', 'Value'), sliced to the plot range
        df_melted = df_tmp.melt(id_vars=df_tmp.columns.values[:1], var_name='Date', value_name='Value')[plot_day_start:plot_day_end]
        df_melted.plot(kind='line', x='Date', y='Value', ax=ax, label=cr)

    ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last.date(), plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_X)
    ax.set_ylabel(PLOT_LABEL_Y)

    # Legend anchored to the right of the axes
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.show()

    if FORCE_SAVE_PLOT_TO_FILE or PLOT_HIGHEST_COUNTRIES_SAVE_PLOT_TO_FILE:
        save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last, '{}-Countries-with-highest-infection-rates.png'.format(NR_OF_HIGHEST_COUNTRIES))

    plt.close(fig)

In [None]:
# Plot: Countries with highest number of deaths
# (countries ranked by the highest number of deaths reached on any date)
# All settings come from the HIGHEST_DEATHS_COUNTRIES group and all data from the 'deaths' frame.

plot_name = '{} Countries with highest number of deaths'.format(NR_OF_HIGHEST_DEATHS_COUNTRIES)

if FORCE_PLOT or PLOT_HIGHEST_DEATHS_COUNTRIES:
    logging.info('Plotting "{}"'.format(plot_name))

    # Validate plot start and end days against the 'deaths' dates: an out-of-range end means
    # "until the last day", an out-of-range start means "from the first day"
    plot_day_end = HIGHEST_DEATHS_COUNTRIES_END_DAY if (HIGHEST_DEATHS_COUNTRIES_END_DAY > 0 and HIGHEST_DEATHS_COUNTRIES_END_DAY < len(dates2)) else len(dates2)
    plot_day_start = HIGHEST_DEATHS_COUNTRIES_START_DAY if HIGHEST_DEATHS_COUNTRIES_START_DAY > 0 and HIGHEST_DEATHS_COUNTRIES_START_DAY < plot_day_end else 0
    logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))

    # Calculate the highest value per country in one row-wise pass (all non-index columns
    # are dates) instead of filtering and melting the whole frame once per country
    dict_highest_all = df2_grouped_summed.set_index('Country/Region').max(axis=1).to_dict()

    # Extract the n highest country names
    countries = nlargest(NR_OF_HIGHEST_DEATHS_COUNTRIES, dict_highest_all, key=dict_highest_all.get)

    # Plot
    fig, ax = plt.subplots(figsize=PLOT_SIZE)

    for cr in countries:
        df2_tmp = df2_grouped_summed[df2_grouped_summed['Country/Region']==cr]
        # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
        # Result: one row per date ('Country/Region', 'Date', 'Value'), sliced to the plot range
        df2_melted = df2_tmp.melt(id_vars=df2_tmp.columns.values[:1], var_name='Date', value_name='Value')[plot_day_start:plot_day_end]
        df2_melted.plot(kind='line', x='Date', y='Value', ax=ax, label=cr)

    ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last2.date(), plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_DEATHS_X)
    ax.set_ylabel(PLOT_LABEL_DEATHS_Y)

    # Legend anchored to the right of the axes
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.show()

    if FORCE_SAVE_PLOT_TO_FILE or PLOT_HIGHEST_DEATHS_COUNTRIES_SAVE_PLOT_TO_FILE:
        save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last2, '{}-Countries-with-highest-number-of-deaths.png'.format(NR_OF_HIGHEST_DEATHS_COUNTRIES))

    plt.close(fig)

In [None]:
# Curve fit
# Fits func_fit to the configured day range of one country and plots the fit,
# the infections (data points) and the deaths over the plot range.

plot_name = 'Curve fit for country "{}"'.format(CURVE_FIT_COUNTRY)

if FORCE_PLOT or PLOT_CURVE_FIT:
    logging.info('Plotting "{}"'.format(plot_name))

    if CURVE_FIT_COUNTRY in all_countries_list:
        fig, ax = plt.subplots(figsize=PLOT_SIZE)
        df_tmp = df_grouped_summed[df_grouped_summed['Country/Region']==CURVE_FIT_COUNTRY]
        df2_tmp = df2_grouped_summed[df2_grouped_summed['Country/Region']==CURVE_FIT_COUNTRY]
        # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
        df_melted = df_tmp.melt(id_vars=df_tmp.columns.values[:1], var_name='Date', value_name='Value')
        df2_melted = df2_tmp.melt(id_vars=df2_tmp.columns.values[:1], var_name='Date', value_name='Value')
        nr_of_days = len(df_melted.Value)

        # Validate start and end days
        day_end = nr_of_days
        day_start = 0
        if not CURVE_FIT_PLOT_RAW_DATA_ONLY:
            day_end = CURVE_FIT_END_DAY if (CURVE_FIT_END_DAY > 0 and CURVE_FIT_END_DAY < nr_of_days) else nr_of_days
            day_start = CURVE_FIT_START_DAY if CURVE_FIT_START_DAY > 0 and CURVE_FIT_START_DAY < day_end else 0
            logging.info('Fitting to days [{}, {}]'.format(day_start, day_end))
        # Validate plot start and end days
        plot_day_end = CURVE_FIT_PLOT_END_DAY if (CURVE_FIT_PLOT_END_DAY > 0 and CURVE_FIT_PLOT_END_DAY < nr_of_days) else nr_of_days
        plot_day_start = CURVE_FIT_PLOT_START_DAY if CURVE_FIT_PLOT_START_DAY > 0 and CURVE_FIT_PLOT_START_DAY < plot_day_end else 0
        logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))

        # x is the day index, i.e. the same x the data and the fit are plotted at.
        # (np.linspace(0, n, num=n) would space the points n/(n-1) apart and shift the
        # fitted curve against the plotted data.)
        vals_x = np.arange(day_start, day_end, dtype=float)
        vals_y = list(df_melted.Value)[day_start:day_end]
        vals_sigma = [func_sigma(y) for y in vals_y]

        vals_x_to_end = list(range(plot_day_start, plot_day_end))
        vals_y_to_end = list(df_melted.Value)[plot_day_start:plot_day_end]
        vals_y2_to_end = list(df2_melted.Value)[plot_day_start:plot_day_end]

        if not CURVE_FIT_PLOT_RAW_DATA_ONLY:
            # Scipy curve fit
            try:
                params, params_cov = scipy.optimize.curve_fit(func_fit, xdata=vals_x, ydata=vals_y, sigma=vals_sigma)
                vals_y_fit = [func_fit(x, params[0], params[1]) for x in vals_x_to_end]
                plt.plot(vals_x_to_end, vals_y_fit, '--', color ='blue', label ='Fit (days {}-{})'.format(day_start, day_end))
            except Exception as e:
                # A failed fit should not prevent plotting the raw data below
                logging.info('Could not find curve fit, exception: {}'.format(e))
        else:
            logging.info('Just logging data')

        # Plot deaths
        plt.plot(vals_x_to_end, vals_y2_to_end, '-', color ='red', label ='Deaths')

        # Plot data
        plt.plot(vals_x_to_end, vals_y_to_end, 'o', color ='green', label ='Data')

        ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last.date(), plot_name), loc='center')
        ax.set_xlabel(PLOT_LABEL_X)
        ax.set_ylabel(PLOT_LABEL_Y)

        # Calculate ticks and labels (=the dates on the x-axis), every second day
        ticks = list(range(plot_day_start, plot_day_end))[::2]
        if plot_day_end not in ticks and len(ticks) % 2 == 0:
            ticks = ticks + [plot_day_end]
        if CURVE_FIT_PLOT_DAYS_AS_X_LABEL:
            labels = [d for d in ticks]
        else:
            labels = [str((date_first + datetime.timedelta(days=d)).date()) for d in ticks]
        plt.xticks(ticks=ticks, labels=labels)

        plt.legend(loc='upper left')
        plt.show()

        if FORCE_SAVE_PLOT_TO_FILE or PLOT_CURVE_FIT_SAVE_PLOT_TO_FILE:
            save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last, 'Curve-Fit-{}.png'.format(CURVE_FIT_COUNTRY))

        plt.close(fig)
    else:
        logging.info('Could not find given country "{}"'.format(CURVE_FIT_COUNTRY))

In [None]:
# Multi curve fit
# Fits one curve per MULTI_CURVE_FIT_DAYS segment to a single country and plots all fits,
# the infections (data points) and the deaths over the union of the segments' plot ranges.

plot_name = 'Multi curve fit for country "{}"'.format(MULTI_CURVE_FIT_COUNTRY)

if FORCE_PLOT or PLOT_MULTI_CURVE_FIT:
    logging.info('Plotting "{}"'.format(plot_name))

    if MULTI_CURVE_FIT_COUNTRY in all_countries_list:
        fig, ax = plt.subplots(figsize=PLOT_SIZE)
        df_tmp = df_grouped_summed[df_grouped_summed['Country/Region']==MULTI_CURVE_FIT_COUNTRY]
        # Deaths must come from the same country as the infections
        df2_tmp = df2_grouped_summed[df2_grouped_summed['Country/Region']==MULTI_CURVE_FIT_COUNTRY]
        # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
        df_melted = df_tmp.melt(id_vars=df_tmp.columns.values[:1], var_name='Date', value_name='Value')
        df2_melted = df2_tmp.melt(id_vars=df2_tmp.columns.values[:1], var_name='Date', value_name='Value')
        nr_of_days = len(df_melted.Value)

        # Union of all segments' plot ranges, used for the data/deaths series and the ticks
        lowest_start_day = nr_of_days
        highest_end_day = 0
        for data in MULTI_CURVE_FIT_DAYS:
            # Validate start and end days
            day_end = data['end_day'] if (data['end_day'] > 0 and data['end_day'] < nr_of_days) else nr_of_days
            day_start = data['start_day'] if data['start_day'] > 0 and data['start_day'] < day_end else 0
            logging.info('Fitting to days [{}, {}]'.format(day_start, day_end))

            # Validate plot start and end days
            plot_day_end = data['plot_end_day'] if (data['plot_end_day'] > 0 and data['plot_end_day'] < nr_of_days) else nr_of_days
            plot_day_start = data['plot_start_day'] if data['plot_start_day'] > 0 and data['plot_start_day'] < plot_day_end else 0
            logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))
            lowest_start_day = min(lowest_start_day, plot_day_start)
            highest_end_day = max(highest_end_day, plot_day_end)

            # x is the day index, i.e. the same x the data and the fit are plotted at.
            # (np.linspace(0, n, num=n) would space the points n/(n-1) apart and shift the
            # fitted curve against the plotted data.)
            vals_x = np.arange(day_start, day_end, dtype=float)
            vals_y = list(df_melted.Value)[day_start:day_end]
            vals_sigma = [func_sigma(y) for y in vals_y]

            vals_x_to_end = list(range(plot_day_start, plot_day_end))

            if not MULTI_CURVE_FIT_PLOT_RAW_DATA_ONLY:
                # Scipy curve fit, optionally with a per-segment fit function
                try:
                    func = data.get('fit_func', func_fit)
                    params, params_cov = scipy.optimize.curve_fit(func, xdata=vals_x, ydata=vals_y, sigma=vals_sigma)
                    vals_y_fit = [func(x, params[0], params[1], 0) for x in vals_x_to_end]
                    plt.plot(vals_x_to_end, vals_y_fit, '--', color=data['color'], label ='Fit (days {}-{})'.format(day_start, day_end))
                except Exception as e:
                    # A failed fit should not prevent the other segments or the raw data from being plotted
                    logging.info('Could not find curve fit, exception: {}'.format(e))
            else:
                logging.info('Just logging data')

        # Plot deaths
        vals_x_to_end = list(range(lowest_start_day, highest_end_day))
        vals_y2_to_end = list(df2_melted.Value)[lowest_start_day:highest_end_day]
        plt.plot(vals_x_to_end, vals_y2_to_end, '-', color ='red', label ='Deaths')

        # Plot data
        plot_vals_x = list(range(lowest_start_day, highest_end_day))
        plot_vals_y = list(df_melted.Value)[lowest_start_day:highest_end_day]
        plt.plot(plot_vals_x, plot_vals_y, 'o', color ='green', label ='Data')

        ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last.date(), plot_name), loc='center')
        ax.set_xlabel(PLOT_LABEL_X)
        ax.set_ylabel(PLOT_LABEL_Y)

        logging.info('Calculating ticks for days [{}, {}]'.format(lowest_start_day, highest_end_day))
        # Calculate ticks and labels (=the dates on the x-axis), every second day over the
        # union of all plot ranges (not just the last segment's range)
        ticks = list(range(lowest_start_day, highest_end_day))[::2]
        if highest_end_day not in ticks and len(ticks) % 2 == 0:
            ticks = ticks + [highest_end_day]
        if CURVE_FIT_PLOT_DAYS_AS_X_LABEL:
            labels = [d for d in ticks]
        else:
            labels = [str((date_first + datetime.timedelta(days=d)).date()) for d in ticks]
        plt.xticks(ticks=ticks, labels=labels)

        plt.legend(loc='upper left')
        plt.show()

        if FORCE_SAVE_PLOT_TO_FILE or PLOT_MULTI_CURVE_FIT_SAVE_PLOT_TO_FILE:
            save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last, 'Multi-Curve-Fit-{}.png'.format(MULTI_CURVE_FIT_COUNTRY))

        plt.close(fig)
    else:
        logging.info('Could not find given country "{}"'.format(MULTI_CURVE_FIT_COUNTRY))

In [None]:
# Curve fit multiple countries
# Fits func_fit (or an entry's optional 'fit_func') to each configured country and plots
# the data points and fitted curves of all countries into one figure.

countries = [x['name'] for x in CURVE_FIT_MULTI_COUNTRIES]
plot_name = 'Curve fit for countries "{}"'.format(', '.join(countries))

if FORCE_PLOT or PLOT_CURVE_FIT_MULTI:
    logging.info('Plotting "{}"'.format(plot_name))

    fig, ax = plt.subplots(figsize=PLOT_SIZE)

    # Validate plot start and end days (shared by all countries)
    plot_day_end = CURVE_FIT_MULTI_PLOT_END_DAY if (CURVE_FIT_MULTI_PLOT_END_DAY > 0 and CURVE_FIT_MULTI_PLOT_END_DAY < len(dates)) else len(dates)
    plot_day_start = CURVE_FIT_MULTI_PLOT_START_DAY if CURVE_FIT_MULTI_PLOT_START_DAY > 0 and CURVE_FIT_MULTI_PLOT_START_DAY < plot_day_end else 0
    logging.info('Plotting to days [{}, {}]'.format(plot_day_start, plot_day_end))

    for country in CURVE_FIT_MULTI_COUNTRIES:
        logging.info('Preparing country "{}"'.format(country['name']))
        country_name = country['name']

        if country_name in all_countries_list:
            # Unpivot a DataFrame from wide to long format, optionally leaving identifiers set.
            df_tmp = df_grouped_summed[df_grouped_summed['Country/Region']==country_name]
            df_melted = df_tmp.melt(id_vars=df_tmp.columns.values[:1], var_name='Date', value_name='Value')

            # Check best fit data for start and end day
            start_day = country['start_day']
            end_day = country['end_day']
            # Validate start and end days
            day_end = len(dates)
            day_start = 0
            if not CURVE_FIT_PLOT_RAW_DATA_ONLY:
                day_end = end_day if (end_day > 0 and end_day < len(dates)) else len(dates)
                day_start = start_day if start_day > 0 and start_day < day_end else 0
                logging.info('Fitting to days [{}, {}]'.format(day_start, day_end))

            # x is the day index, i.e. the same x the data and the fit are plotted at.
            # (np.linspace(0, n, num=n) would space the points n/(n-1) apart and shift the
            # fitted curve against the plotted data.)
            vals_x = np.arange(day_start, day_end, dtype=float)
            vals_y = list(df_melted.Value)[day_start:day_end]
            vals_sigma = [func_sigma(y) for y in vals_y]

            vals_x_to_end = list(range(plot_day_start, plot_day_end))
            vals_y_to_end = list(df_melted.Value)[plot_day_start:plot_day_end]

            ax.plot(vals_x_to_end, vals_y_to_end, 'o', color=country['color'], label='Data - {}'.format(country_name))
            if not CURVE_FIT_PLOT_RAW_DATA_ONLY:
                # Scipy curve fit; the fit function comes from this country's entry
                # (previously read from an unrelated, possibly undefined `data` variable)
                try:
                    func = country.get('fit_func', func_fit)
                    params, params_cov = scipy.optimize.curve_fit(func, xdata=vals_x, ydata=vals_y, sigma=vals_sigma)
                    vals_y_fit = [func(x, params[0], params[1], 0) for x in vals_x_to_end]
                    ax.plot(vals_x_to_end, vals_y_fit, '--', color=country['color'], label='Fit (days {}-{}) - {}'.format(day_start, day_end, country_name))
                except Exception as e:
                    # A failed fit should not prevent the other countries from being plotted
                    logging.info('Could not find curve fit, exception: {}'.format(e))
            else:
                logging.info('Just logging data')
        else:
            logging.info('Could not find given country "{}"'.format(country_name))

    ax.set_title('{} - {} - {}'.format(PLOT_TITLE, date_last.date(), plot_name), loc='center')
    ax.set_xlabel(PLOT_LABEL_X)
    ax.set_ylabel(PLOT_LABEL_Y)

    # Calculate ticks and labels (=the dates on the x-axis), every second day
    ticks = list(range(plot_day_start, plot_day_end))[::2]
    if plot_day_end not in ticks and len(ticks) % 2 == 0:
        ticks = ticks + [plot_day_end]
    if CURVE_FIT_PLOT_DAYS_AS_X_LABEL:
        labels = [d for d in ticks]
    else:
        labels = [str((date_first + datetime.timedelta(days=d)).date()) for d in ticks]
    plt.xticks(ticks=ticks, labels=labels)

    plt.legend(loc='upper left')
    plt.show()

    if FORCE_SAVE_PLOT_TO_FILE or PLOT_CURVE_FIT_MULTI_SAVE_PLOT_TO_FILE:
        save_plot(os.getcwd(), fig, PLOT_IMAGE_PATH, date_last, 'Curve-Fit-{}.png'.format('-'.join(countries)))

    plt.close(fig)