
rnn: hx is not contiguous [BUG] #730

Closed
korosig opened this issue Jan 11, 2022 · 2 comments
Labels
bug (Something isn't working) · duplicate (This issue or pull request already exists)

Comments


korosig commented Jan 11, 2022

Hello,

Nice work, I like your package.

After a few days of waiting, I finally managed to install all the necessary libraries on my computer.
When my Torch install was CPU-only, the template ran fine, but when I tested the program on GPU I got the following error:

rnn: hx is not contiguous

Template:
https://github.com/h3ik0th/TFT_darts/blob/main/TFT_2g5.ipynb

System:
absl-py
anyio==3.4.0
argcomplete
argon2-cffi
argon2-cffi-bindings
arviz
astor
async-generator==1.10
attrs
Babel==2.9.1
backcall
backports.functools-lru-cache
bleach
bokeh
cached-property
cachetools==4.2.4
certifi==2021.10.8
cffi
cftime
charset-normalizer==2.0.10
click
cloudpickle
cmdstanpy==0.9.68
colorama
convertdate
cycler
Cython
cytoolz==0.11.2
darts==0.15.0
dask
dataclasses
debugpy
decorator
defusedxml
distributed
docopt==0.6.2
entrypoints
ephem
filterpy==1.4.5
flit_core
fonttools
fsspec
gast
google-auth==2.3.3
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio
h5py
HeapDict==1.0.1
hijri-converter
holidays
idna==3.3
importlib-metadata
importlib-resources
ipykernel
ipython
ipython-genutils==0.2.0
ipywidgets
jedi
Jinja2
joblib==1.1.0
json5==0.9.6
jsonschema
jupyter-client
jupyter-core
jupyter-server==1.13.1
jupyterlab==3.2.6
jupyterlab-pygments
jupyterlab-server==2.10.3
jupyterlab-widgets
Keras-Applications==1.0.8
Keras-Preprocessing
kiwisolver
korean-lunar-calendar
lightgbm==3.3.2
locket==0.2.0
LunarCalendar==0.0.9
Markdown
MarkupSafe
matplotlib==3.5.1
matplotlib-inline
mistune
msgpack
munkres==1.1.4
nbclassic==0.3.4
nbclient
nbconvert
nbformat
nest-asyncio
netCDF4
notebook
numpy==1.21.5
oauthlib==3.1.1
olefile
packaging
pandas
pandocfilters
parso
partd
patsy==0.5.2
pickleshare
Pillow==9.0.0
pipreqs==0.4.11
pmdarima==1.8.4
prometheus-client
prompt-toolkit
prophet
protobuf==3.19.2
psutil
pyarrow==6.0.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser
Pygments
PyMeeus==0.5.11
pyparsing
PyQt5==5.12.3
PyQt5_sip==4.19.18
PyQtChart==5.12
PyQtWebEngine==5.12.1
pyrsistent
pystan==2.19.1.1
python-dateutil
pytz
pywin32==303
pywinpty
PyYAML
pyzmq
requests==2.27.1
requests-oauthlib==1.3.0
rsa==4.8
scikit-learn==1.0.2
scipy
Send2Trash
setuptools-git==1.2
six
sniffio==1.2.0
sortedcontainers
statsmodels==0.13.1
tblib
tensorboard==1.15.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==1.14.0
tensorflow-estimator==1.14.0
termcolor==1.1.0
terminado
testpath
threadpoolctl==3.0.0
toolz
torch==1.10.1+cu113
torchaudio==0.10.1+cu113
torchvision==0.11.2+cu113
tornado
tqdm
traitlets
typing_extensions==4.0.1
ujson==5.1.0
unicodedata2
urllib3==1.26.8
wcwidth
webencodings==0.5.1
websocket-client==1.2.3
Werkzeug==2.0.2
widgetsnbextension
wrapt
xarray
yarg==0.1.9
zict==2.0.0
zipp

Error:

RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_29804/2078052997.py in
2 model.fit( ts_ttrain,
3 future_covariates=tcov,
----> 4 verbose=True)

~\anaconda3\envs\TFT-test\lib\site-packages\darts\utils\torch.py in decorator(self, *args, **kwargs)
63 with fork_rng():
64 manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65 return decorated(self, *args, **kwargs)
66 return decorator

~\anaconda3\envs\TFT-test\lib\site-packages\darts\models\forecasting\torch_forecasting_model.py in fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, verbose, epochs, max_samples_per_ts, num_loader_workers)
477 logger.info('Train dataset contains {} samples.'.format(len(train_dataset)))
478
--> 479 self.fit_from_dataset(train_dataset, val_dataset, verbose, epochs, num_loader_workers)
480
481 @property

~\anaconda3\envs\TFT-test\lib\site-packages\darts\utils\torch.py in decorator(self, *args, **kwargs)
63 with fork_rng():
64 manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 65 return decorated(self, *args, **kwargs)
66 return decorator

~\anaconda3\envs\TFT-test\lib\site-packages\darts\models\forecasting\torch_forecasting_model.py in fit_from_dataset(self, train_dataset, val_dataset, verbose, epochs, num_loader_workers)
591
592 # Train model
--> 593 self._train(train_loader, val_loader, tb_writer, verbose, train_num_epochs)
594
595 # Close tensorboard writer

~\anaconda3\envs\TFT-test\lib\site-packages\darts\models\forecasting\torch_forecasting_model.py in _train(self, train_loader, val_loader, tb_writer, verbose, epochs)
886 self.model.train()
887 train_batch = self._batch_to_device(train_batch)
--> 888 output = self._produce_train_output(train_batch[:-1])
889 target = train_batch[-1] # By convention target is always the last element returned by datasets
890 loss = self._compute_loss(output, target)

~\anaconda3\envs\TFT-test\lib\site-packages\darts\models\forecasting\tft_model.py in _produce_train_output(self, input_batch)
800
801 def _produce_train_output(self, input_batch: Tuple):
--> 802 return self.model(input_batch)
803
804 def predict(self, n, *args, **kwargs):

~\anaconda3\envs\TFT-test\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

~\anaconda3\envs\TFT-test\lib\site-packages\darts\models\forecasting\tft_model.py in forward(self, x)
447
448 # run local lstm encoder
--> 449 encoder_out, (hidden, cell) = self.lstm_encoder(input=embeddings_varying_encoder, hx=(input_hidden, input_cell))
450
451 # run local lstm decoder

~\anaconda3\envs\TFT-test\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

~\anaconda3\envs\TFT-test\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
690 if batch_sizes is None:
691 result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
--> 692 self.dropout, self.training, self.bidirectional, self.batch_first)
693 else:
694 result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,

RuntimeError: rnn: hx is not contiguous
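
For reference, the failure can be reproduced outside darts: the GPU (cuDNN) LSTM path requires hx, the initial hidden and cell states, to be contiguous in memory, and a transposed view has the right shape but not the right layout. The following minimal sketch is illustrative and not part of the original report; the error only triggers on a CUDA device:

import torch
import torch.nn as nn

# hx built from a transposed view has the required
# (num_layers, batch, hidden_size) shape but a non-contiguous layout.
device = "cuda" if torch.cuda.is_available() else "cpu"
lstm = nn.LSTM(input_size=4, hidden_size=4, num_layers=2, batch_first=True).to(device)
x = torch.randn(3, 5, 4, device=device)  # (batch, seq_len, input_size)

h0 = torch.randn(3, 2, 4, device=device).transpose(0, 1)  # shape (2, 3, 4), non-contiguous
c0 = torch.randn(3, 2, 4, device=device).transpose(0, 1)
print(h0.is_contiguous())  # False

# On a CUDA device this raises: RuntimeError: rnn: hx is not contiguous
out, (hn, cn) = lstm(x, (h0, c0))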

korosig added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jan 11, 2022
korosig (Author) commented Jan 11, 2022

Actually, I don't know if this is a good solution, but in tft_model.py (line 449) I overwrote

# run local lstm encoder
encoder_out, (hidden, cell) = self.lstm_encoder(input=embeddings_varying_encoder, hx=(input_hidden, input_cell))

with

# run local lstm encoder
encoder_out, (hidden, cell) = self.lstm_encoder(input=embeddings_varying_encoder, hx=(input_hidden.contiguous(), input_cell.contiguous()))
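
For what it's worth, .contiguous() is a no-op when a tensor already has a dense layout (it returns the tensor itself), so the extra calls only cost a copy in the case that would otherwise fail. An illustrative check, not from the original thread:

import torch

t = torch.randn(2, 3, 4)
view = t.transpose(0, 1)  # same data, permuted strides

print(t.contiguous() is t)                # True: already dense, no copy made
print(view.is_contiguous())               # False: transposed view
print(view.contiguous().is_contiguous())  # True: .contiguous() made a dense copy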

dennisbader added the duplicate (This issue or pull request already exists) and bug (Something isn't working) labels and removed the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jan 11, 2022
dennisbader (Collaborator) commented

Hi @korosig and thanks for writing.

We are working on it; I'm closing this as a duplicate of #722.
