## Import libraries

In [45]:
import warnings

warnings.filterwarnings("ignore")

import os
import time
import random
import pandas as pd
import pickle
import numpy as np
from tqdm.auto import tqdm
from datetime import datetime
from itertools import product
import torch
from torch import nn
from typing import List, Tuple, Dict, Optional
from sklearn.preprocessing import MaxAbsScaler
from sklearn.linear_model import Ridge
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.utils.losses import SmapeLoss
from darts.dataprocessing.transformers import Scaler
from darts.metrics import smape
from darts.utils.utils import SeasonalityMode, TrendMode, ModelMode
from darts.models import *

# Forecast horizon, measured in time steps of whichever series is being
# forecast (one hour for the hourly traffic data, one month for monthly data).
# Kept at a single step so the predictions stay reasonable.
horizon = 1

## Load data
- We load several datasets and compare which one gives the best forecasting results.

In [127]:
from darts.datasets import TrafficDataset, AirPassengersDataset, AustralianTourismDataset
from ucimlrepo import fetch_ucirepo


def load_darts_frame(dataset) -> pd.DataFrame:
    """Load a darts dataset into a flat pandas DataFrame.

    Parameters
    ----------
    dataset
        A darts dataset loader instance; its ``load()`` result is converted
        with ``pd_dataframe()``.

    Returns
    -------
    pd.DataFrame
        The series values with the time index turned into an ordinary column
        (via ``reset_index``) and the columns-axis name cleared, so only the
        plain column labels remain.
    """
    frame = dataset.load().pd_dataframe().reset_index()
    frame.columns.name = None
    return frame


# Room occupancy estimation data from the UCI ML repository (dataset id 864).
# NOTE(review): this is a network download and the result is not used in this
# cell -- confirm it is needed downstream before keeping it.
room_occupancy_estimation = fetch_ucirepo(id=864)

traffic_data = load_darts_frame(TrafficDataset())
airpass_data = load_darts_frame(AirPassengersDataset())

# Keep only the per-state columns and relabel the time column.
# NOTE(review): the time column is renamed to 'Month'; confirm the sampling
# frequency of this dataset before relying on that label.
aus_tour_data = load_darts_frame(AustralianTourismDataset())
aus_tour_data = aus_tour_data[['time', 'NSW', 'VIC', 'QLD', 'SA', 'WA', "TAS", "NT"]].rename(columns={"time": 'Month'})

traffic_data.head()

Unnamed: 0,Date,0,1,2,3,4,5,6,7,8,...,852,853,854,855,856,857,858,859,860,861
0,2015-01-01 00:00:00,0.0048,0.0146,0.0289,0.0142,0.0064,0.0232,0.0162,0.0242,0.0341,...,0.0051,0.0051,0.0074,0.0079,0.0051,0.0051,0.0339,0.0051,0.01,0.0121
1,2015-01-01 01:00:00,0.0072,0.0148,0.035,0.0174,0.0084,0.024,0.0201,0.0338,0.0434,...,0.0036,0.0036,0.0107,0.0058,0.0036,0.0036,0.0348,0.0036,0.0087,0.0136
2,2015-01-01 02:00:00,0.004,0.0101,0.0267,0.0124,0.0049,0.017,0.0127,0.0255,0.0332,...,0.003,0.003,0.0043,0.005,0.003,0.003,0.0327,0.003,0.0061,0.0107
3,2015-01-01 03:00:00,0.0039,0.006,0.0218,0.009,0.0029,0.0118,0.0088,0.0163,0.0211,...,0.0033,0.0033,0.0019,0.0052,0.0033,0.0033,0.0292,0.0033,0.004,0.0071
4,2015-01-01 04:00:00,0.0042,0.0055,0.0191,0.0082,0.0024,0.0095,0.0064,0.0087,0.0144,...,0.0049,0.0049,0.0011,0.0071,0.0049,0.0049,0.0264,0.0049,0.004,0.0039


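## Scale the data
- Divide each series by its maximum absolute value (`MaxAbsScaler`), which maps every entry into $[-1, 1]$; for non-negative data such as the traffic values below, this gives $0 \leq x \leq 1$.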

In [130]:
def max_abs_scale(frame: pd.DataFrame, time_col: str = 'Date') -> Tuple[MaxAbsScaler, pd.DataFrame]:
    """Scale every value column of ``frame`` by its maximum absolute value.

    Parameters
    ----------
    frame : pd.DataFrame
        A time column plus numeric value columns (one per series).
    time_col : str
        Name of the time column; it is excluded from scaling and copied through
        unchanged as the first column of the result.

    Returns
    -------
    (scaler, scaled) : tuple
        ``scaler`` is the fitted ``MaxAbsScaler``, returned so the scaling can
        be inverted later. ``scaled`` holds ``time_col`` first, then the scaled
        value columns labelled by position ``0..k-1``, then ``'Scaled Avg'``,
        the row-wise mean of the scaled value columns.
    """
    scaler = MaxAbsScaler()
    # Fit on a plain ndarray, as the original cell did, so the fitted scaler is
    # not tied to DataFrame column names.
    values = np.array(frame.drop(columns=time_col))
    # Reuse the source index so the concat below aligns row-for-row even when
    # ``frame`` does not carry a default RangeIndex.
    scaled = pd.DataFrame(scaler.fit_transform(values), index=frame.index)
    scaled = pd.concat([frame[time_col], scaled], axis=1)
    scaled['Scaled Avg'] = scaled.drop(columns=time_col).mean(axis=1)
    return scaler, scaled


# NOTE(review): the scaler is fit on the whole series; if a train/test split
# follows, fit on the training window only to avoid look-ahead leakage.
scaler, scaled_traffic = max_abs_scale(traffic_data, 'Date')
scaled_traffic

Unnamed: 0,Date,0,1,2,3,4,5,6,7,8,...,853,854,855,856,857,858,859,860,861,Scaled Avg
0,2015-01-01 00:00:00,0.010112,0.032444,0.068793,0.028704,0.015729,0.045481,0.044751,0.049137,0.093068,...,0.058219,0.021637,0.014221,0.017653,0.012830,0.238901,0.014439,0.024044,0.055658,0.054815
1,2015-01-01 01:00:00,0.015167,0.032889,0.083313,0.035173,0.020644,0.047050,0.055525,0.068629,0.118450,...,0.041096,0.031287,0.010441,0.012461,0.009057,0.245243,0.010193,0.020918,0.062557,0.065740
2,2015-01-01 02:00:00,0.008426,0.022444,0.063556,0.025066,0.012042,0.033327,0.035083,0.051777,0.090611,...,0.034247,0.012573,0.009001,0.010384,0.007547,0.230444,0.008494,0.014667,0.049218,0.049981
3,2015-01-01 03:00:00,0.008216,0.013333,0.051892,0.018193,0.007127,0.023133,0.024309,0.033096,0.057587,...,0.037671,0.005556,0.009361,0.011423,0.008302,0.205779,0.009343,0.009618,0.032659,0.035139
4,2015-01-01 04:00:00,0.008848,0.012222,0.045465,0.016576,0.005898,0.018624,0.017680,0.017665,0.039301,...,0.055936,0.003216,0.012781,0.016961,0.012327,0.186047,0.013873,0.009618,0.017939,0.029147
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
17539,2016-12-31 19:00:00,0.072046,0.073778,0.328017,0.112998,0.040551,0.075671,0.126243,0.139492,0.195415,...,0.204338,0.081287,0.050405,0.078574,0.094340,0.100070,0.056908,0.081991,0.197332,0.127081
17540,2016-12-31 20:00:00,0.060038,0.056222,0.129017,0.098039,0.032195,0.064889,0.102762,0.115533,0.182587,...,0.176941,0.058772,0.048065,0.071651,0.080503,0.085976,0.046149,0.073575,0.166973,0.103579
17541,2016-12-31 21:00:00,0.058142,0.057111,0.131159,0.095816,0.029737,0.063909,0.105249,0.114924,0.184225,...,0.162100,0.060819,0.045365,0.065421,0.072201,0.086681,0.043035,0.070209,0.149494,0.101878
17542,2016-12-31 22:00:00,0.048873,0.050889,0.118067,0.087932,0.029246,0.058224,0.099724,0.104975,0.163210,...,0.141553,0.064035,0.038704,0.058498,0.064906,0.067653,0.036806,0.060591,0.128335,0.096013
