In [4]:
import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset
import os
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl

class RETRO_pl(pl.LightningModule):
    def __init__(self, retro):
        super().__init__()
        self.model = retro

    def training_step(self, batch):
        seq, retrieved = batch
        loss = retro(
            seq,
            retrieved,
            return_loss = True
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-5)
        return optimizer


In [3]:
import torch
from retro_pytorch import RETRO, TrainingWrapper
# instantiate RETRO, fit it into the TrainingWrapper with correct settings

retro = RETRO(
    max_seq_len = 512,                      # max sequence length
    enc_dim = 896,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
)

wrapper = TrainingWrapper(
    retro = retro,                                 # path to retro instance
    knn = 2,                                       # knn (2 in paper was sufficient)
    chunk_size = 64,                               # chunk size (64 in paper)
    documents_path = '../text_files',              # path to folder of text
    glob = '../**/*.txt',                             # text glob
    chunks_memmap_path = '../all_d/train.chunks.dat',     # path to chunks
    seqs_memmap_path = '../all_d/train.seq.dat',          # path to sequence data
    doc_ids_memmap_path = '../all_d/train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                        # maximum cap to chunks
    max_seqs = 300_000,                            # maximum seqs
    knn_extra_neighbors = 100,                     # num extra neighbors to fetch
    max_index_memory_usage = '10G',
    current_memory_available = '100G',
    reprocess= True
)

# get the dataloader and optimizer (AdamW with all the correct settings)

train_dl = wrapper.get_dataloader(batch_size = 2, shuffle = True)
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)


found to be previously processed at processed-stats.json
preprocessed knn found at ../all_d/train.chunks.knn.dat, faiss index reconstituted from .tmp/.index/knn.index


In [5]:

model = RETRO_pl(wrapper.retro)

In [6]:
train_dl = wrapper.get_dataloader(batch_size = 16, shuffle = True)
# ds_torch = DataLoader(train_dl, batch_size = 2)

In [7]:
train_dl

<torch.utils.data.dataloader.DataLoader at 0x7efd477bbb50>

In [8]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger(project="retro-finetuning")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msaisam1[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
trainer = pl.Trainer(max_epochs=10,
                     accelerator="gpu",
                     logger=wandb_logger,
                     precision=16,
                     default_root_dir="../run_no_mod",
                     accumulate_grad_batches=7)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
model

RETRO_pl(
  (model): RETRO(
    (token_emb): Embedding(28996, 896)
    (pos_emb): Embedding(512, 896)
    (to_decoder_model_dim): Linear(in_features=896, out_features=768, bias=True)
    (encoder): Encoder(
      (layers): ModuleList(
        (0): ModuleList(
          (0): PreNorm(
            (fn): Attention(
              (dropout): Dropout(p=0.0, inplace=False)
              (to_q): Linear(in_features=896, out_features=512, bias=False)
              (to_k): Linear(in_features=896, out_features=512, bias=False)
              (to_v): Linear(in_features=896, out_features=512, bias=False)
              (to_out): Linear(in_features=512, out_features=896, bias=True)
            )
            (norm): RMSNorm()
          )
          (1): PreNorm(
            (fn): Attention(
              (dropout): Dropout(p=0.0, inplace=False)
              (to_q): Linear(in_features=896, out_features=512, bias=False)
              (to_k): Linear(in_features=768, out_features=512, bias=False)
           

In [11]:
CKPT_PATH = "/workspace/RETRO/retro-finetuning/pwiqzvvr/checkpoints/epoch=4-step=4135.ckpt"

In [12]:

trainer.fit(model=model, train_dataloaders=train_dl,ckpt_path=CKPT_PATH)

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

Restoring states from the checkpoint path at /workspace/RETRO/retro-finetuning/pwiqzvvr/checkpoints/epoch=4-step=4135.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type  | Params
--------------------------------
0 | model | RETRO | 161 M 
--------------------------------
161 M     Trainable params
0         Non-trainable params
161 M     Total params
323.801   Total estimated model params size (MB)
Restored all states from the checkpoint file at /workspace/RETRO/retro-finetuning/pwiqzvvr/checkpoints/epoch=4-step=4135.ckpt
  rank_zero_

Epoch 5:   0%|          | 6/5783 [00:03<1:00:07,  1.60it/s, loss=5.74, v_num=p8gc]



Epoch 9: 100%|██████████| 5783/5783 [49:30<00:00,  1.95it/s, loss=5.35, v_num=p8gc]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 5783/5783 [49:32<00:00,  1.95it/s, loss=5.35, v_num=p8gc]


In [10]:

# encode prompt
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

prompt_str = "In Rainbows,the latest album from the British rock band "

prompt_ids = tokenizer(prompt_str)['input_ids'][1:-1]

prompt = torch.tensor([prompt_ids])

sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0)

# decode sample
decoded = tokenizer.decode(sampled.tolist()[0])

print(decoded)

retrieved at 64 / 512
retrieved at 128 / 512
retrieved at 192 / 512
In Rainbows, the latest album from the British rock band and American - Ts. Reporting until million on it. For personalves's'politicians he had always gczher's's healthding of She showing as he was moving for holiday had a claim in Iran with a good and dressed. The last reports who are named Brad contracts in outside five years, 6, Vicurita - year - old who really their players of £10Vulse ; a statement. Thereh is the coach Court used while directed first's oil Aification " The field " for the players loan, " and trulyist said. Theyly make it Mike a two health hat. It will go down the Senior by ] State treatments and St announcing a sustainable as a The DaCmp. Melania a halt ji Updated : 44z Dan a title, the main Department could be in the couple's trying to diaper in the Glasgow's NiL to take after being held down club. J. 75 million War. Any software in the U. St.. June 2017 [SEP]


In [None]:
from retro_pytorch import RETRO, TrainingWrapper

In [12]:
wrapper = TrainingWrapper(
    retro = wrapper.retro,                                 # path to retro instance
    knn = 2,                                       # knn (2 in paper was sufficient)
    chunk_size = 64,                               # chunk size (64 in paper)
    documents_path = '../text_files',              # path to folder of text
    glob = '../**/*.txt',                             # text glob
    chunks_memmap_path = '../all_d/train.chunks.dat',     # path to chunks
    seqs_memmap_path = '../all_d/train.seq.dat',          # path to sequence data
    doc_ids_memmap_path = '../all_d/train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                        # maximum cap to chunks
    max_seqs = 300_000,                            # maximum seqs
    knn_extra_neighbors = 100,                     # num extra neighbors to fetch
    max_index_memory_usage = '10G',
    current_memory_available = '100G',
    reprocess= True
)

found to be previously processed at processed-stats.json
preprocessed knn found at ../all_d/train.chunks.knn.dat, faiss index reconstituted from .tmp/.index/knn.index


In [13]:

# encode prompt
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

prompt_str = "In Rainbows,the latest album from the British rock band "

prompt_ids = tokenizer(prompt_str)['input_ids'][1:-1]

prompt = torch.tensor([prompt_ids])

sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0, seq_len_max=30)

# decode sample
decoded = tokenizer.decode(sampled.tolist()[0])

print(decoded)

TypeError: generate() got an unexpected keyword argument 'seq_len_max'