In [None]:
# hide
%load_ext autoreload
%autoreload 2

In [None]:
# default_exp core

# lockdowndates

> Retrieve the dates of restrictions imposed on countries around the world during the COVID-19 pandemic. Helpful during the feature engineering phase of machine learning projects that have a time element.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export

from datetime import datetime as dt
from pathlib import Path
from typing import List, Union, Tuple, Optional

import numpy as np
import pandas as pd

In [None]:
# export
class LockdownDates:
    '''
       Retrieve the dates of the restrictions imposed by governments around the world during the covid-19 pandemic.

       `country`: Country (or list of countries) from the table of countries in README.md
       <br/>`start_date`: First date to include (inclusive), in "YYYY-MM-DD" format
       <br/>`end_date`: Last date to include (inclusive), in "YYYY-MM-DD" format
       <br/>`restrictions`: List or tuple of restrictions to be returned, as listed in README.md
    '''

    def __init__(self, country:Union[List[str],str], start_date:str, end_date:str, restrictions: Union[List, Tuple]):
        # A single country name is wrapped in a list so every other method can
        # iterate over `self.country` without special-casing.
        self.country = country if isinstance(country, list) else [country]

        # Fail fast on wrong argument types: only printing here would leave a
        # half-initialised object that breaks later with an AttributeError.
        if not isinstance(start_date, str):
            raise TypeError("Incorrect format for start_date, expecting %Y-%m-%d")
        self.start_date = dt.strptime(start_date, "%Y-%m-%d")

        if not isinstance(end_date, str):
            raise TypeError("Incorrect format for end_date, expecting %Y-%m-%d")
        self.end_date = dt.strptime(end_date, "%Y-%m-%d")

        if not isinstance(restrictions, (list, tuple)):
            raise TypeError(
                f"Incorrect format for restrictions, you provided {type(restrictions)}, needed to be list or tuple"
            )
        self.restrictions = restrictions
        # Maps the short restriction name used by this class to the column
        # name requested from the remote parquet files.
        self.restriction_keys = {
            "stay_at_home" : "C6_Stay at home requirements",
            "masks" : "H6_Facial Coverings"
        }

    @staticmethod
    def _slug(country: str) -> str:
        '''Lower-case a country name and strip spaces (the form used in the data file names).'''
        return country.lower().replace(' ', '')

    def fetch(self) -> Optional[pd.DataFrame]:
        '''
        Download the raw restriction columns for every requested country.

        Returns one concatenated DataFrame, or `None` (after printing the
        error) when reading any of the files fails. A restriction name that
        is not in `restriction_keys` raises `KeyError`.
        '''
        restrictions_to_be_used = [self.restriction_keys[restriction] for restriction in self.restrictions]

        usecols = ["CountryName", "CountryCode", "Date"] + restrictions_to_be_used

        df_dtype = {
            "CountryName": str,
            "CountryCode": str,
            "Date": str,
        }

        print("Fetching lockdown dates...")
        try:
            urls = [
                f"https://github.com/seanyboi/lockdowndates_data/blob/main/data/{self._slug(country)}.parquet?raw=true"
                for country in self.country
            ]
            lockdown_df = pd.concat([pd.read_parquet(u, columns=usecols, engine="pyarrow") for u in urls])
            return lockdown_df.astype(df_dtype)
        except Exception as e:
            print(f"Error fetching lockdown data - {e}")
            return None

    def engineer_df(self) -> Optional[pd.DataFrame]:
        '''
        Fetch the raw data and tidy it: friendly column names, a datetime
        `timestamp` column and categorical restriction/country columns.

        Returns `None` when fetching failed or (after printing the error)
        when formatting failed.
        '''
        df = self.fetch()
        if df is None:
            # fetch() has already reported the problem.
            return None

        try:
            # Invert restriction_keys: raw column name -> short name.
            restriction_names = {column: key for key, column in self.restriction_keys.items()}
            rename_columns = {
                "CountryName": "country",
                "CountryCode": "country_code",
                "Date": "timestamp",
                **restriction_names
            }
            df = df.rename(columns=rename_columns)

            # Dates arrive as "YYYYMMDD" strings; insert dashes, then parse.
            df["timestamp"] = df["timestamp"].str.replace(r'(\d{4})(\d{2})(\d{2})', r'\g<1>-\g<2>-\g<3>', regex=True)
            df["timestamp"] = pd.to_datetime(df["timestamp"])

            # Only the requested restrictions were fetched, so only those
            # columns exist; converting every known restriction would raise
            # KeyError whenever a subset is requested.
            category_columns = list(self.restrictions) + ['country', 'country_code']
            df = df.astype({col: 'category' for col in category_columns})

            print(f"Fetched lockdown dates for: {', '.join(self.country)}")

            return df
        except Exception as e:
            print(f"Formatting data failed - please raise an issue on our repo! - {e}")
            return None

    def filter_df(self) -> Optional[pd.DataFrame]:
        '''
        Pivot the tidied data to one row per timestamp with
        `<country>_<column>` columns, restricted to the requested date range.

        Returns `None` when no data could be prepared. When the selection is
        empty or the reshaping fails, a message is printed and the frame is
        returned in whatever state it reached.
        '''
        df = self.engineer_df()
        no_data_message = f"No lockdown data for {self.country} between {self.start_date} and {self.end_date}"
        if df is None:
            print(no_data_message)
            return None

        try:
            # Compare on the same normalised form used to build the download
            # URL, so "aruba" matches the "Aruba" rows it fetched.
            requested = {self._slug(country) for country in self.country}
            df = df[df['country'].astype(str).map(self._slug).isin(requested)]
            df = df.pivot_table(index="timestamp", columns='country', aggfunc='first')
            # Flatten the (column, country) MultiIndex into "<country>_<column>".
            df.columns = ["{}_{}".format(col[1].lower().replace("'", "").replace(" ", ""), col[0]) for col in df.columns.values]
            # Label slicing on the timestamp index includes both end points.
            df = df.loc[self.start_date : self.end_date]

            if df.empty:
                print(no_data_message)
        except Exception:
            print(no_data_message)

        return df

    def dates(self, save:bool = False) -> Optional[pd.DataFrame]:
        '''
        Returns the restriction lockdown dates for a specific set of countries.
        
        <b>Parameters</b>
            <br/> &nbsp;&nbsp;&nbsp;&nbsp; `save` : bool, optional
            <br/>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; saves restrictions (including the timestamp index) to a csv file in `lockdown_data/` for caching (default is False)
    
        <b>Returns</b>:
            <br/> &nbsp;&nbsp;&nbsp;&nbsp; DataFrame containing the dates a country was subject to certain restrictions during the pandemic, or `None` if the data could not be collected.

        <b>Notes</b>:
            <br/> &nbsp;&nbsp;&nbsp;&nbsp; Failures while fetching or saving are printed rather than raised.
        '''
        restrictions = self.filter_df()
        if save:
            try:
                if restrictions is None:
                    raise ValueError("no restrictions data to save")
                # self.country is always a list; join the per-country slugs.
                country_slug = "-".join(self._slug(country) for country in self.country)
                output_file = f"{country_slug}-lockdown-restrictions.csv"
                output_dir = Path("lockdown_data")
                output_dir.mkdir(parents=True, exist_ok=True)
                # Keep the index: it holds the timestamps.
                restrictions.to_csv(output_dir / output_file)
                print(f"Saved restrictions to - {output_dir}/{output_file}")
            except Exception as e:
                print(f"Failed to save restrictions to csv file - {e}")

        return restrictions

In [None]:
# Example: stay-at-home and mask restrictions for Aruba during January 2022.
ld = LockdownDates(
    country="Aruba",
    start_date="2022-01-01",
    end_date="2022-01-30",
    restrictions=("stay_at_home", "masks"),
)
aruba_restrictions = ld.dates()
aruba_restrictions

Fetching lockdown dates...
Fetched lockdown dates for: Aruba


Unnamed: 0_level_0,aruba_country_code,aruba_masks,aruba_stay_at_home
timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2022-01-01,ABW,2.0,2.0
2022-01-02,ABW,2.0,2.0
2022-01-03,ABW,2.0,2.0
2022-01-04,ABW,2.0,2.0
2022-01-05,ABW,2.0,2.0
2022-01-06,ABW,2.0,2.0
2022-01-07,ABW,2.0,2.0
2022-01-08,ABW,2.0,2.0
2022-01-09,ABW,2.0,2.0
2022-01-10,ABW,2.0,2.0


In [None]:
show_doc(LockdownDates.dates)

<h4 id="LockdownDates.dates" class="doc_header"><code>LockdownDates.dates</code><a href="__main__.py#L97" class="source_link" style="float:right">[source]</a></h4>

> <code>LockdownDates.dates</code>(**`save`**:`bool`=*`False`*)

Returns the restriction lockdown dates for a specific set of countries.

<b>Parameters</b>
    <br/> &nbsp;&nbsp;&nbsp;&nbsp; `save` : bool, optional
    <br/>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; saves restrictions to a csv file for caching (default is False)

<b>Returns</b>:
    <br/> &nbsp;&nbsp;&nbsp;&nbsp; DataFrame containing the dates a country was subject to certain restrictions during the pandemic.

<b>Raises</b>:
    <br/> &nbsp;&nbsp;&nbsp;&nbsp; Exception: failed to collect data.

In [None]:
#ignore
# Shared instance inspected by the assertion cells that follow.
ld = LockdownDates(
    country="Aruba",
    start_date="2022-01-01",
    end_date="2022-01-30",
    restrictions=["stay_at_home", "masks"],
)

In [None]:
#ignore
# A single country string is stored as a one-element list.
assert ld.country == ["Aruba"]
assert type(ld.country) is list

In [None]:
#ignore
# start_date is parsed from "YYYY-MM-DD" into a datetime at midnight.
assert ld.start_date == dt(2022, 1, 1, 0, 0)
assert type(ld.start_date) is dt

In [None]:
#ignore
# end_date is parsed from "YYYY-MM-DD" into a datetime at midnight.
assert ld.end_date == dt(2022, 1, 30, 0, 0)
assert type(ld.end_date) is dt

In [None]:
#ignore
# The restrictions argument is stored exactly as given.
assert ld.restrictions == ["stay_at_home", "masks"]
assert type(ld.restrictions) is list

In [None]:
# hide
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
