In [4]:
import os

# Move to the parent of the notebook's launch directory, presumably so the
# local `utils` / `run_exp` modules imported in the next cell resolve.
# The launch directory is remembered on first execution, so re-running this
# cell is idempotent instead of climbing one extra level per run.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))


In [5]:
from utils import load_dataset
from run_exp import get_instance
import optuna
from pytorch_lightning.utilities.model_summary import ModelSummary
import pandas as pd

def build_m3_and_m4_dictionary(
    dataset,
    group,
    freq,
    base_dir="C:\\Users\\ricar\\mixture_of_experts_time_series\\data\\",
):
    """
    Build the dataset-spec dictionary for an M3/M4 dataset group.

    Args:
        dataset: Dataset name (e.g. ``"m3"``); also used as the sub-directory name.
        group: Group within the dataset (e.g. ``"Monthly"``).
        freq: Frequency code for the group (e.g. ``"M"``).
        base_dir: Root data directory. Defaults to the original hard-coded path,
            so existing calls produce exactly the same dictionary. If it does not
            end with a path separator, ``os.sep`` is appended.

    Returns:
        dict with keys ``"name"``, ``"directory"``, ``"group"`` and ``"freq"``;
        ``"directory"`` is ``base_dir/<dataset>/`` with a trailing separator.
    """
    # Reuse the separator style already present in base_dir so the default
    # Windows path is reproduced byte-for-byte; otherwise use the platform one.
    if base_dir.endswith(("\\", "/")):
        sep = base_dir[-1]
    else:
        sep = os.sep
        if base_dir:
            base_dir = base_dir + sep

    return {
        "name": dataset,
        "directory": f"{base_dir}{dataset}{sep}",
        "group": group,
        "freq": freq,
    }

def calculate_avg_active_params(nr_of_parameters, nr_of_experts, active_experts, model_name):
    """
    Calculate the average number of active parameters.

    The whole parameter count (including any non-expert parameters) is scaled by
    the fraction of experts that are active per input.

    Args:
        nr_of_parameters: Total parameter count of the model.
        nr_of_experts: Number of experts, or ``None`` when the model has none.
        active_experts: Number of experts selected per input (top-k), or ``None``.
        model_name: Model identifier. For ``"nbeatsmoeshared"`` both counts are
            incremented by one (presumably to account for an always-active
            shared expert — confirm against the model definition).

    Returns:
        ``nr_of_parameters`` unchanged when expert information is missing or
        zero, otherwise the scaled count rounded to 2 decimals.
    """
    # Dense models carry no expert settings. A partially specified config
    # (e.g. nr_experts tuned but top_k absent) would otherwise raise TypeError
    # on None arithmetic, so treat it as "no routing information" too.
    if active_experts is None or nr_of_experts is None:
        return nr_of_parameters
    if not active_experts and not nr_of_experts:
        return nr_of_parameters

    if model_name == "nbeatsmoeshared":
        return round((nr_of_parameters * (active_experts + 1)) / (nr_of_experts + 1), 2)

    # Zero experts with a non-zero top_k would otherwise divide by zero.
    if not nr_of_experts:
        return nr_of_parameters

    return round((nr_of_parameters * active_experts) / nr_of_experts, 2)

# Model variants compared in this notebook (one Optuna study per model/dataset).
list_models = [
    "nbeats",
    "nbeatsmoe",
    "nbeatsmoeshared",
    "nbeatsstackmoe",
]

# Fallback forecast horizon keyed by frequency code, used when the dataset
# loader does not provide a horizon itself.
map_horizon_freq = dict(M=18, Q=8, Y=6)

# GluonTS datasets only need a name; M3/M4 groups get a full spec built by
# build_m3_and_m4_dictionary. Order is significant: it drives the order of the
# printed output and of the rows/keys written later.
list_datasets = [
    *(
        {"name": gluonts_name}
        for gluonts_name in (
            "gluonts_m1_quarterly",
            "gluonts_m1_monthly",
            "gluonts_m1_yearly",
            "gluonts_tourism_monthly",
            "gluonts_tourism_quarterly",
            "gluonts_tourism_yearly",
        )
    ),
    *(
        build_m3_and_m4_dictionary(source, group, freq_code)
        for source in ("m3", "m4")
        for group, freq_code in (("Monthly", "M"), ("Quarterly", "Q"), ("Yearly", "Y"))
    ),
]


  from .autonotebook import tqdm as notebook_tqdm
2025-04-30 15:04:34,298	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-04-30 15:04:35,016	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [7]:
# Optuna RDB storage URL holding the tuned studies (SQLite file on disk).
STORAGE = (
    "sqlite:///"
    "c:/Users/ricar/mixture_of_experts_time_series/db/study_nbeats_blcs.db"
)

In [None]:
# Rebuild every tuned model from its best Optuna trial, record its parameter
# counts and print a layer summary. Rows are collected as plain dicts and the
# DataFrame is built once at the end (growing it with pd.concat per row is quadratic).
summary_columns = [
    "model_name",
    "dataset",
    "freq",
    "nr_of_parameters",
    "nr_experts",
    "top_k",
    "average_active_parameters",
]
summary_records = []

for dataset_spec in list_datasets:
    # Work on a copy: the fields below are rewritten from the loader output, and
    # mutating the shared list_datasets entries makes re-runs (and later cells)
    # start from already-modified specs.
    dataset = dict(dataset_spec)
    Y_ALL = load_dataset(dataset["name"], dataset)

    # When load_dataset returns a tuple it is unpacked as
    # (data, horizon, n_lags, group, _); the group doubles as the frequency code.
    # This is loop-invariant, so it is done once per dataset, not once per model.
    horizon = None
    if isinstance(Y_ALL, tuple):
        Y_ALL, horizon, n_lags, dataset["group"], _ = Y_ALL
        dataset["freq"] = dataset["group"]
        dataset["name"] = dataset["name"].replace(" ", "_")

    if horizon is None:
        horizon = map_horizon_freq[dataset["freq"]]

    for model_name in list_models:
        study_name = f"{model_name}_{dataset['name']}_{dataset['group']}_{horizon}"
        print(study_name)

        study = optuna.load_study(
            study_name=study_name,
            storage=STORAGE,
        )

        # Copy so the override below never mutates the study's params object.
        best_params = dict(study.best_params)

        # Some MoE models were run with a smaller MLP (128 units) on these
        # datasets, mainly so the experiments finished in reasonable time.
        # The grouping is explicit: the previous `a or b and (...)` form applied
        # the override to nbeatsmoe on *every* dataset, disagreeing with the
        # hyperparameter export cell.
        if (
            model_name in ("nbeatsmoe", "nbeatsmoeshared")
            and dataset["name"] in (
                "gluonts_tourism_yearly",
                "gluonts_m1_quarterly",
                "gluonts_m1_monthly",
                "gluonts_tourism_monthly",
            )
            and best_params.get("mlp_units", [[None]])[0][0] in (256, 512)
        ):
            best_params["mlp_units"] = [[128, 128], [128, 128], [128, 128]]

        model = get_instance(model_name, best_params, horizon=horizon)

        nr_of_parameters = sum(p.numel() for p in model.parameters())
        nr_experts = best_params.get("nr_experts", None)
        top_k = best_params.get("top_k", None)

        summary_records.append(
            {
                "model_name": model_name,
                "dataset": dataset["name"],
                "freq": dataset["freq"],
                "nr_of_parameters": nr_of_parameters,
                "nr_experts": nr_experts,
                "top_k": top_k,
                "average_active_parameters": calculate_avg_active_params(
                    nr_of_parameters, nr_experts, top_k, model_name
                ),
            }
        )

        # print the model summary
        summary = ModelSummary(model, max_depth=3)
        print(summary)

        # Free the model before building the next one.
        del model, study, best_params, summary

summaries_pd = pd.DataFrame(summary_records, columns=summary_columns)

# store the dataframe in a csv file
summaries_pd.to_csv(
    "c:/Users/ricar/mixture_of_experts_time_series/results/model_summaries.csv",
    index=False,
)
        


INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\m1_quarterly.


Loading m1_quarterly dataset...
nbeats_gluonts_m1_quarterly_Q_2


INFO:lightning_fabric.utilities.seed:Seed set to 1
INFO:lightning_fabric.utilities.seed:Seed set to 3
INFO:lightning_fabric.utilities.seed:Seed set to 20


   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 152 K  | train
4  | blocks.0         | NBEATSBlock      | 51.0 K | train
5  | blocks.0.layers  | Sequential       | 51.0 K | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.9         | NBEATSBlock      | 51.0 K | train
8  | blocks.9.layers  | Sequential       | 51.0 K | train
9  | blocks.9.basis   | TrendBasis       | 18     | train
10 | blocks.18        | NBEATSBlock      | 50.7 K | train
11 | blocks.18.layers | Sequential       | 50.7 K | train
12 | blocks.18.basis  | SeasonalityBasis | 12     | train
---------------------------------------------------------------
152 K     Trainable params
30        Non-trainable params
15

INFO:lightning_fabric.utilities.seed:Seed set to 18
INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\m1_monthly.


   | Name            | Type             | Params | Mode 
--------------------------------------------------------------
0  | loss            | MAE              | 0      | train
1  | padder_train    | ConstantPad1d    | 0      | train
2  | scaler          | TemporalNorm     | 0      | train
3  | blocks          | ModuleList       | 39.5 K | train
4  | blocks.0        | NBEATSBlock      | 13.2 K | train
5  | blocks.0.layers | Sequential       | 13.2 K | train
6  | blocks.0.basis  | IdentityBasis    | 0      | train
7  | blocks.1        | NBEATSBlock      | 13.2 K | train
8  | blocks.1.layers | Sequential       | 13.2 K | train
9  | blocks.1.basis  | TrendBasis       | 18     | train
10 | blocks.2        | NBEATSBlock      | 13.1 K | train
11 | blocks.2.layers | Sequential       | 13.1 K | train
12 | blocks.2.basis  | SeasonalityBasis | 12     | train
13 | gate            | Sequential       | 23     | train
14 | gate.0          | LayerNorm        | 8      | train
15 | gate.1          | Li

INFO:lightning_fabric.utilities.seed:Seed set to 9
INFO:lightning_fabric.utilities.seed:Seed set to 18


nbeats_gluonts_m1_monthly_M_8
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 4.3 K  | train
4  | blocks.0         | NBEATSBlock      | 1.4 K  | train
5  | blocks.0.layers  | Sequential       | 1.4 K  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.9         | NBEATSBlock      | 1.2 K  | train
8  | blocks.9.layers  | Sequential       | 1.1 K  | train
9  | blocks.9.basis   | TrendBasis       | 60     | train
10 | blocks.18        | NBEATSBlock      | 1.8 K  | train
11 | blocks.18.layers | Sequential       | 1.5 K  | train
12 | blocks.18.basis  | SeasonalityBasis | 280    | train
---------------------------------------------------------------
4.0 K     Trainable params
340

INFO:lightning_fabric.utilities.seed:Seed set to 12
INFO:lightning_fabric.utilities.seed:Seed set to 4
INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\m1_yearly.


   | Name              | Type                | Params | Mode 
-------------------------------------------------------------------
0  | loss              | MAE                 | 0      | train
1  | padder_train      | ConstantPad1d       | 0      | train
2  | scaler            | TemporalNorm        | 0      | train
3  | blocks            | ModuleList          | 6.3 K  | train
4  | blocks.0          | NBEATSMoEBlock      | 2.1 K  | train
5  | blocks.0.gate     | Sequential          | 141    | train
6  | blocks.0.softmax  | Softmax             | 0      | train
7  | blocks.0.pooling  | SharedExpertPooling | 2.1 K  | train
8  | blocks.0.basis    | IdentityBasis       | 0      | train
9  | blocks.6          | NBEATSMoEBlock      | 1.5 K  | train
10 | blocks.6.gate     | Sequential          | 141    | train
11 | blocks.6.softmax  | Softmax             | 0      | train
12 | blocks.6.pooling  | SharedExpertPooling | 1.4 K  | train
13 | blocks.6.basis    | TrendBasis          | 60     | train
14

INFO:lightning_fabric.utilities.seed:Seed set to 5
INFO:lightning_fabric.utilities.seed:Seed set to 20


nbeats_gluonts_m1_yearly_Y_2
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 2.4 M  | train
4  | blocks.0         | NBEATSBlock      | 792 K  | train
5  | blocks.0.layers  | Sequential       | 792 K  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.9         | NBEATSBlock      | 793 K  | train
8  | blocks.9.layers  | Sequential       | 793 K  | train
9  | blocks.9.basis   | TrendBasis       | 15     | train
10 | blocks.18        | NBEATSBlock      | 792 K  | train
11 | blocks.18.layers | Sequential       | 792 K  | train
12 | blocks.18.basis  | SeasonalityBasis | 10     | train
---------------------------------------------------------------
2.4 M     Trainable params
25  

INFO:lightning_fabric.utilities.seed:Seed set to 11
INFO:lightning_fabric.utilities.seed:Seed set to 18
INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\tourism_monthly.


   | Name              | Type                | Params | Mode 
-------------------------------------------------------------------
0  | loss              | MAE                 | 0      | train
1  | padder_train      | ConstantPad1d       | 0      | train
2  | scaler            | TemporalNorm        | 0      | train
3  | blocks            | ModuleList          | 2.7 K  | train
4  | blocks.0          | NBEATSMoEBlock      | 905    | train
5  | blocks.0.gate     | Sequential          | 18     | train
6  | blocks.0.softmax  | Softmax             | 0      | train
7  | blocks.0.pooling  | SharedExpertPooling | 905    | train
8  | blocks.0.basis    | IdentityBasis       | 0      | train
9  | blocks.6          | NBEATSMoEBlock      | 947    | train
10 | blocks.6.gate     | Sequential          | 18     | train
11 | blocks.6.softmax  | Softmax             | 0      | train
12 | blocks.6.pooling  | SharedExpertPooling | 932    | train
13 | blocks.6.basis    | TrendBasis          | 15     | train
14

INFO:lightning_fabric.utilities.seed:Seed set to 19
INFO:lightning_fabric.utilities.seed:Seed set to 7
INFO:lightning_fabric.utilities.seed:Seed set to 20


nbeats_gluonts_tourism_monthly_M_18
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 642 K  | train
4  | blocks.0         | NBEATSBlock      | 214 K  | train
5  | blocks.0.layers  | Sequential       | 214 K  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 205 K  | train
8  | blocks.6.layers  | Sequential       | 205 K  | train
9  | blocks.6.basis   | TrendBasis       | 126    | train
10 | blocks.12        | NBEATSBlock      | 222 K  | train
11 | blocks.12.layers | Sequential       | 221 K  | train
12 | blocks.12.basis  | SeasonalityBasis | 1.4 K  | train
---------------------------------------------------------------
641 K     Trainable para

INFO:lightning_fabric.utilities.seed:Seed set to 10
INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\tourism_quarterly.


   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 9.6 M  | train
4  | blocks.0         | NBEATSBlock      | 3.2 M  | train
5  | blocks.0.layers  | Sequential       | 3.2 M  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 3.2 M  | train
8  | blocks.6.layers  | Sequential       | 3.2 M  | train
9  | blocks.6.basis   | TrendBasis       | 126    | train
10 | blocks.12        | NBEATSBlock      | 3.2 M  | train
11 | blocks.12.layers | Sequential       | 3.2 M  | train
12 | blocks.12.basis  | SeasonalityBasis | 1.4 K  | train
13 | gate             | Sequential       | 498    | train
14 | gate.0           | LayerNorm        | 48     | train
15 | gat

INFO:lightning_fabric.utilities.seed:Seed set to 2
INFO:lightning_fabric.utilities.seed:Seed set to 3


nbeats_gluonts_tourism_quarterly_Q_8
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 1.4 K  | train
4  | blocks.0         | NBEATSBlock      | 364    | train
5  | blocks.0.layers  | Sequential       | 364    | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.9         | NBEATSBlock      | 346    | train
8  | blocks.9.layers  | Sequential       | 310    | train
9  | blocks.9.basis   | TrendBasis       | 36     | train
10 | blocks.18        | NBEATSBlock      | 676    | train
11 | blocks.18.layers | Sequential       | 508    | train
12 | blocks.18.basis  | SeasonalityBasis | 168    | train
---------------------------------------------------------------
1.2 K     Trainable par

INFO:lightning_fabric.utilities.seed:Seed set to 11
INFO:lightning_fabric.utilities.seed:Seed set to 7
INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\tourism_yearly.


   | Name              | Type                | Params | Mode 
-------------------------------------------------------------------
0  | loss              | MAE                 | 0      | train
1  | padder_train      | ConstantPad1d       | 0      | train
2  | scaler            | TemporalNorm        | 0      | train
3  | blocks            | ModuleList          | 11.0 K | train
4  | blocks.0          | NBEATSMoEBlock      | 3.3 K  | train
5  | blocks.0.gate     | Sequential          | 53     | train
6  | blocks.0.softmax  | Softmax             | 0      | train
7  | blocks.0.pooling  | SharedExpertPooling | 3.3 K  | train
8  | blocks.0.basis    | IdentityBasis       | 0      | train
9  | blocks.9          | NBEATSMoEBlock      | 2.9 K  | train
10 | blocks.9.gate     | Sequential          | 53     | train
11 | blocks.9.softmax  | Softmax             | 0      | train
12 | blocks.9.pooling  | SharedExpertPooling | 2.9 K  | train
13 | blocks.9.basis    | TrendBasis          | 36     | train
14

INFO:lightning_fabric.utilities.seed:Seed set to 11
INFO:lightning_fabric.utilities.seed:Seed set to 5


nbeats_gluonts_tourism_yearly_Y_4
   | Name            | Type             | Params | Mode 
--------------------------------------------------------------
0  | loss            | MAE              | 0      | train
1  | padder_train    | ConstantPad1d    | 0      | train
2  | scaler          | TemporalNorm     | 0      | train
3  | blocks          | ModuleList       | 2.4 M  | train
4  | blocks.0        | NBEATSBlock      | 793 K  | train
5  | blocks.0.layers | Sequential       | 793 K  | train
6  | blocks.0.basis  | IdentityBasis    | 0      | train
7  | blocks.1        | NBEATSBlock      | 793 K  | train
8  | blocks.1.layers | Sequential       | 793 K  | train
9  | blocks.1.basis  | TrendBasis       | 21     | train
10 | blocks.2        | NBEATSBlock      | 796 K  | train
11 | blocks.2.layers | Sequential       | 796 K  | train
12 | blocks.2.basis  | SeasonalityBasis | 42     | train
--------------------------------------------------------------
2.4 M     Trainable params
63        Non-t

INFO:lightning_fabric.utilities.seed:Seed set to 14
INFO:lightning_fabric.utilities.seed:Seed set to 8


   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 1.2 M  | train
4  | blocks.0         | NBEATSMoEBlock   | 407 K  | train
5  | blocks.0.gate    | Sequential       | 38     | train
6  | blocks.0.softmax | Softmax          | 0      | train
7  | blocks.0.pooling | SparsePooling    | 407 K  | train
8  | blocks.0.basis   | IdentityBasis    | 0      | train
9  | blocks.3         | NBEATSMoEBlock   | 406 K  | train
10 | blocks.3.gate    | Sequential       | 38     | train
11 | blocks.3.softmax | Softmax          | 0      | train
12 | blocks.3.pooling | SparsePooling    | 406 K  | train
13 | blocks.3.basis   | TrendBasis       | 21     | train
14 | blocks.6         | NBEATSMoEBlock   | 412 K  | train
15 | blo

INFO:lightning_fabric.utilities.seed:Seed set to 8
INFO:lightning_fabric.utilities.seed:Seed set to 20
INFO:lightning_fabric.utilities.seed:Seed set to 12


nbeats_m3_Monthly_18
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 181 K  | train
4  | blocks.0         | NBEATSBlock      | 61.2 K | train
5  | blocks.0.layers  | Sequential       | 61.2 K | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.9         | NBEATSBlock      | 55.2 K | train
8  | blocks.9.layers  | Sequential       | 55.0 K | train
9  | blocks.9.basis   | TrendBasis       | 162    | train
10 | blocks.18        | NBEATSBlock      | 64.9 K | train
11 | blocks.18.layers | Sequential       | 63.0 K | train
12 | blocks.18.basis  | SeasonalityBasis | 1.8 K  | train
---------------------------------------------------------------
179 K     Trainable params
2.0 K     No

INFO:lightning_fabric.utilities.seed:Seed set to 20


nbeatsstackmoe_m3_Monthly_18
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 674 K  | train
4  | blocks.0         | NBEATSBlock      | 229 K  | train
5  | blocks.0.layers  | Sequential       | 229 K  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.9         | NBEATSBlock      | 213 K  | train
8  | blocks.9.layers  | Sequential       | 212 K  | train
9  | blocks.9.basis   | TrendBasis       | 216    | train
10 | blocks.18        | NBEATSBlock      | 231 K  | train
11 | blocks.18.layers | Sequential       | 228 K  | train
12 | blocks.18.basis  | SeasonalityBasis | 2.4 K  | train
13 | gate             | Sequential       | 1.6 K  | train
14 | gate.0           | LayerNorm    

INFO:lightning_fabric.utilities.seed:Seed set to 4
INFO:lightning_fabric.utilities.seed:Seed set to 6


nbeats_m3_Quarterly_8
   | Name            | Type             | Params | Mode 
--------------------------------------------------------------
0  | loss            | MAE              | 0      | train
1  | padder_train    | ConstantPad1d    | 0      | train
2  | scaler          | TemporalNorm     | 0      | train
3  | blocks          | ModuleList       | 612 K  | train
4  | blocks.0        | NBEATSBlock      | 203 K  | train
5  | blocks.0.layers | Sequential       | 203 K  | train
6  | blocks.0.basis  | IdentityBasis    | 0      | train
7  | blocks.3        | NBEATSBlock      | 201 K  | train
8  | blocks.3.layers | Sequential       | 201 K  | train
9  | blocks.3.basis  | TrendBasis       | 48     | train
10 | blocks.6        | NBEATSBlock      | 207 K  | train
11 | blocks.6.layers | Sequential       | 206 K  | train
12 | blocks.6.basis  | SeasonalityBasis | 224    | train
--------------------------------------------------------------
611 K     Trainable params
272       Non-trainable par

INFO:lightning_fabric.utilities.seed:Seed set to 7
INFO:lightning_fabric.utilities.seed:Seed set to 12


   | Name             | Type                | Params | Mode 
------------------------------------------------------------------
0  | loss             | MAE                 | 0      | train
1  | padder_train     | ConstantPad1d       | 0      | train
2  | scaler           | TemporalNorm        | 0      | train
3  | blocks           | ModuleList          | 39.9 K | train
4  | blocks.0         | NBEATSMoEBlock      | 13.6 K | train
5  | blocks.0.gate    | Sequential          | 83     | train
6  | blocks.0.softmax | Softmax             | 0      | train
7  | blocks.0.pooling | SharedExpertPooling | 13.6 K | train
8  | blocks.0.basis   | IdentityBasis       | 0      | train
9  | blocks.3         | NBEATSMoEBlock      | 11.9 K | train
10 | blocks.3.gate    | Sequential          | 83     | train
11 | blocks.3.softmax | Softmax             | 0      | train
12 | blocks.3.pooling | SharedExpertPooling | 11.8 K | train
13 | blocks.3.basis   | TrendBasis          | 72     | train
14 | blocks.6     

INFO:lightning_fabric.utilities.seed:Seed set to 7
INFO:lightning_fabric.utilities.seed:Seed set to 9
INFO:lightning_fabric.utilities.seed:Seed set to 16


nbeats_m3_Yearly_6
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 1.3 K  | train
4  | blocks.0         | NBEATSBlock      | 380    | train
5  | blocks.0.layers  | Sequential       | 380    | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 362    | train
8  | blocks.6.layers  | Sequential       | 326    | train
9  | blocks.6.basis   | TrendBasis       | 36     | train
10 | blocks.12        | NBEATSBlock      | 572    | train
11 | blocks.12.layers | Sequential       | 452    | train
12 | blocks.12.basis  | SeasonalityBasis | 120    | train
---------------------------------------------------------------
1.2 K     Trainable params
156       Non-

INFO:lightning_fabric.utilities.seed:Seed set to 10


   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 3.6 K  | train
4  | blocks.0         | NBEATSBlock      | 1.1 K  | train
5  | blocks.0.layers  | Sequential       | 1.1 K  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 1.1 K  | train
8  | blocks.6.layers  | Sequential       | 1.0 K  | train
9  | blocks.6.basis   | TrendBasis       | 36     | train
10 | blocks.12        | NBEATSBlock      | 1.4 K  | train
11 | blocks.12.layers | Sequential       | 1.3 K  | train
12 | blocks.12.basis  | SeasonalityBasis | 120    | train
13 | gate             | Sequential       | 138    | train
14 | gate.0           | LayerNorm        | 12     | train
15 | gat

INFO:lightning_fabric.utilities.seed:Seed set to 13
INFO:lightning_fabric.utilities.seed:Seed set to 20
INFO:lightning_fabric.utilities.seed:Seed set to 20


nbeats_m4_Monthly_18
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 181 K  | train
4  | blocks.0         | NBEATSBlock      | 61.2 K | train
5  | blocks.0.layers  | Sequential       | 61.2 K | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 55.2 K | train
8  | blocks.6.layers  | Sequential       | 55.0 K | train
9  | blocks.6.basis   | TrendBasis       | 162    | train
10 | blocks.12        | NBEATSBlock      | 64.9 K | train
11 | blocks.12.layers | Sequential       | 63.0 K | train
12 | blocks.12.basis  | SeasonalityBasis | 1.8 K  | train
---------------------------------------------------------------
179 K     Trainable params
2.0 K     No

INFO:lightning_fabric.utilities.seed:Seed set to 13


   | Name             | Type                | Params | Mode 
------------------------------------------------------------------
0  | loss             | MAE                 | 0      | train
1  | padder_train     | ConstantPad1d       | 0      | train
2  | scaler           | TemporalNorm        | 0      | train
3  | blocks           | ModuleList          | 479 K  | train
4  | blocks.0         | NBEATSMoEBlock      | 165 K  | train
5  | blocks.0.gate    | Sequential          | 405    | train
6  | blocks.0.softmax | Softmax             | 0      | train
7  | blocks.0.pooling | SharedExpertPooling | 165 K  | train
8  | blocks.0.basis   | IdentityBasis       | 0      | train
9  | blocks.1         | NBEATSMoEBlock      | 137 K  | train
10 | blocks.1.gate    | Sequential          | 405    | train
11 | blocks.1.softmax | Softmax             | 0      | train
12 | blocks.1.pooling | SharedExpertPooling | 137 K  | train
13 | blocks.1.basis   | TrendBasis          | 162    | train
14 | blocks.2     

INFO:lightning_fabric.utilities.seed:Seed set to 20
INFO:lightning_fabric.utilities.seed:Seed set to 14


nbeats_m4_Quarterly_8
   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 175 K  | train
4  | blocks.0         | NBEATSBlock      | 61.0 K | train
5  | blocks.0.layers  | Sequential       | 61.0 K | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 55.7 K | train
8  | blocks.6.layers  | Sequential       | 55.6 K | train
9  | blocks.6.basis   | TrendBasis       | 144    | train
10 | blocks.12        | NBEATSBlock      | 59.1 K | train
11 | blocks.12.layers | Sequential       | 58.4 K | train
12 | blocks.12.basis  | SeasonalityBasis | 672    | train
---------------------------------------------------------------
174 K     Trainable params
816       N

INFO:lightning_fabric.utilities.seed:Seed set to 14
INFO:lightning_fabric.utilities.seed:Seed set to 13


   | Name             | Type                | Params | Mode 
------------------------------------------------------------------
0  | loss             | MAE                 | 0      | train
1  | padder_train     | ConstantPad1d       | 0      | train
2  | scaler           | TemporalNorm        | 0      | train
3  | blocks           | ModuleList          | 22.1 M | train
4  | blocks.0         | NBEATSMoEBlock      | 7.4 M  | train
5  | blocks.0.gate    | Sequential          | 361    | train
6  | blocks.0.softmax | Softmax             | 0      | train
7  | blocks.0.pooling | SharedExpertPooling | 7.4 M  | train
8  | blocks.0.basis   | IdentityBasis       | 0      | train
9  | blocks.3         | NBEATSMoEBlock      | 7.3 M  | train
10 | blocks.3.gate    | Sequential          | 361    | train
11 | blocks.3.softmax | Softmax             | 0      | train
12 | blocks.3.pooling | SharedExpertPooling | 7.3 M  | train
13 | blocks.3.basis   | TrendBasis          | 120    | train
14 | blocks.6     

INFO:lightning_fabric.utilities.seed:Seed set to 10
INFO:lightning_fabric.utilities.seed:Seed set to 9


   | Name             | Type             | Params | Mode 
---------------------------------------------------------------
0  | loss             | MAE              | 0      | train
1  | padder_train     | ConstantPad1d    | 0      | train
2  | scaler           | TemporalNorm     | 0      | train
3  | blocks           | ModuleList       | 2.4 M  | train
4  | blocks.0         | NBEATSBlock      | 810 K  | train
5  | blocks.0.layers  | Sequential       | 810 K  | train
6  | blocks.0.basis   | IdentityBasis    | 0      | train
7  | blocks.6         | NBEATSBlock      | 800 K  | train
8  | blocks.6.layers  | Sequential       | 800 K  | train
9  | blocks.6.basis   | TrendBasis       | 72     | train
10 | blocks.12        | NBEATSBlock      | 808 K  | train
11 | blocks.12.layers | Sequential       | 807 K  | train
12 | blocks.12.basis  | SeasonalityBasis | 240    | train
---------------------------------------------------------------
2.4 M     Trainable params
312       Non-trainable params
2.

INFO:lightning_fabric.utilities.seed:Seed set to 17
INFO:lightning_fabric.utilities.seed:Seed set to 15


   | Name             | Type                | Params | Mode 
------------------------------------------------------------------
0  | loss             | MAE                 | 0      | train
1  | padder_train     | ConstantPad1d       | 0      | train
2  | scaler           | TemporalNorm        | 0      | train
3  | blocks           | ModuleList          | 5.7 M  | train
4  | blocks.0         | NBEATSMoEBlock      | 1.9 M  | train
5  | blocks.0.gate    | Sequential          | 339    | train
6  | blocks.0.softmax | Softmax             | 0      | train
7  | blocks.0.pooling | SharedExpertPooling | 1.9 M  | train
8  | blocks.0.basis   | IdentityBasis       | 0      | train
9  | blocks.3         | NBEATSMoEBlock      | 1.9 M  | train
10 | blocks.3.gate    | Sequential          | 339    | train
11 | blocks.3.softmax | Softmax             | 0      | train
12 | blocks.3.pooling | SharedExpertPooling | 1.9 M  | train
13 | blocks.3.basis   | TrendBasis          | 108    | train
14 | blocks.6     

In [17]:
import json
import os

# Export the best Optuna hyperparameters of every (model, dataset) study to one JSON file.
# Relies on names defined in earlier cells: list_datasets, list_models, map_horizon_freq,
# load_dataset, optuna, pd and STORAGE.

# NOTE(review): `summaries_pd` is not filled in this cell; presumably a later cell populates it.
summaries_pd = pd.DataFrame(
    {
        "model_name": [],
        "dataset": [],
        "freq": [],
        "nr_of_parameters": [],
        "nr_experts": [],
        "top_k": [],
        "average_active_parameters": [],
    }
)

# For these model/dataset combinations, tuned MLP widths of 256 or 512 are replaced
# by 3 layers of [128, 128].
# NOTE(review): the reason for this cap is not visible here — confirm and document it.
MLP_CAP_MODELS = {"nbeatsmoe", "nbeatsmoeshared"}
MLP_CAP_DATASETS = {
    "gluonts_tourism_yearly",
    "gluonts_m1_quarterly",
    "gluonts_m1_monthly",
    "gluonts_tourism_monthly",
}

output_path = "c:/Users/ricar/mixture_of_experts_time_series/results/hyperparameters_config.json"

all_results = {}  # study_name -> best_params of that Optuna study

for dataset in list_datasets:
    Y_ALL = load_dataset(dataset["name"], dataset)
    horizon = None

    # Some loaders return a 5-tuple (data, horizon, n_lags, frequency, _) instead of data only.
    # In that case the frequency becomes both "group" and "freq" of the (mutated) dataset dict.
    if isinstance(Y_ALL, tuple):
        Y_ALL, horizon, n_lags, dataset["group"], _ = Y_ALL
        dataset["freq"] = dataset["group"]
        dataset["name"] = dataset["name"].replace(" ", "_")

    # Fall back to the frequency-based default horizon when the loader did not supply one.
    if horizon is None:
        horizon = map_horizon_freq[dataset["freq"]]

    for model_name in list_models:
        study_name = f"{model_name}_{dataset['name']}_{dataset['group']}_{horizon}"
        print(study_name)

        study = optuna.load_study(
            study_name=study_name,
            storage=STORAGE,
        )

        best_params = study.best_params

        if model_name in MLP_CAP_MODELS and dataset["name"] in MLP_CAP_DATASETS:
            mlp_units = best_params.get("mlp_units")
            # Guard against a missing or empty "mlp_units" before peeking at the first width.
            if mlp_units and mlp_units[0] and mlp_units[0][0] in (256, 512):
                best_params["mlp_units"] = [[128, 128], [128, 128], [128, 128]]

        all_results[study_name] = best_params

# Write the collected results once, after all studies have been read.
# (Previously the file was rewritten after every single study.)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as json_file:
    json.dump(all_results, json_file, indent=4)  # indent=4 for a human-readable file



INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\m1_quarterly.


Loading m1_quarterly dataset...
nbeats_gluonts_m1_quarterly_Q_2
nbeatsmoe_gluonts_m1_quarterly_Q_2
nbeatsmoeshared_gluonts_m1_quarterly_Q_2
nbeatsstackmoe_gluonts_m1_quarterly_Q_2


INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\m1_monthly.


Loading m1_monthly dataset...


INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\m1_yearly.


nbeats_gluonts_m1_monthly_M_8
nbeatsmoe_gluonts_m1_monthly_M_8
nbeatsmoeshared_gluonts_m1_monthly_M_8
nbeatsstackmoe_gluonts_m1_monthly_M_8
Loading m1_yearly dataset...
nbeats_gluonts_m1_yearly_Y_2
nbeatsmoe_gluonts_m1_yearly_Y_2
nbeatsmoeshared_gluonts_m1_yearly_Y_2
nbeatsstackmoe_gluonts_m1_yearly_Y_2


INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\tourism_monthly.


Loading tourism_monthly dataset...
nbeats_gluonts_tourism_monthly_M_18
nbeatsmoe_gluonts_tourism_monthly_M_18
nbeatsmoeshared_gluonts_tourism_monthly_M_18
nbeatsstackmoe_gluonts_tourism_monthly_M_18


INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\tourism_quarterly.


Loading tourism_quarterly dataset...
nbeats_gluonts_tourism_quarterly_Q_8
nbeatsmoe_gluonts_tourism_quarterly_Q_8
nbeatsmoeshared_gluonts_tourism_quarterly_Q_8
nbeatsstackmoe_gluonts_tourism_quarterly_Q_8


INFO:root:using dataset already processed in path C:\Users\ricar\.gluonts\datasets\tourism_yearly.


Loading tourism_yearly dataset...
nbeats_gluonts_tourism_yearly_Y_4
nbeatsmoe_gluonts_tourism_yearly_Y_4
nbeatsmoeshared_gluonts_tourism_yearly_Y_4
nbeatsstackmoe_gluonts_tourism_yearly_Y_4
Loading m3_monthly dataset...
nbeats_m3_Monthly_18
nbeatsmoe_m3_Monthly_18
nbeatsmoeshared_m3_Monthly_18
nbeatsstackmoe_m3_Monthly_18
Loading m3_monthly dataset...
nbeats_m3_Quarterly_8
nbeatsmoe_m3_Quarterly_8
nbeatsmoeshared_m3_Quarterly_8
nbeatsstackmoe_m3_Quarterly_8
Loading m3_monthly dataset...
nbeats_m3_Yearly_6
nbeatsmoe_m3_Yearly_6
nbeatsmoeshared_m3_Yearly_6
nbeatsstackmoe_m3_Yearly_6
Loading m4_monthly dataset...
nbeats_m4_Monthly_18
nbeatsmoe_m4_Monthly_18
nbeatsmoeshared_m4_Monthly_18
nbeatsstackmoe_m4_Monthly_18
Loading m4_monthly dataset...
nbeats_m4_Quarterly_8
nbeatsmoe_m4_Quarterly_8
nbeatsmoeshared_m4_Quarterly_8
nbeatsstackmoe_m4_Quarterly_8
Loading m4_monthly dataset...
nbeats_m4_Yearly_6
nbeatsmoe_m4_Yearly_6
nbeatsmoeshared_m4_Yearly_6
nbeatsstackmoe_m4_Yearly_6
