In [1]:
import torch

from model import RetNetModel, RetNetModelWithLMHead, RetNetConfig

%load_ext autoreload
%autoreload 2

In [2]:
# Seed before building the model so any random initialisation is reproducible.
torch.manual_seed(0)
config = RetNetConfig(num_layers=8, vocab_size=100, hidden_size=512, num_heads=4, use_default_gamma=False, chunk_size=4)
model = RetNetModel(config)
model.eval()

# Pick an available backend instead of hard-coding one, so the notebook runs
# on Apple silicon (mps), NVIDIA GPUs (cuda) and plain CPU machines alike.
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)

In [3]:
# Show the decay values held by the first block's `msr` module. There are 4
# values here with num_heads=4 — presumably one per head; confirm in model.py.
first_msr = model.blocks[0].msr
first_msr.decay

Parameter containing:
tensor([0.9688, 0.9876, 0.9951, 0.9980], device='mps:0')

In [4]:
# Consistency check: the parallel, recurrent (token-by-token) and chunkwise
# forward implementations should produce the same outputs and the same
# per-layer past_kv state for the same input.
ATOL = 1e-5  # absolute tolerance shared by every comparison below

input_ids = torch.LongTensor([[1, 2, 3, 4, 1, 2, 3, 4]]).to(device)

# Inference only: skip autograd bookkeeping for all three passes.
with torch.no_grad():
    out, parallel_past_kv = model(input_ids, forward_impl='parallel', return_kv=True)

    # Recurrent mode: feed one token at a time, threading past_kv from step to
    # step and passing the token index as sequence_offset.
    past_kv = None
    rnn_outs = []
    for i in range(input_ids.shape[1]):
        rnn_out, past_kv = model(input_ids[:, i:i+1], forward_impl='recurrent', past_kv=past_kv, return_kv=True, sequence_offset=i)
        rnn_outs.append(rnn_out)
    rnn_outs = torch.cat(rnn_outs, dim=1)

    chunk_out, chunk_past_kv = model(input_ids, forward_impl='chunkwise', return_kv=True)

print(torch.allclose(out, rnn_outs, atol=ATOL))
print(torch.allclose(out, chunk_out, atol=ATOL))
print(torch.allclose(rnn_outs, chunk_out, atol=ATOL))

# zip() stops at the shortest input, which would silently hide a missing layer.
assert len(parallel_past_kv) == len(past_kv) == len(chunk_past_kv), "layer counts differ between past_kv results"

for i, (p, r, c) in enumerate(zip(parallel_past_kv, past_kv, chunk_past_kv)):
    print(f"layer: {i + 1}")
    print(torch.allclose(p, r, atol=ATOL))
    print(torch.allclose(p, c, atol=ATOL))
    print(torch.allclose(r, c, atol=ATOL))


True
True
True
layer: 1
True
True
True
layer: 2
True
True
True
layer: 3
True
True
True
layer: 4
True
True
True
layer: 5
True
True
True
layer: 6
True
True
True
layer: 7
True
True
True
layer: 8
True
True
True


In [5]:

# Greedy generation should yield identical tokens whether the prompt is
# processed in one parallel pass or token by token on the recurrent path.
torch.manual_seed(0)  # seed before construction so initialisation is reproducible
config = RetNetConfig(num_layers=8, vocab_size=100, hidden_size=512, num_heads=4, use_default_gamma=False, chunk_size=4)
model = RetNetModelWithLMHead(config).to(device)
model.eval()

# Generation only; no gradients are needed.
with torch.no_grad():
    p_generated = model.generate(input_ids, parallel_compute_prompt=True, max_new_tokens=20, do_sample=False, early_stopping=False)
    r_generated = model.generate(input_ids, parallel_compute_prompt=False, max_new_tokens=20, do_sample=False, early_stopping=False)

# Make the parity check explicit instead of relying on eyeballing the output.
assert torch.equal(p_generated, r_generated), "parallel and recurrent prompt processing generated different tokens"

p_generated, r_generated


(tensor([[27, 17, 36, 96, 79, 51, 87, 97, 70, 16, 11,  2, 93, 68,  4, 62,  4, 62,
           4, 62]], device='mps:0'),
 tensor([[27, 17, 36, 96, 79, 51, 87, 97, 70, 16, 11,  2, 93, 68,  4, 62,  4, 62,
           4, 62]], device='mps:0'))