Only set require_grad for gradient checkpointing
ceshine committed Apr 24, 2021
1 parent acaeee6 commit 658bdd0
Showing 1 changed file with 2 additions and 1 deletion.
src/transformers/models/t5/modeling_t5.py (3 changes: 2 additions & 1 deletion)
@@ -324,6 +324,7 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False):
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
+        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
 
     def prune_heads(self, heads):
         if len(heads) == 0:
@@ -486,7 +487,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                 position_bias = torch.zeros(
                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                 )
-                if self.training:
+                if self.training and self.gradient_checkpointing:
                     position_bias.requires_grad = True
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length)
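
For context, a minimal sketch (not part of the commit) of the code path this patch guards: with gradient checkpointing enabled on the config, T5 recomputes each block's forward pass during backward, so the placeholder position_bias created when has_relative_attention_bias is False must itself require gradients to keep the recomputed graph connected; without checkpointing, forcing requires_grad is unnecessary, which is what the new condition encodes. The tiny config values and token ids below are made up for illustration, and the example assumes a transformers version around the time of this commit, where gradient checkpointing is read off the config via getattr as in the diff above.

# Hedged sketch: exercising the patched branch under config-level gradient checkpointing.
import torch
from transformers import T5Config, T5ForConditionalGeneration

# Deliberately tiny, made-up dimensions so the example runs quickly.
config = T5Config(
    vocab_size=100,
    d_model=64,
    d_kv=16,
    d_ff=128,
    num_layers=2,
    num_heads=4,
    gradient_checkpointing=True,  # stored on the config; picked up in T5Attention.__init__
)
model = T5ForConditionalGeneration(config)
model.train()  # self.training is True, so the patched branch sets position_bias.requires_grad

input_ids = torch.tensor([[5, 6, 7, 1]])  # toy token ids (1 is </s>)
labels = torch.tensor([[8, 9, 1]])
# use_cache is disabled since cached key/values are pointless when blocks are re-run.
loss = model(input_ids=input_ids, labels=labels, use_cache=False).loss
loss.backward()  # works because the recomputed graph keeps a grad-requiring position_bias

In a plain (non-checkpointed) training run, the same zero-filled position_bias is an ordinary buffer-like tensor, so after this commit it no longer requires gradients and no longer shows up as a leaf in the autograd graph.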
