Merge pull request #23 from preinaj/transformer-fixed
transformer bug fixed
pedrolarben committed Mar 27, 2022
2 parents fda2093 + 6833b75 commit 5f163f7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ADLStream/models/transformer.py
@@ -482,6 +482,7 @@ def predict_step(self, data):
             if i != self.target_shape[0] - 1:
                 output = tf.gather(output, [output.shape[1] - 1], axis=1)
                 tar_inp = tf.concat([tar_inp, output], axis=1)
+        output = tf.squeeze(output, axis=[2])
         return output
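The added `tf.squeeze` drops the trailing size-1 feature axis so `predict_step` returns a `(batch, horizon)` prediction instead of `(batch, horizon, 1)`. A NumPy sketch of the tensor shapes in the autoregressive loop above (batch size and horizon are illustrative assumptions, and `np.random.rand` stands in for the model call):

```python
import numpy as np

batch, horizon = 4, 3
tar_inp = np.zeros((batch, 1, 1))  # decoder input starts with a single step

for i in range(horizon):
    # Stand-in for the transformer call; real code produces the decoder output here
    output = np.random.rand(batch, tar_inp.shape[1], 1)
    if i != horizon - 1:
        # tf.gather(output, [output.shape[1] - 1], axis=1): keep the last time step
        last_step = output[:, [-1], :]
        # tf.concat([tar_inp, output], axis=1): feed it back as the next decoder input
        tar_inp = np.concatenate([tar_inp, last_step], axis=1)

# The fixed line: tf.squeeze(output, axis=[2]) removes the size-1 feature axis
output = np.squeeze(output, axis=2)
assert output.shape == (batch, horizon)
```

Without the squeeze, the extra singleton dimension propagated into downstream shape checks, which is the bug this commit fixes.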


@@ -525,8 +526,8 @@ def Transformer(
         output_size (int): Number of neurons of the last layer.
         loss (tf.keras.Loss): Loss to be use for training.
         optimizer (tf.keras.Optimizer): Optimizer that implements the training algorithm.
-            Use "custom" in order to use a customize optimizer for the transformer model.
-        output_shape (tuple): Shape of the output data.
+            Use "custom" in order to use a customize optimizer for the transformer model.
+        output_shape (tuple): Shape of the output data. Must be [forecasting_horizon,1].
         attribute (list): Ordered list of the indexes of the attributes that we want to predict, if the number of
             attributes of the input is different from the ones of the output.
             Defaults to None.
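The docstring change documents that `output_shape` must be `[forecasting_horizon, 1]` for this model. That constraint can be expressed as a simple validator; the helper name below is hypothetical, not part of ADLStream:

```python
def check_output_shape(output_shape):
    """Hypothetical validator mirroring the docstring constraint:
    output_shape must be [forecasting_horizon, 1] for the Transformer model."""
    if len(output_shape) != 2 or output_shape[1] != 1:
        raise ValueError(
            "output_shape must be [forecasting_horizon, 1], got %r" % (output_shape,)
        )
    return output_shape

check_output_shape([24, 1])  # valid: horizon of 24 steps, one target attribute
```

The second element is pinned to 1 because `predict_step` squeezes exactly one feature axis per forecast step, as shown in the first hunk.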

