In [1]:
from __future__ import annotations
import deepscratch
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
from deepscratch.initialisers import Gaussian, Zeros
from deepscratch.typing import PyTree
from deepscratch.activations import ReLU, Softmax, Activation
from deepscratch.dataset.vision import MNISTDataset
from deepscratch.dataset.base import DataLoader
from deepscratch.optimisers import SGD, Adam
from deepscratch.losses import CrossEntropy, Accuracy
from deepscratch.transformations import Reshape

from functools import partial, wraps
from typing import Tuple
import jax
import jax.numpy as jnp

# NOTE(review): these bare attribute references have no effect — the submodules
# are already imported by the `from deepscratch.... import ...` lines above.
# They look like leftovers from interactive exploration and could be removed.
deepscratch.initialisers
deepscratch.activations
deepscratch.dataset.vision
deepscratch.dataset.base
deepscratch.optimisers
deepscratch.losses
deepscratch.transformations


## Fully Connected Layers

In [3]:
class Block(ABC):
    """Abstract base for a parametrised layer.

    Subclasses implement:

    * ``initialise()`` -- build and return a fresh set of weights (a PyTree).
    * ``forward(x, w)`` -- a *static* method applying the layer to ``x`` using
      weights ``w`` passed in explicitly, so it can be jitted with ``w`` as an
      ordinary argument (this is how ``Sequential`` calls it).
    """

    def __init__(self):
        # Keep the weights built at construction time instead of discarding them.
        self._w = self.initialise()

        # jit the subclass's static forward; the instance attribute shadows the
        # class-level staticmethod for calls made through the instance.
        self.forward = jax.jit(self.forward)

    @abstractmethod
    def initialise(self) -> PyTree[jnp.array]:
        """Create and return a fresh set of weights for this block."""

    @staticmethod
    @abstractmethod
    def forward(x, w):
        """Apply the block to input ``x`` with weights ``w`` (as returned by ``initialise``).

        Declared static with no ``self`` so that the abstract contract matches
        how subclasses implement it and how callers invoke it: ``forward(x, w)``.
        """

In [4]:
class LinearBlock(Block):
    """Affine (fully connected) layer computing ``x @ w["w_yx"] + w["b_y"]``.

    ``initialise`` returns a dict with two entries:

    * ``"w_yx"`` -- built by ``weight_init_method((input_len, output_len))``
    * ``"b_y"``  -- built by ``bias_init_method((output_len,))``

    NOTE(review): despite its name, ``w_yx`` is requested with shape
    ``(input_len, output_len)`` and applied as ``x @ w``, i.e. it maps x -> y
    from the right.

    NOTE(review): the default initialiser objects are created once, when the
    class is defined, and are shared by every LinearBlock that relies on the
    defaults -- confirm this is intended if those initialisers carry state.
    """

    def __init__(self, input_len: int, output_len: int, weight_init_method=Gaussian(0,1e-3), bias_init_method=Zeros()):
        # These attributes must exist before Block.__init__ runs, because it
        # calls self.initialise().
        self.input_len, self.output_len = input_len, output_len
        self.weight_init_method, self.bias_init_method = weight_init_method, bias_init_method
        super().__init__()

    def initialise(self) -> PyTree[jnp.array]:
        """Return a fresh ``{"w_yx": ..., "b_y": ...}`` weight dict."""
        weight_shape = (self.input_len, self.output_len)
        bias_shape = (self.output_len,)
        return {
            "w_yx": self.weight_init_method(weight_shape),
            "b_y": self.bias_init_method(bias_shape),
        }

    @staticmethod
    @jax.jit
    def forward(x, w):
        """Compute the affine map of ``x`` with the weights in ``w``."""
        weights, bias = w["w_yx"], w["b_y"]
        return x @ weights + bias

I0000 00:00:1739971269.600857 8747963 service.cc:145] XLA service 0x127059360 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1739971269.600865 8747963 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1739971269.601808 8747963 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1739971269.601814 8747963 mps_client.cc:384] XLA backend will use up to 11452776448 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M3

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



## Sequential Neural Nets

In [None]:
class Sequential:
    """Chain of layers applied in order, with weights held in ``self.w``.

    ``self.w`` is a list containing one weight PyTree per ``Block`` layer, in the
    order those blocks appear in ``layers``. Any other layer (e.g. an
    ``Activation``) is called as ``layer.forward(z)`` with no weights.
    """

    def __init__(self, layers: list[Block | Activation]):
        self._layers = layers
        # Bind the layer list into the pure forward pass so the jitted function
        # only takes (x, w) as arguments.
        self._pure_forward = jax.jit(partial(self._pure_forward, layers=self._layers))
        # Pre-bind the layer list so callers can write ``model.trace_shape(shape)``.
        self.trace_shape = partial(self.trace_shape, layers=self._layers)
        self.initialise()

    @property
    def layers(self):
        """The layers passed at construction time."""
        return self._layers

    def n_params(self):
        """Return the total number of array elements across all weights in ``self.w``."""
        leaves = jax.tree_util.tree_leaves(self.w)
        return sum(leaf.size for leaf in leaves if hasattr(leaf, "size"))

    def forward(self, x: jnp.array) -> jnp.array:
        """Run the network on ``x`` using the current weights ``self.w``."""
        return self._pure_forward(x, self.w)

    def trace_shape(self, shape: Tuple[int], layers) -> jnp.array:
        """Feed ``jnp.ones(shape)`` through ``layers``, printing each intermediate shape.

        ``layers`` is pre-bound in ``__init__``, so callers pass only ``shape``.
        Returns the final output array.
        """
        print(f"Input: shape: {shape}")
        z = jnp.ones(shape)
        i = 0
        for layer in layers:
            if isinstance(layer, Block):
                z = layer.forward(z, self.w[i])
                i += 1
            else:
                z = layer.forward(z)
            print(f"Layer: {layer}\t shape: {z.shape}")

        return z

    @staticmethod
    def _pure_forward(x: jnp.array, w: PyTree[jnp.array], layers) -> jnp.array:
        """Apply ``layers`` to ``x`` using the explicit weight list ``w``.

        Weights are taken as an argument (rather than read from ``self.w``) so
        this function can be jitted and used inside the training objective.
        ``w[i]`` is consumed by the i-th ``Block`` in ``layers``.
        """
        z = x
        i = 0
        for layer in layers:
            if isinstance(layer, Block):
                z = layer.forward(z, w[i])
                i += 1
            else:
                z = layer.forward(z)

        return z

    def initialise(self):
        """Reset ``self.w`` to fresh weights from each ``Block`` layer's ``initialise()``."""
        self.w = [
            layer.initialise() for layer in self.layers
            if isinstance(layer, Block)
        ]

    def train(self, data, loss_func, optimiser_cls, lr=1e-3, epochs=4, device="cpu"):
        """Re-initialise the weights and fit them to ``data``.

        Args:
            data: handed to the training loop; the objective unpacks whatever it
                receives as ``(X, Y)`` -- presumably batches from a DataLoader,
                TODO confirm.
            loss_func: object whose ``forward(prediction, Y)`` returns the loss.
            optimiser_cls: constructed as ``optimiser_cls(objective, lr, device=device)``
                and stored on ``self.stepper``.
            lr: learning rate passed to ``optimiser_cls``.
            epochs: number of epochs passed to the training loop.
            device: forwarded to ``optimiser_cls``.
        """
        self.initialise()
        w0 = self.w

        @jax.jit
        def objective(w, data):
            X, Y = data
            return loss_func.forward(self._pure_forward(X, w), Y)

        self.stepper = optimiser_cls(objective, lr, device=device)
        # NOTE(review): SGD is always used as the outer loop that drives
        # self.stepper over `data`, whatever optimiser_cls is -- confirm intended.
        optimiser = SGD(self.stepper, data, epochs=epochs)
        self.w = optimiser.compute(w0)

    def predict(self, X):
        """Return the network output for ``X`` using the current weights.

        The original ``self(X, *self.w)`` raised TypeError because Sequential
        defines no ``__call__``; route through ``forward`` instead.
        """
        return self.forward(X)

    def accuracy(self, X, y):
        """Fraction of rows where the arg-max of the prediction matches the arg-max of ``y``.

        Assumes ``y`` holds one score/indicator per class along axis 1
        (e.g. one-hot labels) -- confirm against the dataset.
        """
        y_est = self.predict(X)
        return (y_est.argmax(axis=1) == y.argmax(axis=1)).sum() / len(y)

## Residual Connections

In [6]:
from typing import Type

In [7]:
def resconnect(block_cls: Type):
    """Return a subclass of ``block_cls`` whose ``forward`` adds a skip connection.

    The generated class is named ``"Res" + block_cls.__name__``, reports the
    same ``__module__`` as ``block_cls``, and replaces ``forward`` with a static
    method computing ``x + block_cls.forward(x, w)``. The block's output must
    therefore be addable to its input (e.g. same shape).
    """
    residual_cls = type(f"Res{block_cls.__name__}", (block_cls,), {})
    residual_cls.__module__ = block_cls.__module__

    def residual_forward(x: jnp.array, w: PyTree[jnp.array]):
        # block_cls.forward is looked up at call time, not captured here.
        skip = x
        transformed = block_cls.forward(x, w)
        return skip + transformed

    # Copy name/docstring metadata from the parent forward, then expose it as a
    # static method, matching the parent's forward(x, w) calling convention.
    residual_forward = wraps(block_cls.forward)(residual_forward)
    residual_cls.forward = staticmethod(residual_forward)
    return residual_cls

# Residual variant of LinearBlock: its static forward(x, w) returns
# x + LinearBlock.forward(x, w), so the layer's output must be addable to its
# input (e.g. input_len == output_len).
ResLinearBlock = resconnect(LinearBlock)

In [8]:
# forward is a static method: it uses the weights passed in, not the ones each
# instance initialised, so the instance sizes are chosen to match `w` (2 -> 2).
linear = LinearBlock(2, 2)
reslinear = ResLinearBlock(2, 2)

x = jnp.array([[1, 1]])
w = {"w_yx": jnp.array([[1, 0], [0, 1]], dtype="float32"), "b_y": jnp.array([0, 0], dtype="float32")}

# With identity weights and zero bias, the residual output must equal x + linear(x).
assert jnp.allclose(reslinear.forward(x, w), x + linear.forward(x, w))
linear.forward(x, w), reslinear.forward(x, w)

(Array([[1., 1.]], dtype=float32), Array([[2., 2.]], dtype=float32))