In [10]:
from typing import Tuple
from abc import ABC, abstractmethod
import jax.numpy as jnp
import jax

In [9]:
class Transformation(ABC):
    """Abstract base for array-to-array transformations.

    Subclasses implement ``__call__`` so that an instance can be applied
    like a function: it takes one JAX array and returns the transformed
    array. The class cannot be instantiated until ``__call__`` is
    overridden.
    """

    @abstractmethod
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the transformation to ``x`` and return the result."""

In [None]:
class Reshape(Transformation):
    """Reshape an array to a fixed target shape.

    Args:
        shape: target shape forwarded to ``jnp.reshape``. Either a single
            int or a tuple of ints of any length; as with NumPy, one entry
            may be -1 to have that dimension inferred.
    """

    def __init__(self, shape: int | Tuple[int, ...]):
        self.shape = shape

    def __call__(self, arr: jax.Array) -> jax.Array:
        return jnp.reshape(arr, self.shape)


In [None]:
class RandomJitter(Transformation):
    """
    Applies random jitter (small random noise) to the image.
    Useful for data augmentation in image tasks.

    Formula:
        y = x + noise, where noise ~ Uniform(-amount, amount)

    Every call draws fresh noise: when no ``key`` is passed, the stored key
    is split and advanced, so repeated calls do not reuse one random stream.
    Inside ``jax.jit`` pass ``key=`` explicitly instead, because storing a
    traced key on ``self`` would leak the tracer out of the jitted function.

    Args:
        key: PRNG key seeding the noise stream.
        amount: half-width of the uniform noise interval.
    """

    def __init__(self, key: jax.Array, amount: float = 0.05):
        self.key = key
        self.amount = amount

    def __call__(self, arr: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
        if key is None:
            # Advance the internal stream so the next call gets different noise.
            self.key, key = jax.random.split(self.key)
        noise = jax.random.uniform(key, arr.shape, minval=-self.amount, maxval=self.amount)
        return arr + noise

In [None]:
class RandomRotation(Transformation):
    """
    Rotates an image by a random angle.
    Useful for data augmentation to improve model generalization.

    Formula:
        y = rotate(x, angle), where angle ~ Uniform(-max_angle, max_angle)

    The angle is in degrees. The image is rotated about its centre with
    bilinear interpolation; output pixels whose source location falls
    outside the input are filled with 0. The first two axes are treated as
    (height, width) and any trailing axes (e.g. channels) are rotated
    identically, so the output has the same shape as the input.

    When no ``key`` is passed, the stored key is split and advanced on every
    call so successive calls draw different angles. Inside ``jax.jit`` pass
    ``key=`` explicitly instead.

    Args:
        key: PRNG key seeding the angle stream.
        max_angle: maximum absolute rotation in degrees.

    Raises:
        ValueError (on call): if the input has fewer than 2 dimensions.
    """

    def __init__(self, key: jax.Array, max_angle: float = 30.0):
        self.key = key
        self.max_angle = max_angle

    def __call__(self, arr: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
        # Imported locally: ``import jax`` alone is not guaranteed to load jax.scipy.
        from jax.scipy.ndimage import map_coordinates

        if arr.ndim < 2:
            raise ValueError(
                f"RandomRotation expects an array with at least 2 dims, got shape {arr.shape}"
            )
        if key is None:
            self.key, key = jax.random.split(self.key)
        angle = jax.random.uniform(key, (), minval=-self.max_angle, maxval=self.max_angle)
        theta = jnp.deg2rad(angle)
        cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)

        h, w = arr.shape[:2]
        cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
        rows, cols = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing="ij")
        dy, dx = rows - cy, cols - cx
        # Inverse mapping: for each output pixel, the input coordinate that
        # the rotation sends onto it (an orthogonal 2x2 rotation about the centre).
        coords = [cos_t * dy - sin_t * dx + cy, sin_t * dy + cos_t * dx + cx]

        def rotate_plane(plane):
            return map_coordinates(plane, coords, order=1, mode="constant", cval=0.0)

        if arr.ndim == 2:
            return rotate_plane(arr)
        # Collapse all trailing axes into one so each (H, W) plane is rotated the same way.
        planes = arr.reshape(h, w, -1)
        rotated = jax.vmap(rotate_plane, in_axes=2, out_axes=2)(planes)
        return rotated.reshape(arr.shape)

In [None]:
class RandomFlip(Transformation):
    """
    Randomly flips an image horizontally.
    Useful for data augmentation in image classification tasks.

    Flips along axis 1 with probability ``p``. Axis 1 is the horizontal
    (width) axis when images are laid out as (H, W) or (H, W, C).

    When no ``key`` is passed, the stored key is split and advanced on every
    call so successive calls make independent flip decisions. Inside
    ``jax.jit`` pass ``key=`` explicitly instead.

    Args:
        key: PRNG key seeding the flip decisions.
        p: probability of flipping (default 0.5).
    """

    def __init__(self, key: jax.Array, p: float = 0.5):
        self.key = key
        self.p = p

    def __call__(self, arr: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
        if key is None:
            self.key, key = jax.random.split(self.key)
        should_flip = jax.random.bernoulli(key, self.p)
        # jnp.where instead of a Python ``if`` so the choice also works on traced values.
        return jnp.where(should_flip, jnp.flip(arr, axis=1), arr)

In [None]:
class RandomCrop(Transformation):
    """
    Randomly crops a patch from an image.
    Useful for data augmentation to make models robust to translations.

    The first two axes are treated as (height, width) and cropped to
    ``crop_size``; any trailing axes are kept whole. The vertical and
    horizontal offsets are drawn from independent subkeys.

    When no ``key`` is passed, the stored key is split and advanced on every
    call so successive calls pick different patches. Inside ``jax.jit`` pass
    ``key=`` explicitly instead.

    Args:
        key: PRNG key seeding the crop offsets.
        crop_size: (crop_h, crop_w) size of the patch.

    Raises:
        ValueError (on call): if ``crop_size`` does not fit inside the input.
    """

    def __init__(self, key: jax.Array, crop_size: Tuple[int, int]):
        self.key = key
        self.crop_size = crop_size

    def __call__(self, arr: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
        h, w = arr.shape[:2]
        crop_h, crop_w = self.crop_size
        if not (0 <= crop_h <= h and 0 <= crop_w <= w):
            raise ValueError(
                f"crop_size {tuple(self.crop_size)} does not fit in an image of size {(h, w)}"
            )
        if key is None:
            self.key, key = jax.random.split(self.key)
        # Separate subkeys: reusing one key for both offsets correlates them.
        key_h, key_w = jax.random.split(key)
        start_h = jax.random.randint(key_h, (), 0, h - crop_h + 1)
        start_w = jax.random.randint(key_w, (), 0, w - crop_w + 1)
        # dynamic_slice accepts traced offsets, unlike Python slicing.
        start = (start_h, start_w) + (0,) * (arr.ndim - 2)
        size = (crop_h, crop_w) + tuple(arr.shape[2:])
        return jax.lax.dynamic_slice(arr, start, size)

In [None]:
class AvgPool(Transformation):
    """2-D average pooling over a rank-4 array.

    Windows slide over axes 1 and 2 (height and width for an NHWC layout);
    axes 0 and 3 are not pooled. Zero padding counts toward the average:
    every window sum is divided by the full kernel area, including padded
    positions.

    Args:
        kernel_size: window size, an int (square window) or a pair (kh, kw).
        stride: step between windows, an int or a pair (sh, sw). Default 1.
        padding: zeros added on each side of height and width, an int or a
            pair (ph, pw). Default 0.
    """

    def __init__(
        self,
        kernel_size: int | Tuple[int, int],
        stride: int | Tuple[int, int] = 1,
        padding: int | Tuple[int, int] = 0,
    ):
        self.kernel_size = self._pair(kernel_size)
        self.stride = stride
        self.padding = padding

    @staticmethod
    def _pair(value) -> Tuple[int, int]:
        """Expand a scalar to (value, value); pass a 2-sequence through as a tuple."""
        try:
            first, second = value
        except TypeError:
            return (value, value)
        return (first, second)

    def __call__(self, x: jax.Array) -> jax.Array:
        kh, kw = self.kernel_size
        sh, sw = self._pair(self.stride)
        ph, pw = self._pair(self.padding)
        # Explicit reduce_window padding fills with the init value (0), which is
        # the same zero padding as before, without materialising a padded copy.
        window_sum = jax.lax.reduce_window(
            x,
            0.0,
            jax.lax.add,
            window_dimensions=(1, kh, kw, 1),
            window_strides=(1, sh, sw, 1),
            padding=((0, 0), (ph, ph), (pw, pw), (0, 0)),
        )
        return window_sum / (kh * kw)