In [1]:
import torch

from retnet.configuration_retnet import RetNetConfig
from retnet.modeling_retnet import RetNetModel, RetNetModelWithLMHead

%load_ext autoreload
%autoreload 2

In [5]:
# Seed before constructing the model so its initial weights are reproducible.
torch.manual_seed(0)

device = 'cpu'  # cuda, cpu, mps for m1 mac

# A deliberately small configuration: 2 layers, embed dim 8, 2 retention heads.
config = RetNetConfig(
    decoder_layers=2,
    decoder_embed_dim=8,
    decoder_retention_heads=2,
    decoder_ffn_embed_dim=16,
)

# Build the model, move it to the chosen device and put it in evaluation mode.
model = RetNetModel(config).to(device).eval()

In [6]:
# Consistency check on a single unmasked sequence: the 'parallel' and
# 'recurrent' forward implementations are expected to yield the same hidden
# states and the same cache entries (within atol=1e-5).
input_ids = torch.LongTensor([[1,2,1,2]]).to(device)

# Inference only: no_grad avoids recording an autograd graph for every
# forward pass below. The compared values are unaffected.
with torch.no_grad():
    parallel_outputs = model(input_ids, forward_impl='parallel', use_cache=True)
    parallel_state = parallel_outputs.last_hidden_state
    parallel_cache = parallel_outputs.past_key_values

    # Recurrent pass: one call per position. Each call receives the prefix
    # up to and including position `step` plus the cache returned by the
    # previous call.
    past_kv = None
    rnn_steps = []  # per-step hidden states, concatenated along dim 1 below
    for step in range(input_ids.shape[1]):
        rnn_out = model(input_ids[:, :step + 1],
                        forward_impl='recurrent',
                        past_key_values=past_kv,
                        use_cache=True)
        rnn_steps.append(rnn_out.last_hidden_state)
        past_kv = rnn_out.past_key_values
    rnn_state = torch.cat(rnn_steps, dim=1)
    rnn_cache = past_kv  # cache returned by the final recurrent step

    chunk_outputs = model(input_ids, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=2)
    chunk_state = chunk_outputs.last_hidden_state
    chunk_cache = chunk_outputs.past_key_values

print(torch.allclose(parallel_state, rnn_state, atol=1e-5))
# NOTE(review): the chunkwise comparisons are disabled here and in the loop
# below; presumably chunkwise does not match yet — confirm before enabling.
# print(torch.allclose(parallel_state, chunk_state, atol=1e-5))

# Compare the caches layer by layer, one printed result per cache key.
for layer_idx, (p, r, c) in enumerate(zip(parallel_cache, rnn_cache, chunk_cache)):
    print(f"layer: {layer_idx + 1}")
    for key in p.keys():
        print(torch.allclose(p[key], r[key], atol=1e-5))
        # print(torch.allclose(p[key], c[key], atol=1e-5))

True
layer: 1
True
True
layer: 2
True
True


In [8]:
# Same consistency check as the unmasked case, now on a batch of two
# sequences with a retention mask.
# NOTE(review): presumably 0 in retention_mask marks padding (row 0 padded on
# the right, row 1 on the left) — confirm against the model implementation.
input_ids = torch.LongTensor([[1,2,3,4,1,2,5,5],
                              [5,5,1,2,3,4,1,2]]).to(device)
retention_mask = torch.LongTensor([[1,1,1,1,1,1,0,0],
                                   [0,0,1,1,1,1,1,1]]).to(device)

# Inference only: no_grad avoids recording an autograd graph for every
# forward pass below. The compared values are unaffected.
with torch.no_grad():
    parallel_outputs = model(input_ids, retention_mask=retention_mask, forward_impl='parallel', use_cache=True)
    parallel_state = parallel_outputs.last_hidden_state
    parallel_cache = parallel_outputs.past_key_values

    # Recurrent pass: one call per position. Each call receives the prefix
    # up to and including position `step`, the mask column for that position
    # only, and the cache returned by the previous call.
    past_kv = None
    rnn_steps = []  # per-step hidden states, concatenated along dim 1 below
    for step in range(input_ids.shape[1]):
        rnn_out = model(input_ids[:, :step + 1],
                        retention_mask=retention_mask[:, step:step + 1],
                        forward_impl='recurrent',
                        past_key_values=past_kv,
                        use_cache=True)
        rnn_steps.append(rnn_out.last_hidden_state)
        past_kv = rnn_out.past_key_values
    rnn_state = torch.cat(rnn_steps, dim=1)
    rnn_cache = past_kv  # cache returned by the final recurrent step

    chunk_outputs = model(input_ids, retention_mask=retention_mask, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=4)
    chunk_state = chunk_outputs.last_hidden_state
    chunk_cache = chunk_outputs.past_key_values

# Zero out positions where the mask is 0 so they do not affect the comparison.
mask = retention_mask.unsqueeze(-1).float()
print(torch.allclose(parallel_state * mask, rnn_state * mask, atol=1e-5))
# NOTE(review): the chunkwise comparisons are disabled here and in the loop
# below; presumably chunkwise does not match yet — confirm before enabling.
# print(torch.allclose(parallel_state * mask, chunk_state * mask, atol=1e-5))

# Compare the caches layer by layer, one printed result per cache key.
for layer_idx, (p, r, c) in enumerate(zip(parallel_cache, rnn_cache, chunk_cache)):
    print(f"layer: {layer_idx + 1}")
    for key in p.keys():
        print(torch.allclose(p[key], r[key], atol=1e-5))
        # print(torch.allclose(p[key], c[key], atol=1e-5))


True
layer: 1
True
True
layer: 2
True
True


In [9]:
# Re-seed so the LM-head model is initialised reproducibly, then rebind
# `model` to it (same config and device as before).
torch.manual_seed(0)
model = RetNetModelWithLMHead(config)
model = model.to(device)
model.eval()

# Generate twice with identical settings (sampling disabled); the only
# difference is the parallel_compute_prompt flag.
# `input_ids` is not defined in this cell: it is reused from an earlier cell.
shared_generate_kwargs = dict(max_new_tokens=20, do_sample=False, early_stopping=False)
p_generated = model.generate(input_ids, parallel_compute_prompt=True, **shared_generate_kwargs)
r_generated = model.generate(input_ids, parallel_compute_prompt=False, **shared_generate_kwargs)

# Show both results together for visual comparison.
p_generated, r_generated


(tensor([[    5, 20137, 50121, 14818, 14818, 14818, 14818, 14818, 14818, 14818,
          14818, 14818, 14818, 14818, 14818, 14818, 14818, 14818, 14818, 14818,
          14818],
         [    2,  2622, 14777, 47757, 47757, 47757, 47757, 47757, 47757, 47757,
          47757, 47757, 47757, 47757, 47757, 47757, 47757, 47757, 47757, 47757,
          47757]]),
 tensor([[    5, 20137, 50121, 14818, 14818, 14818, 14818, 14818, 14818, 14818,
          14818, 14818, 14818, 14818, 14818, 14818, 14818, 14818, 14818, 14818,
          14818],
         [    2,  2622, 14777, 47757, 47757, 47757, 47757, 47757, 47757, 47757,
          47757, 47757, 47757, 47757, 47757, 47757, 47757, 47757, 47757, 47757,
          47757]]))