Fix uninitialized parameter in conformer relative attention. (huggingface#18368)

`torch.Tensor` creates an uninitialized tensor (as `torch.empty` does). This leads to nondeterministic behavior, poor initialization, and NaNs if you get an unlucky init. The paper does not specify the initialization for the bias terms, so zero seems like a good choice: no bias initially. The memory backing `torch.Tensor` is usually zeroed anyway, so this fix stays close to the behavior the code exhibited on most runs:

```
>>> torch.Tensor(100, 100).sum()
tensor(0.)
>>> torch.Tensor(100, 100).sum()
tensor(nan)
>>> torch.Tensor(100, 100).sum()
tensor(0.)
```
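
For comparison, a quick sanity check (my addition, not part of the commit): `torch.zeros` is deterministic, so the sum is exactly zero on every run:

```
>>> torch.zeros(100, 100).sum()
tensor(0.)
>>> torch.zeros(100, 100).sum()
tensor(0.)
```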
Piotr Dabkowski committed Aug 2, 2022
1 parent df5e423 commit 68a894a
Showing 1 changed file with 2 additions and 2 deletions.
```
@@ -670,8 +670,8 @@ def __init__(self, config):
         self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+        self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+        self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))

     def forward(
         self,
```
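
To see the fix in isolation, here is a minimal, self-contained sketch; the class name and constructor arguments are mine, and only the two `pos_bias` lines mirror the diff:

```
import torch
import torch.nn as nn


class RelPositionAttentionBias(nn.Module):
    """Hypothetical minimal module isolating the fixed initialization."""

    def __init__(self, num_heads: int, head_size: int, hidden_size: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False)
        # Zero-initialized learnable biases for matrices c and d
        # (https://arxiv.org/abs/1901.02860, Section 3.3).
        # torch.zeros is deterministic; torch.Tensor(...) would hand back
        # whatever happens to sit in the freshly allocated memory.
        self.pos_bias_u = nn.Parameter(torch.zeros(num_heads, head_size))
        self.pos_bias_v = nn.Parameter(torch.zeros(num_heads, head_size))


# The biases start at exactly zero on every run:
module = RelPositionAttentionBias(num_heads=4, head_size=16, hidden_size=64)
assert module.pos_bias_u.abs().sum().item() == 0.0
```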
