Skip to content

Commit

Permalink
[ENH] Added test parameters for the LSTM FCNN network (#6281)
Browse files Browse the repository at this point in the history
I have added 2 test parameters for the LSTM-FCNN network returned via
the `get_test_params` function. The test parameters aim to explore 2
quite different variations of the model in the hope to capture a broad
spectrum of the model's parameters.
  • Loading branch information
shlok191 committed Apr 14, 2024
1 parent 4018337 commit 0185d5c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Expand Up @@ -2756,6 +2756,15 @@
"doc"
]
},
{
"login": "ssabarwal",
"name": "Shlok Sabarwal",
"avatar_url": "https://gravatar.com/avatar/cbdbaac712ae282d730cd3028e862d45?s=400&d=robohash&r=x",
"prifle": "https://www.github.com/shlok191/",
"contributions": [
"code"
]
},
{
"login": "mobley-trent",
"name": "Eddy Oyieko",
Expand Down
40 changes: 40 additions & 0 deletions sktime/networks/lstmfcn.py
Expand Up @@ -123,3 +123,43 @@ def build_network(self, input_shape, **kwargs):
output_layer = keras.layers.concatenate([x, y])

return input_layer, output_layer

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
params = [
# Advanced model version
{
"kernel_sizes": (8, 5, 3), # Keep standard kernel sizes
"filter_sizes": (128, 256, 128), # Keep standard kernel counts
"lstm_size": 8,
"dropout": 0.25, # Maintain lower dropout rate for attention model
"attention": True,
},
# Simpler model version
{
"kernel_sizes": (4, 2, 1), # Reduce kernel sizes
"filter_sizes": (32, 64, 32), # Reduc filter sizes for cheaper model
"lstm_size": 8, # Keeping LSTM output size fixed
"dropout": 0.75, # Maintain higher dropout rate for non attention model
"attention": False,
},
{},
]

return params

0 comments on commit 0185d5c

Please sign in to comment.