In [1]:
from deepscratch.dataset.base import DataLoader
from deepscratch.dataset.sequence import SequenceDataset
from deepscratch.models.base import Block, Sequential, LinearBlock
from deepscratch.models.sequence.tokeniser import HFTokeniser
from deepscratch.models.sequence.embeddings import LSA, OHE, HFEmbedder
from deepscratch.initialisers import Orthonormal, Zeros, Xavier
from deepscratch.activations import Activation, Sigmoid, Tanh, Softmax, ReLU, LinearActivation
from deepscratch.losses import RMSE, CrossEntropy
from deepscratch.optimisers import Adam
from deepscratch.normalisers import BatchNorm

import random   

from functools import partial
import jax.numpy as jnp
from jax import lax

deepscratch.dataset.base
deepscratch.dataset.sequence
deepscratch.models.sequence.tokeniser


  from .autonotebook import tqdm as notebook_tqdm


deepscratch.models.sequence.embeddings
deepscratch.models.base
deepscratch.initialisers
deepscratch.activations
deepscratch.dataset.vision
deepscratch.optimisers




deepscratch.losses
deepscratch.transformations
Metal device set to: Apple M3

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

deepscratch.normalisers


I0000 00:00:1740234695.025963 9763662 service.cc:145] XLA service 0x15a0e9f30 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1740234695.025971 9763662 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1740234695.026901 9763662 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1740234695.026908 9763662 mps_client.cc:384] XLA backend will use up to 11452776448 bytes on device 0 for SimpleAllocator.


In [2]:
tokeniser = HFTokeniser('bert-base-uncased')

In [3]:
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    embedder = OHE(f, tokeniser)

In [4]:
with open("../../__deepscratchcache__/thetimemachine.pkl", "rb") as f:
    embedder = LSA.from_cache(f)

In [5]:
embedder = HFEmbedder('bert-base-uncased')

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    ds = SequenceDataset(f, 10, 10, tokeniser, embedder)

BATCH_SIZE = 256
dl = DataLoader(ds, BATCH_SIZE, shuffle=True, num_workers=16, drop_last=True, iobound=False)

# RNN

In [7]:
class RNNEncoder(Block):
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            activation: Activation = Tanh(),
            input_weight_init_method=Xavier(),
            recur_weight_init_method=Orthonormal(),
            bias_init_method=Zeros()
        ):
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.activation = activation
        
        self.input_weight_init_method = input_weight_init_method
        self.recur_weight_init_method = recur_weight_init_method
        self.bias_init_method = bias_init_method

        self.forward = partial(
            self.forward,
            activation=self.activation.forward,
            hidden_size = self.hidden_size
        )
        
        super().__init__()
        
    def initialise(self):
        w = {}
        w["w_hx"] = self.input_weight_init_method((self.input_size, self.hidden_size))
        w["w_hh"] = self.recur_weight_init_method((self.hidden_size, self.hidden_size))
        w["b_h"] = self.bias_init_method((self.hidden_size,))

        return w
    
    @staticmethod
    def forward(x, w, activation, hidden_size):
        h = jnp.zeros((x.shape[0], hidden_size))

        step = lambda t, h : activation(x[...,t,:] @ w["w_hx"] + h @ w["w_hh"] + w["b_h"])
        h = lax.fori_loop(0, x.shape[-2], step, h) # 

        return h

In [8]:
class RNNDecoder(Block):
    def __init__(
            self,
            n_steps: int,
            input_size: int,
            output_size: int,
            hidden_size: int,

            hidden_activation: Activation = Tanh(),
            output_activation: Activation = LinearActivation(),
    
            input_weight_init_method=Xavier(),
            recur_weight_init_method=Orthonormal(),
            bias_init_method=Zeros()
        ):
        self.n_steps = n_steps
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        self.hidden_activation = hidden_activation
        self.output_activation = output_activation
        
        self.input_weight_init_method = input_weight_init_method
        self.recur_weight_init_method = recur_weight_init_method
        self.bias_init_method = bias_init_method

        self.forward = partial(
            self.forward,
            hidden_activation=self.hidden_activation.forward,
            output_activation=self.output_activation.forward,
            n_steps=self.n_steps,
            hidden_size = self.hidden_size,
            output_size = self.output_size
        )
        
        super().__init__()
        
    def initialise(self):
        w = {}
        # Hidden state weights
        w["w_hx"] = self.input_weight_init_method((self.input_size, self.hidden_size))
        w["w_hh"] = self.recur_weight_init_method((self.hidden_size, self.hidden_size))
        w["w_hy"] = self.input_weight_init_method((self.output_size, self.hidden_size))
        w["b_h"] = self.bias_init_method((self.hidden_size,))

        # Output weights
        w["w_yh"] = self.recur_weight_init_method((self.hidden_size, self.output_size))
        w["b_y"] = self.bias_init_method((self.output_size,))

        return w
    
    @staticmethod
    def forward(x, w, hidden_activation, output_activation, n_steps, hidden_size, output_size):
        h = jnp.zeros((x.shape[0], hidden_size))
        yt = jnp.zeros((x.shape[0], output_size))
        y = []
        for t in range(n_steps):
            h = hidden_activation(x @ w["w_hx"] + h @ w["w_hh"] + yt @ w["w_hy"] + w["b_h"])
            yt = output_activation(h @ w["w_yh"] + w["b_y"])
            y.append(yt)
        
        y = jnp.stack(y, axis=-2)
        return y

In [9]:
x, y = next(iter(dl))
x.shape, y.shape

((256, 10, 768), (256, 10, 768))

In [10]:
encoder_hidden_size = decoder_hidden_size = 50

ann = Sequential([
    RNNEncoder(x.shape[-1], encoder_hidden_size),
    RNNDecoder(y.shape[-2], encoder_hidden_size, y.shape[-1], decoder_hidden_size, output_activation=LinearActivation())
])

In [11]:
len(ds), ann.n_params()

(40480, 123568)

In [12]:
ann.initialise()

In [None]:
ann.train(
    dl,
    RMSE(),
    Adam,
    lr=1e-4,
    epochs=1,
    device="METAL"
)

Iter: 632	Step: 1.7E-03	Loss: 4.127E-01


### Predictions

In [14]:
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    text = f.read()

corpus = jnp.unique(tokeniser.tokenise(text))
corpus_embeddings = embedder.embed(corpus)

def to_words(arr: jnp.array) -> list[str]:
    words = []
    for t in range(arr.shape[-2]):
        if isinstance(embedder, LSA):
            euc_dist = ((embedder.embeddings - arr[...,t,:]) ** 2).sum(axis=-1)
            i = euc_dist.argmin()
            words.append(embedder.idx_to_token[i.item()])

        elif isinstance(embedder, OHE):
            i = arr[...,t,:].argmax()
            words.append(embedder.idx_to_token[i.item()])

        elif isinstance(embedder, HFEmbedder):
            distances = ((corpus_embeddings - arr[...,t,:]) ** 2).sum(axis=1)
            min_idx = distances.argmin()
            min_token = corpus[min_idx]
            words.append(tokeniser.idx_to_token[min_token.item()])

    return words

In [15]:
# Take one sample
i = random.randint(0, len(dl))
for _, (x,y) in zip(range(i), dl):
    pass

x0, y0 = x[jnp.array([0])], y[jnp.array([0])] 
y_est = ann.forward(x)
y0_est = y_est[jnp.array([0])]

print(f"Input: {" ".join(to_words(x0))}\nPredicted Passage: {" ".join(to_words(y0_est))}")

Input: them with the lever ##s , and began to scramble
Predicted Passage: bird with with with with with with with with with


In [102]:
i = random.randint(0, len(ds))
x, y = ds[i]
x, y = jnp.expand_dims(x, axis=0), jnp.expand_dims(y, axis=0)   # add batch dim
y_est = ann.forward(x)
print(f"Start: {" ".join(to_words(x))}\nActual: {" ".join(to_words(y))}\nPredicted: {" ".join(to_words(y_est))}")

Start: first impenetrably dark to me
Actual: i entered it groping for
Predicted: the the the i i


### Forecasting

In [103]:
# Take one sample
i = random.randint(0, len(ds))
x, y = ds[i]
x, y = jnp.expand_dims(x, axis=0), jnp.expand_dims(y, axis=0)   # add batch dim

# encode x to a context vector
w_emb, w_enc, w_dec = ann.w
embed_fc, encoder, _ = ann.layers
x_embed = embed_fc.forward(x, w_emb)
context = encoder.forward(x_embed, w_enc)

# New decoder that forecasts 50 steps ahead
decoder = RNNDecoder(50, encoder_hidden_size, y.shape[-1], decoder_hidden_size, output_activation=Softmax())
y_est = decoder.forward(context, w_dec)

In [104]:
print(f"Predicted Passage:\n{" ".join(to_words(y_est))}")

Predicted Passage:
the the the the i i the the i the the the the the i i i i i i i i i i i i i i i i i i i i i i i suffering i i i i i i i i i i i i


## LSTM

In [12]:
class LSTMBlock(Block):
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            activation=Tanh(),
            gate_activation=Sigmoid(),
            input_weight_init_method=Xavier(),
            recur_weight_init_method=Orthonormal(),
            bias_init_method=Zeros()
        ):
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.activation = activation
        self.gate_activation = gate_activation
        
        self.input_weight_init_method = input_weight_init_method
        self.recur_weight_init_method = recur_weight_init_method
        self.bias_init_method = bias_init_method
        
        super().__init__()
        
    def initialise(self):
        self.w = {}
        self.w["w_hx"] = self.input_weight_init_method((self.input_size, self.hidden_size * 4))
        self.w["w_hh"] = self.recur_weight_init_method((self.hidden_size, self.hidden_size * 4))
        self.w["b"] = self.bias_init_method((self.hidden_size * 4,))
    
    def __call__(self, x, h, c):
        gates = x @ self.w["w_x"] + h @ self.w["w_h"] + self.w["b"]
        i, f, g, o = jnp.split(gates, 4, axis=-1)
        i = self.gate_activation(i)
        f = self.gate_activation(f)
        g = self.activation(g)
        o = self.gate_activation(o)
        c_next = f * c + i * g
        h_next = o * self.activation(c_next)
        return h_next, c_next

## Self-Attention

In [13]:
class SelfAttentionBlock(Block):
    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            activation,
            weight_init_method = Xavier(),
            bias_init_method = Zeros()
        ):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.activation = activation
        
        self.weight_init_method = weight_init_method
        self.bias_init_method = bias_init_method
        
        super().__init__()
        
    def initialise(self):
        self.w = {}
        self.w["w_q"] = self.weight_init_method((self.embed_dim, self.embed_dim))
        self.w["w_k"] = self.weight_init_method((self.embed_dim, self.embed_dim))
        self.w["w_v"] = self.weight_init_method((self.embed_dim, self.embed_dim))
        self.w["w_o"] = self.weight_init_method((self.embed_dim, self.embed_dim))
        self.w["b_o"] = self.bias_init_method((self.embed_dim,))
    
    def __call__(self, x):
        q = x @ self.w["w_q"]
        k = x @ self.w["w_k"]
        v = x @ self.w["w_v"]
        
        q = q.reshape(q.shape[0], -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(k.shape[0], -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(v.shape[0], -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        
        attn_scores = jnp.einsum('bhid,bhjd->bhij', q, k) / jnp.sqrt(self.head_dim)
        attn_weights = Softmax()(attn_scores, axis=-1)
        attn_output = jnp.einsum('bhij,bhjd->bhid', attn_weights, v)
        
        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(x.shape[0], -1, self.embed_dim)
        return attn_output @ self.w["w_o"] + self.w["b_o"]

## Transformer

In [14]:
class TransformerBlock(Block):
    def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            feedforward_dim: int,
            activation, 
            weight_init_method = Xavier(),
            bias_init_method = Zeros()
        ):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.feedforward_dim = feedforward_dim
        self.activation = activation
        
        self.weight_init_method = weight_init_method
        self.bias_init_method = bias_init_method
        
        super().__init__()
        
    def initialise(self):
        self.attention = SelfAttentionBlock(
            self.embed_dim, self.num_heads, self.activation, self.weight_init_method, self.bias_init_method, self.key
        )
        
        self.w = {}
        self.w["w_ff1"] = self.weight_init_method((self.embed_dim, self.feedforward_dim))
        self.w["b_ff1"] = self.bias_init_method((self.feedforward_dim,))
        self.w["w_ff2"] = self.weight_init_method((self.feedforward_dim, self.embed_dim))
        self.w["b_ff2"] = self.bias_init_method((self.embed_dim,))
    
    def __call__(self, x):
        attn_out = self.attention(x) + x
        ff_out = self.activation(attn_out @ self.w["w_ff1"] + self.w["b_ff1"])
        ff_out = ff_out @ self.w["w_ff2"] + self.w["b_ff2"]
        return ff_out + attn_out