In [1]:
import jax
import jax.numpy as jnp
from functools import partial 

from deepscratch.dataset.base import Dataset
from deepscratch.models.sequence.tokeniser import WordTokeniser, Tokeniser
from deepscratch.models.sequence.embeddings import Embedder, OHE, LSA

deepscratch.dataset.base
deepscratch.models.sequence.tokeniser
deepscratch.models.sequence.embeddings


# Sequential datasets

Sequential datasets are iterables over Corpus objects.
 - Seq2Vec yields the next embedded batch from the corpus together with the matching slice of `y` values

In [3]:
# Restore the LSA embedder from its cache file.
cache_path = "../__deepscratchcache__/thetimemachine.pkl"
with open(cache_path, "rb") as cache_file:
    embedder = LSA.from_cache(cache_file)



Metal device set to: Apple M3

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



I0000 00:00:1739968881.207415 8712519 service.cc:145] XLA service 0x104255280 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1739968881.207427 8712519 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1739968881.208594 8712519 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1739968881.208604 8712519 mps_client.cc:384] XLA backend will use up to 11452776448 bytes on device 0 for SimpleAllocator.


In [None]:
class SequenceDataset(Dataset):
    """
    Sliding-window dataset over a tokenised text file.

    The whole file is tokenised once at construction. Item ``i`` is the token
    window ``tokens[i : i + input_len + output_len]``, embedded on demand and
    split into an input part and an output part.
    """

    def __init__(
        self,
        f,
        input_len,
        output_len,
        tokeniser: Tokeniser,
        embedder: Embedder,
    ):
        """
        Args:
            f: Open, seekable text file object. It is rewound, read in full and
                then closed by this constructor.
            input_len: Number of tokens in the input part of each window.
            output_len: Number of tokens in the output part of each window.
            tokeniser: Object whose ``tokenise(text)`` result can be converted
                with ``jnp.array``.
            embedder: Object whose ``embed`` is applied to each token window.
        """
        self.input_len = input_len
        self.output_len = output_len
        self.tokeniser = tokeniser
        self.embedder = embedder

        # Read the full text and tokenise
        f.seek(0)
        self.tokens = jnp.array(tokeniser.tokenise(f.read()))
        # NOTE(review): this closes a handle owned by the caller; kept as-is so
        # callers that rely on it do not start leaking file handles.
        f.close()

        # Bind params: shadow the static method with an instance-level partial
        # so that __getitem__ only has to supply the index.
        self._jitgetitem = partial(
            self._jitgetitem,
            tokens=self.tokens,
            input_len=self.input_len,
            output_len=self.output_len,
            embedder_func=self.embedder.embed
        )

    def __len__(self) -> int:
        """
        Returns the number of possible sequences (windows) available.

        For index i, __getitem__(i) covers
        tokens[i : i + input_len + output_len]. A corpus shorter than one
        window yields 0 rather than a negative length.
        """
        n_windows = len(self.tokens) - self.input_len - self.output_len + 1
        return max(n_windows, 0)

    def __getitem__(self, index: int) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Return the ``(x, y)`` pair for the window starting at ``index``.

        Plain Python ints are bounds-checked (negative values count from the
        end) because ``dynamic_slice`` clamps an out-of-range start instead of
        raising, which would silently return the wrong window. Any other index
        type is passed through unchecked.

        Raises:
            IndexError: If an int ``index`` is outside ``[-len(self), len(self))``.
        """
        if isinstance(index, int):
            n_windows = len(self)
            if index < 0:
                index += n_windows
            if not 0 <= index < n_windows:
                raise IndexError("SequenceDataset index out of range")
        return self._jitgetitem(index)

    @staticmethod
    def _jitgetitem(
        index: int,
        tokens,
        input_len,
        output_len,
        embedder_func
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Slice one token window, embed it and split it into ``(x, y)``.

        Despite the name, this function is not wrapped in ``jax.jit`` here.

        Args:
            index: Start position of the window in ``tokens``.
            tokens: 1-D array of tokens (the slice below uses a single index).
            input_len: Number of positions assigned to ``x``.
            output_len: Number of positions assigned to ``y``.
            embedder_func: Callable applied to the raw token window.

        Returns:
            ``(x, y)``: the first ``input_len`` and the remaining positions of
            the embedded window along its second-to-last axis.
        """
        # Extract the token window
        window = jax.lax.dynamic_slice(
            tokens, (index,), (input_len + output_len,)
        )

        # Embed the tokens on demand (this may be memory heavy, but only for a small window)
        embedded_window = embedder_func(window)
        # assumes the embedding has at least two axes, (..., position, feature) — TODO confirm
        x, y = embedded_window[...,:input_len,:], embedded_window[...,input_len:,:]
        return x, y

In [5]:
tokeniser = WordTokeniser()
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    # 3 input tokens followed by 2 output tokens per window. The constructor
    # takes no `window_len` argument; the window length is input_len + output_len.
    ds = SequenceDataset(f, 3, 2, tokeniser, embedder)

ds[0]

(Array([[-7.0361847e+02,  3.3582416e+02,  1.6287428e+02],
        [-1.1196708e-02, -9.6939774e-03, -2.0985298e-02],
        [-1.3034204e+03, -3.2633771e+02,  1.4264659e+02],
        [-1.0618215e+02, -3.3759235e+01,  6.7791662e+00],
        [-3.2326031e+01, -2.8942438e+01,  8.6696997e+00]], dtype=float32),
 Array([[-2.2718998e+01,  3.7305080e+01, -4.2888706e+01,  1.0025843e+02,
         -1.6993481e+01, -1.6425310e+01,  4.7479172e+01],
        [ 1.0182316e-02,  6.3046985e-03,  1.1931198e-02, -5.4324535e-03,
          3.7624363e-03, -1.0620670e-02,  8.6728754e-03],
        [ 4.0849900e+00,  5.5485703e+01, -2.1984842e+01, -1.8106064e+01,
         -1.2444858e+01,  1.5683472e+00, -4.9172010e+00],
        [ 4.8439548e+01, -7.0512054e+01, -6.3593425e-02,  6.6587166e+01,
          1.9610695e+00, -2.0335042e+00, -1.3289231e+01],
        [ 4.0634182e+01, -3.2517994e+01, -3.6052942e+00,  2.7541046e+01,
         -5.4291263e+00, -8.5053606e+00,  1.7380468e+00]], dtype=float32))

In [6]:
from deepscratch.dataset.base import DataLoader

# Wrap the windowed dataset in a shuffling batch loader.
BATCH_SIZE = 256
dl = DataLoader(
    ds,
    BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    drop_last=True,
    iobound=False,
)

In [7]:
# Run one full pass over the loader; the batches themselves are discarded.
for batch in dl:
    continue

In [None]:
class Seq2VecDataset:
    """
    Endless batch iterator pairing embedded corpus batches with target values.

    Each step pulls the next batch from ``corpus``, embeds it and pairs it with
    the next ``len(batch)`` entries of ``y``. When the corpus is exhausted it is
    reset and iteration starts again from the beginning, so iteration never
    stops on its own for a non-empty corpus.
    """

    def __init__(self, corpus, y, embed, batch_size):
        """
        Args:
            corpus: Iterator of batches that also provides ``reset()``.
            y: Sliceable targets, consumed in step with the corpus batches.
            embed: Callable applied to each raw batch taken from ``corpus``.
            batch_size: Stored on the instance; not used by this class itself.
        """
        self.corpus = corpus
        self.y = y
        self.i = 0  # offset into ``y`` of the next batch's first target
        self.embed = embed
        self.batch_size = batch_size

    def _rewind(self):
        """Reset the corpus and the ``y`` offset together so they stay aligned."""
        self.corpus.reset()
        self.i = 0

    def __len__(self):
        """
        Count the batches in one full pass over the corpus.

        This walks the whole corpus and rewinds the dataset, so any iteration
        in progress restarts from the first batch afterwards.
        """
        # Rewind first so a call made mid-iteration still counts every batch.
        self._rewind()
        n_batches = sum(1 for _ in self.corpus)
        self._rewind()
        return n_batches

    def __iter__(self):
        """Return ``self`` so the dataset can be used directly in ``for`` loops."""
        return self

    def __next__(self):
        """
        Return the next ``{"x": embedded_batch, "y": targets}`` pair.

        Raises:
            StopIteration: Only if the corpus yields nothing even after a reset.
        """
        try:
            batch_idx = next(self.corpus)
        except StopIteration:
            # End of the corpus: start over from the first batch and target.
            self._rewind()
            batch_idx = next(self.corpus)

        batch = self.embed(batch_idx)
        y = self.y[self.i:self.i + len(batch)]
        self.i += len(batch)

        return {"x":batch, "y":y}

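Seq2Seq datasets iterate over the corpus, embed each batch, and split it along the sequence axis into an input part (the first `input_len` positions) and an output part (the remaining positions).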

In [None]:
class Seq2SeqDataset:
    """
    Endless batch iterator that splits embedded batches into input/output parts.

    Each step pulls the next batch from ``iterator``, embeds it and splits it
    along axis 1 at ``input_len``. When the iterator is exhausted it is reset
    and iteration starts again from the beginning.
    """

    def __init__(self, iterator, input_len, output_len, embed, batch_size, target_memory_usage=100):
        """
        Args:
            iterator: Iterator of batches that also provides ``reset()``.
            input_len: Number of leading positions (axis 1) assigned to ``x``.
            output_len: Stored on the instance; the split itself assigns
                everything after ``input_len`` to ``y``.
            embed: Callable applied to each raw batch taken from ``iterator``.
            batch_size: Stored on the instance; not used by this class itself.
            target_memory_usage: Accepted for compatibility; currently unused.
        """
        self.input_len = input_len
        self.output_len = output_len
        self.batch_size = batch_size
        self.iterator = iterator
        self.embed = embed

    def __len__(self):
        """
        Count the batches in one full pass over the iterator.

        This walks the whole iterator and resets it, so any iteration in
        progress restarts from the first batch afterwards.
        """
        # Reset first so a call made mid-iteration still counts every batch.
        self.iterator.reset()
        n_batches = sum(1 for _ in self.iterator)
        self.iterator.reset()
        return n_batches

    def __iter__(self):
        """Return ``self`` so the dataset can be used directly in ``for`` loops."""
        return self

    @staticmethod
    @partial(jax.jit, static_argnums=1)
    def split_batch(batch, cutoff):
        """
        Split ``batch`` along axis 1 into ``(batch[:, :cutoff], batch[:, cutoff:])``.

        ``cutoff`` is a static argument of the jitted function, so each
        distinct value is compiled separately.
        """
        return batch[:,:cutoff], batch[:, cutoff:]

    def __next__(self):
        """
        Return the next ``{"x": inputs, "y": outputs}`` pair.

        Raises:
            StopIteration: Only if the iterator yields nothing even after a reset.
        """
        try:
            next_batch_idx = next(self.iterator)
        except StopIteration:
            # End of the data: start over from the first batch.
            self.iterator.reset()
            next_batch_idx = next(self.iterator)

        next_batch = self.embed(next_batch_idx)
        x, y = self.split_batch(next_batch, self.input_len)

        return {"x":x, "y":y}