# In [10]:
import deepscratch
from deepscratch.models.base import Block
from deepscratch.initialisers import He, Zeros

from functools import partial

import jax
from jax import lax
import jax.numpy as jnp

# In [11]:
class ConvBlock(Block):
    """2-D convolution over NHWC inputs with an HWIO kernel and a bias.

    Args:
        input_channels: number of input channels (the ``C`` of NHWC).
        output_channels: number of output channels / filters.
        kernel_size: kernel (height, width); a single int is used for both.
        stride: window stride (height, width); a single int is used for both.
        padding: zeros added to both sides of (height, width); a single int
            is used for both. Must be non-negative.
        weight_init_method: callable taking a shape tuple, used to create the
            kernel. ``None`` (default) means a fresh ``He()``.
        bias_init_method: callable taking a shape tuple, used to create the
            bias. ``None`` (default) means a fresh ``Zeros()``.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int] = 1,
        padding: int | tuple[int, int] = 0,
        weight_init_method=None,
        bias_init_method=None,
    ):
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.kernel_size = self._pair(kernel_size, "kernel_size")

        # Stored exactly as passed; forward() normalises its own arguments.
        self.stride = stride
        self.padding = padding
        # Bind the hyper-parameters as Python constants and jit the rest, so
        # the compiled function takes only (x, w). This instance attribute
        # shadows the static method for calls made through the instance.
        # Normalising here also surfaces malformed values at construction.
        self.forward = jax.jit(
            partial(
                self.forward,
                padding=self._pair(padding, "padding"),
                stride=self._pair(stride, "stride"),
            )
        )

        # Create default initialisers per instance instead of sharing a single
        # object that a default in the signature would create once at class
        # definition time.
        self.weight_init_method = He() if weight_init_method is None else weight_init_method
        self.bias_init_method = Zeros() if bias_init_method is None else bias_init_method

        super().__init__()

    @staticmethod
    def _pair(value, name):
        """Return ``value`` as a 2-tuple: sequences are checked, scalars duplicated.

        Raises:
            ValueError: if ``value`` is a tuple/list whose length is not 2.
        """
        if isinstance(value, (tuple, list)):
            pair = tuple(value)
            if len(pair) != 2:
                raise ValueError(f"{name} must be an int or a pair of ints, got {value!r}")
            return pair
        return (value, value)

    def initialise(self):
        """Create this block's parameters.

        Returns:
            dict with "w_yxc", created by ``weight_init_method`` for shape
            (kernel_h, kernel_w, input_channels, output_channels), and "b_y",
            created by ``bias_init_method`` for shape (output_channels,).
        """
        w = {}
        w["w_yxc"] = self.weight_init_method(
            (self.kernel_size[0], self.kernel_size[1], self.input_channels, self.output_channels)
        )
        w["b_y"] = self.bias_init_method((self.output_channels,))
        return w

    @staticmethod
    def forward(x, w, padding, stride):
        """Convolve ``x`` with ``w["w_yxc"]`` and add ``w["b_y"]``.

        Args:
            x: input batch laid out NHWC.
            w: parameter dict as returned by ``initialise``.
            padding: int or (h, w) pair of zeros added to both sides of each
                spatial dimension.
            stride: int or (h, w) pair of window strides.

        Returns:
            The convolution output laid out NHWC, plus the bias broadcast
            over the channel (last) axis.

        Raises:
            ValueError: if padding is negative, or padding/stride is a
                sequence whose length is not 2.
        """
        pad_h, pad_w = ConvBlock._pair(padding, "padding")
        # Negative padding would make the conv crop silently; keep it an error.
        if pad_h < 0 or pad_w < 0:
            raise ValueError(f"padding must be non-negative, got {padding!r}")
        # Let the convolution apply zero padding itself rather than
        # materialising a padded copy of x; the result is identical.
        out = lax.conv_general_dilated(
            x,
            w["w_yxc"],
            window_strides=ConvBlock._pair(stride, "stride"),
            padding=((pad_h, pad_h), (pad_w, pad_w)),
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
        return out + w["b_y"]