# Deep Learning for Business Applications course

## TOPIC 8: More Tasks for Deep Learning — Time-Series Forecasting with TimesFM

### 1. Libraries and parameters

In [None]:
!pip install statsforecast
!git clone https://github.com/google-research/timesfm.git
!cd timesfm && pip install -e .[torch]

In [None]:
import os
import torch
import timesfm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.seasonal import seasonal_decompose
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA, AutoETS

# Target device label for this notebook.
# NOTE(review): DEVICE is not referenced anywhere else in this notebook
# (the TimesFM model below is loaded without it) — confirm it is still needed.
DEVICE = "cpu"
# Allow faster, lower-internal-precision float32 matmul kernels where the backend
# supports them; likely has no effect on a pure-CPU run.
torch.set_float32_matmul_precision(precision="high")

### 2. Dataset

In [None]:
# Dataset: world gold prices, 1979-2021.

DATA_PATH = '~/__DATA/DLBA_F25/topic_08'
file_name = 'gold_prices_1979-2021.csv'
# Expand '~' explicitly: pandas.read_csv expands it on its own, but other
# consumers (open(), os.path.exists(), ...) do not, so keep a real filesystem path.
file_path = os.path.expanduser(os.path.join(DATA_PATH, file_name))

In [None]:
# Load the CSV, parse the 'Date' column as datetimes, and average all rows that
# fall into the same calendar month ('MS' labels each month by its first day).
df = (
    pd.read_csv(file_path)
    .assign(Date=lambda frame: pd.to_datetime(frame['Date']))
    .set_index('Date')
    .resample('MS')
    .mean()
)
display(df.head())

In [None]:
# Inspect the available columns; one of them is chosen as TARGET in the EDA section.
df.columns

### 3. EDA

In [None]:
# Column (gold price in UAE dirhams) used as the forecasting target.
TARGET = "United Arab Emirates(AED)"

In [None]:
# Decompose the monthly target series into trend, seasonal and residual parts.
# No period is passed, so statsmodels infers it from the index frequency.
decomposition = seasonal_decompose(df[TARGET])

# (y-axis label, series) for each panel, top to bottom.
components = [
    ('Observed', decomposition.observed),
    ('Trend', decomposition.trend),
    ('Seasonal', decomposition.seasonal),
    ('Residual', decomposition.resid),
]

fig, axes = plt.subplots(len(components), 1, figsize=(10, 12))
axes[0].set_title(f'Monthly gold prices in {TARGET}')
for ax, (label, series) in zip(axes, components):
    series.plot(ax=ax, color='green')
    ax.set_ylabel(label)
plt.tight_layout()
plt.show()

### 4. Train-test split

In [None]:
# Reshape into long format with a series id ('unique_id'), a timestamp ('ds')
# and the target ('price') — the id/time column names StatsForecast uses by default.
# NOTE: this overwrites df; the new frame no longer has a TARGET column, so
# re-running this cell on its own raises KeyError.
df = pd.DataFrame(dict(
    unique_id=[1] * len(df),
    ds=df.index,
    price=df[TARGET],
))
display(df.head())

In [None]:
# Chronological hold-out split: months up to and including cut_date form the
# training set; every later month is kept for evaluation.
cut_date = '2019-07-01'
df_train = df.loc[df['ds'].le(cut_date)]
df_test = df.loc[df['ds'].gt(cut_date)]

In [None]:
# (rows, columns) of the training split.
df_train.shape

In [None]:
# Last rows of the training split; all dates are at or before cut_date.
df_train.tail()

In [None]:
# (rows, columns) of the test split.
df_test.shape

In [None]:
# First rows of the test split; all dates are strictly after cut_date.
df_test.head()

### 5. Zero-shot forecasting with pretrained TimesFM

In [None]:
# Initialize TimesFM 2.5 (200M parameters, PyTorch backend) and load its
# pretrained checkpoint. The model is used as-is (zero-shot); it is not trained
# in this notebook.
# NOTE(review): DEVICE is not passed here — confirm which device the model runs on.

model_name = 'google/timesfm-2.5-200m-pytorch'
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(model_name)

In [None]:
# Configure the model for inference.
# NOTE(review): the boolean flags below are TimesFM 2.5 inference options whose
# exact semantics are defined by the timesfm package — see its ForecastConfig docs.
model.compile(
    timesfm.ForecastConfig(
        max_context=512,  # maximum number of past observations fed to the model as context
        max_horizon=256,  # maximum forecast horizon; presumably must be >= the horizon=24 requested below
        normalize_inputs=True,  # presumably rescales each input series before inference — confirm
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
    )
)

In [None]:
# Zero-shot forecast: the training history is only used as context; nothing is
# fitted. Returns point forecasts (indexed below as [series, step]) plus quantiles.
# NOTE: horizon=24 must stay in sync with h=24 used for the statistical models.
point_forecast, quantile_forecast = model.forecast(
    horizon=24,
    # TimesFM's forecast API takes a list of 1-D NumPy arrays (one per series);
    # pass plain values instead of a date-indexed pandas Series.
    inputs=[df_train['price'].to_numpy()],
)

### 6. Training statistical models

In [None]:
# Two classical baselines to compare against TimesFM: AutoARIMA and AutoETS,
# both with annual seasonality (season_length=12 periods of monthly data).
autoarima, autoets = (
    AutoARIMA(season_length=12),
    AutoETS(season_length=12),
)

In [None]:
# Wrap both baselines in a single StatsForecast object for month-start ('MS')
# data; n_jobs=-1 lets it use all available CPU cores.
baseline_models = [autoarima, autoets]
statforecast = StatsForecast(models=baseline_models, freq='MS', n_jobs=-1)

# Fit on the training split, reading the target values from the 'price' column.
statforecast.fit(df=df_train, target_col='price')

# Forecast the next 24 monthly periods after the end of the training data.
forecasts = statforecast.predict(h=24)

In [None]:
# Preview the baseline forecasts (one column per StatsForecast model).
forecasts.head()

### 7. Compare results

In [None]:
# Attach the TimesFM point forecast for the first (only) series as a new column;
# values line up positionally with the rows of the StatsForecast output.
forecasts = forecasts.assign(timesfm=point_forecast[0])
forecasts.head()

In [None]:
# Inner-join forecasts with the actual test values on the date column; dates
# missing from either side are dropped.
# NOTE(review): if both frames carry a unique_id column, joining only on 'ds'
# yields suffixed unique_id_x / unique_id_y columns — confirm this is intended.
forecasts = forecasts.merge(df_test, on='ds')

In [None]:
# Preview forecasts side by side with the actual 'price' values.
forecasts.head()

In [None]:
def error_metrics(y, y_pred):
    """
    Calculate MAE, RMSE and MAPE between actual and predicted values.

    Parameters
    ----------
    y : array-like
        Actual (observed) values.
    y_pred : array-like
        Predicted values; must have the same shape as ``y``.

    Returns
    -------
    dict
        ``{'MAE': ..., 'RMSE': ..., 'MAPE': ...}``; MAPE is expressed in percent.

    Raises
    ------
    ValueError
        If ``y`` and ``y_pred`` do not have the same shape.

    Notes
    -----
    MAPE divides by ``y``, so zero actual values produce ``inf``/``nan``.
    """
    # Cast to float so integer inputs cannot overflow when squared.
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Without this check, NumPy broadcasting would silently compare e.g. an (n,)
    # array against an (n, 1) or length-1 array and report meaningless metrics.
    if y.shape != y_pred.shape:
        raise ValueError(
            f'y and y_pred must have the same shape, got {y.shape} and {y_pred.shape}'
        )
    residuals = y - y_pred
    metrics = {
        'MAE': np.mean(np.abs(residuals)),
        'RMSE': np.sqrt(np.mean(residuals ** 2)),
        'MAPE': np.mean(np.abs(residuals / y)) * 100,
    }
    return metrics

In [None]:
# Evaluate every model on the test period with the same metrics.
models = ['timesfm', 'AutoARIMA', 'AutoETS']
# One record per model: the metric dict followed by the model name
# (keys: MAE, RMSE, MAPE, model).
err_metrics = [
    {**error_metrics(y=forecasts['price'], y_pred=forecasts[col]), 'model': col}
    for col in models
]

In [None]:
# Summary table: one row per model with MAE, RMSE and MAPE (%).
pd.DataFrame(err_metrics)