In [1]:
import torch

from retnet.configuration_retnet import RetNetConfig
from retnet.modeling_retnet import RetNetModel, RetNetModelWithLMHead

%load_ext autoreload
%autoreload 2

In [6]:
# Seed the global RNG before building the model so the construction below is
# reproducible from a fresh kernel.
torch.manual_seed(0)

# Small RetNet configuration used by the consistency checks in the following cells.
config = RetNetConfig(decoder_layers=8,
                      decoder_embed_dim=512,
                      decoder_retention_heads=4,
                      decoder_ffn_embed_dim=1024)

model = RetNetModel(config)
model.eval()

# Choose the compute backend automatically instead of hard-coding 'cuda', so the
# notebook survives "Restart & Run All" on machines without an NVIDIA GPU.
# Preference order: CUDA, then Apple-silicon MPS, then CPU. Override `device`
# by hand here if a specific backend is required.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    # `torch.backends.mps` does not exist on older torch builds, hence the getattr guard.
    device = 'mps'
else:
    device = 'cpu'
model = model.to(device)

In [7]:
# Consistency check: run the same prompt through the 'parallel', 'recurrent'
# and 'chunkwise' forward implementations and confirm that the hidden states
# and the returned caches agree within an absolute tolerance of 1e-5.
input_ids = torch.LongTensor([[1,2,3,4,1,2,3,4]]).to(device)

# These are inference-only comparisons; disabling autograd avoids recording a
# graph for every forward pass (and keeping it alive through the cached
# outputs), without changing the computed values.
with torch.no_grad():
    # Whole sequence in a single parallel pass.
    parallel_outputs = model(input_ids, forward_impl='parallel', use_cache=True)
    parallel_state = parallel_outputs.last_hidden_state
    parallel_cache = parallel_outputs.past_key_values

    # One token at a time, threading the cache returned by the previous step
    # into the next one; `sequence_offset` is the position of the current token.
    past_kv = None
    rnn_state = []
    for i in range(input_ids.shape[1]):
        rnn_out = model(input_ids[:, i:i+1], forward_impl='recurrent', past_key_values=past_kv, use_cache=True, sequence_offset=i)
        rnn_state.append(rnn_out.last_hidden_state)
        past_kv = rnn_out.past_key_values
    # Stitch the per-step hidden states back together along dim 1 so they can
    # be compared against the full-sequence outputs.
    rnn_state = torch.cat(rnn_state, dim=1)
    # Cache after the final recurrent step.
    rnn_cache = rnn_out.past_key_values

    # Whole sequence again, processed in chunks of 4 tokens.
    chunk_outputs = model(input_ids, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=4)
    chunk_state = chunk_outputs.last_hidden_state
    chunk_cache = chunk_outputs.past_key_values

# Pairwise agreement of the hidden states.
print(torch.allclose(parallel_state, rnn_state, atol=1e-5))
print(torch.allclose(parallel_state, chunk_state, atol=1e-5))
print(torch.allclose(rnn_state, chunk_state, atol=1e-5))

# Pairwise agreement of the caches, one entry at a time.
# NOTE(review): assumes each cache entry is a single tensor per layer — confirm
# against the RetNet model's `past_key_values` format.
for i, (p, r, c) in enumerate(zip(parallel_cache, rnn_cache, chunk_cache)):
    print(f"layer: {i + 1}")
    print(torch.allclose(p, r, atol=1e-5))
    print(torch.allclose(p, c, atol=1e-5))
    print(torch.allclose(r, c, atol=1e-5))


True
True
True
layer: 1
True
True
True
layer: 2
True
True
True
layer: 3
True
True
True
layer: 4
True
True
True
layer: 5
True
True
True
layer: 6
True
True
True
layer: 7
True
True
True
layer: 8
True
True
True


In [8]:
# Same three-way consistency check ('parallel' vs 'recurrent' vs 'chunkwise'),
# this time passing a retention mask alongside the input ids.
input_ids = torch.LongTensor([[1,2,3,4,1,2,3,4]]).to(device)
# NOTE(review): presumably 1 marks a real token and 0 a padded position, so the
# last three positions are masked out — confirm against the model's mask convention.
retention_mask = torch.LongTensor([[1,1,1,1,1,0,0,0]]).to(device)

# These are inference-only comparisons; disabling autograd avoids recording a
# graph for every forward pass (and keeping it alive through the cached
# outputs), without changing the computed values.
with torch.no_grad():
    # Whole sequence in a single parallel pass.
    parallel_outputs = model(input_ids, retention_mask=retention_mask, forward_impl='parallel', use_cache=True)
    parallel_state = parallel_outputs.last_hidden_state
    parallel_cache = parallel_outputs.past_key_values

    # One token at a time; the mask is sliced the same way as the ids so each
    # step receives the mask entry for its own position.
    past_kv = None
    rnn_state = []
    for i in range(input_ids.shape[1]):
        rnn_out = model(input_ids[:, i:i+1], retention_mask=retention_mask[:, i:i+1], forward_impl='recurrent', past_key_values=past_kv, use_cache=True, sequence_offset=i)
        rnn_state.append(rnn_out.last_hidden_state)
        past_kv = rnn_out.past_key_values
    # Stitch the per-step hidden states back together along dim 1 so they can
    # be compared against the full-sequence outputs.
    rnn_state = torch.cat(rnn_state, dim=1)
    # Cache after the final recurrent step.
    rnn_cache = rnn_out.past_key_values

    # Whole sequence again, processed in chunks of 4 tokens.
    chunk_outputs = model(input_ids, retention_mask=retention_mask, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=4)
    chunk_state = chunk_outputs.last_hidden_state
    chunk_cache = chunk_outputs.past_key_values

# Pairwise agreement of the hidden states.
print(torch.allclose(parallel_state, rnn_state, atol=1e-5))
print(torch.allclose(parallel_state, chunk_state, atol=1e-5))
print(torch.allclose(rnn_state, chunk_state, atol=1e-5))

# Pairwise agreement of the caches, one entry at a time.
# NOTE(review): assumes each cache entry is a single tensor per layer — confirm
# against the RetNet model's `past_key_values` format.
for i, (p, r, c) in enumerate(zip(parallel_cache, rnn_cache, chunk_cache)):
    print(f"layer: {i + 1}")
    print(torch.allclose(p, r, atol=1e-5))
    print(torch.allclose(p, c, atol=1e-5))
    print(torch.allclose(r, c, atol=1e-5))


True
True
True
layer: 1
True
True
True
layer: 2
True
True
True
layer: 3
True
True
True
layer: 4
True
True
True
layer: 5
True
True
True
layer: 6
True
True
True
layer: 7
True
True
True
layer: 8
True
True
True


In [9]:
# Re-seed, then build the LM-head variant of the model from the same config,
# move it to the selected device and switch it to evaluation mode.
torch.manual_seed(0)
model = RetNetModelWithLMHead(config).to(device)
model.eval()

# Greedy-decode 20 new tokens twice from the same prompt: first with
# parallel_compute_prompt=True, then with parallel_compute_prompt=False.
# The generator is consumed in order, so the two calls happen in that sequence.
p_generated, r_generated = (
    model.generate(
        input_ids,
        parallel_compute_prompt=use_parallel_prompt,
        max_new_tokens=20,
        do_sample=False,
        early_stopping=False,
    )
    for use_parallel_prompt in (True, False)
)

# Show both results side by side; they are expected to match.
p_generated, r_generated


(tensor([[29772, 22884,   842, 49295, 42077,  1681, 43525, 38407, 50177, 24993,
          27512, 35252, 28429,  6718, 36836, 24775, 42771,    46, 13646,  2228]],
        device='cuda:0'),
 tensor([[29772, 22884,   842, 49295, 42077,  1681, 43525, 38407, 50177, 24993,
          27512, 35252, 28429,  6718, 36836, 24775, 42771,    46, 13646,  2228]],
        device='cuda:0'))