In [5]:
import torch

from model import RetNetModel, RetNetModelWithLMHead, RetNetConfig

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Build the RetNet backbone used by the equivalence checks. Seeding first makes
# the weight initialisation reproducible across kernel restarts.
torch.manual_seed(0)
config = RetNetConfig(
    num_layers=8,
    vocab_size=100,
    hidden_size=512,
    num_heads=4,
    use_default_gamma=False,
    chunk_size=4,
)
model = RetNetModel(config).eval()  # nn.Module.eval() returns the module itself
model

RetNetModel(
  (embedding): Embedding(100, 512)
  (blocks): ModuleList(
    (0-7): 8 x RetNetBlock(
      (msr): MultiScaleRetention(
        (qkv): Linear(in_features=512, out_features=2048, bias=False)
        (silu): SiLU()
        (gated): Linear(in_features=512, out_features=1024, bias=False)
        (proj): Linear(in_features=1024, out_features=512, bias=True)
        (gn): GroupNorm(4, 1024, eps=1e-05, affine=False)
        (xpos): XPOS()
      )
      (ffn): Sequential(
        (0): Linear(in_features=512, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=512, bias=True)
      )
      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [7]:
# Check that the three forward implementations agree on both the final hidden
# states and the per-layer KV states they return:
#   - 'parallel'
#   - 'recurrent' (one token per call, feeding the returned KV state into the
#     next call)
#   - 'chunkwise'
# Every comparison is still printed. The cell also raises AssertionError on any
# mismatch instead of silently printing False.
ATOL = 1e-5  # absolute tolerance for all torch.allclose comparisons
input_ids = torch.LongTensor([[1,2,3,4, 1,2,3,4]])

# model.eval() does not disable autograd; no_grad avoids building a graph for
# these inference-only passes.
with torch.no_grad():
    out, parallel_past_kv = model(input_ids, forward_impl='parallel', return_kv=True)

    past_kv = None
    rnn_outs = []
    for i in range(input_ids.shape[1]):
        # sequence_offset is set to the index of the single token fed in this step.
        rnn_out, past_kv = model(input_ids[:, i:i+1], forward_impl='recurrent', past_kv=past_kv, return_kv=True, sequence_offset=i)
        rnn_outs.append(rnn_out)
    rnn_outs = torch.cat(rnn_outs, dim=1)

    chunk_out, chunk_past_kv = model(input_ids, forward_impl='chunkwise', return_kv=True)

# Order: parallel vs recurrent, parallel vs chunkwise, recurrent vs chunkwise.
output_matches = [
    torch.allclose(out, rnn_outs, atol=ATOL),
    torch.allclose(out, chunk_out, atol=ATOL),
    torch.allclose(rnn_outs, chunk_out, atol=ATOL),
]
print(*output_matches, sep='\n')

# zip() would silently drop trailing layers if one implementation returned fewer states.
assert len(parallel_past_kv) == len(past_kv) == len(chunk_past_kv), "implementations returned different numbers of KV layers"

kv_matches = []
for i, (p, r, c) in enumerate(zip(parallel_past_kv, past_kv, chunk_past_kv)):
    layer_matches = [
        torch.allclose(p, r, atol=ATOL),
        torch.allclose(p, c, atol=ATOL),
        torch.allclose(r, c, atol=ATOL),
    ]
    print(f"layer: {i + 1}")
    print(*layer_matches, sep='\n')
    kv_matches.extend(layer_matches)

assert all(output_matches), "forward implementations disagree on outputs"
assert all(kv_matches), "forward implementations disagree on KV states"


True
True
True
layer: 1
True
True
True
layer: 2
True
True
True
layer: 3
True
True
True
layer: 4
True
True
True
layer: 5
True
True
True
layer: 6
True
True
True
layer: 7
True
True
True
layer: 8
True
True
True


In [8]:

# Build an LM-head model with the same seed and config, then check that
# generate() gives the same tokens whether parallel_compute_prompt is True or
# False. With do_sample=False the two calls are expected to match exactly.
# Uses `input_ids` from the equivalence-check cell above.
# NOTE(review): this rebinds `model` to the LM-head variant, so re-running the
# earlier equivalence cell after this one would use a different model class.
torch.manual_seed(0)
config = RetNetConfig(num_layers=8, vocab_size=100, hidden_size=512, num_heads=4, use_default_gamma=False, chunk_size=4)
model = RetNetModelWithLMHead(config)
model.eval()

p_generated = model.generate(input_ids, parallel_compute_prompt=True, max_new_tokens=20, do_sample=False, early_stopping=False)
r_generated = model.generate(input_ids, parallel_compute_prompt=False, max_new_tokens=20, do_sample=False, early_stopping=False)

# Fail loudly instead of relying on a visual comparison of the displayed tensors.
assert torch.equal(p_generated, r_generated), "parallel and recurrent prompt processing produced different generations"

p_generated, r_generated


(tensor([[27, 17, 36, 96, 79, 51, 87, 97, 70, 16, 11,  2, 93, 68,  4, 62,  4, 62,
           4, 62]]),
 tensor([[27, 17, 36, 96, 79, 51, 87, 97, 70, 16, 11,  2, 93, 68,  4, 62,  4, 62,
           4, 62]]))