
Retentive Network (RetNet)

A minimal PyTorch implementation of Retentive Network: A Successor to Transformer for Large Language Models (arXiv:2307.08621)

Notes

  • This repository exists mostly for educational purposes, for me and for anyone else who wants to learn about RetNet.
  • It is basically a direct translation of the math in the paper, complex numbers and all. I haven't looked into it, but there are other implementations that claim to work without complex numbers.
  • It makes heavy use of torch.einsum, so make sure you understand einsum before trying to understand this code (see the sketch after these notes).
  • I haven't implemented the chunkwise recurrent mode yet; this repo only has the parallel and recurrent modes (a minimal recurrent-mode sketch follows the usage example below).
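
Since the notes above lean on torch.einsum and the parallel mode, here is a minimal, self-contained sketch of the paper's parallel retention for a single head. This is illustrative, not this repo's exact code: seq_len, d, gamma, and the random q/k/v are placeholders, and the projections and complex-valued rotation from the paper are omitted for brevity.

import torch

# Parallel retention, following the paper's formulation:
# Retention(X) = (Q K^T ⊙ D) V, where D[n, m] = gamma^(n - m) for n >= m, else 0.
seq_len, d = 8, 16
gamma = 0.9

q = torch.randn(seq_len, d)  # queries
k = torch.randn(seq_len, d)  # keys
v = torch.randn(seq_len, d)  # values

# Causal decay matrix D: gamma^(n - m) below the diagonal, zero above it.
n = torch.arange(seq_len)
decay = gamma ** (n[:, None] - n[None, :]).float()
d_mask = torch.tril(decay)

# einsum spells out each contraction: scores[n, m] = sum_i q[n, i] * k[m, i]
scores = torch.einsum('nd,md->nm', q, k)
out = torch.einsum('nm,md->nd', scores * d_mask, v)
print(out.shape)  # torch.Size([8, 16])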

Usage

For more examples, see test.py.

import torch
from retnet import RetNet

# Build a small RetNet; see retnet.py for the meaning of each constructor argument.
model = RetNet(256, 64, 4, 4)

# A batch of one sequence of 64 token IDs drawn from a 256-token vocabulary.
x = torch.randint(0, 256, (1, 64), dtype=torch.long)

# Compute and print the model's training loss on the random batch.
print(model.loss(x))
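
For comparison with the parallel sketch above, the recurrent mode computes the same outputs one token at a time by carrying a d × d state. The single-head sketch below follows the paper's recurrence (again illustrative, with placeholder shapes and random q/k/v, not this repo's API):

import torch

# Recurrent retention for one head:
# S_n = gamma * S_{n-1} + k_n^T v_n,    o_n = q_n S_n
seq_len, d = 8, 16
gamma = 0.9

q = torch.randn(seq_len, d)
k = torch.randn(seq_len, d)
v = torch.randn(seq_len, d)

state = torch.zeros(d, d)
outputs = []
for t in range(seq_len):
    state = gamma * state + torch.outer(k[t], v[t])  # update the running state
    outputs.append(q[t] @ state)                     # read out for this step
out = torch.stack(outputs)
print(out.shape)  # torch.Size([8, 16])

Unrolling the recurrence gives o_n = sum over m <= n of gamma^(n - m) (q_n · k_m) v_m, which is exactly the masked-decay sum the parallel form computes, so the two modes produce identical outputs.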
