In [5]:
import jax.numpy as jnp
from abc import ABC, abstractmethod
from functools import partial

# Activation functions
Activation functions are simple, non-linear layers in a neural network that greatly increase the range of functions the network can approximate. They are also cheap to compute and easily differentiable, which enables the use of backpropagation to train the network.

In [6]:
class Activation(ABC):
    """Abstract interface shared by the activation functions in this notebook."""

    @staticmethod
    @abstractmethod
    def forward(x: jnp.array) -> jnp.array:
        """Apply the activation to ``x``; concrete subclasses supply the computation."""

## Linear Activation

In [None]:
class LinearActivation(ABC):
    """Identity activation: the output is the input, unchanged."""

    # NOTE(review): this derives from ABC directly rather than from Activation,
    # unlike the other activation classes in this notebook - confirm whether
    # that is intentional.
    @staticmethod
    def forward(x: jnp.array) -> jnp.array:
        """Return ``x`` exactly as it was passed in."""
        return x

## Sigmoid Activation Function
The **Sigmoid** activation function maps any real-valued input to a range between $0$ and $1$, making it useful for probability-based outputs in binary classification problems. The function is defined as:  
\begin{equation}
\sigma(x) = \frac{1}{1 + e^{-x}}
\end{equation}
where the exponential term $e^{-x}$ makes the output smooth and continuous. The sigmoid function is differentiable, but it suffers from the **vanishing gradient problem**, making it less suitable for deep networks.

In [7]:
class Sigmoid(Activation):
    """Logistic sigmoid, evaluated directly from its definition 1 / (1 + e^{-x})."""

    # NOTE(review): the next cell defines a class with this same name, which
    # replaces this one when the notebook is run top to bottom.
    def __init__(self):
        """Nothing to configure: the sigmoid has no parameters."""
        pass

    @staticmethod
    def forward(x):
        """Return the sigmoid of ``x``, computed as ``1 / (1 + exp(-x))``."""
        denominator = 1 + jnp.exp(-x)
        return 1 / denominator

In [8]:
class Sigmoid(Activation):
    """Logistic sigmoid computed through ``jnp.logaddexp``."""

    def __init__(self):
        """Nothing to configure: the sigmoid has no parameters."""
        pass

    @staticmethod
    def forward(x):
        """Return the sigmoid of ``x`` as ``exp(-log(1 + e^{-x}))``."""
        # logaddexp(0, -x) = log(e^0 + e^{-x}) = log(1 + e^{-x}),
        # i.e. the log of the sigmoid's denominator.
        log_denominator = jnp.logaddexp(0, -x)
        return jnp.exp(-log_denominator)

## ReLU (Rectified Linear Unit) Activation Function
The **ReLU (Rectified Linear Unit)** activation function is widely used in deep learning due to its simplicity and effectiveness in mitigating vanishing gradient issues. It is defined as:  
$$
\text{ReLU}(x) = \max(0, x)
$$  
or, in the generalized form with a threshold $t$:  
$$
\text{ReLU}(x) = \begin{cases} 
x, & x > t \\ 
0, & x \leq t 
\end{cases}
$$
ReLU is computationally efficient and promotes sparsity in activations, but it can suffer from the **dying ReLU problem**, where neurons become inactive for negative inputs.

In [9]:
class ReLU(Activation):
    """Thresholded rectified linear unit.

    Inputs strictly greater than ``threshold`` pass through unchanged; every
    other input is mapped to 0.
    """

    def __init__(self, threshold=0) -> None:
        """
        'threshold' = cut-off value; inputs at or below it are zeroed
        """
        self.threshold = threshold
        # Shadow the static method on this instance with a copy that has the
        # threshold pre-bound, so callers can simply write `relu.forward(x)`.
        # The value is bound here, at construction time: assigning to
        # `self.threshold` afterwards does not change what `forward` uses.
        self.forward = partial(self.forward, threshold=self.threshold)

    @staticmethod
    def forward(x, threshold=0) -> jnp.array:
        """
        'x' = input array
        'threshold' = cut-off value; defaults to 0 (the standard ReLU) so the
                      method can also be called directly as `ReLU.forward(x)`
        """
        return jnp.where(x > threshold, x, 0)

## Tanh (Hyperbolic Tangent) Activation Function 
The **Tanh** activation function is a scaled and shifted version of the sigmoid function, mapping inputs to a range between $-1$ and $1$. The function is given by:  
$$
\tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1}
$$  
or equivalently, using the sigmoid function:  
$$
\tanh(x) = 2\sigma(2x) - 1
$$  
Tanh is zero-centered, making it preferable over sigmoid for training deep networks, though it still suffers from vanishing gradients for inputs of large magnitude (strongly positive or strongly negative).

In [10]:
class Tanh(Activation):
    """Hyperbolic tangent activation."""

    def __init__(self):
        """Nothing to configure: tanh has no parameters."""
        pass

    @staticmethod
    def forward(x) -> jnp.array:
        """
        'x' = input array

        Returns tanh(x).
        """
        # Use the library routine rather than the hand-written
        # 2 / (1 + exp(-2x)) - 1: that form subtracts two nearly equal numbers
        # when x is close to 0 (losing the result to rounding) and evaluates
        # exp(-2x), which overflows for large negative x.
        return jnp.tanh(x)

## Softmax Activation Function 
The **Softmax** function is used primarily in classification tasks where outputs need to be interpreted as probabilities. It normalizes an input vector into a probability distribution by computing:  
$$
\text{Softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}
$$  
where the subtraction of $\max(x)$ (log-sum-exp trick) helps prevent numerical instability. Softmax ensures that all output values sum to 1, making it ideal for multi-class classification problems.

In [None]:
class Softmax(Activation):
    """Softmax along a chosen axis."""

    def __init__(self, axis=-1):
        """
        'axis' = axis to compute softmax over
        """
        self.axis = axis
        # Shadow the static method on this instance with a copy that has the
        # axis pre-bound, so callers can simply write `softmax.forward(x)`.
        # The value is bound here, at construction time: assigning to
        # `self.axis` afterwards does not change what `forward` uses.
        self.forward = partial(self.forward, axis=self.axis)

    @staticmethod
    def forward(x, axis=-1):
        """
        'x' = input array
        'axis' = axis to compute softmax over; defaults to -1 (the last axis)
                 so the method can also be called directly as
                 `Softmax.forward(x)`
        """
        # Subtract the maximum along `axis` before exponentiating. This leaves
        # the result mathematically unchanged but keeps every exponent <= 0,
        # which circumvents overflow in exp().
        x_max = jnp.max(x, axis=axis, keepdims=True)
        exp_x = jnp.exp(x - x_max)
        return exp_x / exp_x.sum(axis=axis, keepdims=True)