In [34]:
ls ../all_d

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
train.chunks.dat  train.chunks.knn.dat  train.doc_ids.dat  train.seq.dat


In [2]:
import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset

# mock data constants

import numpy as np

NUM_CHUNKS = 1000
CHUNK_SIZE = 64
NUM_SEQS = 100
NUM_NEIGHBORS = 2

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_ds = RETRODataset(
    num_sequences = NUM_SEQS,
    num_chunks = NUM_CHUNKS,
    num_neighbors = NUM_NEIGHBORS,
    chunk_size = CHUNK_SIZE,
    seq_len = 512,
    chunk_memmap_path = '/workspace/all_d/train.chunks.dat',
    chunk_nn_memmap_path = '/workspace/all_d/train.chunks.knn.dat',
    seq_memmap_path = '/workspace/all_d/train.seq.dat'
)

In [4]:
train_ds.__len__()

100

In [5]:
train_dl = iter(DataLoader(train_ds, batch_size = 2))

In [6]:
retro = RETRO(
    max_seq_len = 512,                      # max sequence length
    enc_dim = 768,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = ( 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
)

seq, retrieved = map(lambda t: t.cuda(), next(train_dl))

In [10]:
seq.shape, retrieved.shape

(torch.Size([2, 513]), torch.Size([2, 8, 2, 128]))

In [11]:

loss = retro(
    seq,
    retrieved,
    return_loss = True
)

loss.backward()

In [12]:
loss

tensor(10.4309, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [62]:
import torch
from retro_pytorch import RETRO, TrainingWrapper

# instantiate RETRO, fit it into the TrainingWrapper with correct settings

retro = RETRO(
    max_seq_len = 512,                      # max sequence length
    enc_dim = 896,                           # encoder model dimension
    enc_depth = 3,                           # encoder depth
    dec_dim = 768,                           # decoder model dimensions
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
).cuda()

wrapper = TrainingWrapper(
    retro = retro,                                 # path to retro instance
    knn = 2,                                       # knn (2 in paper was sufficient)
    chunk_size = 64,                               # chunk size (64 in paper)
    documents_path = '/workspace/text_files',              # path to folder of text
    glob = '/workspace/**/*.txt',                             # text glob
    chunks_memmap_path = '/workspace/all_d/train.chunks.dat',     # path to chunks
    seqs_memmap_path = '/workspace/train.seq.dat',          # path to sequence data
    doc_ids_memmap_path = '/workspace/train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                        # maximum cap to chunks
    max_seqs = 100_000,                            # maximum seqs
    knn_extra_neighbors = 100,                     # num extra neighbors to fetch
    max_index_memory_usage = '100m',
    current_memory_available = '1G',
    reprocess= True
)

# get the dataloader and optimizer (AdamW with all the correct settings)

found to be previously processed at processed-stats.json
preprocessed knn found at /workspace/all_d/train.chunks.knn.dat, faiss index reconstituted from .tmp/.index/knn.index


In [63]:

train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)


In [66]:
wrapper.retro

RETRO(
  (token_emb): Embedding(28996, 896)
  (pos_emb): Embedding(512, 896)
  (to_decoder_model_dim): Linear(in_features=896, out_features=768, bias=True)
  (encoder): Encoder(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (fn): Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=896, out_features=512, bias=False)
            (to_k): Linear(in_features=896, out_features=512, bias=False)
            (to_v): Linear(in_features=896, out_features=512, bias=False)
            (to_out): Linear(in_features=512, out_features=896, bias=True)
          )
          (norm): RMSNorm()
        )
        (1): PreNorm(
          (fn): Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=896, out_features=512, bias=False)
            (to_k): Linear(in_features=768, out_features=512, bias=False)
            (to_v): Linear(in_features=768, out_features=512, bias=False)


In [67]:
import os
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl

import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset

class RETRO_pl(pl.LightningModule):
    def __init__(self, retro):
        super().__init__()
        self.model = retro

    def training_step(self, batch):
        seq, retrieved = batch
        loss = retro(
            seq,
            retrieved,
            return_loss = True
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-5)
        return optimizer


# init the autoencoder
model = RETRO_pl(retro)

In [5]:
import torch
from torch.utils.data import DataLoader
from retro_pytorch import RETRO, RETRODataset

# mock data constants

import numpy as np

NUM_CHUNKS = 1000
CHUNK_SIZE = 64
NUM_SEQS = 100
NUM_NEIGHBORS = 2

train_ds = RETRODataset(
    num_sequences = NUM_SEQS,
    num_chunks = NUM_CHUNKS,
    num_neighbors = NUM_NEIGHBORS,
    chunk_size = CHUNK_SIZE,
    seq_len = 512,
    chunk_memmap_path = '/workspace/all_d/train.chunks.dat',
    chunk_nn_memmap_path = '/workspace/all_d/train.chunks.knn.dat',
    seq_memmap_path = '/workspace/all_d/train.seq.dat'
)

In [3]:
ds_torch = DataLoader(train_ds, batch_size = 2)

In [4]:
trainer = pl.Trainer(max_epochs=20,accelerator="gpu")
trainer.fit(model=model, train_dataloaders=ds_torch)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type  | Params
--------------------------------
0 | model | RETRO | 147 M 
--------------------------------
147 M     Trainable params
0         Non-trainable params
147 M     Total params
591.561   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/25 [00:00<?, ?it/s] 



Epoch 19: 100%|██████████| 25/25 [00:05<00:00,  4.35it/s, loss=6.39, v_num=2]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 25/25 [00:08<00:00,  3.12it/s, loss=6.39, v_num=2]


In [7]:
# encode prompt
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

prompt_str = "With a 102-96 victory Friday night against Charlotte at the Izod Center, they put the finishing touch on their first four-game winning streak and climbed back to"


In [8]:
prompt_ids = tokenizer(prompt_str)['input_ids'][1:-1]

prompt = torch.tensor([prompt_ids])


In [9]:

def top_k(logits, thres = 0.9):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

In [10]:
start = prompt
retrieved = None
filter_fn = top_k
filter_thres = 0.9
temperature = 1.0

In [43]:
assert filter_fn in {top_k}, 'filter function must be either top-k or nucleus'

device = "cuda"

def exists(val):
    return val is not None

if not exists(start):
    start = torch.full((1, 1), SOS_ID, device = device).long()
    
b, start_seq_len = start.shape

In [55]:
start = start.to(device)

if start_seq_len >= CHUNK_SIZE:
            seq_index = (start_seq_len // CHUNK_SIZE) * CHUNK_SIZE
            past_seq_chunks = rearrange(start[:, :seq_index], 'b (n c) -> (b n) c', c = CHUNK_SIZE)

            retrieved = fetch_knn_chunks_fn(past_seq_chunks)
            retrieved = rearrange(retrieved, '(b n) k c -> b n k c', b = b)

In [45]:
out = start

In [46]:
model.model.cuda()
# retrieved = retrieved.to(device)
out = out.to(device)

In [47]:
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))
from einops import rearrange
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

In [48]:
import numpy as np
from functools import partial
import json
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from retro_pytorch import RETRO, RETRODataset
from retro_pytorch.data import knn_to_retrieved_chunks
from retro_pytorch.optimizer import get_optimizer
from retro_pytorch.retrieval import text_folder_to_chunks_, chunks_to_precalculated_knn_, bert_embed, SOS_ID, EOS_ID
from retro_pytorch.utils import memmap, is_true_env_flag

from einops import rearrange


In [54]:
max_seq_len =512
for i in range(start_seq_len - 1, max_seq_len):

            logits =model.model( out, retrieved)
            logits = logits[:, i]

            logits = filter_fn(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, temperature = temperature, dim = -1)
            sampled = rearrange(sampled, 'b -> b 1')

            out = torch.cat((out, sampled), dim = 1)

            # early terminate if all EOS

            is_eos_tokens = (out == EOS_ID)

            if is_eos_tokens.any(dim = -1).all():

                # mask out everything after the eos tokens

                shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                out = out.masked_fill(mask, model.model.pad_id)
                break

            # when the sequence length is a multiple of the chunk size
            # retrieve the next set of knns

            curr_seq_len = out.shape[-1]
            print(curr_seq_len)
            if (curr_seq_len % CHUNK_SIZE) == 0:
                print(curr_seq_len)
                last_chunk = rearrange(out, 'b (c n) -> b c n', n = CHUNK_SIZE)[:, -1]

                knn_chunks = fetch_knn_chunks_fn(last_chunk)

                # concat retrieved knn chunks to all retrieved
                # to be sent to Retro for chunked cross attention at the next iteration

                knn_chunks = rearrange(knn_chunks, 'b k r -> b 1 k r')
                retrieved = safe_cat(retrieved, knn_chunks, dim = 1)

                print(f'retrieved at {curr_seq_len} / {max_seq_len}')

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
128


Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_main
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


AttributeError: 'str' object has no attribute 'search'

In [59]:
knn_memmap_path, faiss_index = chunks_to_precalculated_knn_(
            num_chunks = 1000,
            chunk_size = 64,
        chunk_memmap_path = '/workspace/all_d/train.chunks.dat',
            chunks_memmap_path = '/workspace/all_d/train.chunks.dat',    # path to chunks
            seqs_memmap_path = '/workspace/train.seq.dat',          # path to sequence data
            doc_ids_memmap_path = '/workspace/train.doc_ids.dat', 
            num_nearest_neighbors = 2,
            num_extra_neighbors = 100,
            index_file = "/workspace/all_d/train.chunks.knn.dat",
            force_reprocess = False
        )

preprocessed knn found at /workspace/all_d/train.chunks.knn.dat, faiss index reconstituted from /workspace/all_d/train.chunks.knn.dat


RuntimeError: Error in faiss::Index* faiss::read_index(faiss::IOReader*, int) at /project/faiss/faiss/impl/index_read.cpp:796: Index type 0x0000000b ("\x0b\x00\x00\x00") not recognized

In [52]:
from retro_pytorch.data import knn_to_retrieved_chunks
from functools import partial
from retro_pytorch.training import knn_chunks_from_seq_chunks

In [53]:
fetch_knn_chunks_fn = partial(
            knn_chunks_from_seq_chunks,
            knn = 2,
            chunk_size = 64,
            num_chunks = NUM_CHUNKS,
            chunks_memmap_path = '/workspace/all_d/train.chunks.dat',
            faiss_index = "/workspace/all_d/train.chunks.knn.dat"
        )

In [None]:
b, start_seq_len = start.shape

        # move onto same device as RETRO

        start = start.to(device)

        # prepare retrieval related variables

        if start_seq_len >= CHUNK_SIZE:
            seq_index = (start_seq_len // self.chunk_size) * self.chunk_size
            past_seq_chunks = rearrange(start[:, :seq_index], 'b (n c) -> (b n) c', c = self.chunk_size)

            retrieved = self.fetch_knn_chunks_fn(past_seq_chunks)
            retrieved = rearrange(retrieved, '(b n) k c -> b n k c', b = b)

        # get starting sequence index

        out = start

        # sampling loop

        for i in range(start_seq_len - 1, self.max_seq_len):

            logits = self.retro(out, retrieved = retrieved)
            logits = logits[:, i]

            logits = filter_fn(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, temperature = temperature, dim = -1)
            sampled = rearrange(sampled, 'b -> b 1')

            out = torch.cat((out, sampled), dim = 1)

            # early terminate if all EOS

            is_eos_tokens = (out == EOS_ID)

            if is_eos_tokens.any(dim = -1).all():

                # mask out everything after the eos tokens

                shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                out = out.masked_fill(mask, self.retro.pad_id)
                break

            # when the sequence length is a multiple of the chunk size
            # retrieve the next set of knns

            curr_seq_len = out.shape[-1]

            if (curr_seq_len % self.chunk_size) == 0:
                last_chunk = rearrange(out, 'b (c n) -> b c n', n = self.chunk_size)[:, -1]

                knn_chunks = self.fetch_knn_chunks_fn(last_chunk)

                # concat retrieved knn chunks to all retrieved
                # to be sent to Retro for chunked cross attention at the next iteration

                knn_chunks = rearrange(knn_chunks, 'b k r -> b 1 k r')
                retrieved = safe_cat(retrieved, knn_chunks, dim = 1)

                print(f'retrieved at {curr_seq_len} / {self.max_seq_len}')

        return out