# Libraries

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
from numba import njit

# Paths setup

In [None]:
input_path = Path('../input/m5-forecasting-accuracy/')
output_path = Path('processed')
# exist_ok=True: re-running the notebook must not crash with FileExistsError
# once the output directory exists; parents=True allows a nested output path.
output_path.mkdir(parents=True, exist_ok=True)

# Calendar

In [None]:
# Calendar columns to load and their dtypes: small unsigned ints for the week id
# and SNAP flags, categoricals for the string-valued columns.
cal_dtypes = dict(
    d='category',
    wm_yr_wk=np.uint16,
    event_name_1='category',
    event_type_1='category',
    event_name_2='category',
    event_type_2='category',
    snap_CA=np.uint8,
    snap_TX=np.uint8,
    snap_WI=np.uint8,
)
# 'date' is read as well (parsed to datetime) but has no entry in cal_dtypes.
cal = pd.read_csv(
    input_path / 'calendar.csv',
    usecols=[*cal_dtypes, 'date'],
    dtype=cal_dtypes,
    parse_dates=['date'],
)
cal

In [None]:
# Give missing event entries an explicit 'nan' category so they become a regular
# categorical level instead of NaN.
event_cols = list(filter(lambda name: name.startswith('event'), cal_dtypes))
for col in event_cols:
    cal[col] = (
        cal[col]
        .cat.add_categories('nan')
        .fillna('nan')
    )

# Prices

In [None]:
# Weekly sell prices per store/item; ids as categoricals, week id as uint16.
prices_dtypes = dict(
    store_id='category',
    item_id='category',
    wm_yr_wk=np.uint16,
    sell_price=np.float32,
)

prices_file = input_path / 'sell_prices.csv'
prices = pd.read_csv(prices_file, dtype=prices_dtypes)
prices

# Sales

## Read

In [None]:
# Dtypes for the wide sales table. item_id reuses the categorical dtype already
# inferred for prices, so both frames share the same item categories (an
# item_id missing from those categories would be read as NaN).
sales_dtypes = {
    'id': 'category',
    'item_id': prices['item_id'].dtype,
    'dept_id': 'category',
    'cat_id': 'category',
    'store_id': 'category',
    'state_id': 'category',
}
# One float32 column per day: d_1 ... d_1913.
sales_dtypes.update({f'd_{day}': np.float32 for day in range(1, 1914)})
sales = pd.read_csv(input_path / 'sales_train_validation.csv', dtype=sales_dtypes)
sales

## Convert to long format

In [None]:
# Wide -> long: one row per (series, day); day label goes to 'd', value to 'y'.
long = pd.melt(
    sales,
    id_vars=['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'],
    var_name='d',
    value_name='y',
)
long

## Merge with calendar

In [None]:
# Cast the melted day labels to the calendar's categorical dtype so the join
# key has the same dtype on both sides.
day_dtype = cal['d'].dtype
long['d'] = long['d'].astype(day_dtype)
long = pd.merge(long, cal, on='d')
long

## Merge with prices

In [None]:
# Inner join (the merge default): rows with no price for that
# store/item/week are dropped.
long = pd.merge(long, prices, on=['store_id', 'item_id', 'wm_yr_wk'])
long

## Save the future calendar and prices (needed later to update the features)

In [None]:
# Boundaries of the training data.
last_date_train = long['date'].max()
last_wmyrwk = long['wm_yr_wk'].max()

# Calendar rows strictly after the last training day.
cal = cal.loc[cal['date'].gt(last_date_train)]
cal.to_parquet(output_path / 'calendar.parquet')

# Prices from the last training week onwards (inclusive); presumably because
# the first future days can still fall inside that week.
prices = prices.loc[prices['wm_yr_wk'].ge(last_wmyrwk)]
prices.to_parquet(output_path / 'prices.parquet')

## Remove unnecessary information from sales

### Unnecessary columns

In [None]:
long = long.drop(['d', 'wm_yr_wk'], axis='columns')  # join keys no longer needed

### Remove leading zeros from each series

In [None]:
def first_nz_mask(x):
    """Return a boolean mask that is True from the first non-zero value onwards.

    Leading zeros map to False. Every element from the first non-zero value
    (inclusive) to the end maps to True, even if later values are zero. NaN
    compares unequal to zero, so a leading NaN counts as non-zero. An all-zero
    or empty input yields an all-False (or empty) mask.

    Parameters
    ----------
    x : array-like
        1-D sequence of values (ndarray, list, Series, ...).

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same length as ``x``.
    """
    nonzero = np.asarray(x) != 0
    # A running logical OR flips to True at the first non-zero element and stays
    # True afterwards. This is exactly the "skip leading zeros" mask, computed
    # in C without a Python/JIT loop.
    return np.logical_or.accumulate(nonzero)

In [None]:
# Sort so each series is in chronological order: first_nz_mask works on the
# values in the order it receives them.
long = long.sort_values(['id', 'date'])
# 'id' is categorical. observed=True groups only over ids that occur, which
# leaves the transform result unchanged and avoids pandas' FutureWarning about
# the observed=False default.
keep_mask = long.groupby('id', observed=True)['y'].transform(
    lambda x: first_nz_mask(x.values)
)
long = long[keep_mask]
long.to_parquet(output_path / 'sales.parquet')