Skip to content

Commit

Permalink
Make RotaryPositionalEmbedding jit-compatible (#5237)
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorLakomkin committed Jul 7, 2023
1 parent 31fba01 commit 100cd91
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
15 changes: 7 additions & 8 deletions fairseq/modules/rotary_positional_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,26 @@ def __init__(self, dim, base=10000, precision=torch.half):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.seq_len_cached = 0
self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
self.precision = precision

def forward(self, x, seq_len=None):
def forward(self, x, seq_len: int = 0):
"""
Args:
x: Input x with T X B X C
seq_len: Sequence length of input x
"""
if seq_len != self.seq_len_cached:
if seq_len > self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[:, None, None, :]
self.sin_cached = emb.sin()[:, None, None, :]
self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
return self.cos_cached, self.sin_cached


# rotary pos emb helpers:
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
Expand Down
21 changes: 21 additions & 0 deletions tests/test_rotary_positional_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ def test_apply_rotary_pos_emb(self):
)
)

def test_jit_compile_rope_module(self):
module_scripted = torch.jit.script(self.rope_pos_emd)
apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
# Test several different lengths
for T in [3, 5, 10]:
sample = torch.randn(T, self.B, self.C)
# Run forward pass with the original module
cos_original, sin_original = self.rope_pos_emd(sample, T)
query = sample.view(T, self.B, 1, self.C)
new_query, new_key = apply_rotary_pos_emb(query, query, cos_original, sin_original)

# Run forward pass with the scripted module
cos_scripted, sin_scripted = module_scripted(sample, T)
new_query_scripted, new_key_scripted = apply_rotary_scripted(query, query, cos_scripted, sin_scripted)

# Ensure the outputs are the same
self.assertTrue(torch.allclose(cos_original, cos_scripted))
self.assertTrue(torch.allclose(sin_original, sin_scripted))
self.assertTrue(torch.allclose(new_query, new_query_scripted))
self.assertTrue(torch.allclose(new_key, new_key_scripted))


if __name__ == "__main__":
unittest.main()

0 comments on commit 100cd91

Please sign in to comment.