Skip to content

Commit

Permalink
fixed typo
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Jun 18, 2020
1 parent 8e9f31e commit 2649874
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion pts/model/deepvar/deepvar_network.py
Expand Up @@ -61,6 +61,7 @@ def __init__(
batch_first=True,
)

self.target_shape = distr_output.event_shape
self.proj_dist_args = distr_output.get_args_proj(num_cells)

self.embed = FeatureEmbedder(
Expand Down Expand Up @@ -147,7 +148,14 @@ def unroll(
embedded_cat = self.embed(feat_static_cat)
# assert_shape(index_embeddings, (-1, self.target_dim, self.embed_dim))

static_feat = torch.cat((embedded_cat, feat_static_real, scale.log()), dim=1)
static_feat = torch.cat(
(
embedded_cat,
feat_static_real,
scale.log() if len(self.target_shape) == 0 else scale.squeeze(1).log(),
),
dim=1,
)

# (batch_size, seq_len, embed_dim)
repeated_static_feat = static_feat.unsqueeze(1).expand(-1, unroll_length, -1)
Expand Down

0 comments on commit 2649874

Please sign in to comment.