In [1]:
from __future__ import annotations
import deepscratch
from deepscratch.typing import PyTree

from deepscratch.dataset.vision import ImageNet
from deepscratch.dataset.base import DataLoader
from deepscratch.models.base import LinearBlock, Sequential
from deepscratch.models.vision.cnn import ConvBlock
from deepscratch.initialisers import Gaussian, Zeros
from deepscratch.activations import ReLU, Softmax, Activation
from deepscratch.optimisers import SGD, Adam
from deepscratch.losses import CrossEntropy, Accuracy
from deepscratch.transformations import Reshape, AvgPool

import jax.numpy as jnp

import matplotlib.pyplot as plt

deepscratch.dataset.vision




deepscratch.dataset.base
deepscratch.models.base
deepscratch.initialisers
deepscratch.activations
deepscratch.optimisers
deepscratch.losses
deepscratch.transformations
Metal device set to: Apple M3

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

deepscratch.models.vision.cnn


I0000 00:00:1740319505.167365 10202701 service.cc:145] XLA service 0x16a452700 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1740319505.167375 10202701 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1740319505.168410 10202701 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1740319505.168442 10202701 mps_client.cc:384] XLA backend will use up to 11452776448 bytes on device 0 for SimpleAllocator.


## TimeMachine

Below we implement an RNN encoder-decoder that reads a window of 10 consecutive words from the book *The Time Machine* by H. G. Wells and predicts the next 10 words of the passage.

In [10]:
import jax.numpy as jnp
import random

from deepscratch.models.sequence.embeddings import LSA
from deepscratch.dataset.base import DataLoader
from deepscratch.dataset.sequence import SequenceDataset
from deepscratch.models.sequence.tokeniser import WordTokeniser
from deepscratch.models.base import Sequential
from deepscratch.models.sequence.rnn import RNNEncoder, RNNDecoder
from deepscratch.activations import LinearActivation
from deepscratch.losses import RMSE
from deepscratch.optimisers import Adam

In [3]:
# Load the LSA embedder from the project's on-disk cache file.
# NOTE(review): the ".pkl" suffix suggests a pickle payload -- only load cache
# files produced by a trusted run; confirm what LSA.from_cache deserialises.
embedder_cache_path = "../deepscratch/__deepscratchcache__/thetimemachine.pkl"

with open(embedder_cache_path, "rb") as cache_file:
    embedder = LSA.from_cache(cache_file)

In [4]:
# Word-level tokeniser, reused by later cells when mapping text to tokens.
tokeniser = WordTokeniser()

BATCH_SIZE = 256

# NOTE(review): absolute, machine-specific path -- the notebook will not run on
# another machine without editing it.
book_path = "/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt"

# The two 10s are presumably the input and target window lengths (batches come
# out as (256, 10, 50) for both x and y) -- TODO confirm against SequenceDataset.
with open(book_path, "rt") as book_file:
    ds = SequenceDataset(book_file, 10, 10, tokeniser, embedder)

# Shuffled mini-batches; the final incomplete batch is dropped.
dl = DataLoader(
    ds,
    BATCH_SIZE,
    shuffle=True,
    num_workers=16,
    drop_last=True,
    iobound=False,
)

In [5]:
# Pull a single batch from the loader and display the input / target shapes.
first_batch = next(iter(dl))
x, y = first_batch
(x.shape, y.shape)

((256, 10, 50), (256, 10, 50))

In [6]:
# Encoder and decoder share the same hidden size.
encoder_hidden_size = 50
decoder_hidden_size = encoder_hidden_size

# Layer dimensions are read off the sample batch (x, y) drawn from the loader.
rnn_encoder = RNNEncoder(x.shape[-1], encoder_hidden_size)
rnn_decoder = RNNDecoder(
    y.shape[-2],
    encoder_hidden_size,
    y.shape[-1],
    decoder_hidden_size,
    # Linear output: the targets are embedding vectors and the loss is RMSE.
    output_activation=LinearActivation(),
)

ann = Sequential([rnn_encoder, rnn_decoder])

In [7]:
# Show the dataset length next to the model's parameter count.
len(ds), ann.n_params()

(32791, 15150)

In [12]:
# Initialise the model; presumably this creates the parameter values that
# training then updates -- TODO confirm against Sequential.initialise.
ann.initialise()

In [13]:
# Training set-up: RMSE between predicted and target embeddings, Adam optimiser.
# NOTE(review): the device is hard-coded to Apple "METAL"; the recorded output of
# this cell reports "Loss: NAN" -- confirm whether training actually diverged.
loss_fn = RMSE()
train_config = dict(lr=1e-4, epochs=1, device="METAL")

ann.train(dl, loss_fn, Adam, **train_config)

Iter: 128	Step: NAN	Loss: NAN


In [15]:
# Read the whole novel back in as a single string.
# NOTE(review): absolute, machine-specific path.
book_path = "/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt"
with open(book_path, "rt") as book_file:
    text = book_file.read()

# Distinct tokens of the book and their embeddings.
# NOTE(review): neither `corpus` nor `corpus_embeddings` is referenced by the
# other cells visible in this notebook.
token_ids = tokeniser.tokenise(text)
corpus = jnp.unique(token_ids)
corpus_embeddings = embedder.embed(corpus)

def to_words(arr: jnp.ndarray) -> list[str]:
    """
    Map the embedding at each time step to its nearest natural-language word.

    For every step, the row of ``embedder.embeddings`` with the smallest squared
    Euclidean distance to that step's vector is selected and translated to a
    token through ``embedder.idx_to_token``.

    Parameters
    ----------
    arr:
        A single sequence of embeddings, shaped ``(steps, dim)`` or with any
        number of leading singleton axes, e.g. ``(1, steps, dim)``.

    Returns
    -------
    list[str]
        One token per time step, in sequence order.

    Raises
    ------
    ValueError
        If ``arr`` holds more than one sequence (a leading axis larger than 1).
    """
    arr = jnp.asarray(arr)
    n_steps, dim = arr.shape[-2:]
    if arr.size != n_steps * dim:
        # A batch > 1 would otherwise fail on broadcasting or, worse, silently
        # mis-broadcast when the vocabulary size happens to equal the batch size.
        raise ValueError(
            f"to_words expects a single sequence, got array of shape {arr.shape}"
        )
    steps = arr.reshape(n_steps, dim)

    vocab = jnp.asarray(embedder.embeddings)
    # (steps, vocab) squared distances computed in one shot instead of one
    # device round-trip per time step.
    sq_dist = ((steps[:, None, :] - vocab[None, :, :]) ** 2).sum(axis=-1)
    nearest = sq_dist.argmin(axis=-1)

    return [embedder.idx_to_token[idx] for idx in nearest.tolist()]

In [24]:
# Take one sample: step through a random number of batches and keep the last
# batch drawn.
# random.randint is inclusive at both ends; starting at 1 guarantees the loop
# body runs at least once, so x and y always come from this cell instead of
# silently reusing whatever an earlier cell left behind.
# NOTE(review): no random seed is set, so the chosen sample differs per run.
i = random.randint(1, len(dl))
for _, (x, y) in zip(range(i), dl):
    pass

# Keep the leading batch axis (size 1) on the selected sample.
x0, y0 = x[jnp.array([0])], y[jnp.array([0])]
y_est = ann.forward(x)
y0_est = y_est[jnp.array([0])]

# Build the strings first: reusing the same quote character inside an f-string
# replacement field is only valid on Python 3.12+.
input_words = " ".join(to_words(x0))
predicted_words = " ".join(to_words(y0_est))
print(f"Input: {input_words}\nPredicted Passage: {predicted_words}")

Input: that they were very badly broken and weather worn several
Predicted Passage: unsatisfying probably dawn played calm specimens return specimens return specimens


The model picks out some semantic structure, but its output is by and large nonsensical -- reflecting the overparameterisation of the model. The novel is too short compared with the complexity of the English language to provide sufficient training data for our model.